From 79a9522ba7f135c238fa88fdce5795401dbaef18 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 28 Aug 2019 18:07:35 -0700 Subject: [PATCH] Several changes - Fixed some tests that weren't running. Fixed a bug where training crashed if FactorCovariates had any NA Fixed a bug where FactorCovariates were ignored in splitting if nsplit==0 Added a covariate specific option for whether splitting on an NA variable should have a penalty. This penalty is accomplished by first calculating the split score and best split for a covariate without NAs as done previously before. Then NAs are randomly assigned, and the split score is recalculated on that best split. The new score is the lower of the new score and the original. --- .../ca/joeltherrien/randomforest/Main.java | 6 +- .../settings/BooleanCovariateSettings.java | 6 +- .../settings/CovariateSettings.java | 4 +- .../settings/FactorCovariateSettings.java | 6 +- .../settings/NumericCovariateSettings.java | 6 +- .../randomforest/TestPersistence.java | 6 +- .../randomforest/csv/TestLoadingCSV.java | 10 +- .../randomforest/covariates/Covariate.java | 2 + .../covariates/bool/BooleanCovariate.java | 17 +- .../covariates/factor/FactorCovariate.java | 20 ++- .../covariates/numeric/NumericCovariate.java | 7 + .../randomforest/tree/SplitAndScore.java | 10 +- .../randomforest/tree/TreeTrainer.java | 86 ++++++--- .../TestDeterministicForests.java | 6 +- .../TestProvidingInitialForest.java | 4 +- .../randomforest/TestSavingLoading.java | 8 +- .../joeltherrien/randomforest/TestUtils.java | 2 +- .../competingrisk/IBSCalculatorTest.java | 4 +- .../competingrisk/TestCompetingRisk.java | 18 +- .../competingrisk/TestLogRankSplitFinder.java | 2 +- .../covariates/FactorCovariateTest.java | 165 +++++++++++++++++- .../covariates/NumericCovariateTest.java | 6 +- .../randomforest/nas/TestNAs.java | 111 ++++++++++-- .../VariableImportanceCalculatorTest.java | 6 +- .../randomforest/workshop/TrainForest.java | 2 +- .../workshop/TrainSingleTree.java | 4 +- .../workshop/TrainSingleTreeFactor.java | 6 +- 27 files changed, 422 insertions(+), 108 deletions(-) diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/Main.java b/executable/src/main/java/ca/joeltherrien/randomforest/Main.java index 7d5dda2..aeeddc6 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -227,9 +227,9 @@ public class Main { return Settings.builder() .covariateSettings(Utils.easyList( - new NumericCovariateSettings("x1"), - new BooleanCovariateSettings("x2"), - new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) + new NumericCovariateSettings("x1", true), + new BooleanCovariateSettings("x2", false), + new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"), true) ) ) .trainingDataLocation("training_data.csv") diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java index 3ed9028..0c7ba46 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java @@ -24,12 +24,12 @@ import lombok.NoArgsConstructor; @Data public final class BooleanCovariateSettings extends CovariateSettings { - public BooleanCovariateSettings(String name){ - super(name); + public BooleanCovariateSettings(String name, boolean naSplitPenalty){ + super(name, naSplitPenalty); } @Override public BooleanCovariate build(int index) { - return new BooleanCovariate(name, index); + return new BooleanCovariate(name, index, naSplitPenalty); } } diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java index 1b967ee..e658dc0 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java @@ -40,9 +40,11 @@ import lombok.NoArgsConstructor; public abstract class CovariateSettings { String name; + boolean naSplitPenalty; - CovariateSettings(String name){ + CovariateSettings(String name, boolean naSplitPenalty){ this.name = name; + this.naSplitPenalty = naSplitPenalty; } public abstract Covariate build(int index); diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java index 03ccd3b..499bba1 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java @@ -29,13 +29,13 @@ public final class FactorCovariateSettings extends CovariateSettings { private List levels; - public FactorCovariateSettings(String name, List levels){ - super(name); + public FactorCovariateSettings(String name, List levels, boolean naSplitPenalty){ + super(name, naSplitPenalty); this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...) } @Override public FactorCovariate build(int index) { - return new FactorCovariate(name, index, levels); + return new FactorCovariate(name, index, levels, naSplitPenalty); } } diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java index 8d15031..6820dde 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java @@ -24,12 +24,12 @@ import lombok.NoArgsConstructor; @Data public final class NumericCovariateSettings extends CovariateSettings { - public NumericCovariateSettings(String name){ - super(name); + public NumericCovariateSettings(String name, boolean naSplitPenalty){ + super(name, naSplitPenalty); } @Override public NumericCovariate build(int index) { - return new NumericCovariate(name, index); + return new NumericCovariate(name, index, naSplitPenalty); } } diff --git a/executable/src/test/java/ca/joeltherrien/randomforest/TestPersistence.java b/executable/src/test/java/ca/joeltherrien/randomforest/TestPersistence.java index b0449af..acc53c3 100644 --- a/executable/src/test/java/ca/joeltherrien/randomforest/TestPersistence.java +++ b/executable/src/test/java/ca/joeltherrien/randomforest/TestPersistence.java @@ -49,9 +49,9 @@ public class TestPersistence { final Settings settingsOriginal = Settings.builder() .covariateSettings(Utils.easyList( - new NumericCovariateSettings("x1"), - new BooleanCovariateSettings("x2"), - new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) + new NumericCovariateSettings("x1", true), + new BooleanCovariateSettings("x2", false), + new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"), true) ) ) .trainingDataLocation("training_data.csv") diff --git a/executable/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/executable/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index bf164b0..3c5991f 100644 --- a/executable/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/executable/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -55,9 +55,9 @@ public class TestLoadingCSV { final Settings settings = Settings.builder() .trainingDataLocation(filename) .covariateSettings( - Utils.easyList(new NumericCovariateSettings("x1"), - new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")), - new BooleanCovariateSettings("x3")) + Utils.easyList(new NumericCovariateSettings("x1", true), + new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse"), false), + new BooleanCovariateSettings("x3", true)) ) .yVarSettings(yVarSettings) .build(); @@ -71,14 +71,14 @@ public class TestLoadingCSV { } @Test - public void verifyLoadingNormal(final List covariates) throws IOException { + public void testLoadingNormal(final List covariates) throws IOException { final List> data = loadData("src/test/resources/testCSV.csv"); assertData(data, covariates); } @Test - public void verifyLoadingGz(final List covariates) throws IOException { + public void testLoadingGz(final List covariates) throws IOException { final List> data = loadData("src/test/resources/testCSV.csv.gz"); assertData(data, covariates); diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index f0a8cc1..172a944 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -49,6 +49,8 @@ public interface Covariate extends Serializable, Comparable { return getIndex() - other.getIndex(); } + boolean haveNASplitPenalty(); + interface Value extends Serializable{ Covariate getParent(); diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java index 824c5eb..e87e0d0 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java @@ -25,6 +25,7 @@ import lombok.Getter; import java.util.Iterator; import java.util.List; import java.util.Random; +import java.util.stream.Collectors; public final class BooleanCovariate implements Covariate { @@ -40,14 +41,26 @@ public final class BooleanCovariate implements Covariate { private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates. - public BooleanCovariate(String name, int index){ + private final boolean haveNASplitPenalty; + @Override + public boolean haveNASplitPenalty(){ + // penalty would add worthless computational time if there are no NAs + return hasNAs && haveNASplitPenalty; + } + + public BooleanCovariate(String name, int index, boolean haveNASplitPenalty){ this.name = name; this.index = index; - splitRule = new BooleanSplitRule(this); + this.splitRule = new BooleanSplitRule(this); + this.haveNASplitPenalty = haveNASplitPenalty; } @Override public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) { + if(hasNAs){ + data = data.stream().filter(row -> !row.getValueByIndex(index).isNA()).collect(Collectors.toList()); + } + return new SingletonIterator<>(this.splitRule.applyRule(data)); } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java index 63db922..919f345 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java @@ -23,6 +23,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import java.util.*; +import java.util.stream.Collectors; public final class FactorCovariate implements Covariate { @@ -40,8 +41,15 @@ public final class FactorCovariate implements Covariate { private boolean hasNAs; + private final boolean haveNASplitPenalty; + @Override + public boolean haveNASplitPenalty(){ + // penalty would add worthless computational time if there are no NAs + return hasNAs && haveNASplitPenalty; + } - public FactorCovariate(final String name, final int index, List levels){ + + public FactorCovariate(final String name, final int index, List levels, final boolean haveNASplitPenalty){ this.name = name; this.index = index; this.factorLevels = new HashMap<>(); @@ -63,12 +71,22 @@ public final class FactorCovariate implements Covariate { this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1; this.naValue = new FactorValue(null); + + this.haveNASplitPenalty = haveNASplitPenalty; } @Override public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) { + if(hasNAs()){ + data = data.stream().filter(row -> !row.getCovariateValue(this).isNA()).collect(Collectors.toList()); + } + + if(number == 0){ // nsplit = 0 => try every possibility, although we limit it to the number of observations. + number = data.size(); + } + final Set> splits = new HashSet<>(); // This is to ensure we don't get stuck in an infinite loop for small factors diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java index 9b65187..b08512b 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java @@ -47,6 +47,13 @@ public final class NumericCovariate implements Covariate { private boolean hasNAs = false; + private final boolean haveNASplitPenalty; + @Override + public boolean haveNASplitPenalty(){ + // penalty would add worthless computational time if there are no NAs + return hasNAs && haveNASplitPenalty; + } + @Override public NumericSplitRuleUpdater generateSplitRuleUpdater(List> data, int number, Random random) { Stream> stream = data.stream(); diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java index 7b8d9d4..f232a2c 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java @@ -17,15 +17,13 @@ package ca.joeltherrien.randomforest.tree; import lombok.AllArgsConstructor; -import lombok.Getter; +import lombok.Data; @AllArgsConstructor +@Data public class SplitAndScore { - @Getter - private final Split split; - - @Getter - private final Double score; + private Split split; + private Double score; } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 653856a..b41213d 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -17,7 +17,9 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.VisibleForTesting; import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.utils.SingletonIterator; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Builder; @@ -72,31 +74,12 @@ public class TreeTrainer { } - // Now that we have the best split; we need to handle any NAs that were dropped off final double probabilityLeftHand = (double) bestSplit.leftHand.size() / (double) (bestSplit.leftHand.size() + bestSplit.rightHand.size()); // Assign missing values to the split if necessary - if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){ - bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists - - for(Row row : data) { - final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex(); - - if(row.getValueByIndex(covariateIndex).isNA()) { - final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; - - if(randomDecision){ - bestSplit.getLeftHand().add(row); - } - else{ - bestSplit.getRightHand().add(row); - } - - } - } - } + bestSplit = randomlyAssignNAs(data, bestSplit, random); final Node leftNode; final Node rightNode; @@ -144,7 +127,8 @@ public class TreeTrainer { return splitCovariates; } - private Split findBestSplitRule(List> data, List covariatesToTry, Random random){ + @VisibleForTesting + public Split findBestSplitRule(List> data, List covariatesToTry, Random random){ SplitAndScore bestSplitAndScore = null; final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating @@ -157,10 +141,32 @@ public class TreeTrainer { continue; } - final SplitAndScore candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator); + SplitAndScore candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator); - if(candidateSplitAndScore != null && (bestSplitAndScore == null || - candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) { + + if(candidateSplitAndScore == null){ + continue; + } + + // This score was based on splitting only non-NA values. However, there might be a similar covariate we are also considering + // that is just as good at splitting but has less NAs; we should thus penalize the split score for variables with NAs + // We do this by randomly assigning the NAs and then recalculating the split score on the best split we already have. + // + // We only have to penalize the score though if we know it's possible that this might be the best split. If it's not, + // then we can skip the computations. + final boolean mayBeGoodSplit = bestSplitAndScore == null || + candidateSplitAndScore.getScore() > bestSplitAndScore.getScore(); + if(mayBeGoodSplit && covariate.haveNASplitPenalty()){ + Split candiateSplitWithNAs = randomlyAssignNAs(data, candidateSplitAndScore.getSplit(), random); + final Iterator> newSplitWithRandomNAs = new SingletonIterator<>(candiateSplitWithNAs); + final double newScore = splitFinder.findBestSplit(newSplitWithRandomNAs).getScore(); + + // There's a chance that NAs might add noise to *improve* the score; but we want to ensure we penalize it. + // Thus we only change the score if its worse. + candidateSplitAndScore.setScore(Math.min(newScore, candidateSplitAndScore.getScore())); + } + + if(bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) { bestSplitAndScore = candidateSplitAndScore; } @@ -174,6 +180,38 @@ public class TreeTrainer { } + private Split randomlyAssignNAs(List> data, Split existingSplit, Random random){ + + // Now that we have the best split; we need to handle any NAs that were dropped off + final double probabilityLeftHand = (double) existingSplit.leftHand.size() / + (double) (existingSplit.leftHand.size() + existingSplit.rightHand.size()); + + + final int covariateIndex = existingSplit.getSplitRule().getParentCovariateIndex(); + + // Assign missing values to the split if necessary + if(covariates.get(existingSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){ + existingSplit = existingSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists + + for(Row row : data) { + if(row.getValueByIndex(covariateIndex).isNA()) { + final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; + + if(randomDecision){ + existingSplit.getLeftHand().add(row); + } + else{ + existingSplit.getRightHand().add(row); + } + + } + } + } + + return existingSplit; + + } + private boolean nodeIsPure(List> data){ if(!checkNodePurity){ return false; diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java b/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java index 506e939..feb9190 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java @@ -45,20 +45,20 @@ public class TestDeterministicForests { int index = 0; for(int j=0; j<5; j++){ - final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index); + final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index, false); covariateList.add(numericCovariate); index++; } for(int j=0; j<5; j++){ - final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index); + final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index, false); covariateList.add(booleanCovariate); index++; } final List levels = Utils.easyList("cat", "dog", "mouse"); for(int j=0; j<5; j++){ - final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels); + final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels, false); covariateList.add(factorCovariate); index++; } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java b/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java index 15835b1..3ff2876 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java @@ -44,7 +44,7 @@ public class TestProvidingInitialForest { private List> data; public TestProvidingInitialForest(){ - covariateList = Collections.singletonList(new NumericCovariate("x", 0)); + covariateList = Collections.singletonList(new NumericCovariate("x", 0, false)); data = Utils.easyList( Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0), @@ -198,7 +198,7 @@ public class TestProvidingInitialForest { it's not clear if the forest being provided is the same one that trees were saved from. */ @Test - public void verifyExceptions(){ + public void testExceptions(){ final String filePath = "src/test/resources/trees/"; final File directory = new File(filePath); if(directory.exists()){ diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index cff464f..41c70d4 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -47,10 +47,10 @@ public class TestSavingLoading { public List getCovariates(){ return Utils.easyList( - new NumericCovariate("ageatfda", 0), - new BooleanCovariate("idu", 1), - new BooleanCovariate("black", 2), - new NumericCovariate("cd4nadir", 3) + new NumericCovariate("ageatfda", 0, false), + new BooleanCovariate("idu", 1, false), + new BooleanCovariate("black", 2, false), + new NumericCovariate("cd4nadir", 3, false) ); } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/library/src/test/java/ca/joeltherrien/randomforest/TestUtils.java index 1b26a03..e290597 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestUtils.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -156,7 +156,7 @@ public class TestUtils { } @Test - public void reduceListToSize(){ + public void testReduceListToSize(){ final List testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); final Random random = new Random(); for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java index dd3a2af..5598856 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java @@ -52,7 +52,7 @@ public class IBSCalculatorTest { */ @Test - public void resultsWithoutCensoringDistribution(){ + public void testResultsWithoutCensoringDistribution(){ final IBSCalculator calculator = new IBSCalculator(); final double errorDifferentEvent = calculator.calculateError( @@ -74,7 +74,7 @@ public class IBSCalculatorTest { } @Test - public void resultsWithCensoringDistribution(){ + public void testResultsWithCensoringDistribution(){ final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints( Utils.easyList( new Point(0.0, 0.75), diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index afbd199..e72d7fd 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -53,10 +53,10 @@ public class TestCompetingRisk { public List getCovariates(){ return Utils.easyList( - new NumericCovariate("ageatfda", 0), - new BooleanCovariate("idu", 1), - new BooleanCovariate("black", 2), - new NumericCovariate("cd4nadir", 3) + new NumericCovariate("ageatfda", 0, false), + new BooleanCovariate("idu", 1, false), + new BooleanCovariate("black", 2, false), + new NumericCovariate("cd4nadir", 3, false) ); } @@ -109,8 +109,8 @@ public class TestCompetingRisk { // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree. final List covariates = Utils.easyList( - new BooleanCovariate("idu", 0), - new BooleanCovariate("black", 1) + new BooleanCovariate("idu", 0, false), + new BooleanCovariate("black", 1, false) ); final List> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv"); @@ -210,8 +210,8 @@ public class TestCompetingRisk { public void testLogRankSplitFinderTwoBooleans() throws IOException { // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree. final List covariates = Utils.easyList( - new BooleanCovariate("idu", 0), - new BooleanCovariate("black", 1) + new BooleanCovariate("idu", 0, false), + new BooleanCovariate("black", 1, false) ); @@ -259,7 +259,7 @@ public class TestCompetingRisk { } @Test - public void verifyDataset() throws IOException { + public void testDataset() throws IOException { final List covariates = getCovariates(); final List> dataset = getData(covariates, DEFAULT_FILEPATH); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java index 99a30e4..3513f5d 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java @@ -46,7 +46,7 @@ public class TestLogRankSplitFinder { public static Data loadData(String filename) throws IOException { final List covariates = Utils.easyList( - new NumericCovariate("x2", 0) + new NumericCovariate("x2", 0, false) ); final List> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/library/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java index 84c134d..c214c9f 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java @@ -17,12 +17,15 @@ package ca.joeltherrien.randomforest.covariates; +import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; +import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Random; @@ -31,7 +34,7 @@ import static org.junit.jupiter.api.Assertions.*; public class FactorCovariateTest { @Test - void verifyEqualLevels() { + public void testEqualLevels() { final FactorCovariate petCovariate = createTestCovariate(); final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG"); @@ -53,7 +56,7 @@ public class FactorCovariateTest { } @Test - void verifyBadLevelException(){ + public void testBadLevelException(){ final FactorCovariate petCovariate = createTestCovariate(); final Executable badCode = () -> petCovariate.createValue("vulcan"); @@ -61,25 +64,169 @@ public class FactorCovariateTest { } @Test - void testAllSubsets(){ + public void testAllSubsets(){ + final int n = 2*3; // ensure that n is a multiple of 3 for the test final FactorCovariate petCovariate = createTestCovariate(); + final List> data = generateSampleData(petCovariate, n); - final List> splitRules = new ArrayList<>(); + final List> splits = new ArrayList<>(); - petCovariate.generateSplitRuleUpdater(null, 100, new Random()) - .forEachRemaining(split -> splitRules.add(split.getSplitRule())); + petCovariate.generateSplitRuleUpdater(data, 100, new Random()) + .forEachRemaining(split -> splits.add(split)); - assertEquals(splitRules.size(), 3); + assertEquals(splits.size(), 3); - // TODO verify the contents of the split rules + // These are the 3 possibilities + boolean dog_catmouse = false; + boolean cat_dogmouse = false; + boolean mouse_dogcat = false; + for(Split split : splits){ + List> smallerHand; + List> largerHand; + + if(split.getLeftHand().size() < split.getRightHand().size()){ + smallerHand = split.getLeftHand(); + largerHand = split.getRightHand(); + } else{ + smallerHand = split.getRightHand(); + largerHand = split.getLeftHand(); + } + + // There should be exactly one distinct value in the smaller list + assertEquals(n/3, smallerHand.size()); + assertEquals(1, + smallerHand.stream() + .map(row -> row.getCovariateValue(petCovariate).getValue()) + .distinct() + .count() + ); + + // There should be exactly two distinct values in the smaller list + assertEquals(2*n/3, largerHand.size()); + assertEquals(2, + largerHand.stream() + .map(row -> row.getCovariateValue(petCovariate).getValue()) + .distinct() + .count() + ); + + switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){ + case "DOG": + dog_catmouse = true; + case "CAT": + cat_dogmouse = true; + case "MOUSE": + mouse_dogcat = true; + } + + } + + assertTrue(dog_catmouse); + assertTrue(cat_dogmouse); + assertTrue(mouse_dogcat); + + } + + /* + * There was a bug where if number==0 in generateSplitRuleUpdater, then the result was empty. + */ + @Test + public void testNumber0Subsets(){ + final int n = 2*3; // ensure that n is a multiple of 3 for the test + final FactorCovariate petCovariate = createTestCovariate(); + final List> data = generateSampleData(petCovariate, n); + + final List> splits = new ArrayList<>(); + + petCovariate.generateSplitRuleUpdater(data, 0, new Random()) + .forEachRemaining(split -> splits.add(split)); + + assertEquals(splits.size(), 3); + + // These are the 3 possibilities + boolean dog_catmouse = false; + boolean cat_dogmouse = false; + boolean mouse_dogcat = false; + + for(Split split : splits){ + List> smallerHand; + List> largerHand; + + if(split.getLeftHand().size() < split.getRightHand().size()){ + smallerHand = split.getLeftHand(); + largerHand = split.getRightHand(); + } else{ + smallerHand = split.getRightHand(); + largerHand = split.getLeftHand(); + } + + // There should be exactly one distinct value in the smaller list + assertEquals(n/3, smallerHand.size()); + assertEquals(1, + smallerHand.stream() + .map(row -> row.getCovariateValue(petCovariate).getValue()) + .distinct() + .count() + ); + + // There should be exactly two distinct values in the smaller list + assertEquals(2*n/3, largerHand.size()); + assertEquals(2, + largerHand.stream() + .map(row -> row.getCovariateValue(petCovariate).getValue()) + .distinct() + .count() + ); + + switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){ + case "DOG": + dog_catmouse = true; + case "CAT": + cat_dogmouse = true; + case "MOUSE": + mouse_dogcat = true; + } + + } + + assertTrue(dog_catmouse); + assertTrue(cat_dogmouse); + assertTrue(mouse_dogcat); + + } + + @Test + public void testSpitRuleUpdaterWithNAs(){ + // When some NAs were present calling generateSplitRuleUpdater caused an exception. + + final FactorCovariate covariate = createTestCovariate(); + final List> sampleData = generateSampleData(covariate, 10); + sampleData.add(Row.createSimple(Utils.easyMap("pet", "NA"), Collections.singletonList(covariate), 11, 5.0)); + + covariate.generateSplitRuleUpdater(sampleData, 0, new Random()); + + // Test passes if no exception has occurred. } private FactorCovariate createTestCovariate(){ final List levels = Utils.easyList("DOG", "CAT", "MOUSE"); - return new FactorCovariate("pet", 0, levels); + return new FactorCovariate("pet", 0, levels, false); + } + + private List> generateSampleData(Covariate covariate, int n){ + final List covariateList = Collections.singletonList(covariate); + final List> dataList = new ArrayList<>(n); + + final String[] levels = new String[]{"DOG", "CAT", "MOUSE"}; + + for(int i=0; i> dataset = createTestDataset(covariate); @@ -158,7 +158,7 @@ public class NumericCovariateTest { @Test public void testNumericSplitRuleUpdaterWithIndexes(){ - final NumericCovariate covariate = new NumericCovariate("x", 0); + final NumericCovariate covariate = new NumericCovariate("x", 0, false); final List> dataset = createTestDataset(covariate); @@ -223,7 +223,7 @@ public class NumericCovariateTest { */ @Test public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){ - final NumericCovariate covariate = new NumericCovariate("x", 0); + final NumericCovariate covariate = new NumericCovariate("x", 0, false); final List> dataset = createTestDatasetMissingValues(covariate); final NumericSplitRuleUpdater updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random()); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java b/library/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java index d60b13d..4e63b1c 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java @@ -18,31 +18,34 @@ package ca.joeltherrien.randomforest.nas; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; +import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Random; +import static org.junit.jupiter.api.Assertions.*; + public class TestNAs { - private List> generateData(List covariates){ + private List> generateData1(List covariates){ final List> dataList = new ArrayList<>(); - // We must include an NA for one of the values - dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0)); - dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0)); - dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5)); - dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0)); - dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0)); - dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "NA", "y", "true", "z", "green"), covariates, 1, 5.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "1", "y", "NA", "z", "blue"), covariates, 2, 6.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "2", "y", "true", "z", "NA"), covariates, 3, 5.5)); + dataList.add(Row.createSimple(Utils.easyMap("x", "7", "y", "false", "z", "green"), covariates, 4, 0.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "8", "y", "true", "z", "blue"), covariates, 5, 1.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "8.4", "y", "false", "z", "yellow"), covariates, 6, 1.0)); return dataList; @@ -54,14 +57,19 @@ public class TestNAs { // but NumericSplitRuleUpdater had unmodifiable lists when creating the split. // This bug verifies that this no longer causes a crash - final List covariates = Collections.singletonList(new NumericCovariate("x", 0)); - final List> dataset = generateData(covariates); + final List covariates = Utils.easyList( + new NumericCovariate("x", 0, false), + new BooleanCovariate("y", 1, true), + new FactorCovariate("z", 2, Utils.easyList("green", "blue", "yellow"), true) + ); + final List> dataset = generateData1(covariates); final TreeTrainer treeTrainer = TreeTrainer.builder() .checkNodePurity(false) .covariates(covariates) .numberOfSplits(0) .nodeSize(1) + .mtry(3) .maxNodeDepth(1000) .splitFinder(new WeightedVarianceSplitFinder()) .responseCombiner(new MeanResponseCombiner()) @@ -70,6 +78,87 @@ public class TestNAs { treeTrainer.growTree(dataset, new Random(123)); // As long as no exception occurs, we passed + } + + private List> generateData2(List covariates){ + final List> dataList = new ArrayList<>(); + // Idea - when ignoring NAs, BadVar gives a perfect split. + // GoodVar is slightly worse than BadVar when NAs are excluded. + // However, BadVar has a ton of NAs that will degrade its performance. + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "-1.0", "GoodVar", "true") // GoodVars one error + , covariates, 1, 5.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "NA", "GoodVar", "false") + , covariates, 2, 5.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "NA", "GoodVar", "false") + , covariates, 3, 5.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "0.5", "GoodVar", "true") + , covariates, 4, 10.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "NA", "GoodVar", "true") + , covariates, 5, 10.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "NA", "GoodVar", "true") + , covariates, 6, 10.0) + ); + + return dataList; + } + + @Test + // Test that the NA penalty works when selecting a best split. + public void testNAPenalty(){ + final List covariates1 = Utils.easyList( + new NumericCovariate("BadVar", 0, true), + new BooleanCovariate("GoodVar", 1, false) + ); + + final List> dataList1 = generateData2(covariates1); + + final TreeTrainer treeTrainer1 = TreeTrainer.builder() + .checkNodePurity(false) + .covariates(covariates1) + .numberOfSplits(0) + .nodeSize(1) + .mtry(2) + .maxNodeDepth(1000) + .splitFinder(new WeightedVarianceSplitFinder()) + .responseCombiner(new MeanResponseCombiner()) + .build(); + + final Split bestSplit1 = treeTrainer1.findBestSplitRule(dataList1, covariates1, new Random(123)); + assertEquals(1, bestSplit1.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar + + // Run again without the penalty; verify that we get different results + + final List covariates2 = Utils.easyList( + new NumericCovariate("BadVar", 0, false), + new BooleanCovariate("GoodVar", 1, false) + ); + + final List> dataList2 = generateData2(covariates2); + + final TreeTrainer treeTrainer2 = TreeTrainer.builder() + .checkNodePurity(false) + .covariates(covariates2) + .numberOfSplits(0) + .nodeSize(1) + .mtry(2) + .maxNodeDepth(1000) + .splitFinder(new WeightedVarianceSplitFinder()) + .responseCombiner(new MeanResponseCombiner()) + .build(); + + final Split bestSplit2 = treeTrainer2.findBestSplitRule(dataList2, covariates2, new Random(123)); + assertEquals(0, bestSplit2.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java index b656ce0..e8a83a9 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java @@ -39,10 +39,10 @@ public class VariableImportanceCalculatorTest { */ public VariableImportanceCalculatorTest(){ - final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0); - final NumericCovariate numericCovariate = new NumericCovariate("y", 1); + final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false); + final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false); final FactorCovariate factorCovariate = new FactorCovariate("z", 2, - Utils.easyList("red", "blue", "green")); + Utils.easyList("red", "blue", "green"), false); this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index c9a031a..679fef7 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -43,7 +43,7 @@ public class TrainForest { final List covariateList = new ArrayList<>(p); for(int j =0; j < p; j++){ - final NumericCovariate covariate = new NumericCovariate("x"+j, j); + final NumericCovariate covariate = new NumericCovariate("x"+j, j, false); covariateList.add(covariate); } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index 8e1e361..6c9e0ed 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -39,8 +39,8 @@ public class TrainSingleTree { final int n = 1000; final List> trainingSet = new ArrayList<>(n); - final Covariate x1Covariate = new NumericCovariate("x1", 0); - final Covariate x2Covariate = new NumericCovariate("x2", 1); + final Covariate x1Covariate = new NumericCovariate("x1", 0, false); + final Covariate x2Covariate = new NumericCovariate("x2", 1, false); final List> x1List = DoubleStream .generate(() -> random.nextDouble()*10.0) diff --git a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index 1d2b404..bc0bd6e 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -41,9 +41,9 @@ public class TrainSingleTreeFactor { final int n = 10000; final List> trainingSet = new ArrayList<>(n); - final Covariate x1Covariate = new NumericCovariate("x1", 0); - final Covariate x2Covariate = new NumericCovariate("x2", 1); - final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse")); + final Covariate x1Covariate = new NumericCovariate("x1", 0, false); + final Covariate x2Covariate = new NumericCovariate("x2", 1, false); + final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"), false); final List> x1List = DoubleStream .generate(() -> random.nextDouble()*10.0)