diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java index 48b4663..12d3d03 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java @@ -1,11 +1,9 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; -import lombok.RequiredArgsConstructor; import java.util.Arrays; -@RequiredArgsConstructor public class Tree implements Node { private final Node rootNode; @@ -27,7 +25,6 @@ public class Tree implements Node { return bootstrapRowIds.clone(); } - public boolean idInBootstrapSample(int id){ return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0; } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index e33be28..15f0cdb 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -1,6 +1,7 @@ package ca.joeltherrien.randomforest.tree; -import ca.joeltherrien.randomforest.*; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; import lombok.AccessLevel; import lombok.AllArgsConstructor; @@ -52,9 +53,9 @@ public class TreeTrainer { // TODO; what is minimum per tree? if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ final List covariatesToTry = selectCovariates(this.mtry); - final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); + final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry); - if(bestSplitRule == null){ + if(bestSplitRuleAndSplit.splitRule == null){ return new TerminalNode<>( responseCombiner.combine( @@ -65,31 +66,14 @@ public class TreeTrainer { } - final Split split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule - - // We have to handle any NAs - if(split.leftHand.size() == 0 && split.rightHand.size() == 0 && split.naHand.size() > 0){ - throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows"); - } - - final double probabilityLeftHand = (double) split.leftHand.size() / (double) (split.leftHand.size() + split.rightHand.size()); - - final Random random = ThreadLocalRandom.current(); - for(final Row missingValueRow : split.naHand){ - final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; - if(randomDecision){ - split.leftHand.add(missingValueRow); - } - else{ - split.rightHand.add(missingValueRow); - } - } + final Split split = bestSplitRuleAndSplit.split; + // Note that NAs have already been handled final Node leftNode = growNode(split.leftHand, depth+1); final Node rightNode = growNode(split.rightHand, depth+1); - return new SplitNode<>(leftNode, rightNode, bestSplitRule, probabilityLeftHand); + return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand); } else{ @@ -118,8 +102,8 @@ public class TreeTrainer { return splitCovariates; } - private Covariate.SplitRule findBestSplitRule(List> data, List covariatesToTry){ - Covariate.SplitRule bestSplitRule = null; + private SplitRuleAndSplit findBestSplitRule(List> data, List covariatesToTry){ + SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit(); double bestSplitScore = 0.0; boolean first = true; @@ -163,7 +147,10 @@ public class TreeTrainer { if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){ - bestSplitRule = possibleRule; + bestSplitRuleAndSplit.splitRule = possibleRule; + bestSplitRuleAndSplit.split = possibleSplit; + bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand; + bestSplitScore = score; first = false; } @@ -171,7 +158,7 @@ public class TreeTrainer { } - return bestSplitRule; + return bestSplitRuleAndSplit; } @@ -180,4 +167,10 @@ public class TreeTrainer { return data.stream().allMatch(row -> row.getResponse().equals(first)); } + private class SplitRuleAndSplit{ + private Covariate.SplitRule splitRule = null; + private Split split = null; + private double probabilityLeftHand; + } + }