Optimize tree training so that the best split is not applied twice

This commit is contained in:
Joel Therrien 2018-08-08 11:34:02 -07:00
parent 74151b94db
commit 9d9dc9ef8d
2 changed files with 20 additions and 30 deletions

View file

@@ -1,11 +1,9 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import lombok.RequiredArgsConstructor;
import java.util.Arrays; import java.util.Arrays;
@RequiredArgsConstructor
public class Tree<Y> implements Node<Y> { public class Tree<Y> implements Node<Y> {
private final Node<Y> rootNode; private final Node<Y> rootNode;
@@ -27,7 +25,6 @@ public class Tree<Y> implements Node<Y> {
return bootstrapRowIds.clone(); return bootstrapRowIds.clone();
} }
public boolean idInBootstrapSample(int id){ public boolean idInBootstrapSample(int id){
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0; return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
} }

View file

@@ -1,6 +1,7 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.AccessLevel; import lombok.AccessLevel;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@@ -52,9 +53,9 @@ public class TreeTrainer<Y, O> {
// TODO; what is minimum per tree? // TODO; what is minimum per tree?
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final List<Covariate> covariatesToTry = selectCovariates(this.mtry); final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry);
if(bestSplitRule == null){ if(bestSplitRuleAndSplit.splitRule == null){
return new TerminalNode<>( return new TerminalNode<>(
responseCombiner.combine( responseCombiner.combine(
@@ -65,31 +66,14 @@ public class TreeTrainer<Y, O> {
} }
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule final Split<Y> split = bestSplitRuleAndSplit.split;
// Note that NAs have already been handled
// We have to handle any NAs
if(split.leftHand.size() == 0 && split.rightHand.size() == 0 && split.naHand.size() > 0){
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
}
final double probabilityLeftHand = (double) split.leftHand.size() / (double) (split.leftHand.size() + split.rightHand.size());
final Random random = ThreadLocalRandom.current();
for(final Row<Y> missingValueRow : split.naHand){
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
split.leftHand.add(missingValueRow);
}
else{
split.rightHand.add(missingValueRow);
}
}
final Node<O> leftNode = growNode(split.leftHand, depth+1); final Node<O> leftNode = growNode(split.leftHand, depth+1);
final Node<O> rightNode = growNode(split.rightHand, depth+1); final Node<O> rightNode = growNode(split.rightHand, depth+1);
return new SplitNode<>(leftNode, rightNode, bestSplitRule, probabilityLeftHand); return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand);
} }
else{ else{
@@ -118,8 +102,8 @@ public class TreeTrainer<Y, O> {
return splitCovariates; return splitCovariates;
} }
private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){ private SplitRuleAndSplit findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
Covariate.SplitRule bestSplitRule = null; SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit();
double bestSplitScore = 0.0; double bestSplitScore = 0.0;
boolean first = true; boolean first = true;
@@ -163,7 +147,10 @@ public class TreeTrainer<Y, O> {
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){ if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
bestSplitRule = possibleRule; bestSplitRuleAndSplit.splitRule = possibleRule;
bestSplitRuleAndSplit.split = possibleSplit;
bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand;
bestSplitScore = score; bestSplitScore = score;
first = false; first = false;
} }
@@ -171,7 +158,7 @@ public class TreeTrainer<Y, O> {
} }
return bestSplitRule; return bestSplitRuleAndSplit;
} }
@@ -180,4 +167,10 @@ public class TreeTrainer<Y, O> {
return data.stream().allMatch(row -> row.getResponse().equals(first)); return data.stream().allMatch(row -> row.getResponse().equals(first));
} }
private class SplitRuleAndSplit{
private Covariate.SplitRule splitRule = null;
private Split<Y> split = null;
private double probabilityLeftHand;
}
} }