Refactor - rename GroupDifferentiators into SplitFinders

SplitRule would have made more sense, but it was already taken.
This commit is contained in:
Joel Therrien 2019-05-08 16:09:09 -07:00
parent c5c74ad7e9
commit 17ae3a9f5a
18 changed files with 93 additions and 93 deletions

View file

@ -180,8 +180,8 @@ public class Main {
private static Settings defaultTemplate(){ private static Settings defaultTemplate(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator")); splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner")); responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
@ -204,7 +204,7 @@ public class Main {
.validationDataLocation("validation_data.csv") .validationDataLocation("validation_data.csv")
.responseCombinerSettings(responseCombinerSettings) .responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings) .treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings) .splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings) .yVarSettings(yVarSettings)
.maxNodeDepth(100000) .maxNodeDepth(100000)
.mtry(2) .mtry(2)

View file

@ -22,11 +22,11 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogRankSplitFinder;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
@ -77,31 +77,31 @@ public class Settings {
); );
} }
private static Map<String, Function<ObjectNode, GroupDifferentiator>> GROUP_DIFFERENTIATOR_MAP = new HashMap<>(); private static Map<String, Function<ObjectNode, SplitFinder>> SPLIT_FINDER_MAP = new HashMap<>();
public static Function<ObjectNode, GroupDifferentiator> getGroupDifferentiatorConstructor(final String name){ public static Function<ObjectNode, SplitFinder> getSplitFinderConstructor(final String name){
return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase()); return SPLIT_FINDER_MAP.get(name.toLowerCase());
} }
public static void registerGroupDifferentiatorConstructor(final String name, final Function<ObjectNode, GroupDifferentiator> groupDifferentiatorConstructor){ public static void registerSplitFinderConstructor(final String name, final Function<ObjectNode, SplitFinder> splitFinderConstructor){
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor); SPLIT_FINDER_MAP.put(name.toLowerCase(), splitFinderConstructor);
} }
static{ static{
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator", registerSplitFinderConstructor("WeightedVarianceSplitFinder",
(node) -> new WeightedVarianceGroupDifferentiator() (node) -> new WeightedVarianceSplitFinder()
); );
registerGroupDifferentiatorConstructor("GrayLogRankDifferentiator", registerSplitFinderConstructor("GrayLogRankSplitFinder",
(objectNode) -> { (objectNode) -> {
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus")); final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events")); final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
return new GrayLogRankDifferentiator(eventsOfFocusArray, eventArray); return new GrayLogRankSplitFinder(eventsOfFocusArray, eventArray);
} }
); );
registerGroupDifferentiatorConstructor("LogRankDifferentiator", registerSplitFinderConstructor("LogRankSplitFinder",
(objectNode) -> { (objectNode) -> {
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus")); final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events")); final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
return new LogRankDifferentiator(eventsOfFocusArray, eventArray); return new LogRankSplitFinder(eventsOfFocusArray, eventArray);
} }
); );
} }
@ -153,7 +153,7 @@ public class Settings {
private boolean checkNodePurity = false; private boolean checkNodePurity = false;
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private List<CovariateSettings> covariateSettings = new ArrayList<>(); private List<CovariateSettings> covariateSettings = new ArrayList<>();
@ -194,10 +194,10 @@ public class Settings {
} }
@JsonIgnore @JsonIgnore
public GroupDifferentiator getGroupDifferentiator(){ public SplitFinder getSplitFinder(){
final String type = groupDifferentiatorSettings.get("type").asText(); final String type = splitFinderSettings.get("type").asText();
return getGroupDifferentiatorConstructor(type).apply(groupDifferentiatorSettings); return getSplitFinderConstructor(type).apply(splitFinderSettings);
} }
@JsonIgnore @JsonIgnore

View file

