Refactor - rename GroupDifferentiators into SplitFinders
SplitRule would have made more sense but it was already taken.
This commit is contained in:
parent
c5c74ad7e9
commit
17ae3a9f5a
18 changed files with 93 additions and 93 deletions
|
@ -180,8 +180,8 @@ public class Main {
|
||||||
|
|
||||||
private static Settings defaultTemplate(){
|
private static Settings defaultTemplate(){
|
||||||
|
|
||||||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
|
splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
|
||||||
|
|
||||||
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
||||||
|
@ -204,7 +204,7 @@ public class Main {
|
||||||
.validationDataLocation("validation_data.csv")
|
.validationDataLocation("validation_data.csv")
|
||||||
.responseCombinerSettings(responseCombinerSettings)
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeCombinerSettings(treeCombinerSettings)
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.splitFinderSettings(splitFinderSettings)
|
||||||
.yVarSettings(yVarSettings)
|
.yVarSettings(yVarSettings)
|
||||||
.maxNodeDepth(100000)
|
.maxNodeDepth(100000)
|
||||||
.mtry(2)
|
.mtry(2)
|
||||||
|
|
|
@ -22,11 +22,11 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogRankSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.SplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
@ -77,31 +77,31 @@ public class Settings {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, Function<ObjectNode, GroupDifferentiator>> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
|
private static Map<String, Function<ObjectNode, SplitFinder>> SPLIT_FINDER_MAP = new HashMap<>();
|
||||||
public static Function<ObjectNode, GroupDifferentiator> getGroupDifferentiatorConstructor(final String name){
|
public static Function<ObjectNode, SplitFinder> getSplitFinderConstructor(final String name){
|
||||||
return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase());
|
return SPLIT_FINDER_MAP.get(name.toLowerCase());
|
||||||
}
|
}
|
||||||
public static void registerGroupDifferentiatorConstructor(final String name, final Function<ObjectNode, GroupDifferentiator> groupDifferentiatorConstructor){
|
public static void registerSplitFinderConstructor(final String name, final Function<ObjectNode, SplitFinder> splitFinderConstructor){
|
||||||
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
|
SPLIT_FINDER_MAP.put(name.toLowerCase(), splitFinderConstructor);
|
||||||
}
|
}
|
||||||
static{
|
static{
|
||||||
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
|
registerSplitFinderConstructor("WeightedVarianceSplitFinder",
|
||||||
(node) -> new WeightedVarianceGroupDifferentiator()
|
(node) -> new WeightedVarianceSplitFinder()
|
||||||
);
|
);
|
||||||
registerGroupDifferentiatorConstructor("GrayLogRankDifferentiator",
|
registerSplitFinderConstructor("GrayLogRankSplitFinder",
|
||||||
(objectNode) -> {
|
(objectNode) -> {
|
||||||
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
|
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
|
||||||
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
|
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
|
||||||
|
|
||||||
return new GrayLogRankDifferentiator(eventsOfFocusArray, eventArray);
|
return new GrayLogRankSplitFinder(eventsOfFocusArray, eventArray);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
registerGroupDifferentiatorConstructor("LogRankDifferentiator",
|
registerSplitFinderConstructor("LogRankSplitFinder",
|
||||||
(objectNode) -> {
|
(objectNode) -> {
|
||||||
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
|
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
|
||||||
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
|
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
|
||||||
|
|
||||||
return new LogRankDifferentiator(eventsOfFocusArray, eventArray);
|
return new LogRankSplitFinder(eventsOfFocusArray, eventArray);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -153,7 +153,7 @@ public class Settings {
|
||||||
private boolean checkNodePurity = false;
|
private boolean checkNodePurity = false;
|
||||||
|
|
||||||
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
|
||||||
private List<CovariateSettings> covariateSettings = new ArrayList<>();
|
private List<CovariateSettings> covariateSettings = new ArrayList<>();
|
||||||
|
@ -194,10 +194,10 @@ public class Settings {
|
||||||
}
|
}
|
||||||
|
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
public GroupDifferentiator getGroupDifferentiator(){
|
public SplitFinder getSplitFinder(){
|
||||||
final String type = groupDifferentiatorSettings.get("type").asText();
|
final String type = splitFinderSettings.get("type").asText();
|
||||||
|
|
||||||
return getGroupDifferentiatorConstructor(type).apply(groupDifferentiatorSettings);
|
return getSplitFinderConstructor(type).apply(splitFinderSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
|
|
|
@ -14,13 +14,13 @@
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.SplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Split;
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import ca.joeltherrien.randomforest.tree.SplitAndScore;
|
import ca.joeltherrien.randomforest.tree.SplitAndScore;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
|
@ -35,24 +35,24 @@ import java.util.stream.Collectors;
|
||||||
* modifies the abstract method.
|
* modifies the abstract method.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y> {
|
public abstract class CompetingRiskSplitFinder<Y extends CompetingRiskResponse> implements SplitFinder<Y> {
|
||||||
|
|
||||||
abstract protected CompetingRiskSets<Y> createCompetingRiskSets(List<Y> leftHand, List<Y> rightHand);
|
abstract protected CompetingRiskSets<Y> createCompetingRiskSets(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
abstract protected Double getScore(final CompetingRiskSets<Y> competingRiskSets);
|
abstract protected Double getScore(final CompetingRiskSets<Y> competingRiskSets);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) {
|
public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) {
|
||||||
|
|
||||||
if(splitIterator instanceof Covariate.SplitRuleUpdater){
|
if(splitIterator instanceof Covariate.SplitRuleUpdater){
|
||||||
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
|
return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
return differentiateWithBasicIterator(splitIterator);
|
return findBestSplitWithBasicIterator(splitIterator);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private SplitAndScore<Y, ?> differentiateWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){
|
private SplitAndScore<Y, ?> findBestSplitWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){
|
||||||
Double bestScore = null;
|
Double bestScore = null;
|
||||||
Split<Y, ?> bestSplit = null;
|
Split<Y, ?> bestSplit = null;
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskRe
|
||||||
return new SplitAndScore<>(bestSplit, bestScore);
|
return new SplitAndScore<>(bestSplit, bestScore);
|
||||||
}
|
}
|
||||||
|
|
||||||
private SplitAndScore<Y, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) {
|
private SplitAndScore<Y, ?> findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) {
|
||||||
|
|
||||||
final List<Y> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
|
final List<Y> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
|
||||||
.stream().map(Row::getResponse).collect(Collectors.toList());
|
.stream().map(Row::getResponse).collect(Collectors.toList());
|
|
@ -14,7 +14,7 @@
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
|
@ -27,12 +27,12 @@ import java.util.List;
|
||||||
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
|
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public class GrayLogRankDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> {
|
public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponseWithCensorTime> {
|
||||||
|
|
||||||
private final int[] eventsOfFocus;
|
private final int[] eventsOfFocus;
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
public GrayLogRankDifferentiator(int[] eventsOfFocus, int[] events){
|
public GrayLogRankSplitFinder(int[] eventsOfFocus, int[] events){
|
||||||
this.eventsOfFocus = eventsOfFocus;
|
this.eventsOfFocus = eventsOfFocus;
|
||||||
this.events = events;
|
this.events = events;
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
package ca.joeltherrien.randomforest.responses.competingrisk.splitfinder;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
|
@ -27,12 +27,12 @@ import java.util.List;
|
||||||
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
|
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public class LogRankDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> {
|
public class LogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponse> {
|
||||||
|
|
||||||
private final int[] eventsOfFocus;
|
private final int[] eventsOfFocus;
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
public LogRankDifferentiator(int[] eventsOfFocus, int[] events){
|
public LogRankSplitFinder(int[] eventsOfFocus, int[] events){
|
||||||
this.eventsOfFocus = eventsOfFocus;
|
this.eventsOfFocus = eventsOfFocus;
|
||||||
this.events = events;
|
this.events = events;
|
||||||
|
|
|
@ -18,7 +18,7 @@ package ca.joeltherrien.randomforest.responses.regression;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.SplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Split;
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import ca.joeltherrien.randomforest.tree.SplitAndScore;
|
import ca.joeltherrien.randomforest.tree.SplitAndScore;
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
public class WeightedVarianceSplitFinder implements SplitFinder<Double> {
|
||||||
|
|
||||||
private Double getScore(Set leftHand, Set rightHand) {
|
private Double getScore(Set leftHand, Set rightHand) {
|
||||||
|
|
||||||
|
@ -44,17 +44,17 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SplitAndScore<Double, ?> differentiate(Iterator<Split<Double, ?>> splitIterator) {
|
public SplitAndScore<Double, ?> findBestSplit(Iterator<Split<Double, ?>> splitIterator) {
|
||||||
|
|
||||||
if(splitIterator instanceof Covariate.SplitRuleUpdater){
|
if(splitIterator instanceof Covariate.SplitRuleUpdater){
|
||||||
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
|
return findBestSplitWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
return differentiateWithBasicIterator(splitIterator);
|
return findBestSplitWithBasicIterator(splitIterator);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private SplitAndScore<Double, ?> differentiateWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){
|
private SplitAndScore<Double, ?> findBestSplitWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){
|
||||||
Double bestScore = null;
|
Double bestScore = null;
|
||||||
Split<Double, ?> bestSplit = null;
|
Split<Double, ?> bestSplit = null;
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<
|
||||||
return new SplitAndScore<>(bestSplit, bestScore);
|
return new SplitAndScore<>(bestSplit, bestScore);
|
||||||
}
|
}
|
||||||
|
|
||||||
private SplitAndScore<Double, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) {
|
private SplitAndScore<Double, ?> findBestSplitWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) {
|
||||||
|
|
||||||
final List<Double> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
|
final List<Double> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
|
||||||
.stream().map(Row::getResponse).collect(Collectors.toList());
|
.stream().map(Row::getResponse).collect(Collectors.toList());
|
|
@ -22,10 +22,10 @@ import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {
|
public abstract class SimpleSplitFinder<Y> implements SplitFinder<Y> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) {
|
public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) {
|
||||||
Double bestScore = null;
|
Double bestScore = null;
|
||||||
Split<Y, ?> bestSplit = null;
|
Split<Y, ?> bestSplit = null;
|
||||||
|
|
|
@ -21,14 +21,14 @@ import java.util.Iterator;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
||||||
* The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an
|
* The SplitFinder has one method that cycles through an iterator of Splits (FYI; check if the iterator is an
|
||||||
* instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits)
|
* instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits)
|
||||||
*
|
*
|
||||||
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
|
* If you want to implement a very trivial SplitFinder that just takes two Lists as arguments, try extending
|
||||||
* SimpleGroupDifferentiator.
|
* SimpleSplitFinder.
|
||||||
*/
|
*/
|
||||||
public interface GroupDifferentiator<Y> extends Serializable {
|
public interface SplitFinder<Y> extends Serializable {
|
||||||
|
|
||||||
SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator);
|
SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator);
|
||||||
|
|
||||||
}
|
}
|
|
@ -31,7 +31,7 @@ import java.util.stream.Collectors;
|
||||||
public class TreeTrainer<Y, O> {
|
public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
private final ResponseCombiner<Y, O> responseCombiner;
|
private final ResponseCombiner<Y, O> responseCombiner;
|
||||||
private final GroupDifferentiator<Y> groupDifferentiator;
|
private final SplitFinder<Y> splitFinder;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The number of splits to perform on each covariate. A value of 0 means all possible splits are tried.
|
* The number of splits to perform on each covariate. A value of 0 means all possible splits are tried.
|
||||||
|
@ -58,7 +58,7 @@ public class TreeTrainer<Y, O> {
|
||||||
this.checkNodePurity = settings.isCheckNodePurity();
|
this.checkNodePurity = settings.isCheckNodePurity();
|
||||||
|
|
||||||
this.responseCombiner = settings.getResponseCombiner();
|
this.responseCombiner = settings.getResponseCombiner();
|
||||||
this.groupDifferentiator = settings.getGroupDifferentiator();
|
this.splitFinder = settings.getSplitFinder();
|
||||||
this.covariates = covariates;
|
this.covariates = covariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ public class TreeTrainer<Y, O> {
|
||||||
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
||||||
|
|
||||||
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
||||||
final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck
|
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
|
||||||
|
|
||||||
for(final Covariate covariate : covariatesToTry) {
|
for(final Covariate covariate : covariatesToTry) {
|
||||||
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
|
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
|
||||||
|
@ -170,7 +170,7 @@ public class TreeTrainer<Y, O> {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
|
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
|
||||||
|
|
||||||
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
|
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
|
||||||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {
|
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {
|
||||||
|
|
|
@ -44,12 +44,12 @@ public class TestSavingLoading {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Settings getSettings(){
|
public Settings getSettings(){
|
||||||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode splitRuleSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator"));
|
splitRuleSettings.set("type", new TextNode("LogRankSplitFinder"));
|
||||||
groupDifferentiatorSettings.set("eventsOfFocus",
|
splitRuleSettings.set("eventsOfFocus",
|
||||||
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
|
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
|
||||||
);
|
);
|
||||||
groupDifferentiatorSettings.set("events",
|
splitRuleSettings.set("events",
|
||||||
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ public class TestSavingLoading {
|
||||||
.validationDataLocation("src/test/resources/wihs.csv")
|
.validationDataLocation("src/test/resources/wihs.csv")
|
||||||
.responseCombinerSettings(responseCombinerSettings)
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeCombinerSettings(treeCombinerSettings)
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.splitFinderSettings(splitRuleSettings)
|
||||||
.yVarSettings(yVarSettings)
|
.yVarSettings(yVarSettings)
|
||||||
.maxNodeDepth(100000)
|
.maxNodeDepth(100000)
|
||||||
// TODO fill in these settings
|
// TODO fill in these settings
|
||||||
|
|
|
@ -53,12 +53,12 @@ public class TestCompetingRisk {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Settings getSettings(){
|
public Settings getSettings(){
|
||||||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
groupDifferentiatorSettings.set("type", new TextNode("LogRankDifferentiator"));
|
splitFinderSettings.set("type", new TextNode("LogRankSplitFinder"));
|
||||||
groupDifferentiatorSettings.set("eventsOfFocus",
|
splitFinderSettings.set("eventsOfFocus",
|
||||||
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
|
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
|
||||||
);
|
);
|
||||||
groupDifferentiatorSettings.set("events",
|
splitFinderSettings.set("events",
|
||||||
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ public class TestCompetingRisk {
|
||||||
.trainingDataLocation("src/test/resources/wihs.csv")
|
.trainingDataLocation("src/test/resources/wihs.csv")
|
||||||
.responseCombinerSettings(responseCombinerSettings)
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeCombinerSettings(treeCombinerSettings)
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.splitFinderSettings(splitFinderSettings)
|
||||||
.yVarSettings(yVarSettings)
|
.yVarSettings(yVarSettings)
|
||||||
.maxNodeDepth(100000)
|
.maxNodeDepth(100000)
|
||||||
// TODO fill in these settings
|
// TODO fill in these settings
|
||||||
|
@ -222,7 +222,7 @@ public class TestCompetingRisk {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
public void testLogRankSplitFinderTwoBooleans() throws IOException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setCovariateSettings(Utils.easyList(
|
settings.setCovariateSettings(Utils.easyList(
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
|
@ -337,7 +337,7 @@ public class TestCompetingRisk {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException {
|
public void testLogRankSplitFinderAllCovariates() throws IOException {
|
||||||
|
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setNtree(300); // results are too variable at 100
|
settings.setNtree(300); // results are too variable at 100
|
||||||
|
|
|
@ -21,7 +21,7 @@ import ca.joeltherrien.randomforest.Settings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Split;
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||||
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||||
|
@ -39,7 +39,7 @@ import java.util.List;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestLogRankDifferentiator {
|
public class TestLogRankSplitFinder {
|
||||||
|
|
||||||
private Iterator<Split<CompetingRiskResponse, ?>> turnIntoSplitIterator(List<Row<CompetingRiskResponse>> leftList,
|
private Iterator<Split<CompetingRiskResponse, ?>> turnIntoSplitIterator(List<Row<CompetingRiskResponse>> leftList,
|
||||||
List<Row<CompetingRiskResponse>> rightList){
|
List<Row<CompetingRiskResponse>> rightList){
|
||||||
|
@ -70,14 +70,14 @@ public class TestLogRankDifferentiator {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSplitRule() throws IOException {
|
public void testSplitRule() throws IOException {
|
||||||
final LogRankDifferentiator groupDifferentiator = new LogRankDifferentiator(new int[]{1,2}, new int[]{1,2});
|
final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1,2}, new int[]{1,2});
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> data = loadData("src/test/resources/test_split_data.csv").getRows();
|
final List<Row<CompetingRiskResponse>> data = loadData("src/test/resources/test_split_data.csv").getRows();
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
|
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
|
||||||
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
|
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
|
||||||
|
|
||||||
final double scoreBad = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Bad, group2Bad)).getScore();
|
final double scoreBad = splitFinder.findBestSplit(turnIntoSplitIterator(group1Bad, group2Bad)).getScore();
|
||||||
|
|
||||||
// expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea
|
// expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea
|
||||||
closeEnough(9.413002, scoreBad, 0.00001);
|
closeEnough(9.413002, scoreBad, 0.00001);
|
|
@ -18,8 +18,8 @@ package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.SplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Split;
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
@ -33,13 +33,13 @@ import java.util.List;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestLogRankSingleGroupDifferentiator {
|
public class TestLogRankSplitFinderSingleEvent {
|
||||||
|
|
||||||
private double getScore(final GroupDifferentiator<CompetingRiskResponse> groupDifferentiator, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){
|
private double getScore(final SplitFinder<CompetingRiskResponse> splitFinder, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){
|
||||||
final Iterator<Split<CompetingRiskResponse, ?>> iterator = new SingletonIterator<>(
|
final Iterator<Split<CompetingRiskResponse, ?>> iterator = new SingletonIterator<>(
|
||||||
new Split<>(null, left, right, Collections.emptyList()));
|
new Split<>(null, left, right, Collections.emptyList()));
|
||||||
|
|
||||||
return groupDifferentiator.differentiate(iterator).getScore();
|
return splitFinder.findBestSplit(iterator).getScore();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,9 +81,9 @@ public class TestLogRankSingleGroupDifferentiator {
|
||||||
final List<Row<CompetingRiskResponse>> data1 = generateData1();
|
final List<Row<CompetingRiskResponse>> data1 = generateData1();
|
||||||
final List<Row<CompetingRiskResponse>> data2 = generateData2();
|
final List<Row<CompetingRiskResponse>> data2 = generateData2();
|
||||||
|
|
||||||
final LogRankDifferentiator differentiator = new LogRankDifferentiator(new int[]{1}, new int[]{1});
|
final LogRankSplitFinder splitFinder = new LogRankSplitFinder(new int[]{1}, new int[]{1});
|
||||||
|
|
||||||
final double score = getScore(differentiator, data1, data2);
|
final double score = getScore(splitFinder, data1, data2);
|
||||||
final double margin = 0.000001;
|
final double margin = 0.000001;
|
||||||
|
|
||||||
// Tested using 855 method
|
// Tested using 855 method
|
||||||
|
@ -94,21 +94,21 @@ public class TestLogRankSingleGroupDifferentiator {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCorrectSplit() throws IOException {
|
public void testCorrectSplit() throws IOException {
|
||||||
final LogRankDifferentiator groupDifferentiator =
|
final LogRankSplitFinder splitFinder =
|
||||||
new LogRankDifferentiator(new int[]{1}, new int[]{1,2});
|
new LogRankSplitFinder(new int[]{1}, new int[]{1,2});
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> data = TestLogRankDifferentiator.
|
final List<Row<CompetingRiskResponse>> data = TestLogRankSplitFinder.
|
||||||
loadData("src/test/resources/test_single_split.csv").getRows();
|
loadData("src/test/resources/test_single_split.csv").getRows();
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
|
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
|
||||||
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
|
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
|
||||||
|
|
||||||
final double scoreGood = getScore(groupDifferentiator, group1Good, group2Good);
|
final double scoreGood = getScore(splitFinder, group1Good, group2Good);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
|
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
|
||||||
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
|
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
|
||||||
|
|
||||||
final double scoreBad = getScore(groupDifferentiator, group1Bad, group2Bad);
|
final double scoreBad = getScore(splitFinder, group1Bad, group2Bad);
|
||||||
|
|
||||||
// Apparently not all groups are unique when splitting
|
// Apparently not all groups are unique when splitting
|
||||||
assertEquals(scoreGood, scoreBad);
|
assertEquals(scoreGood, scoreBad);
|
|
@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
@ -47,7 +47,7 @@ public class TestNAs {
|
||||||
.numberOfSplits(0)
|
.numberOfSplits(0)
|
||||||
.nodeSize(1)
|
.nodeSize(1)
|
||||||
.maxNodeDepth(1000)
|
.maxNodeDepth(1000)
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -35,8 +35,8 @@ public class TestPersistence {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSaving() throws IOException {
|
public void testSaving() throws IOException {
|
||||||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
|
splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
|
||||||
|
|
||||||
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
||||||
|
@ -59,7 +59,7 @@ public class TestPersistence {
|
||||||
.validationDataLocation("validation_data.csv")
|
.validationDataLocation("validation_data.csv")
|
||||||
.responseCombinerSettings(responseCombinerSettings)
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeCombinerSettings(treeCombinerSettings)
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.splitFinderSettings(splitFinderSettings)
|
||||||
.yVarSettings(yVarSettings)
|
.yVarSettings(yVarSettings)
|
||||||
.maxNodeDepth(100000)
|
.maxNodeDepth(100000)
|
||||||
.mtry(2)
|
.mtry(2)
|
||||||
|
|
|
@ -20,7 +20,7 @@ import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ public class TrainForest {
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
.mtry(4)
|
.mtry(4)
|
||||||
.maxNodeDepth(100000000)
|
.maxNodeDepth(100000000)
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
@ -66,7 +66,7 @@ public class TrainSingleTree {
|
||||||
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
|
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
|
||||||
|
|
||||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
.covariates(covariateNames)
|
.covariates(covariateNames)
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
.maxNodeDepth(30)
|
.maxNodeDepth(30)
|
||||||
|
|
|
@ -22,7 +22,7 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
@ -86,7 +86,7 @@ public class TrainSingleTreeFactor {
|
||||||
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
|
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
|
||||||
|
|
||||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
.covariates(covariateNames)
|
.covariates(covariateNames)
|
||||||
.maxNodeDepth(30)
|
.maxNodeDepth(30)
|
||||||
|
|
Loading…
Reference in a new issue