Add parameter to decide on whether to check for node purity or not
This commit is contained in:
parent
7fba964af9
commit
cce5ad1e0f
2 changed files with 24 additions and 2 deletions
|
@ -177,6 +177,7 @@ public class Settings {
|
|||
private int numberOfSplits = 5;
|
||||
private int nodeSize = 5;
|
||||
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
|
||||
private boolean checkNodePurity = false;
|
||||
|
||||
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
|
|
|
@ -27,6 +27,12 @@ public class TreeTrainer<Y, O> {
|
|||
private final int maxNodeDepth;
|
||||
private final int mtry;
|
||||
|
||||
/**
|
||||
* Whether to check if a node is pure or not when deciding to split. Splitting on a pure node won't change predictive accuracy,
|
||||
* but (depending on conditions) may hurt performance.
|
||||
*/
|
||||
private final boolean checkNodePurity;
|
||||
|
||||
private final List<Covariate> covariates;
|
||||
|
||||
public TreeTrainer(final Settings settings, final List<Covariate> covariates){
|
||||
|
@ -34,6 +40,7 @@ public class TreeTrainer<Y, O> {
|
|||
this.nodeSize = settings.getNodeSize();
|
||||
this.maxNodeDepth = settings.getMaxNodeDepth();
|
||||
this.mtry = settings.getMtry();
|
||||
this.checkNodePurity = settings.isCheckNodePurity();
|
||||
|
||||
this.responseCombiner = settings.getResponseCombiner();
|
||||
this.groupDifferentiator = settings.getGroupDifferentiator();
|
||||
|
@ -48,7 +55,7 @@ public class TreeTrainer<Y, O> {
|
|||
}
|
||||
|
||||
private Node<O> growNode(List<Row<Y>> data, int depth){
|
||||
// TODO; what is minimum per tree?
|
||||
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
|
||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
|
||||
final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry);
|
||||
|
@ -161,8 +168,22 @@ public class TreeTrainer<Y, O> {
|
|||
}
|
||||
|
||||
private boolean nodeIsPure(List<Row<Y>> data){
|
||||
if(!checkNodePurity){
|
||||
return false;
|
||||
}
|
||||
|
||||
if(data.size() <= 1){
|
||||
return true;
|
||||
}
|
||||
|
||||
final Y first = data.get(0).getResponse();
|
||||
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
||||
for(int i = 1; i< data.size(); i++){
|
||||
if(!data.get(i).getResponse().equals(first)){
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private class SplitRuleAndSplit{
|
||||
|
|
Loading…
Reference in a new issue