@ -14,13 +14,13 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.tree.SplitAndScore; import ca.joeltherrien.randomforest.tree.SplitAndScore;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@ -35,24 +35,24 @@ import java.util.stream.Collectors;
* modifies the abstract method. * modifies the abstract method.
* *
*/ */
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y> { public abstract class CompetingRiskSplitFinder<Y extends CompetingRiskResponse> implements SplitFinder<Y> {
abstract protected CompetingRiskSets<Y> createCompetingRiskSets(List<Y> leftHand, List<Y> rightHand); abstract protected CompetingRiskSets<Y> createCompetingRiskSets(List<Y> leftHand, List<Y> rightHand);
abstract protected Double getScore(final CompetingRiskSets<Y> competingRiskSets); abstract protected Double getScore(final CompetingRiskSets<Y> competingRiskSets);
@Override @Override
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) { public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) {
if(splitIterator instanceof Covariate.SplitRuleUpdater){ if(splitIterator instanceof Covariate.SplitRuleUpdater){
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
} }
else{ else{
return differentiateWithBasicIterator(splitIterator); return findBestSplitWithBasicIterator(splitIterator);
} }
} }
private SplitAndScore<Y, ?> differentiateWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){ private SplitAndScore<Y, ?> findBestSplitWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){
Double bestScore = null; Double bestScore = null;
Split<Y, ?> bestSplit = null; Split<Y, ?> bestSplit = null;
@ -83,7 +83,7 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskRe
return new SplitAndScore<>(bestSplit, bestScore); return new SplitAndScore<>(bestSplit, bestScore);
} }
private SplitAndScore<Y, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) { private SplitAndScore<Y, ?> findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) {
final List<Y> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand() final List<Y> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
.stream().map(Row::getResponse).collect(Collectors.toList()); .stream().map(Row::getResponse).collect(Collectors.toList());

View file

@ -14,7 +14,7 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
@ -27,12 +27,12 @@ import java.util.List;
* See page 761 of Random survival forests for competing risks by Ishwaran et al. * See page 761 of Random survival forests for competing risks by Ishwaran et al.
* *
*/ */
public class GrayLogRankDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> { public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponseWithCensorTime> {
private final int[] eventsOfFocus; private final int[] eventsOfFocus;
private final int[] events; private final int[] events;
public GrayLogRankDifferentiator(int[] eventsOfFocus, int[] events){ public GrayLogRankSplitFinder(int[] eventsOfFocus, int[] events){
this.eventsOfFocus = eventsOfFocus; this.eventsOfFocus = eventsOfFocus;
this.events = events; this.events = events;

View file

@ -14,7 +14,7 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
@ -27,12 +27,12 @@ import java.util.List;
* See page 761 of Random survival forests for competing risks by Ishwaran et al. * See page 761 of Random survival forests for competing risks by Ishwaran et al.
* *
*/ */
public class LogRankDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> { public class LogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponse> {
private final int[] eventsOfFocus; private final int[] eventsOfFocus;
private final int[] events; private final int[] events;
public LogRankDifferentiator(int[] eventsOfFocus, int[] events){ public LogRankSplitFinder(int[] eventsOfFocus, int[] events){
this.eventsOfFocus = eventsOfFocus; this.eventsOfFocus = eventsOfFocus;
this.events = events; this.events = events;

View file

@ -18,7 +18,7 @@ package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.tree.SplitAndScore; import ca.joeltherrien.randomforest.tree.SplitAndScore;
@ -26,7 +26,7 @@ import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> { public class WeightedVarianceSplitFinder implements SplitFinder<Double> {
private Double getScore(Set leftHand, Set rightHand) { private Double getScore(Set leftHand, Set rightHand) {
@ -44,17 +44,17 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<
} }
@Override @Override
public SplitAndScore<Double, ?> differentiate(Iterator<Split<Double, ?>> splitIterator) { public SplitAndScore<Double, ?> findBestSplit(Iterator<Split<Double, ?>> splitIterator) {
if(splitIterator instanceof Covariate.SplitRuleUpdater){ if(splitIterator instanceof Covariate.SplitRuleUpdater){
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
} }
else{ else{
return differentiateWithBasicIterator(splitIterator); return findBestSplitWithBasicIterator(splitIterator);
} }
} }
private SplitAndScore<Double, ?> differentiateWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){ private SplitAndScore<Double, ?> findBestSplitWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){
Double bestScore = null; Double bestScore = null;
Split<Double, ?> bestSplit = null; Split<Double, ?> bestSplit = null;
@ -86,7 +86,7 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<
return new SplitAndScore<>(bestSplit, bestScore); return new SplitAndScore<>(bestSplit, bestScore);
} }
private SplitAndScore<Double, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) { private SplitAndScore<Double, ?> findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) {
final List<Double> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand() final List<Double> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
.stream().map(Row::getResponse).collect(Collectors.toList()); .stream().map(Row::getResponse).collect(Collectors.toList());

View file

@ -22,10 +22,10 @@ import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> { public abstract class SimpleSplitFinder<Y> implements SplitFinder<Y> {
@Override @Override
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) { public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) {
Double bestScore = null; Double bestScore = null;
Split<Y, ?> bestSplit = null; Split<Y, ?> bestSplit = null;

View file

@ -21,14 +21,14 @@ import java.util.Iterator;
/** /**
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups. * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
* The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI: check if the iterator is an * The SplitFinder has one method that cycles through an iterator of Splits (FYI: check if the iterator is an
* instance of Covariate.SplitRuleUpdater, in which case you get access to the rows that change between splits) * instance of Covariate.SplitRuleUpdater, in which case you get access to the rows that change between splits)
* *
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending * If you want to implement a very trivial SplitFinder that just takes two Lists as arguments, try extending
* SimpleGroupDifferentiator. * SimpleSplitFinder.
*/ */
public interface GroupDifferentiator<Y> extends Serializable { public interface SplitFinder<Y> extends Serializable {
SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator); SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator);
} }

View file

@ -31,7 +31,7 @@ import java.util.stream.Collectors;
public class TreeTrainer<Y, O> { public class TreeTrainer<Y, O> {
private final ResponseCombiner<Y, O> responseCombiner; private final ResponseCombiner<Y, O> responseCombiner;
private final GroupDifferentiator<Y> groupDifferentiator; private final SplitFinder<Y> splitFinder;
/** /**
* The number of splits to perform on each covariate. A value of 0 means all possible splits are tried. * The number of splits to perform on each covariate. A value of 0 means all possible splits are tried.
@ -58,7 +58,7 @@ public class TreeTrainer<Y, O> {
this.checkNodePurity = settings.isCheckNodePurity(); this.checkNodePurity = settings.isCheckNodePurity();
this.responseCombiner = settings.getResponseCombiner(); this.responseCombiner = settings.getResponseCombiner();
this.groupDifferentiator = settings.getGroupDifferentiator(); this.splitFinder = settings.getSplitFinder();
this.covariates = covariates; this.covariates = covariates;
} }
@ -160,7 +160,7 @@ public class TreeTrainer<Y, O> {
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){ private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
SplitAndScore<Y, ?> bestSplitAndScore = null; SplitAndScore<Y, ?> bestSplitAndScore = null;
final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
for(final Covariate covariate : covariatesToTry) { for(final Covariate covariate : covariatesToTry) {
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random); final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
@ -170,7 +170,7 @@ public class TreeTrainer<Y, O> {
continue; continue;
} }
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator); final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
if(candidateSplitAndScore != null && (bestSplitAndScore == null || if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) { candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {

View file

@ -44,12 +44,12 @@ public class TestSavingLoading {
* @return * @return
*/ */
public Settings getSettings(){ public Settings getSettings(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode splitRuleSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator")); splitRuleSettings.set("type", new TextNode("LogRankSplitFinder"));
groupDifferentiatorSettings.set("eventsOfFocus", splitRuleSettings.set("eventsOfFocus",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1))) new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
); );
groupDifferentiatorSettings.set("events", splitRuleSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
); );
@ -85,7 +85,7 @@ public class TestSavingLoading {
.validationDataLocation("src/test/resources/wihs.csv") .validationDataLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings) .responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings) .treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings) .splitFinderSettings(splitRuleSettings)
.yVarSettings(yVarSettings) .yVarSettings(yVarSettings)
.maxNodeDepth(100000) .maxNodeDepth(100000)
// TODO fill in these settings // TODO fill in these settings

View file

@ -53,12 +53,12 @@ public class TestCompetingRisk {
* @return * @return
*/ */
public Settings getSettings(){ public Settings getSettings(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator")); splitFinderSettings.set("type", new TextNode("LogRankSplitFinder"));
groupDifferentiatorSettings.set("eventsOfFocus", splitFinderSettings.set("eventsOfFocus",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1))) new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
); );
groupDifferentiatorSettings.set("events", splitFinderSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
); );
@ -93,7 +93,7 @@ public class TestCompetingRisk {
.trainingDataLocation("src/test/resources/wihs.csv") .trainingDataLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings) .responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings) .treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings) .splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings) .yVarSettings(yVarSettings)
.maxNodeDepth(100000) .maxNodeDepth(100000)
// TODO fill in these settings // TODO fill in these settings
@ -222,7 +222,7 @@ public class TestCompetingRisk {
} }
@Test @Test
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException { public void testLogRankSplitFinderTwoBooleans() throws IOException {
final Settings settings = getSettings(); final Settings settings = getSettings();
settings.setCovariateSettings(Utils.easyList( settings.setCovariateSettings(Utils.easyList(
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
@ -337,7 +337,7 @@ public class TestCompetingRisk {
} }
@Test @Test
public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException { public void testLogRankSplitFinderAllCovariates() throws IOException {
final Settings settings = getSettings(); final Settings settings = getSettings();
settings.setNtree(300); // results are too variable at 100 settings.setNtree(300); // results are too variable at 100

View file

@ -21,7 +21,7 @@ import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.SingletonIterator; import ca.joeltherrien.randomforest.utils.SingletonIterator;
@ -39,7 +39,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLogRankDifferentiator { public class TestLogRankSplitFinder {
private Iterator<Split<CompetingRiskResponse, ?>> turnIntoSplitIterator(List<Row<CompetingRiskResponse>> leftList, private Iterator<Split<CompetingRiskResponse, ?>> turnIntoSplitIterator(List<Row<CompetingRiskResponse>> leftList,
List<Row<CompetingRiskResponse>> rightList){ List<Row<CompetingRiskResponse>> rightList){
@ -70,14 +70,14 @@ public class TestLogRankDifferentiator {
@Test @Test
public void testSplitRule() throws IOException { public void testSplitRule() throws IOException {
final LogRankDifferentiator groupDifferentiator = new LogRankDifferentiator(new int[]{1,2}, new int[]{1,2}); final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1,2}, new int[]{1,2});
final List<Row<CompetingRiskResponse>> data = loadData("src/test/resources/test_split_data.csv").getRows(); final List<Row<CompetingRiskResponse>> data = loadData("src/test/resources/test_split_data.csv").getRows();
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196); final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size()); final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
final double scoreBad = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Bad, group2Bad)).getScore(); final double scoreBad = splitFinder.findBestSplit(turnIntoSplitIterator(group1Bad, group2Bad)).getScore();
// expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea // expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea
closeEnough(9.413002, scoreBad, 0.00001); closeEnough(9.413002, scoreBad, 0.00001);

View file

@ -18,8 +18,8 @@ package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.SingletonIterator; import ca.joeltherrien.randomforest.utils.SingletonIterator;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -33,13 +33,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLogRankSingleGroupDifferentiator { public class TestLogRankSplitFinderSingleEvent {
private double getScore(final GroupDifferentiator<CompetingRiskResponse> groupDifferentiator, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){ private double getScore(final SplitFinder<CompetingRiskResponse> splitFinder, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){
final Iterator<Split<CompetingRiskResponse, ?>> iterator = new SingletonIterator<>( final Iterator<Split<CompetingRiskResponse, ?>> iterator = new SingletonIterator<>(
new Split<>(null, left, right, Collections.emptyList())); new Split<>(null, left, right, Collections.emptyList()));
return groupDifferentiator.differentiate(iterator).getScore(); return splitFinder.findBestSplit(iterator).getScore();
} }
@ -81,9 +81,9 @@ public class TestLogRankSingleGroupDifferentiator {
final List<Row<CompetingRiskResponse>> data1 = generateData1(); final List<Row<CompetingRiskResponse>> data1 = generateData1();
final List<Row<CompetingRiskResponse>> data2 = generateData2(); final List<Row<CompetingRiskResponse>> data2 = generateData2();
final LogRankDifferentiator differentiator = new LogRankDifferentiator(new int[]{1}, new int[]{1}); final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1}, new int[]{1});
final double score = getScore(differentiator, data1, data2); final double score = getScore(splitFinder, data1, data2);
final double margin = 0.000001; final double margin = 0.000001;
// Tested using 855 method // Tested using 855 method
@ -94,21 +94,21 @@ public class TestLogRankSingleGroupDifferentiator {
@Test @Test
public void testCorrectSplit() throws IOException { public void testCorrectSplit() throws IOException {
final LogRankDifferentiator groupDifferentiator = final LogRankSplitFinder splitFinder =
new LogRankDifferentiator(new int[]{1}, new int[]{1,2}); new LogRankSplitFinder(new int[]{1}, new int[]{1,2});
final List<Row<CompetingRiskResponse>> data = TestLogRankDifferentiator. final List<Row<CompetingRiskResponse>> data = TestLogRankSplitFinder.
loadData("src/test/resources/test_single_split.csv").getRows(); loadData("src/test/resources/test_single_split.csv").getRows();
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221); final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size()); final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
final double scoreGood = getScore(groupDifferentiator, group1Good, group2Good); final double scoreGood = getScore(splitFinder, group1Good, group2Good);
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222); final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size()); final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
final double scoreBad = getScore(groupDifferentiator, group1Bad, group2Bad); final double scoreBad = getScore(splitFinder, group1Bad, group2Bad);
// Apparently not all groups are unique when splitting // Apparently not all groups are unique when splitting
assertEquals(scoreGood, scoreBad); assertEquals(scoreGood, scoreBad);

View file

@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -47,7 +47,7 @@ public class TestNAs {
.numberOfSplits(0) .numberOfSplits(0)
.nodeSize(1) .nodeSize(1)
.maxNodeDepth(1000) .maxNodeDepth(1000)
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
.build(); .build();

View file

@ -35,8 +35,8 @@ public class TestPersistence {
@Test @Test
public void testSaving() throws IOException { public void testSaving() throws IOException {
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator")); splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner")); responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
@ -59,7 +59,7 @@ public class TestPersistence {
.validationDataLocation("validation_data.csv") .validationDataLocation("validation_data.csv")
.responseCombinerSettings(responseCombinerSettings) .responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings) .treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings) .splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings) .yVarSettings(yVarSettings)
.maxNodeDepth(100000) .maxNodeDepth(100000)
.mtry(2) .mtry(2)

View file

@ -20,7 +20,7 @@ import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
@ -72,7 +72,7 @@ public class TrainForest {
.nodeSize(5) .nodeSize(5)
.mtry(4) .mtry(4)
.maxNodeDepth(100000000) .maxNodeDepth(100000000)
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
.build(); .build();

View file

@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
@ -66,7 +66,7 @@ public class TrainSingleTree {
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate); final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder() final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .splitFinder(new WeightedVarianceSplitFinder())
.covariates(covariateNames) .covariates(covariateNames)
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
.maxNodeDepth(30) .maxNodeDepth(30)

View file

@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
@ -86,7 +86,7 @@ public class TrainSingleTreeFactor {
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate); final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder() final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
.covariates(covariateNames) .covariates(covariateNames)
.maxNodeDepth(30) .maxNodeDepth(30)