diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index e353c7f..3aa29ab 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -180,8 +180,8 @@ public class Main { private static Settings defaultTemplate(){ - final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); - groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator")); + final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); + splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder")); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner")); @@ -204,7 +204,7 @@ public class Main { .validationDataLocation("validation_data.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) - .groupDifferentiatorSettings(groupDifferentiatorSettings) + .splitFinderSettings(splitFinderSettings) .yVarSettings(yVarSettings) .maxNodeDepth(100000) .mtry(2) diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index da83ff4..64bd8da 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -22,11 +22,11 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankDifferentiator; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogRankSplitFinder; +import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; +import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.Utils; @@ -77,31 +77,31 @@ public class Settings { ); } - private static Map> GROUP_DIFFERENTIATOR_MAP = new HashMap<>(); - public static Function getGroupDifferentiatorConstructor(final String name){ - return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase()); + private static Map> SPLIT_FINDER_MAP = new HashMap<>(); + public static Function getSplitFinderConstructor(final String name){ + return SPLIT_FINDER_MAP.get(name.toLowerCase()); } - public static void registerGroupDifferentiatorConstructor(final String name, final Function groupDifferentiatorConstructor){ - GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor); + public static void registerSplitFinderConstructor(final String name, final Function splitFinderConstructor){ + SPLIT_FINDER_MAP.put(name.toLowerCase(), splitFinderConstructor); } static{ - registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator", - (node) -> new WeightedVarianceGroupDifferentiator() + registerSplitFinderConstructor("WeightedVarianceSplitFinder", + (node) -> new WeightedVarianceSplitFinder() ); - registerGroupDifferentiatorConstructor("GrayLogRankDifferentiator", + registerSplitFinderConstructor("GrayLogRankSplitFinder", (objectNode) -> { final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus")); final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events")); - return new GrayLogRankDifferentiator(eventsOfFocusArray, eventArray); + return new GrayLogRankSplitFinder(eventsOfFocusArray, eventArray); } ); - registerGroupDifferentiatorConstructor("LogRankDifferentiator", + registerSplitFinderConstructor("LogRankSplitFinder", (objectNode) -> { final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus")); final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events")); - return new LogRankDifferentiator(eventsOfFocusArray, eventArray); + return new LogRankSplitFinder(eventsOfFocusArray, eventArray); } ); } @@ -153,7 +153,7 @@ public class Settings { private boolean checkNodePurity = false; private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); - private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); + private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); private List covariateSettings = new ArrayList<>(); @@ -194,10 +194,10 @@ public class Settings { } @JsonIgnore - public GroupDifferentiator getGroupDifferentiator(){ - final String type = groupDifferentiatorSettings.get("type").asText(); + public SplitFinder getSplitFinder(){ + final String type = splitFinderSettings.get("type").asText(); - return getGroupDifferentiatorConstructor(type).apply(groupDifferentiatorSettings); + return getSplitFinderConstructor(type).apply(splitFinderSettings); } @JsonIgnore diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/CompetingRiskSplitFinder.java similarity index 92% rename from src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/CompetingRiskSplitFinder.java index aef9b4c..b2cddb8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/CompetingRiskSplitFinder.java @@ -14,13 +14,13 @@ * along with this program. If not, see . */ -package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.SplitAndScore; import lombok.AllArgsConstructor; @@ -35,24 +35,24 @@ import java.util.stream.Collectors; * modifies the abstract method. * */ -public abstract class CompetingRiskGroupDifferentiator implements GroupDifferentiator { +public abstract class CompetingRiskSplitFinder implements SplitFinder { abstract protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand); abstract protected Double getScore(final CompetingRiskSets competingRiskSets); @Override - public SplitAndScore differentiate(Iterator> splitIterator) { + public SplitAndScore findBestSplit(Iterator> splitIterator) { if(splitIterator instanceof Covariate.SplitRuleUpdater){ - return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); + return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); } else{ - return differentiateWithBasicIterator(splitIterator); + return findBestSplitWithBasicIterator(splitIterator); } } - private SplitAndScore differentiateWithBasicIterator(Iterator> splitIterator){ + private SplitAndScore findBestSplitWithBasicIterator(Iterator> splitIterator){ Double bestScore = null; Split bestSplit = null; @@ -83,7 +83,7 @@ public abstract class CompetingRiskGroupDifferentiator(bestSplit, bestScore); } - private SplitAndScore differentiateWithSplitUpdater(Covariate.SplitRuleUpdater splitRuleUpdater) { + private SplitAndScore findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater splitRuleUpdater) { final List leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand() .stream().map(Row::getResponse).collect(Collectors.toList()); diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/GrayLogRankSplitFinder.java similarity index 91% rename from src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/GrayLogRankSplitFinder.java index 420ad15..b6eed13 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/GrayLogRankSplitFinder.java @@ -14,7 +14,7 @@ * along with this program. If not, see . */ -package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; @@ -27,12 +27,12 @@ import java.util.List; * See page 761 of Random survival forests for competing risks by Ishwaran et al. * */ -public class GrayLogRankDifferentiator extends CompetingRiskGroupDifferentiator { +public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder { private final int[] eventsOfFocus; private final int[] events; - public GrayLogRankDifferentiator(int[] eventsOfFocus, int[] events){ + public GrayLogRankSplitFinder(int[] eventsOfFocus, int[] events){ this.eventsOfFocus = eventsOfFocus; this.events = events; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder.java similarity index 91% rename from src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder.java index d528de3..c31bb84 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder.java @@ -14,7 +14,7 @@ * along with this program. If not, see . */ -package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; @@ -27,12 +27,12 @@ import java.util.List; * See page 761 of Random survival forests for competing risks by Ishwaran et al. * */ -public class LogRankDifferentiator extends CompetingRiskGroupDifferentiator { +public class LogRankSplitFinder extends CompetingRiskSplitFinder { private final int[] eventsOfFocus; private final int[] events; - public LogRankDifferentiator(int[] eventsOfFocus, int[] events){ + public LogRankSplitFinder(int[] eventsOfFocus, int[] events){ this.eventsOfFocus = eventsOfFocus; this.events = events; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder.java similarity index 91% rename from src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder.java index a16f403..9fb4d43 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder.java @@ -18,7 +18,7 @@ package ca.joeltherrien.randomforest.responses.regression; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.SplitAndScore; @@ -26,7 +26,7 @@ import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator { +public class WeightedVarianceSplitFinder implements SplitFinder { private Double getScore(Set leftHand, Set rightHand) { @@ -44,17 +44,17 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator< } @Override - public SplitAndScore differentiate(Iterator> splitIterator) { + public SplitAndScore findBestSplit(Iterator> splitIterator) { if(splitIterator instanceof Covariate.SplitRuleUpdater){ - return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); + return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); } else{ - return differentiateWithBasicIterator(splitIterator); + return findBestSplitWithBasicIterator(splitIterator); } } - private SplitAndScore differentiateWithBasicIterator(Iterator> splitIterator){ + private SplitAndScore findBestSplitWithBasicIterator(Iterator> splitIterator){ Double bestScore = null; Split bestSplit = null; @@ -86,7 +86,7 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator< return new SplitAndScore<>(bestSplit, bestScore); } - private SplitAndScore differentiateWithSplitUpdater(Covariate.SplitRuleUpdater splitRuleUpdater) { + private SplitAndScore findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater splitRuleUpdater) { final List leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand() .stream().map(Row::getResponse).collect(Collectors.toList()); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleSplitFinder.java similarity index 93% rename from src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/tree/SimpleSplitFinder.java index 2cacec7..7fabe4b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleSplitFinder.java @@ -22,10 +22,10 @@ import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -public abstract class SimpleGroupDifferentiator implements GroupDifferentiator { +public abstract class SimpleSplitFinder implements SplitFinder { @Override - public SplitAndScore differentiate(Iterator> splitIterator) { + public SplitAndScore findBestSplit(Iterator> splitIterator) { Double bestScore = null; Split bestSplit = null; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitFinder.java similarity index 73% rename from src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/tree/SplitFinder.java index 8d72efb..3f09dd5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitFinder.java @@ -21,14 +21,14 @@ import java.util.Iterator; /** * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups. - * The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an + * The SplitFinder has one method that cycles through an iterator of Splits (FYI; check if the iterator is an * instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits) * - * If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending - * SimpleGroupDifferentiator. + * If you want to implement a very trivial SplitFinder that just takes two Lists as arguments, try extending + * SimpleSplitFinder. */ -public interface GroupDifferentiator extends Serializable { +public interface SplitFinder extends Serializable { - SplitAndScore differentiate(Iterator> splitIterator); + SplitAndScore findBestSplit(Iterator> splitIterator); } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index a633664..6e016c8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -31,7 +31,7 @@ import java.util.stream.Collectors; public class TreeTrainer { private final ResponseCombiner responseCombiner; - private final GroupDifferentiator groupDifferentiator; + private final SplitFinder splitFinder; /** * The number of splits to perform on each covariate. A value of 0 means all possible splits are tried. @@ -58,7 +58,7 @@ public class TreeTrainer { this.checkNodePurity = settings.isCheckNodePurity(); this.responseCombiner = settings.getResponseCombiner(); - this.groupDifferentiator = settings.getGroupDifferentiator(); + this.splitFinder = settings.getSplitFinder(); this.covariates = covariates; } @@ -160,7 +160,7 @@ public class TreeTrainer { private Split findBestSplitRule(List> data, List covariatesToTry, Random random){ SplitAndScore bestSplitAndScore = null; - final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck + final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating for(final Covariate covariate : covariatesToTry) { final Iterator iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random); @@ -170,7 +170,7 @@ public class TreeTrainer { continue; } - final SplitAndScore candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator); + final SplitAndScore candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator); if(candidateSplitAndScore != null && (bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) { diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 58dd3e7..7ae2dea 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -44,12 +44,12 @@ public class TestSavingLoading { * @return */ public Settings getSettings(){ - final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); - groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator")); - groupDifferentiatorSettings.set("eventsOfFocus", + final ObjectNode splitRuleSettings = new ObjectNode(JsonNodeFactory.instance); + splitRuleSettings.set("type", new TextNode("LogRankSplitFinder")); + splitRuleSettings.set("eventsOfFocus", new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1))) ); - groupDifferentiatorSettings.set("events", + splitRuleSettings.set("events", new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) ); @@ -85,7 +85,7 @@ public class TestSavingLoading { .validationDataLocation("src/test/resources/wihs.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) - .groupDifferentiatorSettings(groupDifferentiatorSettings) + .splitFinderSettings(splitRuleSettings) .yVarSettings(yVarSettings) .maxNodeDepth(100000) // TODO fill in these settings diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 9ba04d7..1f78096 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -53,12 +53,12 @@ public class TestCompetingRisk { * @return */ public Settings getSettings(){ - final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); - groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator")); - groupDifferentiatorSettings.set("eventsOfFocus", + final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); + splitFinderSettings.set("type", new TextNode("LogRankSplitFinder")); + splitFinderSettings.set("eventsOfFocus", new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1))) ); - groupDifferentiatorSettings.set("events", + splitFinderSettings.set("events", new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) ); @@ -93,7 +93,7 @@ public class TestCompetingRisk { .trainingDataLocation("src/test/resources/wihs.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) - .groupDifferentiatorSettings(groupDifferentiatorSettings) + .splitFinderSettings(splitFinderSettings) .yVarSettings(yVarSettings) .maxNodeDepth(100000) // TODO fill in these settings @@ -222,7 +222,7 @@ public class TestCompetingRisk { } @Test - public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException { + public void testLogRankSplitFinderTwoBooleans() throws IOException { final Settings settings = getSettings(); settings.setCovariateSettings(Utils.easyList( new BooleanCovariateSettings("idu"), @@ -337,7 +337,7 @@ public class TestCompetingRisk { } @Test - public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException { + public void testLogRankSplitFinderAllCovariates() throws IOException { final Settings settings = getSettings(); settings.setNtree(300); // results are too variable at 100 diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java similarity index 91% rename from src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankDifferentiator.java rename to src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java index 6a600dd..259410b 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java @@ -21,7 +21,7 @@ import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.SingletonIterator; @@ -39,7 +39,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertTrue; -public class TestLogRankDifferentiator { +public class TestLogRankSplitFinder { private Iterator> turnIntoSplitIterator(List> leftList, List> rightList){ @@ -70,14 +70,14 @@ public class TestLogRankDifferentiator { @Test public void testSplitRule() throws IOException { - final LogRankDifferentiator groupDifferentiator = new LogRankDifferentiator(new int[]{1,2}, new int[]{1,2}); + final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1,2}, new int[]{1,2}); final List> data = loadData("src/test/resources/test_split_data.csv").getRows(); final List> group1Bad = data.subList(0, 196); final List> group2Bad = data.subList(196, data.size()); - final double scoreBad = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Bad, group2Bad)).getScore(); + final double scoreBad = splitFinder.findBestSplit(turnIntoSplitIterator(group1Bad, group2Bad)).getScore(); // expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea closeEnough(9.413002, scoreBad, 0.00001); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinderSingleEvent.java similarity index 80% rename from src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java rename to src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinderSingleEvent.java index 8b0693f..205500f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinderSingleEvent.java @@ -18,8 +18,8 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; -import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; +import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.utils.SingletonIterator; import org.junit.jupiter.api.Test; @@ -33,13 +33,13 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -public class TestLogRankSingleGroupDifferentiator { +public class TestLogRankSplitFinderSingleEvent { - private double getScore(final GroupDifferentiator groupDifferentiator, List> left, List> right){ + private double getScore(final SplitFinder splitFinder, List> left, List> right){ final Iterator> iterator = new SingletonIterator<>( new Split<>(null, left, right, Collections.emptyList())); - return groupDifferentiator.differentiate(iterator).getScore(); + return splitFinder.findBestSplit(iterator).getScore(); } @@ -81,9 +81,9 @@ public class TestLogRankSingleGroupDifferentiator { final List> data1 = generateData1(); final List> data2 = generateData2(); - final LogRankDifferentiator differentiator = new LogRankDifferentiator(new int[]{1}, new int[]{1}); + final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1}, new int[]{1}); - final double score = getScore(differentiator, data1, data2); + final double score = getScore(splitFinder, data1, data2); final double margin = 0.000001; // Tested using 855 method @@ -94,21 +94,21 @@ public class TestLogRankSingleGroupDifferentiator { @Test public void testCorrectSplit() throws IOException { - final LogRankDifferentiator groupDifferentiator = - new LogRankDifferentiator(new int[]{1}, new int[]{1,2}); + final LogRankSplitFinder splitFinder = + new LogRankSplitFinder(new int[]{1}, new int[]{1,2}); - final List> data = TestLogRankDifferentiator. + final List> data = TestLogRankSplitFinder. loadData("src/test/resources/test_single_split.csv").getRows(); final List> group1Good = data.subList(0, 221); final List> group2Good = data.subList(221, data.size()); - final double scoreGood = getScore(groupDifferentiator, group1Good, group2Good); + final double scoreGood = getScore(splitFinder, group1Good, group2Good); final List> group1Bad = data.subList(0, 222); final List> group2Bad = data.subList(222, data.size()); - final double scoreBad = getScore(groupDifferentiator, group1Bad, group2Bad); + final double scoreBad = getScore(splitFinder, group1Bad, group2Bad); // Apparently not all groups are unique when splitting assertEquals(scoreGood, scoreBad); diff --git a/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java b/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java index d748f1e..cb9a116 100644 --- a/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java +++ b/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java @@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -47,7 +47,7 @@ public class TestNAs { .numberOfSplits(0) .nodeSize(1) .maxNodeDepth(1000) - .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .splitFinder(new WeightedVarianceSplitFinder()) .responseCombiner(new MeanResponseCombiner()) .build(); diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index d1d0a41..acac42f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -35,8 +35,8 @@ public class TestPersistence { @Test public void testSaving() throws IOException { - final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); - groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator")); + final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); + splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder")); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner")); @@ -59,7 +59,7 @@ public class TestPersistence { .validationDataLocation("validation_data.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) - .groupDifferentiatorSettings(groupDifferentiatorSettings) + .splitFinderSettings(splitFinderSettings) .yVarSettings(yVarSettings) .maxNodeDepth(100000) .mtry(2) diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index 8579e0e..b27045e 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -20,7 +20,7 @@ import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer; @@ -72,7 +72,7 @@ public class TrainForest { .nodeSize(5) .mtry(4) .maxNodeDepth(100000000) - .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .splitFinder(new WeightedVarianceSplitFinder()) .responseCombiner(new MeanResponseCombiner()) .build(); diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index 1527d4a..8e1e361 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.Utils; @@ -66,7 +66,7 @@ public class TrainSingleTree { final List covariateNames = Utils.easyList(x1Covariate, x2Covariate); final TreeTrainer treeTrainer = TreeTrainer.builder() - .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .splitFinder(new WeightedVarianceSplitFinder()) .covariates(covariateNames) .responseCombiner(new MeanResponseCombiner()) .maxNodeDepth(30) diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index 7313d22..1d2b404 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.Utils; @@ -86,7 +86,7 @@ public class TrainSingleTreeFactor { final List covariateNames = Utils.easyList(x1Covariate, x2Covariate); final TreeTrainer treeTrainer = TreeTrainer.builder() - .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .splitFinder(new WeightedVarianceSplitFinder()) .responseCombiner(new MeanResponseCombiner()) .covariates(covariateNames) .maxNodeDepth(30)