diff --git a/.gitignore b/.gitignore
index f61aa0d..73c3108 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,5 @@
.settings
.project
target/
+*.iml
+.idea
diff --git a/pom.xml b/pom.xml
new file mode 100644
index 0000000..637f869
--- /dev/null
+++ b/pom.xml
@@ -0,0 +1,31 @@
+
+
+ 4.0.0
+
+ ca.joeltherrien
+ RandomSurvivalForests
+ 1.0-SNAPSHOT
+
+
+ 1.10
+ 1.10
+ 1.10
+
+
+
+
+
+ org.projectlombok
+ lombok
+ 1.18.0
+ provided
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/ca/joeltherrien/randomforest/Main.java b/src/ca/joeltherrien/randomforest/Main.java
deleted file mode 100644
index 1ae6537..0000000
--- a/src/ca/joeltherrien/randomforest/Main.java
+++ /dev/null
@@ -1,10 +0,0 @@
-package ca.joeltherrien.randomforest;
-
-public class Main {
-
- public static void main(String[] args) {
- System.out.println("Hello world!");
-
- }
-
-}
diff --git a/src/ca/joeltherrien/randomforest/Node.java b/src/ca/joeltherrien/randomforest/Node.java
deleted file mode 100644
index fa6efce..0000000
--- a/src/ca/joeltherrien/randomforest/Node.java
+++ /dev/null
@@ -1,5 +0,0 @@
-package ca.joeltherrien.randomforest;
-
-public class Node {
-
-}
diff --git a/src/ca/joeltherrien/randomforest/NumericSplitRule.java b/src/ca/joeltherrien/randomforest/NumericSplitRule.java
deleted file mode 100644
index d4e3fa9..0000000
--- a/src/ca/joeltherrien/randomforest/NumericSplitRule.java
+++ /dev/null
@@ -1,44 +0,0 @@
-package ca.joeltherrien.randomforest;
-
-import java.util.LinkedList;
-import java.util.List;
-
-import ca.joeltherrien.randomforest.exceptions.MissingValueException;
-
-public class NumericSplitRule implements SplitRule{
-
- public final String covariateName;
- public final double threshold;
-
- public NumericSplitRule(String covariateName, double threshold) {
- super();
- this.covariateName = covariateName;
- this.threshold = threshold;
- }
-
- @Override
- public final String toString() {
- return "NumericSplitRule on " + covariateName + " at " + threshold;
- }
-
- @Override
- public Split applyRule(List> rows) {
- final List> leftHand = new LinkedList<>();
- final List> rightHand = new LinkedList<>();
-
- for(final Row row : rows) {
- final Value x = row.getCovariate(covariateName);
- if(x == null) {
- throw new MissingValueException(row, this);
- }
-
- final NumericValue xNum = (NumericValue) x;
-
- }
-
- // TODO Auto-generated method stub
- return null;
- }
-
-
-}
diff --git a/src/ca/joeltherrien/randomforest/Row.java b/src/ca/joeltherrien/randomforest/Row.java
deleted file mode 100644
index ca379ae..0000000
--- a/src/ca/joeltherrien/randomforest/Row.java
+++ /dev/null
@@ -1,33 +0,0 @@
-package ca.joeltherrien.randomforest;
-
-import java.util.Map;
-
-public class Row {
-
- private final Map covariates;
- private final Y response;
- private final int id;
-
- public Row(Map covariates, Y response, int id) {
- super();
- this.covariates = covariates;
- this.response = response;
- this.id = id;
- }
-
- public Value getCovariate(String name) {
- return this.covariates.get(name);
- }
-
- public Y getResponse() {
- return this.response;
- }
-
- @Override
- public String toString() {
- return "Row " + this.id;
- }
-
-
-
-}
diff --git a/src/ca/joeltherrien/randomforest/SplitRule.java b/src/ca/joeltherrien/randomforest/SplitRule.java
deleted file mode 100644
index d7efca6..0000000
--- a/src/ca/joeltherrien/randomforest/SplitRule.java
+++ /dev/null
@@ -1,9 +0,0 @@
-package ca.joeltherrien.randomforest;
-
-import java.util.List;
-
-public interface SplitRule {
-
- Split applyRule(List> rows);
-
-}
diff --git a/src/ca/joeltherrien/randomforest/Value.java b/src/ca/joeltherrien/randomforest/Value.java
deleted file mode 100644
index f92affe..0000000
--- a/src/ca/joeltherrien/randomforest/Value.java
+++ /dev/null
@@ -1,7 +0,0 @@
-package ca.joeltherrien.randomforest;
-
-public interface Value {
-
- // TODO
-
-}
diff --git a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java
new file mode 100644
index 0000000..beccc3a
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java
@@ -0,0 +1,26 @@
+package ca.joeltherrien.randomforest;
+
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+
+import java.util.Map;
+
+/**
+ * One observation's covariate values, looked up by covariate name, plus an integer id.
+ * This class holds no response value.
+ */
+@RequiredArgsConstructor
+public class CovariateRow {
+
+ // Covariate values keyed by covariate name; read only through getCovariate.
+ private final Map valueMap;
+
+ // Identifier included in toString(); the Lombok @Getter exposes it as getId().
+ @Getter
+ private final int id;
+
+ /**
+ * Looks up the value stored for a covariate.
+ *
+ * @param name covariate name used as the key into the value map
+ * @return whatever the map holds for {@code name}
+ */
+ public Value> getCovariate(String name){
+ return valueMap.get(name);
+
+ }
+
+ /** Returns the text "CovariateRow " followed by the id. */
+ @Override
+ public String toString(){
+ return "CovariateRow " + this.id;
+ }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java
new file mode 100644
index 0000000..b2a8edb
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/Main.java
@@ -0,0 +1,146 @@
+package ca.joeltherrien.randomforest;
+
+
+import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
+import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.Node;
+import ca.joeltherrien.randomforest.tree.TreeTrainer;
+
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+
+/**
+ * Demo program: configures a TreeTrainer, grows one tree and prints the tree's output for nine
+ * hand-picked covariate rows. Synthetic responses come from the piecewise-constant
+ * generateResponse(x1, x2) defined at the bottom of this class.
+ */
+public class Main {
+
+ /**
+ * Runs the demo and prints everything to standard output.
+ *
+ * @param args not read
+ */
+ public static void main(String[] args) {
+ System.out.println("Hello world!");
+
+ // Fixed seed, so the generated covariate values are the same on every run.
+ final Random random = new Random(123);
+
+ final int n = 1000;
+ final List> trainingSet = new ArrayList<>(n);
+
+ // n values of random.nextDouble()*10.0, each wrapped in a NumericValue.
+ final List> x1List = DoubleStream
+ .generate(() -> random.nextDouble()*10.0)
+ .limit(n)
+ .mapToObj(x1 -> new NumericValue(x1))
+ .collect(Collectors.toList());
+
+ // Same construction for the second covariate.
+ final List> x2List = DoubleStream
+ .generate(() -> random.nextDouble()*10.0)
+ .limit(n)
+ .mapToObj(x1 -> new NumericValue(x1))
+ .collect(Collectors.toList());
+
+
+
+ // NOTE(review): the next line looks truncated in this copy - the loop header stops mid-condition,
+ // startTime is read further down but never declared in the visible text, and nothing visible
+ // adds rows to trainingSet. Presumably the loop filled trainingSet from x1List/x2List - confirm
+ // against the original file.
+ for(int i=0; i treeTrainer = TreeTrainer.builder()
+ .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
+ .responseCombiner(new MeanResponseCombiner())
+ .maxNodeDepth(30)
+ .nodeSize(5)
+ .numberOfSplits(0)
+ .build();
+
+ final long endTime = System.currentTimeMillis();
+
+ // Elapsed time between startTime and endTime, converted from milliseconds to seconds.
+ System.out.println(((double)(endTime - startTime))/1000.0);
+
+ final List covariateNames = List.of("x1", "x2");
+
+ final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames);
+
+
+ // Test points; arguments are (x1, x2, id).
+ final List testSet = new ArrayList<>();
+ testSet.add(generateCovariateRow(9, 2, 1)); // expect 1
+ testSet.add(generateCovariateRow(5, 2, 5));
+ testSet.add(generateCovariateRow(2, 2, 3));
+ testSet.add(generateCovariateRow(9, 5, 0));
+ testSet.add(generateCovariateRow(6, 5, 8));
+ testSet.add(generateCovariateRow(3, 5, 10));
+ testSet.add(generateCovariateRow(1, 5, 3));
+ testSet.add(generateCovariateRow(7, 9, 2));
+ testSet.add(generateCovariateRow(1, 9, 4));
+
+ // Print each test row followed by the tree's output for it and a blank line.
+ for(final CovariateRow testCase : testSet){
+ System.out.println(testCase);
+ System.out.println(baseNode.evaluate(testCase));
+ System.out.println();
+
+
+ }
+
+
+
+
+
+ }
+
+ /**
+ * Builds a Row with covariates "x1" and "x2" whose response is generateResponse(x1, x2).
+ *
+ * @param x1 value stored under "x1"
+ * @param x2 value stored under "x2"
+ * @param id row id
+ * @return the new row
+ */
+ public static Row generateRow(double x1, double x2, int id){
+ double y = generateResponse(x1, x2);
+
+ final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2));
+
+ return new Row<>(map, id, y);
+
+ }
+
+
+ /**
+ * Builds a CovariateRow (no response) with covariates "x1" and "x2".
+ *
+ * @param x1 value stored under "x1"
+ * @param x2 value stored under "x2"
+ * @param id row id
+ * @return the new covariate row
+ */
+ public static CovariateRow generateCovariateRow(double x1, double x2, int id){
+ final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2));
+
+ return new CovariateRow(map, id);
+
+ }
+
+
+ /**
+ * Piecewise-constant response over the (x1, x2) plane:
+ * x2 <= 3: 3 if x1 <= 3, 5 if 3 < x1 <= 7, otherwise 1;
+ * x2 > 3 and x1 >= 5: 2 if x2 > 6, else 0 if x1 >= 8, else 8;
+ * x2 > 3 and x1 <= 2: 4 if x2 >= 7, else 3;
+ * x2 > 3 and 2 < x1 < 5: 10.
+ *
+ * @param x1 first covariate
+ * @param x2 second covariate
+ * @return the constant assigned to the region containing (x1, x2)
+ */
+ public static double generateResponse(double x1, double x2){
+ if(x2 <= 3){
+ if(x1 <= 3){
+ return 3;
+ }
+ else if(x1 <= 7){
+ return 5;
+ }
+ else{
+ return 1;
+ }
+ }
+ else if(x1 >= 5){
+ if(x2 > 6){
+ return 2;
+ }
+ else if(x1 >= 8){
+ return 0;
+ }
+ else{
+ return 8;
+ }
+ }
+ else if(x1 <= 2){
+ if(x2 >= 7){
+ return 4;
+ }
+ else{
+ return 3;
+ }
+ }
+ else{
+ return 10;
+ }
+
+
+ }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java
new file mode 100644
index 0000000..505a58c
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java
@@ -0,0 +1,33 @@
+package ca.joeltherrien.randomforest;
+
+import java.util.LinkedList;
+import java.util.List;
+
+import ca.joeltherrien.randomforest.exceptions.MissingValueException;
+import lombok.AllArgsConstructor;
+
+/**
+ * Split rule for a single numeric covariate: a row is sent left when its value for
+ * {@code covariateName} is less than or equal to {@code threshold}.
+ */
+@AllArgsConstructor
+public class NumericSplitRule extends SplitRule{
+
+ // Name of the covariate this rule reads from each row.
+ public final String covariateName;
+ // Cut point; a value equal to it counts as left-hand.
+ public final double threshold;
+
+ /** Returns "NumericSplitRule on " + covariateName + " at " + threshold. */
+ @Override
+ public final String toString() {
+ return "NumericSplitRule on " + covariateName + " at " + threshold;
+ }
+
+ /**
+ * Decides which side of the split a row falls on.
+ *
+ * @param row row whose value for {@code covariateName} is tested
+ * @return {@code true} when that value is less than or equal to {@code threshold}
+ * @throws MissingValueException if the row returns {@code null} for {@code covariateName}
+ */
+ @Override
+ public boolean isLeftHand(CovariateRow row) {
+ final Value> x = row.getCovariate(covariateName);
+ if(x == null) {
+ throw new MissingValueException(row, this);
+ }
+
+ // The stored value is cast to Double; a value of any other type fails here with a ClassCastException.
+ final double xNum = (Double) x.getValue();
+
+ return xNum <= threshold;
+ }
+
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericValue.java b/src/main/java/ca/joeltherrien/randomforest/NumericValue.java
new file mode 100644
index 0000000..9defa06
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/NumericValue.java
@@ -0,0 +1,19 @@
+package ca.joeltherrien.randomforest;
+
+import lombok.RequiredArgsConstructor;
+
+/**
+ * A {@link Value} wrapping a single primitive {@code double}.
+ */
+@RequiredArgsConstructor
+public class NumericValue implements Value {
+
+    /** The wrapped number; fixed at construction. */
+    private final double value;
+
+    /**
+     * @return the wrapped number, boxed as a {@link Double}
+     */
+    @Override
+    public Double getValue() {
+        return Double.valueOf(this.value);
+    }
+
+    /**
+     * Builds a split rule for the given covariate that uses this value as its threshold.
+     *
+     * @param covariateName name of the covariate the rule should test
+     * @return a {@link NumericSplitRule} on {@code covariateName} at this value
+     */
+    @Override
+    public SplitRule generateSplitRule(final String covariateName) {
+        final double threshold = this.value;
+        return new NumericSplitRule(covariateName, threshold);
+    }
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java
new file mode 100644
index 0000000..7f24928
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java
@@ -0,0 +1,9 @@
+package ca.joeltherrien.randomforest;
+
+import java.util.List;
+
+/**
+ * Reduces a list of responses to a single response. TreeTrainer uses it to compute the value
+ * stored in a terminal node from the responses of the rows that ended up in that node.
+ */
+public interface ResponseCombiner {
+
+ /**
+ * @param responses the responses to combine
+ * @return the single combined response
+ */
+ Y combine(List responses);
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/Row.java b/src/main/java/ca/joeltherrien/randomforest/Row.java
new file mode 100644
index 0000000..898229b
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/Row.java
@@ -0,0 +1,27 @@
+package ca.joeltherrien.randomforest;
+
+
+import java.util.Map;
+
+/**
+ * A CovariateRow that additionally carries a response value.
+ */
+public class Row extends CovariateRow {
+
+ // Response associated with this row's covariates.
+ private final Y response;
+
+ /**
+ * @param valueMap covariate values, passed through to the CovariateRow constructor
+ * @param id row identifier, passed through to the CovariateRow constructor
+ * @param response the response stored on this row
+ */
+ public Row(Map valueMap, int id, Y response){
+ super(valueMap, id);
+ this.response = response;
+ }
+
+
+ /** Returns the response given at construction. */
+ public Y getResponse() {
+ return this.response;
+ }
+
+ /** Returns the text "Row " followed by the id. */
+ @Override
+ public String toString() {
+ return "Row " + this.getId();
+ }
+
+
+
+}
diff --git a/src/ca/joeltherrien/randomforest/Split.java b/src/main/java/ca/joeltherrien/randomforest/Split.java
similarity index 68%
rename from src/ca/joeltherrien/randomforest/Split.java
rename to src/main/java/ca/joeltherrien/randomforest/Split.java
index 741b608..9804804 100644
--- a/src/ca/joeltherrien/randomforest/Split.java
+++ b/src/main/java/ca/joeltherrien/randomforest/Split.java
@@ -1,5 +1,7 @@
package ca.joeltherrien.randomforest;
+import lombok.Data;
+
import java.util.List;
/**
@@ -8,13 +10,10 @@ import java.util.List;
* @author joel
*
*/
+@Data
public class Split {
+ // NOTE(review): this hunk removes the hand-written two-argument constructor and adds Lombok @Data;
+ // presumably @Data now supplies the (leftHand, rightHand) constructor that callers use - confirm.
+ // Rows for which the rule's isLeftHand returned true.
public final List> leftHand;
+ // All remaining rows.
public final List> rightHand;
-
- public Split(List> leftHand, List> rightHand){
- this.leftHand = leftHand;
- this.rightHand = rightHand;
- }
+
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/SplitRule.java b/src/main/java/ca/joeltherrien/randomforest/SplitRule.java
new file mode 100644
index 0000000..a4b9c7a
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/SplitRule.java
@@ -0,0 +1,36 @@
+package ca.joeltherrien.randomforest;
+
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * Base class for rules that assign a row to either the left-hand or the right-hand side of a split.
+ * Subclasses supply isLeftHand; applyRule partitions a whole list of rows with it.
+ */
+public abstract class SplitRule {
+
+ /**
+ * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.
+ * This method is primarily used during the training of a tree when splits are being tested.
+ *
+ * @param rows the rows to partition; each one is tested once with isLeftHand, in list order
+ * @return a Split whose leftHand holds the rows for which isLeftHand returned true and whose
+ * rightHand holds the rest
+ */
+ public Split applyRule(List> rows) {
+ final List> leftHand = new LinkedList<>();
+ final List> rightHand = new LinkedList<>();
+
+ for(final Row row : rows) {
+
+ if(isLeftHand(row)){
+ leftHand.add(row);
+ }
+ else{
+ rightHand.add(row);
+ }
+
+ }
+
+ return new Split<>(leftHand, rightHand);
+ }
+
+ /**
+ * Decides the side of the split for a single row.
+ *
+ * @param row the covariates to test
+ * @return {@code true} if the row belongs on the left-hand side, {@code false} for the right-hand side
+ */
+ public abstract boolean isLeftHand(CovariateRow row);
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/Value.java b/src/main/java/ca/joeltherrien/randomforest/Value.java
new file mode 100644
index 0000000..fd16563
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/Value.java
@@ -0,0 +1,11 @@
+package ca.joeltherrien.randomforest;
+
+
+/**
+ * A covariate value that can report its underlying value and produce a SplitRule from itself.
+ */
+public interface Value {
+
+ /** Returns the underlying value. */
+ V getValue();
+
+ /**
+ * Creates a split rule for the named covariate (NumericValue, for example, uses its own number
+ * as the rule's threshold).
+ *
+ * @param covariateName name of the covariate the rule should test
+ * @return the generated rule
+ */
+ SplitRule generateSplitRule(String covariateName);
+
+
+}
diff --git a/src/ca/joeltherrien/randomforest/exceptions/MissingValueException.java b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java
similarity index 62%
rename from src/ca/joeltherrien/randomforest/exceptions/MissingValueException.java
rename to src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java
index 1de5e30..6e58b93 100644
--- a/src/ca/joeltherrien/randomforest/exceptions/MissingValueException.java
+++ b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java
@@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.exceptions;
+import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.SplitRule;
@@ -10,8 +11,8 @@ public class MissingValueException extends RuntimeException{
*/
private static final long serialVersionUID = 6808060079431207726L;
- public MissingValueException(Row> row, SplitRule rule) {
- super("Missing value at row " + row + rule);
+    /**
+     * Creates the exception with a message naming the row and the rule that needed the missing value.
+     *
+     * @param row  the row that has no value for the covariate the rule reads
+     * @param rule the split rule that was being evaluated
+     */
+    public MissingValueException(CovariateRow row, SplitRule rule) {
+        // Keep the two toString() outputs apart; the row's own toString() already starts with its type name.
+        super("Missing value at " + row + " for " + rule);
     }
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java
new file mode 100644
index 0000000..7c196e6
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java
@@ -0,0 +1,26 @@
+package ca.joeltherrien.randomforest.regression;
+
+import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
+
+import java.util.List;
+
+/**
+ * Scores a candidate split by the absolute difference between the mean response of the two groups.
+ */
+public class MeanGroupDifferentiator implements GroupDifferentiator {
+
+ /**
+ * @param leftHand responses on the left-hand side of the split
+ * @param rightHand responses on the right-hand side of the split
+ * @return {@code |mean(leftHand) - mean(rightHand)|}, or {@code null} if either list is empty
+ */
+ @Override
+ public Double differentiate(List leftHand, List rightHand) {
+
+ double leftHandSize = leftHand.size();
+ double rightHandSize = rightHand.size();
+
+ // No score can be given when one side has no rows.
+ if(leftHandSize == 0 || rightHandSize == 0){
+ return null;
+ }
+
+ // Each mean is accumulated as the sum of (element / group size).
+ double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
+ double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
+
+ return Math.abs(leftHandMean - rightHandMean);
+
+ }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java
new file mode 100644
index 0000000..1e37f3c
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java
@@ -0,0 +1,16 @@
+package ca.joeltherrien.randomforest.regression;
+
+import ca.joeltherrien.randomforest.ResponseCombiner;
+
+import java.util.List;
+
+/**
+ * Combines responses by taking their arithmetic mean.
+ */
+public class MeanResponseCombiner implements ResponseCombiner {
+
+ /**
+ * @param responses the values to average
+ * @return the sum of (value / list size) over all values; for an empty list the stream sum is 0.0
+ */
+ @Override
+ public Double combine(List responses) {
+ double size = responses.size();
+
+ return responses.stream().mapToDouble(db -> db/size).sum();
+
+ }
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java
new file mode 100644
index 0000000..2f40999
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java
@@ -0,0 +1,30 @@
+package ca.joeltherrien.randomforest.regression;
+
+import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
+
+import java.util.List;
+
+/**
+ * Scores a candidate split by the negated pooled within-group squared deviation of the responses:
+ * the squared deviations from each group's own mean are summed over both groups and divided by the
+ * total number of responses.
+ */
+public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator {
+
+ /**
+ * @param leftHand responses on the left-hand side of the split
+ * @param rightHand responses on the right-hand side of the split
+ * @return the negated score described on the class, or {@code null} if either list is empty
+ */
+ @Override
+ public Double differentiate(List leftHand, List rightHand) {
+
+ final double leftHandSize = leftHand.size();
+ final double rightHandSize = rightHand.size();
+ final double n = leftHandSize + rightHandSize;
+
+ // No score can be given when one side has no rows.
+ if(leftHandSize == 0 || rightHandSize == 0){
+ return null;
+ }
+
+ // Each mean is accumulated as the sum of (element / group size).
+ final double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
+ final double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
+
+ // Despite the names, these are sums of squared deviations; they are not divided by the group size.
+ final double leftVariance = leftHand.stream().mapToDouble(db -> (db - leftHandMean)*(db - leftHandMean)).sum();
+ final double rightVariance = rightHand.stream().mapToDouble(db -> (db - rightHandMean)*(db - rightHandMean)).sum();
+
+ // Negated, so a smaller within-group spread gives a larger score.
+ return -(leftVariance + rightVariance) / n;
+
+ }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
new file mode 100644
index 0000000..cbd1247
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
@@ -0,0 +1,15 @@
+package ca.joeltherrien.randomforest.tree;
+
+import java.util.List;
+
+/**
+ * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
+ * The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score,
+ * the greater the difference.
+ *
+ */
+public interface GroupDifferentiator {
+
+ /**
+ * Scores how different the two groups of responses are.
+ *
+ * @param leftHand responses of the rows on the left-hand side of a candidate split
+ * @param rightHand responses of the rows on the right-hand side
+ * @return the score, where larger means more different; the implementations in the regression
+ * package return {@code null} when either list is empty, and TreeTrainer skips null scores
+ */
+ Double differentiate(List leftHand, List rightHand);
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Node.java b/src/main/java/ca/joeltherrien/randomforest/tree/Node.java
new file mode 100644
index 0000000..25547e0
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/Node.java
@@ -0,0 +1,9 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.CovariateRow;
+
+/**
+ * A node of a tree; evaluating it on a row's covariates yields a response.
+ */
+public interface Node {
+
+ /**
+ * @param row the covariates to evaluate
+ * @return the response this node produces for {@code row}
+ */
+ Y evaluate(CovariateRow row);
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java
new file mode 100644
index 0000000..1cedfb5
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java
@@ -0,0 +1,26 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.CovariateRow;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.SplitRule;
+import lombok.Builder;
+
+/**
+ * Internal tree node: hands a row to its left-hand or right-hand child according to a SplitRule.
+ */
+@Builder
+public class SplitNode implements Node {
+
+ // Child evaluated when splitRule.isLeftHand(row) is true.
+ private final Node leftHand;
+ // Child evaluated otherwise.
+ private final Node rightHand;
+ // Rule deciding which child handles a row.
+ private final SplitRule splitRule;
+
+ /**
+ * @param row the covariates to evaluate
+ * @return the result of evaluating {@code row} on the child selected by the split rule
+ */
+ @Override
+ public Y evaluate(CovariateRow row) {
+
+ if(splitRule.isLeftHand(row)){
+ return leftHand.evaluate(row);
+ }
+ else{
+ return rightHand.evaluate(row);
+ }
+
+ }
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java
new file mode 100644
index 0000000..93bc61f
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java
@@ -0,0 +1,20 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.CovariateRow;
+
+import lombok.RequiredArgsConstructor;
+
+/**
+ * Leaf of a tree: returns the same stored response for every row.
+ */
+@RequiredArgsConstructor
+public class TerminalNode implements Node {
+
+ // The response returned by evaluate, fixed at construction.
+ private final Y responseValue;
+
+ /**
+ * @param row ignored
+ * @return the stored response value
+ */
+ @Override
+ public Y evaluate(CovariateRow row){
+ return responseValue;
+ }
+
+
+
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
new file mode 100644
index 0000000..df2ae5c
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
@@ -0,0 +1,105 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.ResponseCombiner;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.Split;
+import ca.joeltherrien.randomforest.SplitRule;
+import lombok.Builder;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Grows a single decision tree from training rows by recursive binary splitting.
+ *
+ * <p>At each node the candidate split rules are scored with the {@link GroupDifferentiator}; the
+ * highest-scoring rule is applied and both halves are grown recursively. A node becomes terminal when
+ * it holds fewer than {@code 2 * nodeSize} rows, when {@code maxNodeDepth} is reached, when all of its
+ * responses are equal, or when no candidate rule receives a score. The value of a terminal node is
+ * produced by the {@link ResponseCombiner}.
+ *
+ * @param <Y> the response type
+ */
+@Builder
+public class TreeTrainer<Y> {
+
+    private final ResponseCombiner<Y> responseCombiner;
+    private final GroupDifferentiator<Y> groupDifferentiator;
+
+    /**
+     * The number of splits to perform on each covariate. A value of 0 means all possible splits are tried.
+     *
+     */
+    private final int numberOfSplits;
+
+    /** A node is only considered for splitting while it holds at least {@code 2 * nodeSize} rows. */
+    private final int nodeSize;
+
+    /** Nodes at this depth (the root is depth 0) are never split. */
+    private final int maxNodeDepth;
+
+
+    /**
+     * Grows a tree on the given training rows.
+     *
+     * @param data            training rows; the list itself is not modified
+     * @param covariatesToTry names of the covariates that may be split on
+     * @return the root node of the grown tree
+     */
+    public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
+        return growNode(data, covariatesToTry, 0);
+    }
+
+    /** Builds the node for {@code data} at the given depth, recursing into both halves when a split is made. */
+    private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){
+        // TODO; what is minimum per tree?
+        if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){
+            final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
+
+            // No rule comes back when no candidate split received a score; fall through to a terminal
+            // node in that case instead of dereferencing null.
+            if(bestSplitRule != null){
+                final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
+
+                final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
+                final Node<Y> rightNode = growNode(split.rightHand, covariatesToTry, depth+1);
+
+                return new SplitNode<>(leftNode, rightNode, bestSplitRule);
+            }
+        }
+
+        return new TerminalNode<>(responseCombiner.combine(responsesOf(data)));
+    }
+
+    /**
+     * Searches the candidate rules generated from the rows' own covariate values and returns the one with
+     * the highest score.
+     *
+     * @return the best rule, or {@code null} if no candidate received a (non-null) score
+     */
+    private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
+        // Shuffle a private copy: the caller's list is left untouched (it may be unmodifiable) and
+        // indexed access stays cheap even when data is a LinkedList.
+        final List<Row<Y>> candidates = new ArrayList<>(data);
+
+        // 0 means "try every row"; otherwise try exactly numberOfSplits rows, never more than exist.
+        final int splitsPerCovariate = (numberOfSplits == 0)
+                ? candidates.size()
+                : Math.min(numberOfSplits, candidates.size());
+
+        SplitRule bestSplitRule = null;
+        double bestSplitScore = 0;
+
+        for(final String covariate : covariatesToTry){
+            Collections.shuffle(candidates);
+
+            for(int tries = 0; tries < splitsPerCovariate; tries++){
+                final SplitRule possibleRule = candidates.get(tries).getCovariate(covariate).generateSplitRule(covariate);
+                final Split<Y> possibleSplit = possibleRule.applyRule(data);
+
+                final Double score = groupDifferentiator.differentiate(
+                        responsesOf(possibleSplit.leftHand),
+                        responsesOf(possibleSplit.rightHand)
+                );
+
+                // The first scored rule is always accepted; later ones must beat the best score so far.
+                if(score != null && (bestSplitRule == null || score > bestSplitScore)){
+                    bestSplitRule = possibleRule;
+                    bestSplitScore = score;
+                }
+            }
+
+        }
+
+        return bestSplitRule;
+
+    }
+
+    /** Returns true when every row has a response equal to the first row's; an empty list counts as pure. */
+    private boolean nodeIsPure(List<Row<Y>> data, List<String> covariatesToTry){
+        // TODO how is this done?
+
+        if(data.isEmpty()){
+            return true;
+        }
+
+        final Y first = data.get(0).getResponse();
+        return data.stream().allMatch(row -> row.getResponse().equals(first));
+    }
+
+    /** Extracts the responses of the given rows, in list order. */
+    private List<Y> responsesOf(List<Row<Y>> rows){
+        return rows.stream()
+                .map(row -> row.getResponse())
+                .collect(Collectors.toList());
+    }
+
+}