Optimize tree training so that the best split is not applied twice
This commit is contained in:
parent
74151b94db
commit
9d9dc9ef8d
2 changed files with 20 additions and 30 deletions
|
@ -1,11 +1,9 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class Tree<Y> implements Node<Y> {
|
public class Tree<Y> implements Node<Y> {
|
||||||
|
|
||||||
private final Node<Y> rootNode;
|
private final Node<Y> rootNode;
|
||||||
|
@ -27,7 +25,6 @@ public class Tree<Y> implements Node<Y> {
|
||||||
return bootstrapRowIds.clone();
|
return bootstrapRowIds.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public boolean idInBootstrapSample(int id){
|
public boolean idInBootstrapSample(int id){
|
||||||
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
|
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import lombok.AccessLevel;
|
import lombok.AccessLevel;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
|
@ -52,9 +53,9 @@ public class TreeTrainer<Y, O> {
|
||||||
// TODO; what is minimum per tree?
|
// TODO; what is minimum per tree?
|
||||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||||
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
|
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
|
||||||
final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry);
|
||||||
|
|
||||||
if(bestSplitRule == null){
|
if(bestSplitRuleAndSplit.splitRule == null){
|
||||||
|
|
||||||
return new TerminalNode<>(
|
return new TerminalNode<>(
|
||||||
responseCombiner.combine(
|
responseCombiner.combine(
|
||||||
|
@ -65,31 +66,14 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
final Split<Y> split = bestSplitRuleAndSplit.split;
|
||||||
|
// Note that NAs have already been handled
|
||||||
// We have to handle any NAs
|
|
||||||
if(split.leftHand.size() == 0 && split.rightHand.size() == 0 && split.naHand.size() > 0){
|
|
||||||
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
|
||||||
}
|
|
||||||
|
|
||||||
final double probabilityLeftHand = (double) split.leftHand.size() / (double) (split.leftHand.size() + split.rightHand.size());
|
|
||||||
|
|
||||||
final Random random = ThreadLocalRandom.current();
|
|
||||||
for(final Row<Y> missingValueRow : split.naHand){
|
|
||||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
|
||||||
if(randomDecision){
|
|
||||||
split.leftHand.add(missingValueRow);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
split.rightHand.add(missingValueRow);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
final Node<O> leftNode = growNode(split.leftHand, depth+1);
|
final Node<O> leftNode = growNode(split.leftHand, depth+1);
|
||||||
final Node<O> rightNode = growNode(split.rightHand, depth+1);
|
final Node<O> rightNode = growNode(split.rightHand, depth+1);
|
||||||
|
|
||||||
return new SplitNode<>(leftNode, rightNode, bestSplitRule, probabilityLeftHand);
|
return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand);
|
||||||
|
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
|
@ -118,8 +102,8 @@ public class TreeTrainer<Y, O> {
|
||||||
return splitCovariates;
|
return splitCovariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
private SplitRuleAndSplit findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||||
Covariate.SplitRule bestSplitRule = null;
|
SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit();
|
||||||
double bestSplitScore = 0.0;
|
double bestSplitScore = 0.0;
|
||||||
boolean first = true;
|
boolean first = true;
|
||||||
|
|
||||||
|
@ -163,7 +147,10 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
|
|
||||||
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
|
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
|
||||||
bestSplitRule = possibleRule;
|
bestSplitRuleAndSplit.splitRule = possibleRule;
|
||||||
|
bestSplitRuleAndSplit.split = possibleSplit;
|
||||||
|
bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand;
|
||||||
|
|
||||||
bestSplitScore = score;
|
bestSplitScore = score;
|
||||||
first = false;
|
first = false;
|
||||||
}
|
}
|
||||||
|
@ -171,7 +158,7 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return bestSplitRule;
|
return bestSplitRuleAndSplit;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,4 +167,10 @@ public class TreeTrainer<Y, O> {
|
||||||
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class SplitRuleAndSplit{
|
||||||
|
private Covariate.SplitRule splitRule = null;
|
||||||
|
private Split<Y> split = null;
|
||||||
|
private double probabilityLeftHand;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue