diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 907f8ef..e830f4d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -31,9 +31,18 @@ public class TreeTrainer { private Node growNode(List> data, List covariatesToTry, int depth){ // TODO; what is minimum per tree? - if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){ + if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); + if(bestSplitRule == null){ + return new TerminalNode<>( + data.stream() + .map(row -> row.getResponse()) + .collect(responseCombiner) + + ); + } + final Split split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule final Node leftNode = growNode(split.leftHand, covariatesToTry, depth+1); @@ -56,7 +65,7 @@ public class TreeTrainer { private SplitRule findBestSplitRule(List> data, List covariatesToTry){ SplitRule bestSplitRule = null; - Double bestSplitScore = 0.0; // may be null + double bestSplitScore = 0.0; boolean first = true; for(final String covariate : covariatesToTry){ @@ -92,7 +101,7 @@ public class TreeTrainer { possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()) ); - if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){ + if(score != null && (score > bestSplitScore || first)){ bestSplitRule = possibleRule; bestSplitScore = score; first = false; @@ -107,9 +116,7 @@ public class TreeTrainer { } - private boolean nodeIsPure(List> data, List covariatesToTry){ - // TODO how is this done? - + private boolean nodeIsPure(List> data){ final Y first = data.get(0).getResponse(); return data.stream().allMatch(row -> row.getResponse().equals(first)); }