diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index e353c7f..3aa29ab 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -180,8 +180,8 @@ public class Main { private static Settings defaultTemplate(){ - final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); - groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator")); + final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); + splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder")); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner")); @@ -204,7 +204,7 @@ public class Main { .validationDataLocation("validation_data.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) - .groupDifferentiatorSettings(groupDifferentiatorSettings) + .splitFinderSettings(splitFinderSettings) .yVarSettings(yVarSettings) .maxNodeDepth(100000) .mtry(2) diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index da83ff4..64bd8da 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -22,11 +22,11 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankDifferentiator; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogRankSplitFinder; +import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; +import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.Utils; @@ -77,31 +77,31 @@ public class Settings { ); } - private static Map<String, Function<ObjectNode, GroupDifferentiator>> GROUP_DIFFERENTIATOR_MAP = new HashMap<>(); - public static Function<ObjectNode, GroupDifferentiator> getGroupDifferentiatorConstructor(final String name){ - return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase()); + private static Map<String, Function<ObjectNode, SplitFinder>> SPLIT_FINDER_MAP = new HashMap<>(); + public static Function<ObjectNode, SplitFinder> getSplitFinderConstructor(final String name){ + return SPLIT_FINDER_MAP.get(name.toLowerCase()); } - public static void registerGroupDifferentiatorConstructor(final String name, final Function<ObjectNode, GroupDifferentiator> groupDifferentiatorConstructor){ - GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor); + public static void registerSplitFinderConstructor(final String name, final Function<ObjectNode, SplitFinder> splitFinderConstructor){ + SPLIT_FINDER_MAP.put(name.toLowerCase(), splitFinderConstructor); } static{ - registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator", - (node) -> new WeightedVarianceGroupDifferentiator() + registerSplitFinderConstructor("WeightedVarianceSplitFinder", + (node) -> new WeightedVarianceSplitFinder() ); - registerGroupDifferentiatorConstructor("GrayLogRankDifferentiator", + registerSplitFinderConstructor("GrayLogRankSplitFinder", (objectNode) -> { final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus")); final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events")); - return new GrayLogRankDifferentiator(eventsOfFocusArray, eventArray); + return new GrayLogRankSplitFinder(eventsOfFocusArray, eventArray); } ); - registerGroupDifferentiatorConstructor("LogRankDifferentiator", + registerSplitFinderConstructor("LogRankSplitFinder", (objectNode) -> { final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus")); final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events")); - return new LogRankDifferentiator(eventsOfFocusArray, eventArray); + return new LogRankSplitFinder(eventsOfFocusArray, eventArray); } ); } @@ -153,7 +153,7 @@ public class Settings { private boolean checkNodePurity = false; private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); - private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); + private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); private List<CovariateSettings> covariateSettings = new ArrayList<>(); @@ -194,10 +194,10 @@ public class Settings { } @JsonIgnore - public GroupDifferentiator getGroupDifferentiator(){ - final String type = groupDifferentiatorSettings.get("type").asText(); + public SplitFinder getSplitFinder(){ + final String type = splitFinderSettings.get("type").asText(); - return getGroupDifferentiatorConstructor(type).apply(groupDifferentiatorSettings); + return getSplitFinderConstructor(type).apply(splitFinderSettings); } @JsonIgnore diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/CompetingRiskSplitFinder.java similarity index 92% rename from src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/CompetingRiskSplitFinder.java index aef9b4c..b2cddb8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/CompetingRiskSplitFinder.java @@ -14,13 +14,13 @@ * along with this program. If not, see <https://www.gnu.org/licenses/>. */ -package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.SplitAndScore; import lombok.AllArgsConstructor; @@ -35,24 +35,24 @@ import java.util.stream.Collectors; * modifies the abstract method. * */ -public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y> { +public abstract class CompetingRiskSplitFinder<Y extends CompetingRiskResponse> implements SplitFinder<Y> { abstract protected CompetingRiskSets<Y> createCompetingRiskSets(List<Y> leftHand, List<Y> rightHand); abstract protected Double getScore(final CompetingRiskSets<Y> competingRiskSets); @Override - public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) { + public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) { if(splitIterator instanceof Covariate.SplitRuleUpdater){ - return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); + return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); } else{ - return differentiateWithBasicIterator(splitIterator); + return findBestSplitWithBasicIterator(splitIterator); } } - private SplitAndScore<Y, ?> differentiateWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){ + private SplitAndScore<Y, ?> findBestSplitWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){ Double bestScore = null; Split<Y, ?> bestSplit = null; @@ -83,7 +83,7 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskRe return new SplitAndScore<>(bestSplit, bestScore); } - private SplitAndScore<Y, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) { + private SplitAndScore<Y, ?> findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) { final List<Y> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand() .stream().map(Row::getResponse).collect(Collectors.toList()); diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/GrayLogRankSplitFinder.java similarity index 91% rename from src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/GrayLogRankSplitFinder.java index 420ad15..b6eed13 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/GrayLogRankSplitFinder.java @@ -14,7 +14,7 @@ * along with this program. If not, see <https://www.gnu.org/licenses/>. */ -package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; @@ -27,12 +27,12 @@ import java.util.List; * See page 761 of Random survival forests for competing risks by Ishwaran et al. * */ -public class GrayLogRankDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> { +public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponseWithCensorTime> { private final int[] eventsOfFocus; private final int[] events; - public GrayLogRankDifferentiator(int[] eventsOfFocus, int[] events){ + public GrayLogRankSplitFinder(int[] eventsOfFocus, int[] events){ this.eventsOfFocus = eventsOfFocus; this.events = events; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder.java similarity index 91% rename from src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder.java index d528de3..c31bb84 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder.java @@ -14,7 +14,7 @@ * along with this program. If not, see <https://www.gnu.org/licenses/>. */ -package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; @@ -27,12 +27,12 @@ import java.util.List; * See page 761 of Random survival forests for competing risks by Ishwaran et al. * */ -public class LogRankDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> { +public class LogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponse> { private final int[] eventsOfFocus; private final int[] events; - public LogRankDifferentiator(int[] eventsOfFocus, int[] events){ + public LogRankSplitFinder(int[] eventsOfFocus, int[] events){ this.eventsOfFocus = eventsOfFocus; this.events = events; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder.java similarity index 91% rename from src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder.java index a16f403..9fb4d43 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder.java @@ -18,7 +18,7 @@ package ca.joeltherrien.randomforest.responses.regression; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.SplitAndScore; @@ -26,7 +26,7 @@ import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> { +public class WeightedVarianceSplitFinder implements SplitFinder<Double> { private Double getScore(Set leftHand, Set rightHand) { @@ -44,17 +44,17 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator< } @Override - public SplitAndScore<Double, ?> differentiate(Iterator<Split<Double, ?>> splitIterator) { + public SplitAndScore<Double, ?> findBestSplit(Iterator<Split<Double, ?>> splitIterator) { if(splitIterator instanceof Covariate.SplitRuleUpdater){ - return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); + return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); } else{ - return differentiateWithBasicIterator(splitIterator); + return findBestSplitWithBasicIterator(splitIterator); } } - private SplitAndScore<Double, ?> differentiateWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){ + private SplitAndScore<Double, ?> findBestSplitWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){ Double bestScore = null; Split<Double, ?> bestSplit = null; @@ -86,7 +86,7 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator< return new SplitAndScore<>(bestSplit, bestScore); } - private SplitAndScore<Double, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) { + private SplitAndScore<Double, ?> findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) { final List<Double> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand() .stream().map(Row::getResponse).collect(Collectors.toList()); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleSplitFinder.java similarity index 93% rename from src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/tree/SimpleSplitFinder.java index 2cacec7..7fabe4b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleSplitFinder.java @@ -22,10 +22,10 @@ import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> { +public abstract class SimpleSplitFinder<Y> implements SplitFinder<Y> { @Override - public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) { + public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) { Double bestScore = null; Split<Y, ?> bestSplit = null; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitFinder.java similarity index 73% rename from src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/tree/SplitFinder.java index 8d72efb..3f09dd5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitFinder.java @@ -21,14 +21,14 @@ import java.util.Iterator; /** * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups. - * The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an + * The SplitFinder has one method that cycles through an iterator of Splits (FYI; check if the iterator is an * instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits) * - * If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending - * SimpleGroupDifferentiator. + * If you want to implement a very trivial SplitFinder that just takes two Lists as arguments, try extending + * SimpleSplitFinder. */ -public interface GroupDifferentiator<Y> extends Serializable { +public interface SplitFinder<Y> extends Serializable { - SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator); + SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator); } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index a633664..6e016c8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -31,7 +31,7 @@ import java.util.stream.Collectors; public class TreeTrainer<Y, O> { private final ResponseCombiner<Y, O> responseCombiner; - private final GroupDifferentiator<Y> groupDifferentiator; + private final SplitFinder<Y> splitFinder; /** * The number of splits to perform on each covariate. A value of 0 means all possible splits are tried. @@ -58,7 +58,7 @@ public class TreeTrainer<Y, O> { this.checkNodePurity = settings.isCheckNodePurity(); this.responseCombiner = settings.getResponseCombiner(); - this.groupDifferentiator = settings.getGroupDifferentiator(); + this.splitFinder = settings.getSplitFinder(); this.covariates = covariates; } @@ -160,7 +160,7 @@ public class TreeTrainer<Y, O> { private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){ SplitAndScore<Y, ?> bestSplitAndScore = null; - final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck + final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating for(final Covariate covariate : covariatesToTry) { final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random); @@ -170,7 +170,7 @@ public class TreeTrainer<Y, O> { continue; } - final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator); + final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator); if(candidateSplitAndScore != null && (bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) { diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 58dd3e7..7ae2dea 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -44,12 +44,12 @@ public class TestSavingLoading { * @return */ public Settings getSettings(){ - final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); - groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator")); - groupDifferentiatorSettings.set("eventsOfFocus", + final ObjectNode splitRuleSettings = new ObjectNode(JsonNodeFactory.instance); + splitRuleSettings.set("type", new TextNode("LogRankSplitFinder")); + splitRuleSettings.set("eventsOfFocus", new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1))) ); - groupDifferentiatorSettings.set("events", + splitRuleSettings.set("events", new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) ); @@ -85,7 +85,7 @@ public class TestSavingLoading { .validationDataLocation("src/test/resources/wihs.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) - .groupDifferentiatorSettings(groupDifferentiatorSettings) + .splitFinderSettings(splitRuleSettings) .yVarSettings(yVarSettings) .maxNodeDepth(100000) // TODO fill in these settings diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 9ba04d7..1f78096 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -53,12 +53,12 @@ public class TestCompetingRisk { * @return */ public Settings getSettings(){ - final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); - groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator")); - groupDifferentiatorSettings.set("eventsOfFocus", + final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); + splitFinderSettings.set("type", new TextNode("LogRankSplitFinder")); + splitFinderSettings.set("eventsOfFocus", new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1))) ); - groupDifferentiatorSettings.set("events", + splitFinderSettings.set("events", new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) ); @@ -93,7 +93,7 @@ public class TestCompetingRisk { .trainingDataLocation("src/test/resources/wihs.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) - .groupDifferentiatorSettings(groupDifferentiatorSettings) + .splitFinderSettings(splitFinderSettings) .yVarSettings(yVarSettings) .maxNodeDepth(100000) // TODO fill in these settings @@ -222,7 +222,7 @@ public class TestCompetingRisk { } @Test - public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException { + public void testLogRankSplitFinderTwoBooleans() throws IOException { final Settings settings = getSettings(); settings.setCovariateSettings(Utils.easyList( new BooleanCovariateSettings("idu"), @@ -337,7 +337,7 @@ public class TestCompetingRisk { } @Test - public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException { + public void testLogRankSplitFinderAllCovariates() throws IOException { final Settings settings = getSettings(); settings.setNtree(300); // results are too variable at 100 diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java similarity index 91% rename from src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankDifferentiator.java rename to src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java index 6a600dd..259410b 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java @@ -21,7 +21,7 @@ import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.SingletonIterator; @@ -39,7 +39,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertTrue; -public class TestLogRankDifferentiator { +public class TestLogRankSplitFinder { private Iterator<Split<CompetingRiskResponse, ?>> turnIntoSplitIterator(List<Row<CompetingRiskResponse>> leftList, List<Row<CompetingRiskResponse>> rightList){ @@ -70,14 +70,14 @@ public class TestLogRankDifferentiator { @Test public void testSplitRule() throws IOException { - final LogRankDifferentiator groupDifferentiator = new LogRankDifferentiator(new int[]{1,2}, new int[]{1,2}); + final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1,2}, new int[]{1,2}); final List<Row<CompetingRiskResponse>> data = loadData("src/test/resources/test_split_data.csv").getRows(); final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196); final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size()); - final double scoreBad = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Bad, group2Bad)).getScore(); + final double scoreBad = splitFinder.findBestSplit(turnIntoSplitIterator(group1Bad, group2Bad)).getScore(); // expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea closeEnough(9.413002, scoreBad, 0.00001); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinderSingleEvent.java similarity index 80% rename from src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java rename to src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinderSingleEvent.java index 8b0693f..205500f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinderSingleEvent.java @@ -18,8 +18,8 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; +import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.utils.SingletonIterator; import org.junit.jupiter.api.Test; @@ -33,13 +33,13 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -public class TestLogRankSingleGroupDifferentiator { +public class TestLogRankSplitFinderSingleEvent { - private double getScore(final GroupDifferentiator<CompetingRiskResponse> groupDifferentiator, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){ + private double getScore(final SplitFinder<CompetingRiskResponse> splitFinder, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){ final Iterator<Split<CompetingRiskResponse, ?>> iterator = new SingletonIterator<>( new Split<>(null, left, right, Collections.emptyList())); - return groupDifferentiator.differentiate(iterator).getScore(); + return splitFinder.findBestSplit(iterator).getScore(); } @@ -81,9 +81,9 @@ public class TestLogRankSingleGroupDifferentiator { final List<Row<CompetingRiskResponse>> data1 = generateData1(); final List<Row<CompetingRiskResponse>> data2 = generateData2(); - final LogRankDifferentiator differentiator = new LogRankDifferentiator(new int[]{1}, new int[]{1}); + final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1}, new int[]{1}); - final double score = getScore(differentiator, data1, data2); + final double score = getScore(splitFinder, data1, data2); final double margin = 0.000001; // Tested using 855 method @@ -94,21 +94,21 @@ public class TestLogRankSingleGroupDifferentiator { @Test public void testCorrectSplit() throws IOException { - final LogRankDifferentiator groupDifferentiator = - new LogRankDifferentiator(new int[]{1}, new int[]{1,2}); + final LogRankSplitFinder splitFinder = + new LogRankSplitFinder(new int[]{1}, new int[]{1,2}); - final List<Row<CompetingRiskResponse>> data = TestLogRankDifferentiator. + final List<Row<CompetingRiskResponse>> data = TestLogRankSplitFinder. loadData("src/test/resources/test_single_split.csv").getRows(); final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221); final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size()); - final double scoreGood = getScore(groupDifferentiator, group1Good, group2Good); + final double scoreGood = getScore(splitFinder, group1Good, group2Good); final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222); final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size()); - final double scoreBad = getScore(groupDifferentiator, group1Bad, group2Bad); + final double scoreBad = getScore(splitFinder, group1Bad, group2Bad); // Apparently not all groups are unique when splitting assertEquals(scoreGood, scoreBad); diff --git a/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java b/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java index d748f1e..cb9a116 100644 --- a/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java +++ b/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java @@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -47,7 +47,7 @@ public class TestNAs { .numberOfSplits(0) .nodeSize(1) .maxNodeDepth(1000) - .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .splitFinder(new WeightedVarianceSplitFinder()) .responseCombiner(new MeanResponseCombiner()) .build(); diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index d1d0a41..acac42f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -35,8 +35,8 @@ public class TestPersistence { @Test public void testSaving() throws IOException { - final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); - groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator")); + final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); + splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder")); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner")); @@ -59,7 +59,7 @@ public class TestPersistence { .validationDataLocation("validation_data.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) - .groupDifferentiatorSettings(groupDifferentiatorSettings) + .splitFinderSettings(splitFinderSettings) .yVarSettings(yVarSettings) .maxNodeDepth(100000) .mtry(2) diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index 8579e0e..b27045e 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -20,7 +20,7 @@ import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer; @@ -72,7 +72,7 @@ public class TrainForest { .nodeSize(5) .mtry(4) .maxNodeDepth(100000000) - .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .splitFinder(new WeightedVarianceSplitFinder()) .responseCombiner(new MeanResponseCombiner()) .build(); diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index 1527d4a..8e1e361 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.Utils; @@ -66,7 +66,7 @@ public class TrainSingleTree { final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate); final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder() - .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .splitFinder(new WeightedVarianceSplitFinder()) .covariates(covariateNames) .responseCombiner(new MeanResponseCombiner()) .maxNodeDepth(30) diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index 7313d22..1d2b404 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.Utils; @@ -86,7 +86,7 @@ public class TrainSingleTreeFactor { final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate); final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder() - .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .splitFinder(new WeightedVarianceSplitFinder()) .responseCombiner(new MeanResponseCombiner()) .covariates(covariateNames) .maxNodeDepth(30)