From 3c9c78741fd83c7f9e514d505f291222140228b6 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Sun, 1 Jul 2018 22:22:12 -0700 Subject: [PATCH] Basic functinality to train a single regression tree is implemented. --- .gitignore | 2 + pom.xml | 31 ++++ src/ca/joeltherrien/randomforest/Main.java | 10 -- src/ca/joeltherrien/randomforest/Node.java | 5 - .../randomforest/NumericSplitRule.java | 44 ------ src/ca/joeltherrien/randomforest/Row.java | 33 ---- .../joeltherrien/randomforest/SplitRule.java | 9 -- src/ca/joeltherrien/randomforest/Value.java | 7 - .../randomforest/CovariateRow.java | 26 ++++ .../ca/joeltherrien/randomforest/Main.java | 146 ++++++++++++++++++ .../randomforest/NumericSplitRule.java | 33 ++++ .../randomforest/NumericValue.java | 19 +++ .../randomforest/ResponseCombiner.java | 9 ++ .../ca/joeltherrien/randomforest/Row.java | 27 ++++ .../ca/joeltherrien/randomforest/Split.java | 9 +- .../joeltherrien/randomforest/SplitRule.java | 36 +++++ .../ca/joeltherrien/randomforest/Value.java | 11 ++ .../exceptions/MissingValueException.java | 5 +- .../regression/MeanGroupDifferentiator.java | 26 ++++ .../regression/MeanResponseCombiner.java | 16 ++ .../WeightedVarianceGroupDifferentiator.java | 30 ++++ .../tree/GroupDifferentiator.java | 15 ++ .../joeltherrien/randomforest/tree/Node.java | 9 ++ .../randomforest/tree/SplitNode.java | 26 ++++ .../randomforest/tree/TerminalNode.java | 20 +++ .../randomforest/tree/TreeTrainer.java | 105 +++++++++++++ 26 files changed, 594 insertions(+), 115 deletions(-) create mode 100644 pom.xml delete mode 100644 src/ca/joeltherrien/randomforest/Main.java delete mode 100644 src/ca/joeltherrien/randomforest/Node.java delete mode 100644 src/ca/joeltherrien/randomforest/NumericSplitRule.java delete mode 100644 src/ca/joeltherrien/randomforest/Row.java delete mode 100644 src/ca/joeltherrien/randomforest/SplitRule.java delete mode 100644 src/ca/joeltherrien/randomforest/Value.java create mode 100644 
src/main/java/ca/joeltherrien/randomforest/CovariateRow.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/Main.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/NumericValue.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/Row.java rename src/{ => main/java}/ca/joeltherrien/randomforest/Split.java (68%) create mode 100644 src/main/java/ca/joeltherrien/randomforest/SplitRule.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/Value.java rename src/{ => main/java}/ca/joeltherrien/randomforest/exceptions/MissingValueException.java (62%) create mode 100644 src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/tree/Node.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java diff --git a/.gitignore b/.gitignore index f61aa0d..73c3108 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ .settings .project target/ +*.iml +.idea diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..637f869 --- /dev/null +++ b/pom.xml @@ -0,0 +1,31 @@ + + + 4.0.0 + + ca.joeltherrien + RandomSurvivalForests + 1.0-SNAPSHOT + + + 1.10 + 1.10 + 1.10 + + + + + + org.projectlombok + lombok + 1.18.0 + provided + + + + + + + + \ No newline at end of file diff --git 
a/src/ca/joeltherrien/randomforest/Main.java b/src/ca/joeltherrien/randomforest/Main.java deleted file mode 100644 index 1ae6537..0000000 --- a/src/ca/joeltherrien/randomforest/Main.java +++ /dev/null @@ -1,10 +0,0 @@ -package ca.joeltherrien.randomforest; - -public class Main { - - public static void main(String[] args) { - System.out.println("Hello world!"); - - } - -} diff --git a/src/ca/joeltherrien/randomforest/Node.java b/src/ca/joeltherrien/randomforest/Node.java deleted file mode 100644 index fa6efce..0000000 --- a/src/ca/joeltherrien/randomforest/Node.java +++ /dev/null @@ -1,5 +0,0 @@ -package ca.joeltherrien.randomforest; - -public class Node { - -} diff --git a/src/ca/joeltherrien/randomforest/NumericSplitRule.java b/src/ca/joeltherrien/randomforest/NumericSplitRule.java deleted file mode 100644 index d4e3fa9..0000000 --- a/src/ca/joeltherrien/randomforest/NumericSplitRule.java +++ /dev/null @@ -1,44 +0,0 @@ -package ca.joeltherrien.randomforest; - -import java.util.LinkedList; -import java.util.List; - -import ca.joeltherrien.randomforest.exceptions.MissingValueException; - -public class NumericSplitRule implements SplitRule{ - - public final String covariateName; - public final double threshold; - - public NumericSplitRule(String covariateName, double threshold) { - super(); - this.covariateName = covariateName; - this.threshold = threshold; - } - - @Override - public final String toString() { - return "NumericSplitRule on " + covariateName + " at " + threshold; - } - - @Override - public Split applyRule(List> rows) { - final List> leftHand = new LinkedList<>(); - final List> rightHand = new LinkedList<>(); - - for(final Row row : rows) { - final Value x = row.getCovariate(covariateName); - if(x == null) { - throw new MissingValueException(row, this); - } - - final NumericValue xNum = (NumericValue) x; - - } - - // TODO Auto-generated method stub - return null; - } - - -} diff --git a/src/ca/joeltherrien/randomforest/Row.java 
b/src/ca/joeltherrien/randomforest/Row.java deleted file mode 100644 index ca379ae..0000000 --- a/src/ca/joeltherrien/randomforest/Row.java +++ /dev/null @@ -1,33 +0,0 @@ -package ca.joeltherrien.randomforest; - -import java.util.Map; - -public class Row { - - private final Map covariates; - private final Y response; - private final int id; - - public Row(Map covariates, Y response, int id) { - super(); - this.covariates = covariates; - this.response = response; - this.id = id; - } - - public Value getCovariate(String name) { - return this.covariates.get(name); - } - - public Y getResponse() { - return this.response; - } - - @Override - public String toString() { - return "Row " + this.id; - } - - - -} diff --git a/src/ca/joeltherrien/randomforest/SplitRule.java b/src/ca/joeltherrien/randomforest/SplitRule.java deleted file mode 100644 index d7efca6..0000000 --- a/src/ca/joeltherrien/randomforest/SplitRule.java +++ /dev/null @@ -1,9 +0,0 @@ -package ca.joeltherrien.randomforest; - -import java.util.List; - -public interface SplitRule { - - Split applyRule(List> rows); - -} diff --git a/src/ca/joeltherrien/randomforest/Value.java b/src/ca/joeltherrien/randomforest/Value.java deleted file mode 100644 index f92affe..0000000 --- a/src/ca/joeltherrien/randomforest/Value.java +++ /dev/null @@ -1,7 +0,0 @@ -package ca.joeltherrien.randomforest; - -public interface Value { - - // TODO - -} diff --git a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java new file mode 100644 index 0000000..beccc3a --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java @@ -0,0 +1,26 @@ +package ca.joeltherrien.randomforest; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.Map; + +@RequiredArgsConstructor +public class CovariateRow { + + private final Map valueMap; + + @Getter + private final int id; + + public Value getCovariate(String name){ + return 
valueMap.get(name); + + } + + @Override + public String toString(){ + return "CovariateRow " + this.id; + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java new file mode 100644 index 0000000..b2a8edb --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -0,0 +1,146 @@ +package ca.joeltherrien.randomforest; + + +import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator; +import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.tree.Node; +import ca.joeltherrien.randomforest.tree.TreeTrainer; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; + +public class Main { + + public static void main(String[] args) { + System.out.println("Hello world!"); + + final Random random = new Random(123); + + final int n = 1000; + final List> trainingSet = new ArrayList<>(n); + + final List> x1List = DoubleStream + .generate(() -> random.nextDouble()*10.0) + .limit(n) + .mapToObj(x1 -> new NumericValue(x1)) + .collect(Collectors.toList()); + + final List> x2List = DoubleStream + .generate(() -> random.nextDouble()*10.0) + .limit(n) + .mapToObj(x1 -> new NumericValue(x1)) + .collect(Collectors.toList()); + + + + for(int i=0; i treeTrainer = TreeTrainer.builder() + .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .responseCombiner(new MeanResponseCombiner()) + .maxNodeDepth(30) + .nodeSize(5) + .numberOfSplits(0) + .build(); + + final long endTime = System.currentTimeMillis(); + + System.out.println(((double)(endTime - startTime))/1000.0); + + final List covariateNames = List.of("x1", "x2"); + + final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames); + + + final List testSet = new ArrayList<>(); + testSet.add(generateCovariateRow(9, 2, 1)); // expect 1 + 
testSet.add(generateCovariateRow(5, 2, 5)); + testSet.add(generateCovariateRow(2, 2, 3)); + testSet.add(generateCovariateRow(9, 5, 0)); + testSet.add(generateCovariateRow(6, 5, 8)); + testSet.add(generateCovariateRow(3, 5, 10)); + testSet.add(generateCovariateRow(1, 5, 3)); + testSet.add(generateCovariateRow(7, 9, 2)); + testSet.add(generateCovariateRow(1, 9, 4)); + + for(final CovariateRow testCase : testSet){ + System.out.println(testCase); + System.out.println(baseNode.evaluate(testCase)); + System.out.println(); + + + } + + + + + + } + + public static Row generateRow(double x1, double x2, int id){ + double y = generateResponse(x1, x2); + + final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2)); + + return new Row<>(map, id, y); + + } + + + public static CovariateRow generateCovariateRow(double x1, double x2, int id){ + final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2)); + + return new CovariateRow(map, id); + + } + + + public static double generateResponse(double x1, double x2){ + if(x2 <= 3){ + if(x1 <= 3){ + return 3; + } + else if(x1 <= 7){ + return 5; + } + else{ + return 1; + } + } + else if(x1 >= 5){ + if(x2 > 6){ + return 2; + } + else if(x1 >= 8){ + return 0; + } + else{ + return 8; + } + } + else if(x1 <= 2){ + if(x2 >= 7){ + return 4; + } + else{ + return 3; + } + } + else{ + return 10; + } + + + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java new file mode 100644 index 0000000..505a58c --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java @@ -0,0 +1,33 @@ +package ca.joeltherrien.randomforest; + +import java.util.LinkedList; +import java.util.List; + +import ca.joeltherrien.randomforest.exceptions.MissingValueException; +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class NumericSplitRule extends SplitRule{ + + public final String covariateName; + 
public final double threshold; + + @Override + public final String toString() { + return "NumericSplitRule on " + covariateName + " at " + threshold; + } + + @Override + public boolean isLeftHand(CovariateRow row) { + final Value x = row.getCovariate(covariateName); + if(x == null) { + throw new MissingValueException(row, this); + } + + final double xNum = (Double) x.getValue(); + + return xNum <= threshold; + } + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericValue.java b/src/main/java/ca/joeltherrien/randomforest/NumericValue.java new file mode 100644 index 0000000..9defa06 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/NumericValue.java @@ -0,0 +1,19 @@ +package ca.joeltherrien.randomforest; + +import lombok.RequiredArgsConstructor; + +@RequiredArgsConstructor +public class NumericValue implements Value { + + private final double value; + + @Override + public Double getValue() { + return value; + } + + @Override + public SplitRule generateSplitRule(final String covariateName) { + return new NumericSplitRule(covariateName, value); + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java new file mode 100644 index 0000000..7f24928 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java @@ -0,0 +1,9 @@ +package ca.joeltherrien.randomforest; + +import java.util.List; + +public interface ResponseCombiner { + + Y combine(List responses); + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/Row.java b/src/main/java/ca/joeltherrien/randomforest/Row.java new file mode 100644 index 0000000..898229b --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/Row.java @@ -0,0 +1,27 @@ +package ca.joeltherrien.randomforest; + + +import java.util.Map; + +public class Row extends CovariateRow { + + private final Y response; + + public Row(Map valueMap, int id, Y response){ + super(valueMap, id); + 
this.response = response; + } + + + public Y getResponse() { + return this.response; + } + + @Override + public String toString() { + return "Row " + this.getId(); + } + + + +} diff --git a/src/ca/joeltherrien/randomforest/Split.java b/src/main/java/ca/joeltherrien/randomforest/Split.java similarity index 68% rename from src/ca/joeltherrien/randomforest/Split.java rename to src/main/java/ca/joeltherrien/randomforest/Split.java index 741b608..9804804 100644 --- a/src/ca/joeltherrien/randomforest/Split.java +++ b/src/main/java/ca/joeltherrien/randomforest/Split.java @@ -1,5 +1,7 @@ package ca.joeltherrien.randomforest; +import lombok.Data; + import java.util.List; /** @@ -8,13 +10,10 @@ import java.util.List; * @author joel * */ +@Data public class Split { public final List> leftHand; public final List> rightHand; - - public Split(List> leftHand, List> rightHand){ - this.leftHand = leftHand; - this.rightHand = rightHand; - } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/SplitRule.java b/src/main/java/ca/joeltherrien/randomforest/SplitRule.java new file mode 100644 index 0000000..a4b9c7a --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/SplitRule.java @@ -0,0 +1,36 @@ +package ca.joeltherrien.randomforest; + +import java.util.LinkedList; +import java.util.List; + +public abstract class SplitRule { + + /** + * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. + * This method is primarily used during the training of a tree when splits are being tested. 
+ * + * @param rows + * @param + * @return + */ + public Split applyRule(List> rows) { + final List> leftHand = new LinkedList<>(); + final List> rightHand = new LinkedList<>(); + + for(final Row row : rows) { + + if(isLeftHand(row)){ + leftHand.add(row); + } + else{ + rightHand.add(row); + } + + } + + return new Split<>(leftHand, rightHand); + } + + public abstract boolean isLeftHand(CovariateRow row); + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/Value.java b/src/main/java/ca/joeltherrien/randomforest/Value.java new file mode 100644 index 0000000..fd16563 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/Value.java @@ -0,0 +1,11 @@ +package ca.joeltherrien.randomforest; + + +public interface Value { + + V getValue(); + + SplitRule generateSplitRule(String covariateName); + + +} diff --git a/src/ca/joeltherrien/randomforest/exceptions/MissingValueException.java b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java similarity index 62% rename from src/ca/joeltherrien/randomforest/exceptions/MissingValueException.java rename to src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java index 1de5e30..6e58b93 100644 --- a/src/ca/joeltherrien/randomforest/exceptions/MissingValueException.java +++ b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java @@ -1,5 +1,6 @@ package ca.joeltherrien.randomforest.exceptions; +import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.SplitRule; @@ -10,8 +11,8 @@ public class MissingValueException extends RuntimeException{ */ private static final long serialVersionUID = 6808060079431207726L; - public MissingValueException(Row row, SplitRule rule) { - super("Missing value at row " + row + rule); + public MissingValueException(CovariateRow row, SplitRule rule) { + super("Missing value at CovariateRow " + row + rule); } } diff --git 
a/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java
new file mode 100644
index 0000000..7c196e6
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java
@@ -0,0 +1,32 @@
+package ca.joeltherrien.randomforest.regression;
+
+import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
+
+import java.util.List;
+
+/**
+ * Scores a split by the absolute difference between the mean responses of its two groups.
+ */
+public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
+
+    /**
+     * @return the absolute difference of the two group means, or {@code null} if either group is empty
+     */
+    @Override
+    public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
+
+        double leftHandSize = leftHand.size();
+        double rightHandSize = rightHand.size();
+
+        if(leftHandSize == 0 || rightHandSize == 0){
+            return null;
+        }
+
+        double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
+        double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
+
+        return Math.abs(leftHandMean - rightHandMean);
+
+    }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java
new file mode 100644
index 0000000..1e37f3c
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java
@@ -0,0 +1,20 @@
+package ca.joeltherrien.randomforest.regression;
+
+import ca.joeltherrien.randomforest.ResponseCombiner;
+
+import java.util.List;
+
+/**
+ * Combines responses into their arithmetic mean. An empty list yields {@code 0.0}, because the
+ * sum over an empty stream is zero.
+ */
+public class MeanResponseCombiner implements ResponseCombiner<Double> {
+
+    @Override
+    public Double combine(List<Double> responses) {
+        double size = responses.size();
+
+        return responses.stream().mapToDouble(db -> db/size).sum();
+
+    }
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java
new file mode 100644
index 0000000..2f40999
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java
@@ -0,0 +1,37 @@
+package ca.joeltherrien.randomforest.regression;
+
+import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
+
+import java.util.List;
+
+/**
+ * Scores a split by the negated sum of both groups' squared deviations from their own mean,
+ * divided by the total number of responses, so tighter groups score higher.
+ */
+public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
+
+    /**
+     * @return the negated weighted variance, or {@code null} if either group is empty
+     */
+    @Override
+    public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
+
+        final double leftHandSize = leftHand.size();
+        final double rightHandSize = rightHand.size();
+        final double n = leftHandSize + rightHandSize;
+
+        if(leftHandSize == 0 || rightHandSize == 0){
+            return null;
+        }
+
+        final double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
+        final double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
+
+        final double leftVariance = leftHand.stream().mapToDouble(db -> (db - leftHandMean)*(db - leftHandMean)).sum();
+        final double rightVariance = rightHand.stream().mapToDouble(db -> (db - rightHandMean)*(db - rightHandMean)).sum();
+
+        return -(leftVariance + rightVariance) / n;
+
+    }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
new file mode 100644
index 0000000..cbd1247
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
@@ -0,0 +1,20 @@
+package ca.joeltherrien.randomforest.tree;
+
+import java.util.List;
+
+/**
+ * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
+ * The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score,
+ * the greater the difference.
+ *
+ * @param <Y> type of the responses being compared
+ */
+public interface GroupDifferentiator<Y> {
+
+    /**
+     * @return the score, or {@code null} when no score can be computed (the regression
+     *         implementations return {@code null} when either group is empty)
+     */
+    Double differentiate(List<Y> leftHand, List<Y> rightHand);
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Node.java b/src/main/java/ca/joeltherrien/randomforest/tree/Node.java
new file mode 100644
index 0000000..25547e0
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/Node.java
@@ -0,0 +1,12 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.CovariateRow;
+
+/**
+ * A node of a trained tree that can produce a prediction for a row of covariates.
+ */
+public interface Node<Y> {
+
+    Y evaluate(CovariateRow row);
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java
new file mode 100644
index 0000000..1cedfb5
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java
@@ -0,0 +1,30 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.CovariateRow;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.SplitRule;
+import lombok.Builder;
+
+/**
+ * An internal node: delegates evaluation to the left child when the split rule places the row on
+ * the left-hand side, and to the right child otherwise.
+ */
+@Builder
+public class SplitNode<Y> implements Node<Y> {
+
+    private final Node<Y> leftHand;
+    private final Node<Y> rightHand;
+    private final SplitRule splitRule;
+
+    @Override
+    public Y evaluate(CovariateRow row) {
+
+        if(splitRule.isLeftHand(row)){
+            return leftHand.evaluate(row);
+        }
+        else{
+            return rightHand.evaluate(row);
+        }
+
+    }
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java
new file mode 100644
index 0000000..93bc61f
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java
@@ -0,0 +1,20 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.CovariateRow;
+
+import lombok.RequiredArgsConstructor;
+
+@RequiredArgsConstructor
+public class TerminalNode<Y> implements Node<Y> {
+
+    private final Y responseValue;
+
+    @Override
+    public Y evaluate(CovariateRow row){
+        return responseValue;
+    }
+
+
+
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
new file mode 100644
index 0000000..df2ae5c
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
@@ -0,0 +1,130 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.ResponseCombiner;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.Split;
+import ca.joeltherrien.randomforest.SplitRule;
+import lombok.Builder;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Grows a single decision tree from training rows.
+ *
+ * <p>Each node is split with the candidate rule that scores highest under the configured
+ * {@link GroupDifferentiator}; rows that are not split further are reduced to one prediction by the
+ * configured {@link ResponseCombiner}.
+ *
+ * @param <Y> type of the response attached to each training row
+ */
+@Builder
+public class TreeTrainer<Y> {
+
+    private final ResponseCombiner<Y> responseCombiner;
+    private final GroupDifferentiator<Y> groupDifferentiator;
+
+    /**
+     * The number of splits to perform on each covariate. A value of 0 means all possible splits are tried.
+     */
+    private final int numberOfSplits;
+
+    /** A node is only considered for splitting while it holds at least {@code 2 * nodeSize} rows. */
+    private final int nodeSize;
+
+    /** Nodes at this depth (the root is depth 0) are never split. */
+    private final int maxNodeDepth;
+
+    /**
+     * Grows a tree from the given training data.
+     *
+     * @param data training rows; the list is not modified
+     * @param covariatesToTry names of the covariates that may be used in split rules
+     * @return the root node of the trained tree
+     */
+    public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
+        return growNode(data, covariatesToTry, 0);
+    }
+
+    /**
+     * Recursively grows the subtree for {@code data}, returning a terminal node when the stopping
+     * criteria are met or when no candidate rule produced a usable score.
+     */
+    private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){
+        // TODO; what is minimum per tree?
+        if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){
+            final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
+
+            // null means no candidate was scored (e.g. every rule left one side empty);
+            // fall through to a terminal node instead of dereferencing it.
+            if(bestSplitRule != null){
+                final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
+
+                final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
+                final Node<Y> rightNode = growNode(split.rightHand, covariatesToTry, depth+1);
+
+                return new SplitNode<>(leftNode, rightNode, bestSplitRule);
+            }
+        }
+
+        return new TerminalNode<>(responseCombiner.combine(
+                data.stream()
+                        .map(row -> row.getResponse())
+                        .collect(Collectors.toList())));
+    }
+
+    /**
+     * Scores candidate split rules for every covariate and returns the best one.
+     *
+     * @return the highest-scoring rule, or {@code null} if no candidate produced a score
+     */
+    private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
+        SplitRule bestSplitRule = null;
+        double bestSplitScore = 0;
+
+        // Shuffle a private copy so the caller's list is never reordered and get(i) is O(1)
+        // even when data is a LinkedList produced by SplitRule.applyRule.
+        final List<Row<Y>> candidates = new ArrayList<>(data);
+
+        // numberOfSplits == 0 means "try every row"; otherwise never index past the data.
+        final int triesPerCovariate = numberOfSplits == 0
+                ? candidates.size()
+                : Math.min(numberOfSplits, candidates.size());
+
+        for(final String covariate : covariatesToTry){
+            Collections.shuffle(candidates);
+
+            for(int tries = 0; tries < triesPerCovariate; tries++){
+                final SplitRule possibleRule = candidates.get(tries).getCovariate(covariate).generateSplitRule(covariate);
+                final Split<Y> possibleSplit = possibleRule.applyRule(data);
+
+                final Double score = groupDifferentiator.differentiate(
+                        possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
+                        possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
+                );
+
+                // A null score marks a split that could not be scored; skip it.
+                if(score != null && (bestSplitRule == null || score > bestSplitScore)){
+                    bestSplitRule = possibleRule;
+                    bestSplitScore = score;
+                }
+            }
+        }
+
+        return bestSplitRule;
+    }
+
+    /** Returns {@code true} when the node is empty or every row carries an equal response. */
+    private boolean nodeIsPure(List<Row<Y>> data, List<String> covariatesToTry){
+        // TODO how is this done?
+        if(data.isEmpty()){
+            return true;
+        }
+
+        final Y first = data.get(0).getResponse();
+        return data.stream().allMatch(row -> row.getResponse().equals(first));
+    }
+
+}