Refactor - rename GroupDifferentiators into SplitFinders

SplitRule would have made more sense but it was already taken.
This commit is contained in:
Joel Therrien 2019-05-08 16:09:09 -07:00
parent c5c74ad7e9
commit 17ae3a9f5a
18 changed files with 93 additions and 93 deletions

View file

@ -180,8 +180,8 @@ public class Main {
private static Settings defaultTemplate(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
@ -204,7 +204,7 @@ public class Main {
.validationDataLocation("validation_data.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
.mtry(2)

View file

@ -22,11 +22,11 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogRankSplitFinder;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.Utils;
@ -77,31 +77,31 @@ public class Settings {
);
}
private static Map<String, Function<ObjectNode, GroupDifferentiator>> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
public static Function<ObjectNode, GroupDifferentiator> getGroupDifferentiatorConstructor(final String name){
return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase());
private static Map<String, Function<ObjectNode, SplitFinder>> SPLIT_FINDER_MAP = new HashMap<>();
public static Function<ObjectNode, SplitFinder> getSplitFinderConstructor(final String name){
return SPLIT_FINDER_MAP.get(name.toLowerCase());
}
public static void registerGroupDifferentiatorConstructor(final String name, final Function<ObjectNode, GroupDifferentiator> groupDifferentiatorConstructor){
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
public static void registerSplitFinderConstructor(final String name, final Function<ObjectNode, SplitFinder> splitFinderConstructor){
SPLIT_FINDER_MAP.put(name.toLowerCase(), splitFinderConstructor);
}
static{
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
(node) -> new WeightedVarianceGroupDifferentiator()
registerSplitFinderConstructor("WeightedVarianceSplitFinder",
(node) -> new WeightedVarianceSplitFinder()
);
registerGroupDifferentiatorConstructor("GrayLogRankDifferentiator",
registerSplitFinderConstructor("GrayLogRankSplitFinder",
(objectNode) -> {
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
return new GrayLogRankDifferentiator(eventsOfFocusArray, eventArray);
return new GrayLogRankSplitFinder(eventsOfFocusArray, eventArray);
}
);
registerGroupDifferentiatorConstructor("LogRankDifferentiator",
registerSplitFinderConstructor("LogRankSplitFinder",
(objectNode) -> {
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
return new LogRankDifferentiator(eventsOfFocusArray, eventArray);
return new LogRankSplitFinder(eventsOfFocusArray, eventArray);
}
);
}
@ -153,7 +153,7 @@ public class Settings {
private boolean checkNodePurity = false;
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private List<CovariateSettings> covariateSettings = new ArrayList<>();
@ -194,10 +194,10 @@ public class Settings {
}
@JsonIgnore
public GroupDifferentiator getGroupDifferentiator(){
final String type = groupDifferentiatorSettings.get("type").asText();
public SplitFinder getSplitFinder(){
final String type = splitFinderSettings.get("type").asText();
return getGroupDifferentiatorConstructor(type).apply(groupDifferentiatorSettings);
return getSplitFinderConstructor(type).apply(splitFinderSettings);
}
@JsonIgnore

View file

@ -14,13 +14,13 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.tree.SplitAndScore;
import lombok.AllArgsConstructor;
@ -35,24 +35,24 @@ import java.util.stream.Collectors;
* modifies the abstract method.
*
*/
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y> {
public abstract class CompetingRiskSplitFinder<Y extends CompetingRiskResponse> implements SplitFinder<Y> {
abstract protected CompetingRiskSets<Y> createCompetingRiskSets(List<Y> leftHand, List<Y> rightHand);
abstract protected Double getScore(final CompetingRiskSets<Y> competingRiskSets);
@Override
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) {
public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) {
if(splitIterator instanceof Covariate.SplitRuleUpdater){
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
}
else{
return differentiateWithBasicIterator(splitIterator);
return findBestSplitWithBasicIterator(splitIterator);
}
}
private SplitAndScore<Y, ?> differentiateWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){
private SplitAndScore<Y, ?> findBestSplitWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){
Double bestScore = null;
Split<Y, ?> bestSplit = null;
@ -83,7 +83,7 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskRe
return new SplitAndScore<>(bestSplit, bestScore);
}
private SplitAndScore<Y, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) {
private SplitAndScore<Y, ?> findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) {
final List<Y> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
.stream().map(Row::getResponse).collect(Collectors.toList());

View file

@ -14,7 +14,7 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
@ -27,12 +27,12 @@ import java.util.List;
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
*
*/
public class GrayLogRankDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> {
public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponseWithCensorTime> {
private final int[] eventsOfFocus;
private final int[] events;
public GrayLogRankDifferentiator(int[] eventsOfFocus, int[] events){
public GrayLogRankSplitFinder(int[] eventsOfFocus, int[] events){
this.eventsOfFocus = eventsOfFocus;
this.events = events;

View file

@ -14,7 +14,7 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
@ -27,12 +27,12 @@ import java.util.List;
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
*
*/
public class LogRankDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> {
public class LogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponse> {
private final int[] eventsOfFocus;
private final int[] events;
public LogRankDifferentiator(int[] eventsOfFocus, int[] events){
public LogRankSplitFinder(int[] eventsOfFocus, int[] events){
this.eventsOfFocus = eventsOfFocus;
this.events = events;

View file

@ -18,7 +18,7 @@ package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.tree.SplitAndScore;
@ -26,7 +26,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
public class WeightedVarianceSplitFinder implements SplitFinder<Double> {
private Double getScore(Set leftHand, Set rightHand) {
@ -44,17 +44,17 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<
}
@Override
public SplitAndScore<Double, ?> differentiate(Iterator<Split<Double, ?>> splitIterator) {
public SplitAndScore<Double, ?> findBestSplit(Iterator<Split<Double, ?>> splitIterator) {
if(splitIterator instanceof Covariate.SplitRuleUpdater){
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
}
else{
return differentiateWithBasicIterator(splitIterator);
return findBestSplitWithBasicIterator(splitIterator);
}
}
private SplitAndScore<Double, ?> differentiateWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){
private SplitAndScore<Double, ?> findBestSplitWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){
Double bestScore = null;
Split<Double, ?> bestSplit = null;
@ -86,7 +86,7 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<
return new SplitAndScore<>(bestSplit, bestScore);
}
private SplitAndScore<Double, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) {
private SplitAndScore<Double, ?> findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) {
final List<Double> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
.stream().map(Row::getResponse).collect(Collectors.toList());

View file

@ -22,10 +22,10 @@ import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {
public abstract class SimpleSplitFinder<Y> implements SplitFinder<Y> {
@Override
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) {
public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) {
Double bestScore = null;
Split<Y, ?> bestSplit = null;

View file

@ -21,14 +21,14 @@ import java.util.Iterator;
/**
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
* The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an
* The SplitFinder has one method that cycles through an iterator of Splits (FYI: check whether the iterator is an
* instance of Covariate.SplitRuleUpdater, in which case you get access to the rows that change between splits).
*
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
* SimpleGroupDifferentiator.
* If you want to implement a very trivial SplitFinder that just takes two Lists as arguments, try extending
* SimpleSplitFinder.
*/
public interface GroupDifferentiator<Y> extends Serializable {
public interface SplitFinder<Y> extends Serializable {
SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator);
SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator);
}

View file

@ -31,7 +31,7 @@ import java.util.stream.Collectors;
public class TreeTrainer<Y, O> {
private final ResponseCombiner<Y, O> responseCombiner;
private final GroupDifferentiator<Y> groupDifferentiator;
private final SplitFinder<Y> splitFinder;
/**
* The number of splits to perform on each covariate. A value of 0 means all possible splits are tried.
@ -58,7 +58,7 @@ public class TreeTrainer<Y, O> {
this.checkNodePurity = settings.isCheckNodePurity();
this.responseCombiner = settings.getResponseCombiner();
this.groupDifferentiator = settings.getGroupDifferentiator();
this.splitFinder = settings.getSplitFinder();
this.covariates = covariates;
}
@ -160,7 +160,7 @@ public class TreeTrainer<Y, O> {
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
SplitAndScore<Y, ?> bestSplitAndScore = null;
final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
for(final Covariate covariate : covariatesToTry) {
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
@ -170,7 +170,7 @@ public class TreeTrainer<Y, O> {
continue;
}
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {

View file

@ -44,12 +44,12 @@ public class TestSavingLoading {
* @return
*/
public Settings getSettings(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator"));
groupDifferentiatorSettings.set("eventsOfFocus",
final ObjectNode splitRuleSettings = new ObjectNode(JsonNodeFactory.instance);
splitRuleSettings.set("type", new TextNode("LogRankSplitFinder"));
splitRuleSettings.set("eventsOfFocus",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
);
groupDifferentiatorSettings.set("events",
splitRuleSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
);
@ -85,7 +85,7 @@ public class TestSavingLoading {
.validationDataLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.splitFinderSettings(splitRuleSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
// TODO fill in these settings

View file

@ -53,12 +53,12 @@ public class TestCompetingRisk {
* @return
*/
public Settings getSettings(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator"));
groupDifferentiatorSettings.set("eventsOfFocus",
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
splitFinderSettings.set("type", new TextNode("LogRankSplitFinder"));
splitFinderSettings.set("eventsOfFocus",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
);
groupDifferentiatorSettings.set("events",
splitFinderSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
);
@ -93,7 +93,7 @@ public class TestCompetingRisk {
.trainingDataLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
// TODO fill in these settings
@ -222,7 +222,7 @@ public class TestCompetingRisk {
}
@Test
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
public void testLogRankSplitFinderTwoBooleans() throws IOException {
final Settings settings = getSettings();
settings.setCovariateSettings(Utils.easyList(
new BooleanCovariateSettings("idu"),
@ -337,7 +337,7 @@ public class TestCompetingRisk {
}
@Test
public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException {
public void testLogRankSplitFinderAllCovariates() throws IOException {
final Settings settings = getSettings();
settings.setNtree(300); // results are too variable at 100

View file

@ -21,7 +21,7 @@ import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.SingletonIterator;
@ -39,7 +39,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLogRankDifferentiator {
public class TestLogRankSplitFinder {
private Iterator<Split<CompetingRiskResponse, ?>> turnIntoSplitIterator(List<Row<CompetingRiskResponse>> leftList,
List<Row<CompetingRiskResponse>> rightList){
@ -70,14 +70,14 @@ public class TestLogRankDifferentiator {
@Test
public void testSplitRule() throws IOException {
final LogRankDifferentiator groupDifferentiator = new LogRankDifferentiator(new int[]{1,2}, new int[]{1,2});
final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1,2}, new int[]{1,2});
final List<Row<CompetingRiskResponse>> data = loadData("src/test/resources/test_split_data.csv").getRows();
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
final double scoreBad = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Bad, group2Bad)).getScore();
final double scoreBad = splitFinder.findBestSplit(turnIntoSplitIterator(group1Bad, group2Bad)).getScore();
// expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea
closeEnough(9.413002, scoreBad, 0.00001);

View file

@ -18,8 +18,8 @@ package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.SingletonIterator;
import org.junit.jupiter.api.Test;
@ -33,13 +33,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLogRankSingleGroupDifferentiator {
public class TestLogRankSplitFinderSingleEvent {
private double getScore(final GroupDifferentiator<CompetingRiskResponse> groupDifferentiator, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){
private double getScore(final SplitFinder<CompetingRiskResponse> splitFinder, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){
final Iterator<Split<CompetingRiskResponse, ?>> iterator = new SingletonIterator<>(
new Split<>(null, left, right, Collections.emptyList()));
return groupDifferentiator.differentiate(iterator).getScore();
return splitFinder.findBestSplit(iterator).getScore();
}
@ -81,9 +81,9 @@ public class TestLogRankSingleGroupDifferentiator {
final List<Row<CompetingRiskResponse>> data1 = generateData1();
final List<Row<CompetingRiskResponse>> data2 = generateData2();
final LogRankDifferentiator differentiator = new LogRankDifferentiator(new int[]{1}, new int[]{1});
final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1}, new int[]{1});
final double score = getScore(differentiator, data1, data2);
final double score = getScore(splitFinder, data1, data2);
final double margin = 0.000001;
// Tested using 855 method
@ -94,21 +94,21 @@ public class TestLogRankSingleGroupDifferentiator {
@Test
public void testCorrectSplit() throws IOException {
final LogRankDifferentiator groupDifferentiator =
new LogRankDifferentiator(new int[]{1}, new int[]{1,2});
final LogRankSplitFinder splitFinder =
new LogRankSplitFinder(new int[]{1}, new int[]{1,2});
final List<Row<CompetingRiskResponse>> data = TestLogRankDifferentiator.
final List<Row<CompetingRiskResponse>> data = TestLogRankSplitFinder.
loadData("src/test/resources/test_single_split.csv").getRows();
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
final double scoreGood = getScore(groupDifferentiator, group1Good, group2Good);
final double scoreGood = getScore(splitFinder, group1Good, group2Good);
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
final double scoreBad = getScore(groupDifferentiator, group1Bad, group2Bad);
final double scoreBad = getScore(splitFinder, group1Bad, group2Bad);
// Apparently not all groups are unique when splitting
assertEquals(scoreGood, scoreBad);

View file

@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
@ -47,7 +47,7 @@ public class TestNAs {
.numberOfSplits(0)
.nodeSize(1)
.maxNodeDepth(1000)
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
.build();

View file

@ -35,8 +35,8 @@ public class TestPersistence {
@Test
public void testSaving() throws IOException {
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
@ -59,7 +59,7 @@ public class TestPersistence {
.validationDataLocation("validation_data.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
.mtry(2)

View file

@ -20,7 +20,7 @@ import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
@ -72,7 +72,7 @@ public class TrainForest {
.nodeSize(5)
.mtry(4)
.maxNodeDepth(100000000)
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
.build();

View file

@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
@ -66,7 +66,7 @@ public class TrainSingleTree {
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.splitFinder(new WeightedVarianceSplitFinder())
.covariates(covariateNames)
.responseCombiner(new MeanResponseCombiner())
.maxNodeDepth(30)

View file

@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
@ -86,7 +86,7 @@ public class TrainSingleTreeFactor {
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
.covariates(covariateNames)
.maxNodeDepth(30)