Fixed a bug where Splits could be generated that had an empty daughter

node
This commit is contained in:
Joel Therrien 2018-07-03 15:15:09 -07:00
parent 254727e594
commit e7af65e8fd

View file

@ -31,9 +31,18 @@ public class TreeTrainer<Y> {
private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){ private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){
// TODO; what is minimum per tree? // TODO; what is minimum per tree?
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){ if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
if(bestSplitRule == null){
return new TerminalNode<>(
data.stream()
.map(row -> row.getResponse())
.collect(responseCombiner)
);
}
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1); final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
@ -56,7 +65,7 @@ public class TreeTrainer<Y> {
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){ private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
SplitRule bestSplitRule = null; SplitRule bestSplitRule = null;
Double bestSplitScore = 0.0; // may be null double bestSplitScore = 0.0;
boolean first = true; boolean first = true;
for(final String covariate : covariatesToTry){ for(final String covariate : covariatesToTry){
@ -92,7 +101,7 @@ public class TreeTrainer<Y> {
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()) possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
); );
if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){ if(score != null && (score > bestSplitScore || first)){
bestSplitRule = possibleRule; bestSplitRule = possibleRule;
bestSplitScore = score; bestSplitScore = score;
first = false; first = false;
@ -107,9 +116,7 @@ public class TreeTrainer<Y> {
} }
private boolean nodeIsPure(List<Row<Y>> data, List<String> covariatesToTry){ private boolean nodeIsPure(List<Row<Y>> data){
// TODO how is this done?
final Y first = data.get(0).getResponse(); final Y first = data.get(0).getResponse();
return data.stream().allMatch(row -> row.getResponse().equals(first)); return data.stream().allMatch(row -> row.getResponse().equals(first));
} }