Add parameter to decide on whether to check for node purity or not

This commit is contained in:
Joel Therrien 2018-10-15 11:03:35 -07:00
parent 7fba964af9
commit cce5ad1e0f
2 changed files with 24 additions and 2 deletions

View file

@ -177,6 +177,7 @@ public class Settings {
private int numberOfSplits = 5;
private int nodeSize = 5;
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
private boolean checkNodePurity = false;
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);

View file

@ -27,6 +27,12 @@ public class TreeTrainer<Y, O> {
private final int maxNodeDepth;
private final int mtry;
/**
* Whether to check if a node is pure or not when deciding to split. Splitting on a pure node won't change predictive accuracy,
* but (depending on conditions) may hurt performance.
*/
private final boolean checkNodePurity;
private final List<Covariate> covariates;
public TreeTrainer(final Settings settings, final List<Covariate> covariates){
@ -34,6 +40,7 @@ public class TreeTrainer<Y, O> {
this.nodeSize = settings.getNodeSize();
this.maxNodeDepth = settings.getMaxNodeDepth();
this.mtry = settings.getMtry();
this.checkNodePurity = settings.isCheckNodePurity();
this.responseCombiner = settings.getResponseCombiner();
this.groupDifferentiator = settings.getGroupDifferentiator();
@ -48,7 +55,7 @@ public class TreeTrainer<Y, O> {
}
private Node<O> growNode(List<Row<Y>> data, int depth){
// TODO; what is minimum per tree?
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry);
@ -161,8 +168,22 @@ public class TreeTrainer<Y, O> {
}
private boolean nodeIsPure(List<Row<Y>> data){
if(!checkNodePurity){
return false;
}
if(data.size() <= 1){
return true;
}
final Y first = data.get(0).getResponse();
return data.stream().allMatch(row -> row.getResponse().equals(first));
for(int i = 1; i< data.size(); i++){
if(!data.get(i).getResponse().equals(first)){
return false;
}
}
return true;
}
private class SplitRuleAndSplit{