diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/Main.java b/executable/src/main/java/ca/joeltherrien/randomforest/Main.java index 7d5dda2..aeeddc6 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -227,9 +227,9 @@ public class Main { return Settings.builder() .covariateSettings(Utils.easyList( - new NumericCovariateSettings("x1"), - new BooleanCovariateSettings("x2"), - new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) + new NumericCovariateSettings("x1", true), + new BooleanCovariateSettings("x2", false), + new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"), true) ) ) .trainingDataLocation("training_data.csv") diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java index 3ed9028..0c7ba46 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java @@ -24,12 +24,12 @@ import lombok.NoArgsConstructor; @Data public final class BooleanCovariateSettings extends CovariateSettings { - public BooleanCovariateSettings(String name){ - super(name); + public BooleanCovariateSettings(String name, boolean naSplitPenalty){ + super(name, naSplitPenalty); } @Override public BooleanCovariate build(int index) { - return new BooleanCovariate(name, index); + return new BooleanCovariate(name, index, naSplitPenalty); } } diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java index 1b967ee..e658dc0 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java @@ -40,9 +40,11 @@ import lombok.NoArgsConstructor; public abstract class CovariateSettings { String name; + boolean naSplitPenalty; - CovariateSettings(String name){ + CovariateSettings(String name, boolean naSplitPenalty){ this.name = name; + this.naSplitPenalty = naSplitPenalty; } public abstract Covariate build(int index); diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java index 03ccd3b..499bba1 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java @@ -29,13 +29,13 @@ public final class FactorCovariateSettings extends CovariateSettings { private List levels; - public FactorCovariateSettings(String name, List levels){ - super(name); + public FactorCovariateSettings(String name, List levels, boolean naSplitPenalty){ + super(name, naSplitPenalty); this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...) } @Override public FactorCovariate build(int index) { - return new FactorCovariate(name, index, levels); + return new FactorCovariate(name, index, levels, naSplitPenalty); } } diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java index 8d15031..6820dde 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java @@ -24,12 +24,12 @@ import lombok.NoArgsConstructor; @Data public final class NumericCovariateSettings extends CovariateSettings { - public NumericCovariateSettings(String name){ - super(name); + public NumericCovariateSettings(String name, boolean naSplitPenalty){ + super(name, naSplitPenalty); } @Override public NumericCovariate build(int index) { - return new NumericCovariate(name, index); + return new NumericCovariate(name, index, naSplitPenalty); } } diff --git a/executable/src/test/java/ca/joeltherrien/randomforest/TestPersistence.java b/executable/src/test/java/ca/joeltherrien/randomforest/TestPersistence.java index b0449af..acc53c3 100644 --- a/executable/src/test/java/ca/joeltherrien/randomforest/TestPersistence.java +++ b/executable/src/test/java/ca/joeltherrien/randomforest/TestPersistence.java @@ -49,9 +49,9 @@ public class TestPersistence { final Settings settingsOriginal = Settings.builder() .covariateSettings(Utils.easyList( - new NumericCovariateSettings("x1"), - new BooleanCovariateSettings("x2"), - new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) + new NumericCovariateSettings("x1", true), + new BooleanCovariateSettings("x2", false), + new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"), true) ) ) .trainingDataLocation("training_data.csv") diff --git a/executable/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/executable/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index bf164b0..3c5991f 100644 --- a/executable/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/executable/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -55,9 +55,9 @@ public class TestLoadingCSV { final Settings settings = Settings.builder() .trainingDataLocation(filename) .covariateSettings( - Utils.easyList(new NumericCovariateSettings("x1"), - new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")), - new BooleanCovariateSettings("x3")) + Utils.easyList(new NumericCovariateSettings("x1", true), + new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse"), false), + new BooleanCovariateSettings("x3", true)) ) .yVarSettings(yVarSettings) .build(); @@ -71,14 +71,14 @@ public class TestLoadingCSV { } @Test - public void verifyLoadingNormal(final List covariates) throws IOException { + public void testLoadingNormal(final List covariates) throws IOException { final List> data = loadData("src/test/resources/testCSV.csv"); assertData(data, covariates); } @Test - public void verifyLoadingGz(final List covariates) throws IOException { + public void testLoadingGz(final List covariates) throws IOException { final List> data = loadData("src/test/resources/testCSV.csv.gz"); assertData(data, covariates); diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index f0a8cc1..172a944 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -49,6 +49,8 @@ public interface Covariate extends Serializable, Comparable { return getIndex() - other.getIndex(); } + boolean haveNASplitPenalty(); + interface Value extends Serializable{ Covariate getParent(); diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java index 824c5eb..e87e0d0 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java @@ -25,6 +25,7 @@ import lombok.Getter; import java.util.Iterator; import java.util.List; import java.util.Random; +import java.util.stream.Collectors; public final class BooleanCovariate implements Covariate { @@ -40,14 +41,26 @@ public final class BooleanCovariate implements Covariate { private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates. - public BooleanCovariate(String name, int index){ + private final boolean haveNASplitPenalty; + @Override + public boolean haveNASplitPenalty(){ + // penalty would add worthless computational time if there are no NAs + return hasNAs && haveNASplitPenalty; + } + + public BooleanCovariate(String name, int index, boolean haveNASplitPenalty){ this.name = name; this.index = index; - splitRule = new BooleanSplitRule(this); + this.splitRule = new BooleanSplitRule(this); + this.haveNASplitPenalty = haveNASplitPenalty; } @Override public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) { + if(hasNAs){ + data = data.stream().filter(row -> !row.getValueByIndex(index).isNA()).collect(Collectors.toList()); + } + return new SingletonIterator<>(this.splitRule.applyRule(data)); } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java index 63db922..919f345 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java @@ -23,6 +23,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import java.util.*; +import java.util.stream.Collectors; public final class FactorCovariate implements Covariate { @@ -40,8 +41,15 @@ public final class FactorCovariate implements Covariate { private boolean hasNAs; + private final boolean haveNASplitPenalty; + @Override + public boolean haveNASplitPenalty(){ + // penalty would add worthless computational time if there are no NAs + return hasNAs && haveNASplitPenalty; + } - public FactorCovariate(final String name, final int index, List levels){ + + public FactorCovariate(final String name, final int index, List levels, final boolean haveNASplitPenalty){ this.name = name; this.index = index; this.factorLevels = new HashMap<>(); @@ -63,12 +71,22 @@ public final class FactorCovariate implements Covariate { this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1; this.naValue = new FactorValue(null); + + this.haveNASplitPenalty = haveNASplitPenalty; } @Override public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) { + if(hasNAs()){ + data = data.stream().filter(row -> !row.getCovariateValue(this).isNA()).collect(Collectors.toList()); + } + + if(number == 0){ // nsplit = 0 => try every possibility, although we limit it to the number of observations. + number = data.size(); + } + final Set> splits = new HashSet<>(); // This is to ensure we don't get stuck in an infinite loop for small factors diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java index 9b65187..b08512b 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java @@ -47,6 +47,13 @@ public final class NumericCovariate implements Covariate { private boolean hasNAs = false; + private final boolean haveNASplitPenalty; + @Override + public boolean haveNASplitPenalty(){ + // penalty would add worthless computational time if there are no NAs + return hasNAs && haveNASplitPenalty; + } + @Override public NumericSplitRuleUpdater generateSplitRuleUpdater(List> data, int number, Random random) { Stream> stream = data.stream(); diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java index 7b8d9d4..f232a2c 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java @@ -17,15 +17,13 @@ package ca.joeltherrien.randomforest.tree; import lombok.AllArgsConstructor; -import lombok.Getter; +import lombok.Data; @AllArgsConstructor +@Data public class SplitAndScore { - @Getter - private final Split split; - - @Getter - private final Double score; + private Split split; + private Double score; } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 653856a..b41213d 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -17,7 +17,9 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.VisibleForTesting; import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.utils.SingletonIterator; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Builder; @@ -72,31 +74,12 @@ public class TreeTrainer { } - // Now that we have the best split; we need to handle any NAs that were dropped off final double probabilityLeftHand = (double) bestSplit.leftHand.size() / (double) (bestSplit.leftHand.size() + bestSplit.rightHand.size()); // Assign missing values to the split if necessary - if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){ - bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists - - for(Row row : data) { - final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex(); - - if(row.getValueByIndex(covariateIndex).isNA()) { - final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; - - if(randomDecision){ - bestSplit.getLeftHand().add(row); - } - else{ - bestSplit.getRightHand().add(row); - } - - } - } - } + bestSplit = randomlyAssignNAs(data, bestSplit, random); final Node leftNode; final Node rightNode; @@ -144,7 +127,8 @@ public class TreeTrainer { return splitCovariates; } - private Split findBestSplitRule(List> data, List covariatesToTry, Random random){ + @VisibleForTesting + public Split findBestSplitRule(List> data, List covariatesToTry, Random random){ SplitAndScore bestSplitAndScore = null; final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating @@ -157,10 +141,32 @@ public class TreeTrainer { continue; } - final SplitAndScore candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator); + SplitAndScore candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator); - if(candidateSplitAndScore != null && (bestSplitAndScore == null || - candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) { + + if(candidateSplitAndScore == null){ + continue; + } + + // This score was based on splitting only non-NA values. However, there might be a similar covariate we are also considering + // that is just as good at splitting but has less NAs; we should thus penalize the split score for variables with NAs + // We do this by randomly assigning the NAs and then recalculating the split score on the best split we already have. + // + // We only have to penalize the score though if we know it's possible that this might be the best split. If it's not, + // then we can skip the computations. + final boolean mayBeGoodSplit = bestSplitAndScore == null || + candidateSplitAndScore.getScore() > bestSplitAndScore.getScore(); + if(mayBeGoodSplit && covariate.haveNASplitPenalty()){ + Split candiateSplitWithNAs = randomlyAssignNAs(data, candidateSplitAndScore.getSplit(), random); + final Iterator> newSplitWithRandomNAs = new SingletonIterator<>(candiateSplitWithNAs); + final double newScore = splitFinder.findBestSplit(newSplitWithRandomNAs).getScore(); + + // There's a chance that NAs might add noise to *improve* the score; but we want to ensure we penalize it. + // Thus we only change the score if its worse. + candidateSplitAndScore.setScore(Math.min(newScore, candidateSplitAndScore.getScore())); + } + + if(bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) { bestSplitAndScore = candidateSplitAndScore; } @@ -174,6 +180,38 @@ public class TreeTrainer { } + private Split randomlyAssignNAs(List> data, Split existingSplit, Random random){ + + // Now that we have the best split; we need to handle any NAs that were dropped off + final double probabilityLeftHand = (double) existingSplit.leftHand.size() / + (double) (existingSplit.leftHand.size() + existingSplit.rightHand.size()); + + + final int covariateIndex = existingSplit.getSplitRule().getParentCovariateIndex(); + + // Assign missing values to the split if necessary + if(covariates.get(existingSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){ + existingSplit = existingSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists + + for(Row row : data) { + if(row.getValueByIndex(covariateIndex).isNA()) { + final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; + + if(randomDecision){ + existingSplit.getLeftHand().add(row); + } + else{ + existingSplit.getRightHand().add(row); + } + + } + } + } + + return existingSplit; + + } + private boolean nodeIsPure(List> data){ if(!checkNodePurity){ return false; diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java b/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java index 506e939..feb9190 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java @@ -45,20 +45,20 @@ public class TestDeterministicForests { int index = 0; for(int j=0; j<5; j++){ - final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index); + final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index, false); covariateList.add(numericCovariate); index++; } for(int j=0; j<5; j++){ - final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index); + final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index, false); covariateList.add(booleanCovariate); index++; } final List levels = Utils.easyList("cat", "dog", "mouse"); for(int j=0; j<5; j++){ - final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels); + final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels, false); covariateList.add(factorCovariate); index++; } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java b/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java index 15835b1..3ff2876 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java @@ -44,7 +44,7 @@ public class TestProvidingInitialForest { private List> data; public TestProvidingInitialForest(){ - covariateList = Collections.singletonList(new NumericCovariate("x", 0)); + covariateList = Collections.singletonList(new NumericCovariate("x", 0, false)); data = Utils.easyList( Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0), @@ -198,7 +198,7 @@ public class TestProvidingInitialForest { it's not clear if the forest being provided is the same one that trees were saved from. */ @Test - public void verifyExceptions(){ + public void testExceptions(){ final String filePath = "src/test/resources/trees/"; final File directory = new File(filePath); if(directory.exists()){ diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index cff464f..41c70d4 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -47,10 +47,10 @@ public class TestSavingLoading { public List getCovariates(){ return Utils.easyList( - new NumericCovariate("ageatfda", 0), - new BooleanCovariate("idu", 1), - new BooleanCovariate("black", 2), - new NumericCovariate("cd4nadir", 3) + new NumericCovariate("ageatfda", 0, false), + new BooleanCovariate("idu", 1, false), + new BooleanCovariate("black", 2, false), + new NumericCovariate("cd4nadir", 3, false) ); } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/library/src/test/java/ca/joeltherrien/randomforest/TestUtils.java index 1b26a03..e290597 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestUtils.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -156,7 +156,7 @@ public class TestUtils { } @Test - public void reduceListToSize(){ + public void testReduceListToSize(){ final List testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); final Random random = new Random(); for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java index dd3a2af..5598856 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java @@ -52,7 +52,7 @@ public class IBSCalculatorTest { */ @Test - public void resultsWithoutCensoringDistribution(){ + public void testResultsWithoutCensoringDistribution(){ final IBSCalculator calculator = new IBSCalculator(); final double errorDifferentEvent = calculator.calculateError( @@ -74,7 +74,7 @@ public class IBSCalculatorTest { } @Test - public void resultsWithCensoringDistribution(){ + public void testResultsWithCensoringDistribution(){ final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints( Utils.easyList( new Point(0.0, 0.75), diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index afbd199..e72d7fd 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -53,10 +53,10 @@ public class TestCompetingRisk { public List getCovariates(){ return Utils.easyList( - new NumericCovariate("ageatfda", 0), - new BooleanCovariate("idu", 1), - new BooleanCovariate("black", 2), - new NumericCovariate("cd4nadir", 3) + new NumericCovariate("ageatfda", 0, false), + new BooleanCovariate("idu", 1, false), + new BooleanCovariate("black", 2, false), + new NumericCovariate("cd4nadir", 3, false) ); } @@ -109,8 +109,8 @@ public class TestCompetingRisk { // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree. final List covariates = Utils.easyList( - new BooleanCovariate("idu", 0), - new BooleanCovariate("black", 1) + new BooleanCovariate("idu", 0, false), + new BooleanCovariate("black", 1, false) ); final List> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv"); @@ -210,8 +210,8 @@ public class TestCompetingRisk { public void testLogRankSplitFinderTwoBooleans() throws IOException { // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree. final List covariates = Utils.easyList( - new BooleanCovariate("idu", 0), - new BooleanCovariate("black", 1) + new BooleanCovariate("idu", 0, false), + new BooleanCovariate("black", 1, false) ); @@ -259,7 +259,7 @@ public class TestCompetingRisk { } @Test - public void verifyDataset() throws IOException { + public void testDataset() throws IOException { final List covariates = getCovariates(); final List> dataset = getData(covariates, DEFAULT_FILEPATH); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java index 99a30e4..3513f5d 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java @@ -46,7 +46,7 @@ public class TestLogRankSplitFinder { public static Data loadData(String filename) throws IOException { final List covariates = Utils.easyList( - new NumericCovariate("x2", 0) + new NumericCovariate("x2", 0, false) ); final List> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/library/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java index 84c134d..c214c9f 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java @@ -17,12 +17,15 @@ package ca.joeltherrien.randomforest.covariates; +import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; +import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Random; @@ -31,7 +34,7 @@ import static org.junit.jupiter.api.Assertions.*; public class FactorCovariateTest { @Test - void verifyEqualLevels() { + public void testEqualLevels() { final FactorCovariate petCovariate = createTestCovariate(); final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG"); @@ -53,7 +56,7 @@ public class FactorCovariateTest { } @Test - void verifyBadLevelException(){ + public void testBadLevelException(){ final FactorCovariate petCovariate = createTestCovariate(); final Executable badCode = () -> petCovariate.createValue("vulcan"); @@ -61,25 +64,169 @@ public class FactorCovariateTest { } @Test - void testAllSubsets(){ + public void testAllSubsets(){ + final int n = 2*3; // ensure that n is a multiple of 3 for the test final FactorCovariate petCovariate = createTestCovariate(); + final List> data = generateSampleData(petCovariate, n); - final List> splitRules = new ArrayList<>(); + final List> splits = new ArrayList<>(); - petCovariate.generateSplitRuleUpdater(null, 100, new Random()) - .forEachRemaining(split -> splitRules.add(split.getSplitRule())); + petCovariate.generateSplitRuleUpdater(data, 100, new Random()) + .forEachRemaining(split -> splits.add(split)); - assertEquals(splitRules.size(), 3); + assertEquals(splits.size(), 3); - // TODO verify the contents of the split rules + // These are the 3 possibilities + boolean dog_catmouse = false; + boolean cat_dogmouse = false; + boolean mouse_dogcat = false; + for(Split split : splits){ + List> smallerHand; + List> largerHand; + + if(split.getLeftHand().size() < split.getRightHand().size()){ + smallerHand = split.getLeftHand(); + largerHand = split.getRightHand(); + } else{ + smallerHand = split.getRightHand(); + largerHand = split.getLeftHand(); + } + + // There should be exactly one distinct value in the smaller list + assertEquals(n/3, smallerHand.size()); + assertEquals(1, + smallerHand.stream() + .map(row -> row.getCovariateValue(petCovariate).getValue()) + .distinct() + .count() + ); + + // There should be exactly two distinct values in the smaller list + assertEquals(2*n/3, largerHand.size()); + assertEquals(2, + largerHand.stream() + .map(row -> row.getCovariateValue(petCovariate).getValue()) + .distinct() + .count() + ); + + switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){ + case "DOG": + dog_catmouse = true; + case "CAT": + cat_dogmouse = true; + case "MOUSE": + mouse_dogcat = true; + } + + } + + assertTrue(dog_catmouse); + assertTrue(cat_dogmouse); + assertTrue(mouse_dogcat); + + } + + /* + * There was a bug where if number==0 in generateSplitRuleUpdater, then the result was empty. + */ + @Test + public void testNumber0Subsets(){ + final int n = 2*3; // ensure that n is a multiple of 3 for the test + final FactorCovariate petCovariate = createTestCovariate(); + final List> data = generateSampleData(petCovariate, n); + + final List> splits = new ArrayList<>(); + + petCovariate.generateSplitRuleUpdater(data, 0, new Random()) + .forEachRemaining(split -> splits.add(split)); + + assertEquals(splits.size(), 3); + + // These are the 3 possibilities + boolean dog_catmouse = false; + boolean cat_dogmouse = false; + boolean mouse_dogcat = false; + + for(Split split : splits){ + List> smallerHand; + List> largerHand; + + if(split.getLeftHand().size() < split.getRightHand().size()){ + smallerHand = split.getLeftHand(); + largerHand = split.getRightHand(); + } else{ + smallerHand = split.getRightHand(); + largerHand = split.getLeftHand(); + } + + // There should be exactly one distinct value in the smaller list + assertEquals(n/3, smallerHand.size()); + assertEquals(1, + smallerHand.stream() + .map(row -> row.getCovariateValue(petCovariate).getValue()) + .distinct() + .count() + ); + + // There should be exactly two distinct values in the smaller list + assertEquals(2*n/3, largerHand.size()); + assertEquals(2, + largerHand.stream() + .map(row -> row.getCovariateValue(petCovariate).getValue()) + .distinct() + .count() + ); + + switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){ + case "DOG": + dog_catmouse = true; + case "CAT": + cat_dogmouse = true; + case "MOUSE": + mouse_dogcat = true; + } + + } + + assertTrue(dog_catmouse); + assertTrue(cat_dogmouse); + assertTrue(mouse_dogcat); + + } + + @Test + public void testSpitRuleUpdaterWithNAs(){ + // When some NAs were present calling generateSplitRuleUpdater caused an exception. + + final FactorCovariate covariate = createTestCovariate(); + final List> sampleData = generateSampleData(covariate, 10); + sampleData.add(Row.createSimple(Utils.easyMap("pet", "NA"), Collections.singletonList(covariate), 11, 5.0)); + + covariate.generateSplitRuleUpdater(sampleData, 0, new Random()); + + // Test passes if no exception has occurred. } private FactorCovariate createTestCovariate(){ final List levels = Utils.easyList("DOG", "CAT", "MOUSE"); - return new FactorCovariate("pet", 0, levels); + return new FactorCovariate("pet", 0, levels, false); + } + + private List> generateSampleData(Covariate covariate, int n){ + final List covariateList = Collections.singletonList(covariate); + final List> dataList = new ArrayList<>(n); + + final String[] levels = new String[]{"DOG", "CAT", "MOUSE"}; + + for(int i=0; i> dataset = createTestDataset(covariate); @@ -158,7 +158,7 @@ public class NumericCovariateTest { @Test public void testNumericSplitRuleUpdaterWithIndexes(){ - final NumericCovariate covariate = new NumericCovariate("x", 0); + final NumericCovariate covariate = new NumericCovariate("x", 0, false); final List> dataset = createTestDataset(covariate); @@ -223,7 +223,7 @@ public class NumericCovariateTest { */ @Test public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){ - final NumericCovariate covariate = new NumericCovariate("x", 0); + final NumericCovariate covariate = new NumericCovariate("x", 0, false); final List> dataset = createTestDatasetMissingValues(covariate); final NumericSplitRuleUpdater updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random()); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java b/library/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java index d60b13d..4e63b1c 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java @@ -18,31 +18,34 @@ package ca.joeltherrien.randomforest.nas; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; +import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Random; +import static org.junit.jupiter.api.Assertions.*; + public class TestNAs { - private List> generateData(List covariates){ + private List> generateData1(List covariates){ final List> dataList = new ArrayList<>(); - // We must include an NA for one of the values - dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0)); - dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0)); - dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5)); - dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0)); - dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0)); - dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "NA", "y", "true", "z", "green"), covariates, 1, 5.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "1", "y", "NA", "z", "blue"), covariates, 2, 6.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "2", "y", "true", "z", "NA"), covariates, 3, 5.5)); + dataList.add(Row.createSimple(Utils.easyMap("x", "7", "y", "false", "z", "green"), covariates, 4, 0.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "8", "y", "true", "z", "blue"), covariates, 5, 1.0)); + dataList.add(Row.createSimple(Utils.easyMap("x", "8.4", "y", "false", "z", "yellow"), covariates, 6, 1.0)); return dataList; @@ -54,14 +57,19 @@ public class TestNAs { // but NumericSplitRuleUpdater had unmodifiable lists when creating the split. // This bug verifies that this no longer causes a crash - final List covariates = Collections.singletonList(new NumericCovariate("x", 0)); - final List> dataset = generateData(covariates); + final List covariates = Utils.easyList( + new NumericCovariate("x", 0, false), + new BooleanCovariate("y", 1, true), + new FactorCovariate("z", 2, Utils.easyList("green", "blue", "yellow"), true) + ); + final List> dataset = generateData1(covariates); final TreeTrainer treeTrainer = TreeTrainer.builder() .checkNodePurity(false) .covariates(covariates) .numberOfSplits(0) .nodeSize(1) + .mtry(3) .maxNodeDepth(1000) .splitFinder(new WeightedVarianceSplitFinder()) .responseCombiner(new MeanResponseCombiner()) @@ -70,6 +78,87 @@ public class TestNAs { treeTrainer.growTree(dataset, new Random(123)); // As long as no exception occurs, we passed + } + + private List> generateData2(List covariates){ + final List> dataList = new ArrayList<>(); + // Idea - when ignoring NAs, BadVar gives a perfect split. + // GoodVar is slightly worse than BadVar when NAs are excluded. + // However, BadVar has a ton of NAs that will degrade its performance. + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "-1.0", "GoodVar", "true") // GoodVars one error + , covariates, 1, 5.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "NA", "GoodVar", "false") + , covariates, 2, 5.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "NA", "GoodVar", "false") + , covariates, 3, 5.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "0.5", "GoodVar", "true") + , covariates, 4, 10.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "NA", "GoodVar", "true") + , covariates, 5, 10.0) + ); + dataList.add(Row.createSimple( + Utils.easyMap("BadVar", "NA", "GoodVar", "true") + , covariates, 6, 10.0) + ); + + return dataList; + } + + @Test + // Test that the NA penalty works when selecting a best split. + public void testNAPenalty(){ + final List covariates1 = Utils.easyList( + new NumericCovariate("BadVar", 0, true), + new BooleanCovariate("GoodVar", 1, false) + ); + + final List> dataList1 = generateData2(covariates1); + + final TreeTrainer treeTrainer1 = TreeTrainer.builder() + .checkNodePurity(false) + .covariates(covariates1) + .numberOfSplits(0) + .nodeSize(1) + .mtry(2) + .maxNodeDepth(1000) + .splitFinder(new WeightedVarianceSplitFinder()) + .responseCombiner(new MeanResponseCombiner()) + .build(); + + final Split bestSplit1 = treeTrainer1.findBestSplitRule(dataList1, covariates1, new Random(123)); + assertEquals(1, bestSplit1.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar + + // Run again without the penalty; verify that we get different results + + final List covariates2 = Utils.easyList( + new NumericCovariate("BadVar", 0, false), + new BooleanCovariate("GoodVar", 1, false) + ); + + final List> dataList2 = generateData2(covariates2); + + final TreeTrainer treeTrainer2 = TreeTrainer.builder() + .checkNodePurity(false) + .covariates(covariates2) + .numberOfSplits(0) + .nodeSize(1) + .mtry(2) + .maxNodeDepth(1000) + .splitFinder(new WeightedVarianceSplitFinder()) + .responseCombiner(new MeanResponseCombiner()) + .build(); + + final Split bestSplit2 = treeTrainer2.findBestSplitRule(dataList2, covariates2, new Random(123)); + assertEquals(0, bestSplit2.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java index b656ce0..e8a83a9 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java @@ -39,10 +39,10 @@ public class VariableImportanceCalculatorTest { */ public VariableImportanceCalculatorTest(){ - final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0); - final NumericCovariate numericCovariate = new NumericCovariate("y", 1); + final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false); + final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false); final FactorCovariate factorCovariate = new FactorCovariate("z", 2, - Utils.easyList("red", "blue", "green")); + Utils.easyList("red", "blue", "green"), false); this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index c9a031a..679fef7 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -43,7 +43,7 @@ public class TrainForest { final List covariateList = new ArrayList<>(p); for(int j =0; j < p; j++){ - final NumericCovariate covariate = new NumericCovariate("x"+j, j); + final NumericCovariate covariate = new NumericCovariate("x"+j, j, false); covariateList.add(covariate); } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index 8e1e361..6c9e0ed 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -39,8 +39,8 @@ public class TrainSingleTree { final int n = 1000; final List> trainingSet = new ArrayList<>(n); - final Covariate x1Covariate = new NumericCovariate("x1", 0); - final Covariate x2Covariate = new NumericCovariate("x2", 1); + final Covariate x1Covariate = new NumericCovariate("x1", 0, false); + final Covariate x2Covariate = new NumericCovariate("x2", 1, false); final List> x1List = DoubleStream .generate(() -> random.nextDouble()*10.0) diff --git a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index 1d2b404..bc0bd6e 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -41,9 +41,9 @@ public class TrainSingleTreeFactor { final int n = 10000; final List> trainingSet = new ArrayList<>(n); - final Covariate x1Covariate = new NumericCovariate("x1", 0); - final Covariate x2Covariate = new NumericCovariate("x2", 1); - final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse")); + final Covariate x1Covariate = new NumericCovariate("x1", 0, false); + final Covariate x2Covariate = new NumericCovariate("x2", 1, false); + final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"), false); final List> x1List = DoubleStream .generate(() -> random.nextDouble()*10.0)