Optimize tree training so that the best split is not applied twice
This commit is contained in:
parent
74151b94db
commit
9d9dc9ef8d
2 changed files with 20 additions and 30 deletions
|
@ -1,11 +1,9 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class Tree<Y> implements Node<Y> {
|
||||
|
||||
private final Node<Y> rootNode;
|
||||
|
@ -27,7 +25,6 @@ public class Tree<Y> implements Node<Y> {
|
|||
return bootstrapRowIds.clone();
|
||||
}
|
||||
|
||||
|
||||
/**
 * Reports whether the given row id is present in {@code bootstrapRowIds}.
 *
 * @param id the row id to look up
 * @return {@code true} if the binary search locates {@code id} in {@code bootstrapRowIds},
 *         {@code false} otherwise
 */
public boolean idInBootstrapSample(int id){
|
||||
// NOTE(review): Arrays.binarySearch is only defined for sorted arrays; this assumes bootstrapRowIds is kept in ascending order — confirm where it is populated.
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.*;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
|
@ -52,9 +53,9 @@ public class TreeTrainer<Y, O> {
|
|||
// TODO; what is minimum per tree?
|
||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
|
||||
final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
||||
final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry);
|
||||
|
||||
if(bestSplitRule == null){
|
||||
if(bestSplitRuleAndSplit.splitRule == null){
|
||||
|
||||
return new TerminalNode<>(
|
||||
responseCombiner.combine(
|
||||
|
@ -65,31 +66,14 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
}
|
||||
|
||||
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
||||
|
||||
// We have to handle any NAs
|
||||
if(split.leftHand.size() == 0 && split.rightHand.size() == 0 && split.naHand.size() > 0){
|
||||
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
||||
}
|
||||
|
||||
final double probabilityLeftHand = (double) split.leftHand.size() / (double) (split.leftHand.size() + split.rightHand.size());
|
||||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
for(final Row<Y> missingValueRow : split.naHand){
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
if(randomDecision){
|
||||
split.leftHand.add(missingValueRow);
|
||||
}
|
||||
else{
|
||||
split.rightHand.add(missingValueRow);
|
||||
}
|
||||
}
|
||||
final Split<Y> split = bestSplitRuleAndSplit.split;
|
||||
// Note that NAs have already been handled
|
||||
|
||||
|
||||
final Node<O> leftNode = growNode(split.leftHand, depth+1);
|
||||
final Node<O> rightNode = growNode(split.rightHand, depth+1);
|
||||
|
||||
return new SplitNode<>(leftNode, rightNode, bestSplitRule, probabilityLeftHand);
|
||||
return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand);
|
||||
|
||||
}
|
||||
else{
|
||||
|
@ -118,8 +102,8 @@ public class TreeTrainer<Y, O> {
|
|||
return splitCovariates;
|
||||
}
|
||||
|
||||
private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||
Covariate.SplitRule bestSplitRule = null;
|
||||
private SplitRuleAndSplit findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||
SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit();
|
||||
double bestSplitScore = 0.0;
|
||||
boolean first = true;
|
||||
|
||||
|
@ -163,7 +147,10 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
|
||||
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
|
||||
bestSplitRule = possibleRule;
|
||||
bestSplitRuleAndSplit.splitRule = possibleRule;
|
||||
bestSplitRuleAndSplit.split = possibleSplit;
|
||||
bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand;
|
||||
|
||||
bestSplitScore = score;
|
||||
first = false;
|
||||
}
|
||||
|
@ -171,7 +158,7 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
}
|
||||
|
||||
return bestSplitRule;
|
||||
return bestSplitRuleAndSplit;
|
||||
|
||||
}
|
||||
|
||||
|
@ -180,4 +167,10 @@ public class TreeTrainer<Y, O> {
|
|||
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
||||
}
|
||||
|
||||
/**
 * Mutable holder returned by findBestSplitRule: it carries the best-scoring split rule together
 * with the Split recorded for that rule and its left-hand probability, so that growNode can reuse
 * the split directly instead of applying the rule to the data a second time.
 */
private class SplitRuleAndSplit{
|
||||
// Remains null when no candidate rule was accepted; growNode checks this to produce a terminal node.
private Covariate.SplitRule splitRule = null;
|
||||
// The split stored alongside splitRule for the best-scoring candidate; null until one is accepted.
private Split<Y> split = null;
|
||||
// Passed on to the SplitNode constructor by growNode.
// NOTE(review): presumably the share of non-missing rows that went left, used to route rows with missing values — the computation is not visible in this view; confirm.
private double probabilityLeftHand;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue