From cce5ad1e0f83897fc611881d385f03f459471fff Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Mon, 15 Oct 2018 11:03:35 -0700 Subject: [PATCH] Add parameter to decide on whether to check for node purity or not --- .../joeltherrien/randomforest/Settings.java | 1 + .../randomforest/tree/TreeTrainer.java | 25 +++++++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index 7d737bb..e9f8b7f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -177,6 +177,7 @@ public class Settings { private int numberOfSplits = 5; private int nodeSize = 5; private int maxNodeDepth = 1000000; // basically no maxNodeDepth + private boolean checkNodePurity = false; private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 98805b5..0e359b4 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -27,6 +27,12 @@ public class TreeTrainer { private final int maxNodeDepth; private final int mtry; + /** + * Whether to check if a node is pure or not when deciding to split. Splitting on a pure node won't change predictive accuracy, + * but (depending on conditions) may hurt performance. + */ + private final boolean checkNodePurity; + private final List covariates; public TreeTrainer(final Settings settings, final List covariates){ @@ -34,6 +40,7 @@ public class TreeTrainer { this.nodeSize = settings.getNodeSize(); this.maxNodeDepth = settings.getMaxNodeDepth(); this.mtry = settings.getMtry(); + this.checkNodePurity = settings.isCheckNodePurity(); this.responseCombiner = settings.getResponseCombiner(); this.groupDifferentiator = settings.getGroupDifferentiator(); @@ -48,7 +55,7 @@ public class TreeTrainer { } private Node growNode(List> data, int depth){ - // TODO; what is minimum per tree? + // See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom) if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ final List covariatesToTry = selectCovariates(this.mtry); final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry); @@ -161,8 +168,22 @@ public class TreeTrainer { } private boolean nodeIsPure(List> data){ + if(!checkNodePurity){ + return false; + } + + if(data.size() <= 1){ + return true; + } + final Y first = data.get(0).getResponse(); - return data.stream().allMatch(row -> row.getResponse().equals(first)); + for(int i = 1; i< data.size(); i++){ + if(!data.get(i).getResponse().equals(first)){ + return false; + } + } + + return true; } private class SplitRuleAndSplit{