diff --git a/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java b/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java index b73db9d..a7577db 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java +++ b/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java @@ -4,20 +4,20 @@ import lombok.RequiredArgsConstructor; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.ThreadLocalRandom; +import java.util.Random; @RequiredArgsConstructor public class Bootstrapper { final private List originalData; - public List bootstrap(){ + public List bootstrap(Random random){ final int n = originalData.size(); final List newList = new ArrayList<>(n); for(int i=0; i getCovariateValue(Covariate covariate){ + public Covariate.Value getCovariateValue(Covariate covariate){ return valueArray[covariate.getIndex()]; } diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index f9b7194..f5a7731 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -1,9 +1,9 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; -import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; diff --git a/src/main/java/ca/joeltherrien/randomforest/Row.java b/src/main/java/ca/joeltherrien/randomforest/Row.java index 850f035..00c5078 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Row.java +++ b/src/main/java/ca/joeltherrien/randomforest/Row.java @@ -17,6 +17,8 @@ public class Row extends CovariateRow { } + + public Y getResponse() { return this.response; } diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index ee40a35..c89681c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -1,7 +1,7 @@ package ca.joeltherrien.randomforest; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.CovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.CovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; @@ -10,7 +10,6 @@ import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayL import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator; -import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.GroupDifferentiator; @@ -68,9 +67,6 @@ public class Settings { GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor); } static{ - registerGroupDifferentiatorConstructor("MeanGroupDifferentiator", - (node) -> new MeanGroupDifferentiator() - ); registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator", (node) -> new WeightedVarianceGroupDifferentiator() ); diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java index 8959bc4..85221a0 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java @@ -1,14 +1,15 @@ package ca.joeltherrien.randomforest.covariates; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.tree.Split; +import ca.joeltherrien.randomforest.utils.SingletonIterator; import lombok.Getter; import lombok.RequiredArgsConstructor; -import java.util.Collection; -import java.util.Collections; -import java.util.List; +import java.util.*; @RequiredArgsConstructor -public final class BooleanCovariate implements Covariate{ +public final class BooleanCovariate implements Covariate { @Getter private final String name; @@ -19,8 +20,8 @@ public final class BooleanCovariate implements Covariate{ private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. @Override - public Collection generateSplitRules(List> data, int number) { - return Collections.singleton(splitRule); + public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) { + return new SingletonIterator<>(BooleanCovariate.this.splitRule.applyRule(data)); } @Override @@ -74,6 +75,7 @@ public final class BooleanCovariate implements Covariate{ } } + public class BooleanSplitRule implements SplitRule{ @Override diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index 9fa77ef..6e9b94e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -5,10 +5,7 @@ import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.tree.Split; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.LinkedList; -import java.util.List; +import java.util.*; import java.util.concurrent.ThreadLocalRandom; public interface Covariate extends Serializable { @@ -17,7 +14,7 @@ public interface Covariate extends Serializable { int getIndex(); - Collection> generateSplitRules(final List> data, final int number); + Iterator> generateSplitRuleUpdater(final List> data, final int number, final Random random); Value createValue(V value); @@ -39,6 +36,16 @@ public interface Covariate extends Serializable { } + interface SplitRuleUpdater extends Iterator>{ + Split currentSplit(); + SplitUpdate nextUpdate(); + } + + interface SplitUpdate { + SplitRule getSplitRule(); + Collection> rowsMovedToLeftHand(); + } + interface SplitRule extends Serializable{ Covariate getParent(); @@ -51,7 +58,7 @@ public interface Covariate extends Serializable { * @param * @return */ - default Split applyRule(List> rows) { + default Split applyRule(List> rows) { final List> leftHand = new LinkedList<>(); final List> rightHand = new LinkedList<>(); @@ -77,7 +84,7 @@ public interface Covariate extends Serializable { } - return new Split<>(leftHand, rightHand, missingValueRows); + return new Split<>(this, leftHand, rightHand, missingValueRows); } default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java index 402cb7d..fd00c7e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java @@ -1,5 +1,7 @@ package ca.joeltherrien.randomforest.covariates; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.tree.Split; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -42,17 +44,14 @@ public final class FactorCovariate implements Covariate{ @Override - public Set generateSplitRules(List> data, int number) { - final Set splitRules = new HashSet<>(); + public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) { + final Set> splits = new HashSet<>(); // This is to ensure we don't get stuck in an infinite loop for small factors number = Math.min(number, numberOfPossiblePairings); - final Random random = ThreadLocalRandom.current(); final List levels = new ArrayList<>(factorLevels.values()); - - - while(splitRules.size() < number){ + while(splits.size() < number){ Collections.shuffle(levels, random); final Set leftSideValues = new HashSet<>(); leftSideValues.add(levels.get(0)); @@ -63,13 +62,14 @@ public final class FactorCovariate implements Covariate{ } } - splitRules.add(new FactorSplitRule(leftSideValues)); + splits.add(new FactorSplitRule(leftSideValues).applyRule(data)); } - return splitRules; + return splits.iterator(); } + @Override public FactorValue createValue(String value) { if(value == null || value.equalsIgnoreCase("na")){ diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java similarity index 52% rename from src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java index 7687bad..8924626 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java @@ -1,17 +1,21 @@ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.numeric; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.utils.IndexedIterator; +import ca.joeltherrien.randomforest.utils.UniqueSubsetValueIterator; +import ca.joeltherrien.randomforest.utils.UniqueValueIterator; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; import java.util.*; -import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; @RequiredArgsConstructor @ToString -public final class NumericCovariate implements Covariate{ +public final class NumericCovariate implements Covariate { @Getter private final String name; @@ -20,40 +24,44 @@ public final class NumericCovariate implements Covariate{ private final int index; @Override - public Collection generateSplitRules(List> data, int number) { + public NumericSplitRuleUpdater generateSplitRuleUpdater(List> data, int number, Random random) { + data = data.stream() + .filter(row -> !row.getCovariateValue(this).isNA()) + .sorted((r1, r2) -> { + Double d1 = r1.getCovariateValue(this).getValue(); + Double d2 = r2.getCovariateValue(this).getValue(); - final Random random = ThreadLocalRandom.current(); + return d1.compareTo(d2); + }) + .collect(Collectors.toList()); - // only work with non-NA values - data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList()); - //data = data.stream().filter(value -> !value.isNA()).distinct().collect(Collectors.toList()); // TODO which to use? + Iterator sortedDataIterator = data.stream() + .map(row -> row.getCovariateValue(this).getValue()) + .filter(v -> v != null) + .iterator(); - // for this implementation we need to shuffle the data - final List> shuffledData; - if(number >= data.size()){ - shuffledData = data; + + final IndexedIterator dataIterator; + if(number == 0){ + dataIterator = new UniqueValueIterator<>(sortedDataIterator); } - else{ // only need the top number entries - shuffledData = new ArrayList<>(number); - final Set indexesToUse = new HashSet<>(); - //final List indexesToUse = new ArrayList<>(); // TODO which to use? + else{ // random splitting; we will not weight based on how many times a Row appears in the bootstrap sample + final TreeSet indexSet = new TreeSet<>(); - while(indexesToUse.size() < number){ - final int index = random.nextInt(data.size()); + final int maxIndex = data.size(); - if(indexesToUse.add(index)){ - shuffledData.add(data.get(index)); - } + for(int i=0; i( + new UniqueValueIterator<>(sortedDataIterator), + indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered + ); + } - return shuffledData.stream() - .mapToDouble(v -> v.getValue()) - .mapToObj(threshold -> new NumericSplitRule(threshold)) - .collect(Collectors.toSet()); - // by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping - + return new NumericSplitRuleUpdater<>(this, data, dataIterator); } @@ -101,7 +109,7 @@ public final class NumericCovariate implements Covariate{ private final double threshold; - private NumericSplitRule(final double threshold){ + NumericSplitRule(final double threshold){ this.threshold = threshold; } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java new file mode 100644 index 0000000..d41b270 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java @@ -0,0 +1,79 @@ +package ca.joeltherrien.randomforest.covariates.numeric; + +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.tree.Split; +import ca.joeltherrien.randomforest.utils.IndexedIterator; + +import java.util.Collections; +import java.util.List; + +public class NumericSplitRuleUpdater implements Covariate.SplitRuleUpdater { + + private final NumericCovariate covariate; + private final List> orderedData; + private final IndexedIterator dataIterator; + + private Split currentSplit; + + public NumericSplitRuleUpdater(final NumericCovariate covariate, final List> orderedData, final IndexedIterator iterator){ + this.covariate = covariate; + this.orderedData = orderedData; + this.dataIterator = iterator; + + final List> leftHandList = Collections.emptyList(); + final List> rightHandList = orderedData; + + this.currentSplit = new Split<>( + covariate.new NumericSplitRule(Double.MIN_VALUE), + leftHandList, + rightHandList, + Collections.emptyList()); + + } + + @Override + public Split currentSplit() { + return this.currentSplit; + } + + @Override + public NumericSplitUpdate nextUpdate() { + if(hasNext()){ + final int currentPosition = dataIterator.getIndex(); + final Double splitValue = dataIterator.next(); + final int newPosition = dataIterator.getIndex(); + + final List> rowsMoved = orderedData.subList(currentPosition, newPosition); + + final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue); + + // Update current split + this.currentSplit = new Split<>( + splitRule, + orderedData.subList(0, newPosition), + orderedData.subList(newPosition, orderedData.size()), + Collections.emptyList()); + + + return new NumericSplitUpdate<>(splitRule, rowsMoved); + } + + return null; + } + + @Override + public boolean hasNext() { + return dataIterator.hasNext(); + } + + @Override + public Split next() { + if(hasNext()){ + nextUpdate(); + } + + return this.currentSplit(); + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java new file mode 100644 index 0000000..f2757c0 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java @@ -0,0 +1,24 @@ +package ca.joeltherrien.randomforest.covariates.numeric; + +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; +import lombok.AllArgsConstructor; + +import java.util.Collection; + +@AllArgsConstructor +public class NumericSplitUpdate implements Covariate.SplitUpdate { + + private final NumericCovariate.NumericSplitRule numericSplitRule; + private final Collection> rowsMoved; + + @Override + public NumericCovariate.NumericSplitRule getSplitRule() { + return numericSplitRule; + } + + @Override + public Collection> rowsMovedToLeftHand() { + return rowsMoved; + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java similarity index 75% rename from src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java index 5f57c15..ba6366e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java @@ -1,5 +1,6 @@ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.settings; +import ca.joeltherrien.randomforest.covariates.BooleanCovariate; import lombok.Data; import lombok.NoArgsConstructor; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java similarity index 87% rename from src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java index 4d850ac..9b9f93c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java @@ -1,5 +1,6 @@ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.settings; +import ca.joeltherrien.randomforest.covariates.Covariate; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; import lombok.Getter; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java similarity index 82% rename from src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java index 9d7ece5..f40213c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java @@ -1,5 +1,6 @@ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.settings; +import ca.joeltherrien.randomforest.covariates.FactorCovariate; import lombok.Data; import lombok.NoArgsConstructor; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java similarity index 74% rename from src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java index 9cdf898..84a69a7 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java @@ -1,5 +1,6 @@ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.settings; +import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import lombok.Data; import lombok.NoArgsConstructor; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java index f6efdeb..b09026a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java @@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator; import lombok.AllArgsConstructor; import lombok.Data; @@ -14,11 +15,7 @@ import java.util.stream.Stream; * modifies the abstract method. * */ -public abstract class CompetingRiskGroupDifferentiator implements GroupDifferentiator{ - - @Override - public abstract Double differentiate(List leftHand, List rightHand); - +public abstract class CompetingRiskGroupDifferentiator extends SimpleGroupDifferentiator { /** * Calculates the log rank value (or the Gray's test value) for a *specific* event cause. diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankMultipleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankMultipleGroupDifferentiator.java index 64eeeda..62275c2 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankMultipleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankMultipleGroupDifferentiator.java @@ -17,7 +17,7 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi private final int[] events; @Override - public Double differentiate(List leftHand, List rightHand) { + public Double getScore(List leftHand, List rightHand) { if(leftHand.size() == 0 || rightHand.size() == 0){ return null; } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java index 48e3b6f..1287f03 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java @@ -18,7 +18,7 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff private final int[] events; @Override - public Double differentiate(List leftHand, List rightHand) { + public Double getScore(List leftHand, List rightHand) { if(leftHand.size() == 0 || rightHand.size() == 0){ return null; } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java index 2ad2424..2478c8b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java @@ -17,7 +17,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer private final int[] events; @Override - public Double differentiate(List leftHand, List rightHand) { + public Double getScore(List leftHand, List rightHand) { if(leftHand.size() == 0 || rightHand.size() == 0){ return null; } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java index 107964e..817b4eb 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java @@ -18,7 +18,7 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen private final int[] events; @Override - public Double differentiate(List leftHand, List rightHand) { + public Double getScore(List leftHand, List rightHand) { if(leftHand.size() == 0 || rightHand.size() == 0){ return null; } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java deleted file mode 100644 index 75bc129..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java +++ /dev/null @@ -1,26 +0,0 @@ -package ca.joeltherrien.randomforest.responses.regression; - -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; - -import java.util.List; - -public class MeanGroupDifferentiator implements GroupDifferentiator { - - @Override - public Double differentiate(List leftHand, List rightHand) { - - double leftHandSize = leftHand.size(); - double rightHandSize = rightHand.size(); - - if(leftHandSize == 0 || rightHandSize == 0){ - return null; - } - - double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum(); - double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum(); - - return Math.abs(leftHandMean - rightHandMean); - - } - -} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java index 25f7e6e..9ae4673 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java @@ -1,13 +1,13 @@ package ca.joeltherrien.randomforest.responses.regression; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator; import java.util.List; -public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator { +public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator { @Override - public Double differentiate(List leftHand, List rightHand) { + public Double getScore(List leftHand, List rightHand) { final double leftHandSize = leftHand.size(); final double rightHandSize = rightHand.size(); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 37ee4ce..900fdfd 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -14,8 +14,10 @@ import java.io.IOException; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.List; +import java.util.Random; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -45,17 +47,17 @@ public class ForestTrainer { this.covariates = covariates; this.treeResponseCombiner = settings.getTreeCombiner(); this.treeTrainer = new TreeTrainer<>(settings, covariates); - } public Forest trainSerial(){ final List> trees = new ArrayList<>(ntree); final Bootstrapper> bootstrapper = new Bootstrapper<>(data); + final Random random = new Random(); for(int j=0; j { } - private Tree trainTree(final Bootstrapper> bootstrapper){ - final List> bootstrappedData = bootstrapper.bootstrap(); - return treeTrainer.growTree(bootstrappedData); + private Tree trainTree(final Bootstrapper> bootstrapper, Random random){ + final List> bootstrappedData = bootstrapper.bootstrap(random); + return treeTrainer.growTree(bootstrappedData, random); } public void saveTree(final Tree tree, String name) throws IOException { @@ -193,7 +195,8 @@ public class ForestTrainer { @Override public void run() { - final Tree tree = trainTree(bootstrapper); + // ThreadLocalRandom should make sure we don't duplicate seeds + final Tree tree = trainTree(bootstrapper, ThreadLocalRandom.current()); // should be okay as the list structure isn't changing treeList.set(treeIndex, tree); @@ -216,7 +219,8 @@ public class ForestTrainer { @Override public void run() { - final Tree tree = trainTree(bootstrapper); + // ThreadLocalRandom should make sure we don't duplicate seeds + final Tree tree = trainTree(bootstrapper, ThreadLocalRandom.current()); try { saveTree(tree, filename); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java index cbd1247..b040fe0 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java @@ -1,15 +1,17 @@ package ca.joeltherrien.randomforest.tree; -import java.util.List; +import java.util.Iterator; /** * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups. - * The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score, - * the greater the difference. + * The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an + * instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits) * + * If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending + * SimpleGroupDifferentiator. */ public interface GroupDifferentiator { - Double differentiate(List leftHand, List rightHand); + SplitAndScore differentiate(Iterator> splitIterator); } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java new file mode 100644 index 0000000..9a06afe --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java @@ -0,0 +1,50 @@ +package ca.joeltherrien.randomforest.tree; + +import ca.joeltherrien.randomforest.Row; + +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; + +public abstract class SimpleGroupDifferentiator implements GroupDifferentiator { + + @Override + public SplitAndScore differentiate(Iterator> splitIterator) { + Double bestScore = null; + Split bestSplit = null; + + while(splitIterator.hasNext()){ + final Split candidateSplit = splitIterator.next(); + + final List leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList()); + final List rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList()); + + if(leftHand.isEmpty() || rightHand.isEmpty()){ + continue; + } + + final Double score = getScore(leftHand, rightHand); + + if(score != null && (bestScore == null || score > bestScore)){ + bestScore = score; + bestSplit = candidateSplit; + } + } + + if(bestSplit == null){ + return null; + } + + return new SplitAndScore<>(bestSplit, bestScore); + } + + /** + * Return a score; higher is better. + * + * @param leftHand + * @param rightHand + * @return + */ + public abstract Double getScore(List leftHand, List rightHand); + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java index e566e64..b55444c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java @@ -1,19 +1,21 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; import lombok.Data; import java.util.List; /** - * Very simple class that contains three lists; it's essentially a thruple. + * Very simple class that contains three lists and a SplitRule. * * @author joel * */ @Data -public class Split { +public final class Split { + public final Covariate.SplitRule splitRule; public final List> leftHand; public final List> rightHand; public final List> naHand; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java new file mode 100644 index 0000000..1160680 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java @@ -0,0 +1,15 @@ +package ca.joeltherrien.randomforest.tree; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@AllArgsConstructor +public class SplitAndScore { + + @Getter + private final Split split; + + @Getter + private final Double score; + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 1464c93..0318cef 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -3,10 +3,11 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; -import lombok.*; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Builder; import java.util.*; -import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; @Builder @@ -45,20 +46,21 @@ public class TreeTrainer { this.covariates = covariates; } - public Tree growTree(List> data){ + public Tree growTree(List> data, Random random){ - final Node rootNode = growNode(data, 0); + final Node rootNode = growNode(data, 0, random); return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray()); } - private Node growNode(List> data, int depth){ + private Node growNode(List> data, int depth, Random random){ // See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom) if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ - final List covariatesToTry = selectCovariates(this.mtry); - final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry); + final List covariatesToTry = selectCovariates(this.mtry, random); + final Split bestSplit = findBestSplitRule(data, covariatesToTry, random); - if(bestSplitRuleAndSplit.splitRule == null){ + + if(bestSplit == null){ return new TerminalNode<>( responseCombiner.combine( @@ -69,14 +71,31 @@ public class TreeTrainer { } - final Split split = bestSplitRuleAndSplit.split; - // Note that NAs have already been handled + + // Now that we have the best split; we need to handle any NAs that were dropped off + final double probabilityLeftHand = (double) bestSplit.leftHand.size() / + (double) (bestSplit.leftHand.size() + bestSplit.rightHand.size()); + + // Assign missing values to the split + for(Row row : data) { + if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) { + final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; + + if(randomDecision){ + bestSplit.getLeftHand().add(row); + } + else{ + bestSplit.getRightHand().add(row); + } + + } + } - final Node leftNode = growNode(split.leftHand, depth+1); - final Node rightNode = growNode(split.rightHand, depth+1); + final Node leftNode = growNode(bestSplit.leftHand, depth+1, random); + final Node rightNode = growNode(bestSplit.rightHand, depth+1, random); - return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand); + return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand); } else{ @@ -90,13 +109,13 @@ public class TreeTrainer { } - private List selectCovariates(int mtry){ + private List selectCovariates(int mtry, Random random){ if(mtry >= covariates.size()){ return covariates; } final List splitCovariates = new ArrayList<>(covariates); - Collections.shuffle(splitCovariates, ThreadLocalRandom.current()); + Collections.shuffle(splitCovariates, random); if (splitCovariates.size() > mtry) { splitCovariates.subList(mtry, splitCovariates.size()).clear(); @@ -105,63 +124,29 @@ public class TreeTrainer { return splitCovariates; } - private SplitRuleAndSplit findBestSplitRule(List> data, List covariatesToTry){ - final Random random = ThreadLocalRandom.current(); - SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit(); - double bestSplitScore = 0.0; - boolean first = true; + private Split findBestSplitRule(List> data, List covariatesToTry, Random random){ - for(final Covariate covariate : covariatesToTry){ + SplitAndScore bestSplitAndScore = null; + final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck - final int numberToTry = numberOfSplits==0 ? data.size() : numberOfSplits; + for(final Covariate covariate : covariatesToTry) { + final Iterator iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random); - final Collection splitRulesToTry = covariate - .generateSplitRules( - data - .stream() - .map(row -> row.getCovariateValue(covariate)) - .collect(Collectors.toList()) - , numberToTry); + final SplitAndScore candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator); - for(final Covariate.SplitRule possibleRule : splitRulesToTry){ - final Split possibleSplit = possibleRule.applyRule(data); - - // We have to handle any NAs - if(possibleSplit.leftHand.size() == 0 && possibleSplit.rightHand.size() == 0 && possibleSplit.naHand.size() > 0){ - throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows"); - } - - final double probabilityLeftHand = (double) possibleSplit.leftHand.size() / (double) (possibleSplit.leftHand.size() + possibleSplit.rightHand.size()); - - for(final Row missingValueRow : possibleSplit.naHand){ - final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; - if(randomDecision){ - possibleSplit.leftHand.add(missingValueRow); - } - else{ - possibleSplit.rightHand.add(missingValueRow); - } - } - - final Double score = groupDifferentiator.differentiate( - possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()), - possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()) - ); - - - if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){ - bestSplitRuleAndSplit.splitRule = possibleRule; - bestSplitRuleAndSplit.split = possibleSplit; - bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand; - - bestSplitScore = score; - first = false; + if(candidateSplitAndScore != null) { + if (bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) { + bestSplitAndScore = candidateSplitAndScore; } } } - return bestSplitRuleAndSplit; + if(bestSplitAndScore == null){ + return null; + } + + return bestSplitAndScore.getSplit(); } @@ -184,10 +169,4 @@ public class TreeTrainer { return true; } - private class SplitRuleAndSplit{ - private Covariate.SplitRule splitRule = null; - private Split split = null; - private double probabilityLeftHand; - } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/IndexedIterator.java b/src/main/java/ca/joeltherrien/randomforest/utils/IndexedIterator.java new file mode 100644 index 0000000..d82e0ec --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/IndexedIterator.java @@ -0,0 +1,9 @@ +package ca.joeltherrien.randomforest.utils; + +import java.util.Iterator; + +public interface IndexedIterator extends Iterator { + + int getIndex(); + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/SingletonIterator.java b/src/main/java/ca/joeltherrien/randomforest/utils/SingletonIterator.java new file mode 100644 index 0000000..ca6df8b --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/SingletonIterator.java @@ -0,0 +1,29 @@ +package ca.joeltherrien.randomforest.utils; + +import lombok.RequiredArgsConstructor; + +import java.util.Iterator; + +@RequiredArgsConstructor +public class SingletonIterator implements Iterator { + + private final E value; + + private boolean beenCalled = false; + + + @Override + public boolean hasNext() { + return !beenCalled; + } + + @Override + public E next() { + if(!beenCalled){ + beenCalled = true; + return value; + } + + return null; + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/UniqueSubsetValueIterator.java b/src/main/java/ca/joeltherrien/randomforest/utils/UniqueSubsetValueIterator.java new file mode 100644 index 0000000..ebce035 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/UniqueSubsetValueIterator.java @@ -0,0 +1,65 @@ +package ca.joeltherrien.randomforest.utils; + + +import lombok.Getter; + +import java.util.Iterator; + +/** + * Iterator that wraps around a UniqueValueIterator. It continues to iterate until it gets to one of the prespecified indexes, + * and then proceeds just past that to the end of the existing values it's at. + * + * The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together. + * I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations + * + * @param + */ +public class UniqueSubsetValueIterator implements IndexedIterator { + + private final UniqueValueIterator iterator; + private final Integer[] indexValues; + + private int currentIndexSpot = 0; + + public UniqueSubsetValueIterator(final UniqueValueIterator iterator, final Integer[] indexValues){ + this.iterator = iterator; + this.indexValues = indexValues; + } + + @Override + public boolean hasNext() { + return iterator.hasNext() && iterator.getIndex() <= indexValues[indexValues.length-1]; + } + + @Override + public E next() { + if(hasNext()){ + final int indexToStopBy = indexValues[currentIndexSpot]; + + while(iterator.getIndex() <= indexToStopBy){ + iterator.next(); + } + + for(int i = currentIndexSpot + 1; i < indexValues.length; i++){ + if(iterator.getIndex() <= indexValues[i]){ + currentIndexSpot = i; + break; + } + } + + + return iterator.getCurrentValue(); + + } + + return null; + + } + + @Override + public int getIndex(){ + return iterator.getIndex(); + } + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/UniqueValueIterator.java b/src/main/java/ca/joeltherrien/randomforest/utils/UniqueValueIterator.java new file mode 100644 index 0000000..b585152 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/UniqueValueIterator.java @@ -0,0 +1,71 @@ +package ca.joeltherrien.randomforest.utils; + + +import lombok.Getter; + +import java.util.Iterator; + +/** + * Iterator that wraps around another iterator. It continues to iterate until it gets to the *end* of a sequence of identical values. + * It also tracks the current index in the original iterator. + * + * The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together. + * I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations + * + * @param + */ +public class UniqueValueIterator implements IndexedIterator { + + private final Iterator wrappedIterator; + + @Getter private E currentValue = null; + @Getter private E nextValue; + + public UniqueValueIterator(final Iterator wrappedIterator){ + this.wrappedIterator = wrappedIterator; + this.nextValue = wrappedIterator.next(); + } + + // Count must return the index of the last value of the sequence returned by next() + @Getter + private int index = 0; + + @Override + public boolean hasNext() { + return nextValue != null; + } + + @Override + public E next() { + + int count = 1; + while(wrappedIterator.hasNext()){ + final E currentIteratorValue = wrappedIterator.next(); + + if(currentIteratorValue.equals(nextValue)){ + count++; + } + else{ + index +=count; + currentValue = nextValue; + nextValue = currentIteratorValue; + + return currentValue; + } + + } + + if(nextValue != null){ + index += count; + currentValue = nextValue; + nextValue = null; + + return currentValue; + } + else{ + return null; + } + + + } +} diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 7e59727..c3849c8 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -1,8 +1,8 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index ea2d7dd..2d4365c 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -2,6 +2,8 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.*; +import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; @@ -18,6 +20,7 @@ import static org.junit.jupiter.api.Assertions.*; import java.io.IOException; import java.util.List; +import java.util.Random; public class TestCompetingRisk { @@ -104,7 +107,7 @@ public class TestCompetingRisk { final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final TreeTrainer treeTrainer = new TreeTrainer<>(settings, covariates); - final Node node = treeTrainer.growTree(dataset); + final Node node = treeTrainer.growTree(dataset, new Random()); final CovariateRow newRow = getPredictionRow(covariates); @@ -157,7 +160,7 @@ public class TestCompetingRisk { final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final TreeTrainer treeTrainer = new TreeTrainer<>(settings, covariates); - final Node node = treeTrainer.growTree(dataset); + final Node node = treeTrainer.growTree(dataset, new Random()); final CovariateRow newRow = getPredictionRow(covariates); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java index 51d142f..22c975f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java @@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.utils.Utils; @@ -53,14 +53,14 @@ public class TestLogRankMultipleGroupDifferentiator { final List> group1Bad = data.subList(0, 196); final List> group2Bad = data.subList(196, data.size()); - final double scoreBad = groupDifferentiator.differentiate( + final double scoreBad = groupDifferentiator.getScore( group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()), group2Bad.stream().map(Row::getResponse).collect(Collectors.toList())); final List> group1Good = data.subList(0, 199); final List> group2Good= data.subList(199, data.size()); - final double scoreGood = groupDifferentiator.differentiate( + final double scoreGood = groupDifferentiator.getScore( group1Good.stream().map(Row::getResponse).collect(Collectors.toList()), group2Good.stream().map(Row::getResponse).collect(Collectors.toList())); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java index ffdbe4b..e48c888 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java @@ -1,8 +1,7 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.responses.competingrisk.*; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator; import org.junit.jupiter.api.Test; @@ -51,7 +50,7 @@ public class TestLogRankSingleGroupDifferentiator { final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1}); - final double score = differentiator.differentiate(data1, data2); + final double score = differentiator.getScore(data1, data2); final double margin = 0.000001; // Tested using 855 method @@ -71,14 +70,14 @@ public class TestLogRankSingleGroupDifferentiator { final List> group1Good = data.subList(0, 221); final List> group2Good = data.subList(221, data.size()); - final double scoreGood = groupDifferentiator.differentiate( + final double scoreGood = groupDifferentiator.getScore( group1Good.stream().map(Row::getResponse).collect(Collectors.toList()), group2Good.stream().map(Row::getResponse).collect(Collectors.toList())); final List> group1Bad = data.subList(0, 222); final List> group2Bad = data.subList(222, data.size()); - final double scoreBad = groupDifferentiator.differentiate( + final double scoreBad = groupDifferentiator.getScore( group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()), group2Bad.stream().map(Row::getResponse).collect(Collectors.toList())); diff --git a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java index 7d80a8c..adbf3d2 100644 --- a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java +++ b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java @@ -5,8 +5,9 @@ import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; -import java.util.Collection; +import java.util.ArrayList; import java.util.List; +import java.util.Random; import static org.junit.jupiter.api.Assertions.*; @@ -46,7 +47,10 @@ public class FactorCovariateTest { void testAllSubsets(){ final FactorCovariate petCovariate = createTestCovariate(); - final Collection splitRules = petCovariate.generateSplitRules(null, 100); + final List> splitRules = new ArrayList<>(); + + petCovariate.generateSplitRuleUpdater(null, 100, new Random()) + .forEachRemaining(split -> splitRules.add(split.getSplitRule())); assertEquals(splitRules.size(), 3); diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index 82262df..1fbb40f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -3,10 +3,10 @@ package ca.joeltherrien.randomforest.csv; import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; -import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; -import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; @@ -15,7 +15,6 @@ import org.junit.jupiter.api.Test; import java.io.IOException; import java.util.List; -import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index 901cdb1..79a0e6b 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -3,7 +3,9 @@ package ca.joeltherrien.randomforest.settings; import ca.joeltherrien.randomforest.Settings; import static org.junit.jupiter.api.Assertions.assertEquals; -import ca.joeltherrien.randomforest.covariates.*; +import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; diff --git a/src/test/java/ca/joeltherrien/randomforest/utils/TestSingletonIterator.java b/src/test/java/ca/joeltherrien/randomforest/utils/TestSingletonIterator.java new file mode 100644 index 0000000..03805a7 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/utils/TestSingletonIterator.java @@ -0,0 +1,25 @@ +package ca.joeltherrien.randomforest.utils; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestSingletonIterator { + + @Test + public void verifyBehaviour(){ + final Integer element = 5; + + final SingletonIterator iterator = new SingletonIterator<>(element); + + assertTrue(iterator.hasNext()); + assertTrue(iterator.hasNext()); + + assertEquals(Integer.valueOf(5), iterator.next()); + + assertFalse(iterator.hasNext()); + assertNull(iterator.next()); + + } + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/utils/TestUniqueSubsetValueIterator.java b/src/test/java/ca/joeltherrien/randomforest/utils/TestUniqueSubsetValueIterator.java new file mode 100644 index 0000000..0eb0051 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/utils/TestUniqueSubsetValueIterator.java @@ -0,0 +1,68 @@ +package ca.joeltherrien.randomforest.utils; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestUniqueSubsetValueIterator { + + @Test + public void testIterator1(){ + final List testData = Arrays.asList( + 1,1,2,3,5,5,5,6,7,7 + ); + + final Integer[] indexes = new Integer[]{2,3,4,5}; + + final UniqueValueIterator uniqueValueIterator = new UniqueValueIterator<>(testData.iterator()); + final UniqueSubsetValueIterator iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes); + + // we expect to get 2, 3, and 5 back. 5 should happen only once + + assertEquals(iterator.getIndex(), 0); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 2); + assertEquals(iterator.getIndex(), 3); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 3); + assertEquals(iterator.getIndex(), 4); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 5); + assertEquals(iterator.getIndex(), 7); + assertFalse(iterator.hasNext()); + + + } + + @Test + public void testIterator2(){ + final List testData = Arrays.asList( + 1,1,2,3,5,5,5,6,7,7 + ); + + final Integer[] indexes = new Integer[]{1,8}; + + final UniqueValueIterator uniqueValueIterator = new UniqueValueIterator<>(testData.iterator()); + final UniqueSubsetValueIterator iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes); + + assertEquals(iterator.getIndex(), 0); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 1); + assertEquals(iterator.getIndex(), 2); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 7); + assertEquals(iterator.getIndex(), 10); + assertFalse(iterator.hasNext()); + + + } + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/utils/TestUniqueValueIterator.java b/src/test/java/ca/joeltherrien/randomforest/utils/TestUniqueValueIterator.java new file mode 100644 index 0000000..16d6cc0 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/utils/TestUniqueValueIterator.java @@ -0,0 +1,112 @@ +package ca.joeltherrien.randomforest.utils; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestUniqueValueIterator { + + @Test + public void testIterator1(){ + final List testData = Arrays.asList( + 1,1,2,3,5,5,5,6,7,7 + ); + + final UniqueValueIterator iterator = new UniqueValueIterator<>(testData.iterator()); + + assertEquals(iterator.getIndex(), 0); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 1); + assertEquals(iterator.getIndex(), 2); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 2); + assertEquals(iterator.getIndex(), 3); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 3); + assertEquals(iterator.getIndex(), 4); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 5); + assertEquals(iterator.getIndex(), 7); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 6); + assertEquals(iterator.getIndex(), 8); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 7); + assertEquals(iterator.getIndex(), 10); + assertTrue(!iterator.hasNext()); + + } + + @Test + public void testIterator2(){ + final List testData = Arrays.asList( + 1,2,3,5,5,5,6,7 // same numbers; but 1 and 7 only appear once each + ); + + final UniqueValueIterator iterator = new UniqueValueIterator<>(testData.iterator()); + + assertEquals(iterator.getIndex(), 0); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 1); + assertEquals(iterator.getIndex(), 1); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 2); + assertEquals(iterator.getIndex(), 2); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 3); + assertEquals(iterator.getIndex(), 3); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 5); + assertEquals(iterator.getIndex(), 6); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 6); + assertEquals(iterator.getIndex(), 7); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 7); + assertEquals(iterator.getIndex(), 8); + assertFalse(iterator.hasNext()); + + } + + @Test + public void testIterator3(){ + final List testData = Arrays.asList( + 1,1,1,1,1,1,1,2,2,2,2,2,3 + ); + + final UniqueValueIterator iterator = new UniqueValueIterator<>(testData.iterator()); + + assertEquals(iterator.getIndex(), 0); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 1); + assertEquals(iterator.getIndex(), 7); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 2); + assertEquals(iterator.getIndex(), 12); + assertTrue(iterator.hasNext()); + + assertEquals(iterator.next().intValue(), 3); + assertEquals(iterator.getIndex(), 13); + assertFalse(iterator.hasNext()); + + + } + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index f453cd9..d87442b 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.workshop; import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.NumericCovariate; +import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.ForestTrainer; diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index 88a855e..547eb09 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.workshop; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.covariates.NumericCovariate; +import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; @@ -54,13 +54,14 @@ public class TrainSingleTree { .covariates(covariateNames) .responseCombiner(new MeanResponseCombiner()) .maxNodeDepth(30) + .mtry(2) .nodeSize(5) .numberOfSplits(0) .build(); final long startTime = System.currentTimeMillis(); - final Node baseNode = treeTrainer.growTree(trainingSet); + final Node baseNode = treeTrainer.growTree(trainingSet, new Random()); final long endTime = System.currentTimeMillis(); System.out.println(((double)(endTime - startTime))/1000.0); diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index 6c2d430..e126624 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -4,7 +4,7 @@ package ca.joeltherrien.randomforest.workshop; import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.FactorCovariate; -import ca.joeltherrien.randomforest.covariates.NumericCovariate; +import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.Node; @@ -13,7 +13,6 @@ import ca.joeltherrien.randomforest.utils.Utils; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Random; import java.util.stream.Collectors; import java.util.stream.DoubleStream; @@ -21,8 +20,6 @@ import java.util.stream.DoubleStream; public class TrainSingleTreeFactor { public static void main(String[] args) { - System.out.println("Hello world!"); - final Random random = new Random(123); final int n = 10000; @@ -84,7 +81,7 @@ public class TrainSingleTreeFactor { final long startTime = System.currentTimeMillis(); - final Node baseNode = treeTrainer.growTree(trainingSet); + final Node baseNode = treeTrainer.growTree(trainingSet, random); final long endTime = System.currentTimeMillis(); System.out.println(((double)(endTime - startTime))/1000.0);