From e7af65e8fdc1e416687a8aea2eaea9c9ccd1e15b Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Tue, 3 Jul 2018 15:15:09 -0700 Subject: [PATCH 1/7] Fixed a bug where Splits could be generated that had an empty daughter node --- .../randomforest/tree/TreeTrainer.java | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 907f8ef..e830f4d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -31,9 +31,18 @@ public class TreeTrainer { private Node growNode(List> data, List covariatesToTry, int depth){ // TODO; what is minimum per tree? - if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){ + if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); + if(bestSplitRule == null){ + return new TerminalNode<>( + data.stream() + .map(row -> row.getResponse()) + .collect(responseCombiner) + + ); + } + final Split split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule final Node leftNode = growNode(split.leftHand, covariatesToTry, depth+1); @@ -56,7 +65,7 @@ public class TreeTrainer { private SplitRule findBestSplitRule(List> data, List covariatesToTry){ SplitRule bestSplitRule = null; - Double bestSplitScore = 0.0; // may be null + double bestSplitScore = 0.0; boolean first = true; for(final String covariate : covariatesToTry){ @@ -92,7 +101,7 @@ public class TreeTrainer { possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()) ); - if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){ + if(score != null && (score > bestSplitScore || first)){ bestSplitRule = possibleRule; bestSplitScore = score; first = false; @@ -107,9 +116,7 @@ public class TreeTrainer { } - private boolean nodeIsPure(List> data, List covariatesToTry){ - // TODO how is this done? - + private boolean nodeIsPure(List> data){ final Y first = data.get(0).getResponse(); return data.stream().allMatch(row -> row.getResponse().equals(first)); } From e96a578ac97d0a07d46fb11f70cb85e1108a9439 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Tue, 3 Jul 2018 17:00:02 -0700 Subject: [PATCH 2/7] Refactored code to allow for a class of covariates to determine which SplitRules are tested. Most of the refactoring involved the creation of a Covariate class (one instance per column); with SplitRule and Value being folded in as inner classes. 
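Roughly, the new pieces are meant to fit together as in the minimal usage sketch below (illustration only, not part of the diff that follows; the wrapper class name is made up and the generic parameters are approximate):

    import ca.joeltherrien.randomforest.*;

    import java.util.Collection;
    import java.util.List;
    import java.util.Map;

    public class CovariateApiSketch {
        public static void main(String[] args) {
            // One Covariate instance per column; it creates the Values stored in rows
            // and proposes the SplitRules tried during training.
            final NumericCovariate x1 = new NumericCovariate("x1");

            // Rows map covariate names to Covariate.Value objects built by their parent covariate.
            final Map<String, Covariate.Value> cells = Map.of("x1", x1.createValue(0.75));
            final Row<Double> row = new Row<>(cells, 1, 2.5);

            // During training the covariate generates candidate split rules from the observed values...
            final Collection<? extends Covariate.SplitRule> rules =
                    x1.generateSplitRules(List.of(row.getCovariateValue("x1")), 5);

            // ...and each rule partitions rows into the two daughter nodes.
            for (final Covariate.SplitRule rule : rules) {
                final Split<Double> split = rule.applyRule(List.of(row));
                System.out.println(rule + ": left=" + split.leftHand.size()
                        + ", right=" + split.rightHand.size());
            }
        }
    }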
--- .../joeltherrien/randomforest/Covariate.java | 58 ++++++++++ .../randomforest/CovariateRow.java | 4 +- .../randomforest/NumericCovariate.java | 103 ++++++++++++++++++ .../randomforest/NumericSplitRule.java | 33 ------ .../randomforest/NumericValue.java | 24 ---- .../ca/joeltherrien/randomforest/Row.java | 2 +- .../joeltherrien/randomforest/SplitRule.java | 37 ------- .../ca/joeltherrien/randomforest/Value.java | 11 -- .../exceptions/MissingValueException.java | 5 +- .../randomforest/tree/ForestTrainer.java | 5 +- .../randomforest/tree/SplitNode.java | 5 +- .../randomforest/tree/TreeTrainer.java | 49 +++------ .../randomforest/workshop/TrainForest.java | 27 +++-- .../workshop/TrainSingleTree.java | 59 +++++----- 14 files changed, 233 insertions(+), 189 deletions(-) create mode 100644 src/main/java/ca/joeltherrien/randomforest/Covariate.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java delete mode 100644 src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java delete mode 100644 src/main/java/ca/joeltherrien/randomforest/NumericValue.java delete mode 100644 src/main/java/ca/joeltherrien/randomforest/SplitRule.java delete mode 100644 src/main/java/ca/joeltherrien/randomforest/Value.java diff --git a/src/main/java/ca/joeltherrien/randomforest/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/Covariate.java new file mode 100644 index 0000000..86848d8 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/Covariate.java @@ -0,0 +1,58 @@ +package ca.joeltherrien.randomforest; + +import java.io.Serializable; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +public interface Covariate extends Serializable { + + String getName(); + + Collection> generateSplitRules(final List> data, final int number); + + Value createValue(V value); + + interface Value extends Serializable{ + + Covariate getParent(); + + V getValue(); + + } + + interface SplitRule extends Serializable{ + + Covariate getParent(); + + /** + * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. + * This method is primarily used during the training of a tree when splits are being tested. 
+ * + * @param rows + * @param + * @return + */ + default Split applyRule(List> rows) { + final List> leftHand = new LinkedList<>(); + final List> rightHand = new LinkedList<>(); + + for(final Row row : rows) { + + if(isLeftHand(row)){ + leftHand.add(row); + } + else{ + rightHand.add(row); + } + + } + + return new Split<>(leftHand, rightHand); + } + + boolean isLeftHand(CovariateRow row); + } + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java index beccc3a..181d80b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java +++ b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java @@ -8,12 +8,12 @@ import java.util.Map; @RequiredArgsConstructor public class CovariateRow { - private final Map valueMap; + private final Map valueMap; @Getter private final int id; - public Value getCovariate(String name){ + public Covariate.Value getCovariateValue(String name){ return valueMap.get(name); } diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java new file mode 100644 index 0000000..76c83b3 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java @@ -0,0 +1,103 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.exceptions.MissingValueException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +@RequiredArgsConstructor +public class NumericCovariate implements Covariate{ + + @Getter + private final String name; + + @Override + public Collection generateSplitRules(List> data, int number) { + + // for this implementation we need to shuffle the data + final List> shuffledData; + if(number > data.size()){ + shuffledData = new ArrayList<>(data); + Collections.shuffle(shuffledData); + } + else{ // only need the top number entries + shuffledData = new ArrayList<>(number); + final Set indexesToUse = new HashSet<>(); + + while(indexesToUse.size() < number){ + final int index = ThreadLocalRandom.current().nextInt(data.size()); + + if(indexesToUse.add(index)){ + shuffledData.add(data.get(index)); + } + } + + } + + return shuffledData.stream() + .mapToDouble(v -> v.getValue()) + .mapToObj(threshold -> new NumericSplitRule(threshold)) + .collect(Collectors.toSet()); + // by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping + + + } + + @Override + public NumericValue createValue(Double value) { + return new NumericValue(value); + } + + public class NumericValue implements Covariate.Value{ + + private final double value; + + private NumericValue(final double value){ + this.value = value; + } + + @Override + public Covariate getParent() { + return NumericCovariate.this; + } + + @Override + public Double getValue() { + return value; + } + } + + public class NumericSplitRule implements Covariate.SplitRule{ + + private final double threshold; + + private NumericSplitRule(final double threshold){ + this.threshold = threshold; + } + + @Override + public final String toString() { + return "NumericSplitRule on " + getParent().getName() + " at " + threshold; + } + + @Override + public Covariate getParent() { + return NumericCovariate.this; + } + + @Override + public boolean isLeftHand(CovariateRow row) { + final Covariate.Value x = row.getCovariateValue(getParent().getName()); + if(x 
== null) { + throw new MissingValueException(row, this); + } + + final double xNum = (Double) x.getValue(); + + return xNum <= threshold; + } + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java deleted file mode 100644 index 505a58c..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java +++ /dev/null @@ -1,33 +0,0 @@ -package ca.joeltherrien.randomforest; - -import java.util.LinkedList; -import java.util.List; - -import ca.joeltherrien.randomforest.exceptions.MissingValueException; -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class NumericSplitRule extends SplitRule{ - - public final String covariateName; - public final double threshold; - - @Override - public final String toString() { - return "NumericSplitRule on " + covariateName + " at " + threshold; - } - - @Override - public boolean isLeftHand(CovariateRow row) { - final Value x = row.getCovariate(covariateName); - if(x == null) { - throw new MissingValueException(row, this); - } - - final double xNum = (Double) x.getValue(); - - return xNum <= threshold; - } - - -} diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericValue.java b/src/main/java/ca/joeltherrien/randomforest/NumericValue.java deleted file mode 100644 index a3d79bc..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/NumericValue.java +++ /dev/null @@ -1,24 +0,0 @@ -package ca.joeltherrien.randomforest; - -import lombok.RequiredArgsConstructor; - -@RequiredArgsConstructor -public class NumericValue implements Value { - - private final double value; - - @Override - public Double getValue() { - return value; - } - - @Override - public SplitRule generateSplitRule(final String covariateName) { - return new NumericSplitRule(covariateName, value); - } - - @Override - public String toString(){ - return "" + value; - } -} diff --git a/src/main/java/ca/joeltherrien/randomforest/Row.java b/src/main/java/ca/joeltherrien/randomforest/Row.java index 898229b..b762f15 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Row.java +++ b/src/main/java/ca/joeltherrien/randomforest/Row.java @@ -7,7 +7,7 @@ public class Row extends CovariateRow { private final Y response; - public Row(Map valueMap, int id, Y response){ + public Row(Map valueMap, int id, Y response){ super(valueMap, id); this.response = response; } diff --git a/src/main/java/ca/joeltherrien/randomforest/SplitRule.java b/src/main/java/ca/joeltherrien/randomforest/SplitRule.java deleted file mode 100644 index 95d7545..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/SplitRule.java +++ /dev/null @@ -1,37 +0,0 @@ -package ca.joeltherrien.randomforest; - -import java.io.Serializable; -import java.util.LinkedList; -import java.util.List; - -public abstract class SplitRule implements Serializable { - - /** - * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. - * This method is primarily used during the training of a tree when splits are being tested. 
- * - * @param rows - * @param - * @return - */ - public Split applyRule(List> rows) { - final List> leftHand = new LinkedList<>(); - final List> rightHand = new LinkedList<>(); - - for(final Row row : rows) { - - if(isLeftHand(row)){ - leftHand.add(row); - } - else{ - rightHand.add(row); - } - - } - - return new Split<>(leftHand, rightHand); - } - - public abstract boolean isLeftHand(CovariateRow row); - -} diff --git a/src/main/java/ca/joeltherrien/randomforest/Value.java b/src/main/java/ca/joeltherrien/randomforest/Value.java deleted file mode 100644 index fd16563..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/Value.java +++ /dev/null @@ -1,11 +0,0 @@ -package ca.joeltherrien.randomforest; - - -public interface Value { - - V getValue(); - - SplitRule generateSplitRule(String covariateName); - - -} diff --git a/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java index 6e58b93..d340fdc 100644 --- a/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java +++ b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java @@ -1,8 +1,7 @@ package ca.joeltherrien.randomforest.exceptions; +import ca.joeltherrien.randomforest.Covariate; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.SplitRule; public class MissingValueException extends RuntimeException{ @@ -11,7 +10,7 @@ public class MissingValueException extends RuntimeException{ */ private static final long serialVersionUID = 6808060079431207726L; - public MissingValueException(CovariateRow row, SplitRule rule) { + public MissingValueException(CovariateRow row, Covariate.SplitRule rule) { super("Missing value at CovariateRow " + row + rule); } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 277b368..bae9f9f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -1,6 +1,7 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Bootstrapper; +import ca.joeltherrien.randomforest.Covariate; import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.Row; import lombok.Builder; @@ -21,7 +22,7 @@ import java.util.stream.Stream; public class ForestTrainer { private final TreeTrainer treeTrainer; - private final List covariatesToTry; + private final List covariatesToTry; private final ResponseCombiner treeResponseCombiner; private final List> data; @@ -140,7 +141,7 @@ public class ForestTrainer { } private Node trainTree(final Bootstrapper> bootstrapper){ - final List treeCovariates = new ArrayList<>(covariatesToTry); + final List treeCovariates = new ArrayList<>(covariatesToTry); Collections.shuffle(treeCovariates); for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){ diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java index 1cedfb5..3f5dfdb 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java @@ -1,8 +1,7 @@ package ca.joeltherrien.randomforest.tree; +import ca.joeltherrien.randomforest.Covariate; import 
ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.SplitRule; import lombok.Builder; @Builder @@ -10,7 +9,7 @@ public class SplitNode implements Node { private final Node leftHand; private final Node rightHand; - private final SplitRule splitRule; + private final Covariate.SplitRule splitRule; @Override public Y evaluate(CovariateRow row) { diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index e830f4d..94a1dad 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -1,8 +1,6 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.*; -import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; import lombok.Builder; import java.util.*; @@ -22,17 +20,15 @@ public class TreeTrainer { private final int nodeSize; private final int maxNodeDepth; - private final Random random = new Random(); - - public Node growTree(List> data, List covariatesToTry){ + public Node growTree(List> data, List covariatesToTry){ return growNode(data, covariatesToTry, 0); } - private Node growNode(List> data, List covariatesToTry, int depth){ + private Node growNode(List> data, List covariatesToTry, int depth){ // TODO; what is minimum per tree? if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ - final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); + final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); if(bestSplitRule == null){ return new TerminalNode<>( @@ -63,37 +59,24 @@ public class TreeTrainer { } - private SplitRule findBestSplitRule(List> data, List covariatesToTry){ - SplitRule bestSplitRule = null; + private Covariate.SplitRule findBestSplitRule(List> data, List covariatesToTry){ + Covariate.SplitRule bestSplitRule = null; double bestSplitScore = 0.0; boolean first = true; - for(final String covariate : covariatesToTry){ + for(final Covariate covariate : covariatesToTry){ - final List> shuffledData; - if(numberOfSplits == 0 || numberOfSplits > data.size()){ - shuffledData = new ArrayList<>(data); - Collections.shuffle(shuffledData); - } - else{ // only need the top numberOfSplits entries - shuffledData = new ArrayList<>(numberOfSplits); - final Set indexesToUse = new HashSet<>(); + final int numberToTry = numberOfSplits==0 ? 
data.size() : numberOfSplits; - while(indexesToUse.size() < numberOfSplits){ - final int index = random.nextInt(data.size()); + final Collection splitRulesToTry = covariate + .generateSplitRules( + data + .stream() + .map(row -> row.getCovariateValue(covariate.getName())) + .collect(Collectors.toList()) + , numberToTry); - if(indexesToUse.add(index)){ - shuffledData.add(data.get(index)); - } - } - - } - - - int tries = 0; - - while(tries < shuffledData.size()){ - final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate); + for(final Covariate.SplitRule possibleRule : splitRulesToTry){ final Split possibleSplit = possibleRule.applyRule(data); final Double score = groupDifferentiator.differentiate( @@ -106,8 +89,6 @@ public class TreeTrainer { bestSplitScore = score; first = false; } - - tries++; } } diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index c645a35..003256c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -19,21 +19,28 @@ public class TrainForest { final int n = 10000; final int p = 5; + final Random random = new Random(); final List> data = new ArrayList<>(n); double minY = 1000.0; + final List covariateList = new ArrayList<>(p); + for(int j =0; j < p; j++){ + final NumericCovariate covariate = new NumericCovariate("x"+j); + covariateList.add(covariate); + } + for(int i=0; i map = new HashMap<>(); + final Map map = new HashMap<>(); - for(int j=0; j(map, i, y)); @@ -44,10 +51,8 @@ public class TrainForest { } - final List covariateNames = IntStream.range(0, p).mapToObj(j -> "x"+j).collect(Collectors.toList()); - - TreeTrainer treeTrainer = TreeTrainer.builder() + final TreeTrainer treeTrainer = TreeTrainer.builder() .numberOfSplits(5) .nodeSize(5) .maxNodeDepth(100000000) @@ -58,7 +63,7 @@ public class TrainForest { final ForestTrainer forestTrainer = ForestTrainer.builder() .treeTrainer(treeTrainer) .data(data) - .covariatesToTry(covariateNames) + .covariatesToTry(covariateList) .mtry(4) .ntree(100) .treeResponseCombiner(new MeanResponseCombiner()) @@ -69,7 +74,7 @@ public class TrainForest { final long startTime = System.currentTimeMillis(); //final Forest forest = forestTrainer.trainSerial(); - //final Forest forest = forestTrainer.trainParallel(8); + //final Forest forest = forestTrainer.trainParallelInMemory(3); forestTrainer.trainParallelOnDisk(3); final long endTime = System.currentTimeMillis(); @@ -88,9 +93,9 @@ public class TrainForest { System.out.println(forest.evaluate(testRow1)); System.out.println(forest.evaluate(testRow2)); - - System.out.println("MinY = " + minY); */ + System.out.println("MinY = " + minY); + } } diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index d218687..72654a8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -1,11 +1,10 @@ package ca.joeltherrien.randomforest.workshop; +import ca.joeltherrien.randomforest.Covariate; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.NumericValue; +import ca.joeltherrien.randomforest.NumericCovariate; import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.Value; 
-import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator; import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.Node; @@ -25,30 +24,30 @@ public class TrainSingleTree { final int n = 1000; final List> trainingSet = new ArrayList<>(n); - final List> x1List = DoubleStream + final Covariate x1Covariate = new NumericCovariate("x1"); + final Covariate x2Covariate = new NumericCovariate("x2"); + + final List> x1List = DoubleStream .generate(() -> random.nextDouble()*10.0) .limit(n) - .mapToObj(x1 -> new NumericValue(x1)) + .mapToObj(x1 -> x1Covariate.createValue(x1)) .collect(Collectors.toList()); - final List> x2List = DoubleStream + final List> x2List = DoubleStream .generate(() -> random.nextDouble()*10.0) .limit(n) - .mapToObj(x1 -> new NumericValue(x1)) + .mapToObj(x2 -> x1Covariate.createValue(x2)) .collect(Collectors.toList()); for(int i=0; i x1 = x1List.get(i); + final Covariate.Value x2 = x2List.get(i); trainingSet.add(generateRow(x1, x2, i)); } - - final long startTime = System.currentTimeMillis(); - final TreeTrainer treeTrainer = TreeTrainer.builder() .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .responseCombiner(new MeanResponseCombiner()) @@ -57,25 +56,29 @@ public class TrainSingleTree { .numberOfSplits(0) .build(); + final List covariateNames = List.of(x1Covariate, x2Covariate); + + final long startTime = System.currentTimeMillis(); + final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames); final long endTime = System.currentTimeMillis(); System.out.println(((double)(endTime - startTime))/1000.0); - final List covariateNames = List.of("x1", "x2"); - final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames); + + final List testSet = new ArrayList<>(); - testSet.add(generateCovariateRow(9, 2, 1)); // expect 1 - testSet.add(generateCovariateRow(5, 2, 5)); - testSet.add(generateCovariateRow(2, 2, 3)); - testSet.add(generateCovariateRow(9, 5, 0)); - testSet.add(generateCovariateRow(6, 5, 8)); - testSet.add(generateCovariateRow(3, 5, 10)); - testSet.add(generateCovariateRow(1, 5, 3)); - testSet.add(generateCovariateRow(7, 9, 2)); - testSet.add(generateCovariateRow(1, 9, 4)); + testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), 1)); // expect 1 + testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), 5)); + testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), 3)); + testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), 0)); + testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), 8)); + testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), 10)); + testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), 3)); + testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), 2)); + testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), 4)); for(final CovariateRow testCase : testSet){ System.out.println(testCase); @@ -91,18 +94,18 @@ public class TrainSingleTree { } - public static Row generateRow(double x1, double x2, int id){ - double y = generateResponse(x1, x2); + public static Row generateRow(Covariate.Value x1, Covariate.Value x2, int id){ + double y 
= generateResponse(x1.getValue(), x2.getValue()); - final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2)); + final Map map = Map.of("x1", x1, "x2", x2); return new Row<>(map, id, y); } - public static CovariateRow generateCovariateRow(double x1, double x2, int id){ - final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2)); + public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){ + final Map map = Map.of("x1", x1, "x2", x2); return new CovariateRow(map, id); From 38e70dd3a1d02f3399f107a9019c16b964bcf97a Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 4 Jul 2018 10:54:07 -0700 Subject: [PATCH 3/7] Add BooleanCovariate --- .../randomforest/BooleanCovariate.java | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java diff --git a/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java new file mode 100644 index 0000000..8920273 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java @@ -0,0 +1,72 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.exceptions.MissingValueException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +@RequiredArgsConstructor +public class BooleanCovariate implements Covariate{ + + @Getter + private final String name; + + private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. + + @Override + public Collection generateSplitRules(List> data, int number) { + return Collections.singleton(splitRule); + } + + @Override + public BooleanValue createValue(Boolean value) { + return new BooleanValue(value); + } + + public class BooleanValue implements Value{ + + private final boolean value; + + private BooleanValue(final boolean value){ + this.value = value; + } + + @Override + public BooleanCovariate getParent() { + return BooleanCovariate.this; + } + + @Override + public Boolean getValue() { + return value; + } + } + + public class BooleanSplitRule implements SplitRule{ + + @Override + public final String toString() { + return "BooleanSplitRule"; + } + + @Override + public BooleanCovariate getParent() { + return BooleanCovariate.this; + } + + @Override + public boolean isLeftHand(CovariateRow row) { + final Value x = row.getCovariateValue(getParent().getName()); + if(x == null) { + throw new MissingValueException(row, this); + } + + final boolean xBoolean = (Boolean) x.getValue(); + + return !xBoolean; + } + } +} From 2259528c221a6e70da6e32c1e0bf8b33fa4ff635 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 4 Jul 2018 10:54:46 -0700 Subject: [PATCH 4/7] Small modification of NumericCovariate; child classes now guarantee they return NumericCovariate when getParent() is called.
--- .../java/ca/joeltherrien/randomforest/NumericCovariate.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java index 76c83b3..5d1c8dc 100644 --- a/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java @@ -60,7 +60,7 @@ public class NumericCovariate implements Covariate{ } @Override - public Covariate getParent() { + public NumericCovariate getParent() { return NumericCovariate.this; } @@ -84,7 +84,7 @@ public class NumericCovariate implements Covariate{ } @Override - public Covariate getParent() { + public NumericCovariate getParent() { return NumericCovariate.this; } From e0cfed632f81d33cbfd6cb9c8dbcbe067d0c1286 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 4 Jul 2018 12:18:06 -0700 Subject: [PATCH 5/7] Add FactorCovariate; testing required. --- .../randomforest/FactorCovariate.java | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java diff --git a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java new file mode 100644 index 0000000..ba3ac81 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java @@ -0,0 +1,122 @@ +package ca.joeltherrien.randomforest; + +import lombok.EqualsAndHashCode; + +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; + +public final class FactorCovariate implements Covariate{ + + private final String name; + private final Map factorLevels; + private final int numberOfPossiblePairings; + + + public FactorCovariate(final String name, List levels){ + this.name = name; + this.factorLevels = new HashMap<>(); + + for(final String level : levels){ + final FactorValue newValue = new FactorValue(level); + + factorLevels.put(level, newValue); + } + + int numberOfPossiblePairingsTemp = 1; + for(int i=0; i generateSplitRules(List> data, int number) { + final Set splitRules = new HashSet<>(); + + // This is to ensure we don't get stuck in an infinite loop for small factors + number = Math.min(number, numberOfPossiblePairings); + final Random random = ThreadLocalRandom.current(); + final List levels = new ArrayList<>(factorLevels.values()); + + + + while(splitRules.size() < number){ + Collections.shuffle(levels, random); + final Set leftSideValues = new HashSet<>(); + leftSideValues.add(levels.get(0)); + + for(int i=1; i{ + + private final String value; + + private FactorValue(final String value){ + this.value = value; + } + + @Override + public FactorCovariate getParent() { + return FactorCovariate.this; + } + + @Override + public String getValue() { + return value; + } + } + + @EqualsAndHashCode + public final class FactorSplitRule implements Covariate.SplitRule{ + + private final Set leftSideValues; + + private FactorSplitRule(final Set leftSideValues){ + this.leftSideValues = leftSideValues; + } + + @Override + public FactorCovariate getParent() { + return FactorCovariate.this; + } + + @Override + public boolean isLeftHand(CovariateRow row) { + final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue(); + + return leftSideValues.contains(value); + + + } + } +} From c7298f7da6a42c047d0d78b19ed8bc1a428f7fd5 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 4 Jul 2018 12:18:27 -0700 Subject: 
[PATCH 6/7] Fix incorrect use of non-concurrent Random object in NumericCovariate. --- .../java/ca/joeltherrien/randomforest/NumericCovariate.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java index 5d1c8dc..b9be60a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java @@ -17,18 +17,20 @@ public class NumericCovariate implements Covariate{ @Override public Collection generateSplitRules(List> data, int number) { + final Random random = ThreadLocalRandom.current(); + // for this implementation we need to shuffle the data final List> shuffledData; if(number > data.size()){ shuffledData = new ArrayList<>(data); - Collections.shuffle(shuffledData); + Collections.shuffle(shuffledData, random); } else{ // only need the top number entries shuffledData = new ArrayList<>(number); final Set indexesToUse = new HashSet<>(); while(indexesToUse.size() < number){ - final int index = ThreadLocalRandom.current().nextInt(data.size()); + final int index = random.nextInt(data.size()); if(indexesToUse.add(index)){ shuffledData.add(data.get(index)); From 3b8952e13c3a8219c182b8451a8065cadda5ffd4 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 4 Jul 2018 13:24:34 -0700 Subject: [PATCH 7/7] Added some tests for FactorCovariate. Moved workshop over to test codebase. --- pom.xml | 13 ++ .../randomforest/FactorCovariate.java | 4 +- .../covariates/FactorCovariateTest.java | 65 ++++++ .../randomforest/workshop/TrainForest.java | 0 .../workshop/TrainSingleTree.java | 0 .../workshop/TrainSingleTreeFactor.java | 190 ++++++++++++++++++ 6 files changed, 270 insertions(+), 2 deletions(-) create mode 100644 src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java rename src/{main => test}/java/ca/joeltherrien/randomforest/workshop/TrainForest.java (100%) rename src/{main => test}/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java (100%) create mode 100644 src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java diff --git a/pom.xml b/pom.xml index 637f869..01d9203 100644 --- a/pom.xml +++ b/pom.xml @@ -24,6 +24,19 @@ + + org.junit.jupiter + junit-jupiter-api + 5.2.0 + test + + + + org.junit.jupiter + junit-jupiter-engine + 5.2.0 + test + diff --git a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java index ba3ac81..5200e02 100644 --- a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java @@ -26,7 +26,7 @@ public final class FactorCovariate implements Covariate{ for(int i=0; i{ } @Override - public Collection generateSplitRules(List> data, int number) { + public Set generateSplitRules(List> data, int number) { final Set splitRules = new HashSet<>(); // This is to ensure we don't get stuck in an infinite loop for small factors diff --git a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java new file mode 100644 index 0000000..1a22ce9 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java @@ -0,0 +1,65 @@ +package ca.joeltherrien.randomforest.covariates; + + +import ca.joeltherrien.randomforest.FactorCovariate; 
+import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Collection; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +public class FactorCovariateTest { + + @Test + void verifyEqualLevels() { + final FactorCovariate petCovariate = createTestCovariate(); + + final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG"); + final FactorCovariate.FactorValue dog2 = petCovariate.createValue("DO" + "G"); + + assertSame(dog1, dog2); + + final FactorCovariate.FactorValue cat1 = petCovariate.createValue("CAT"); + final FactorCovariate.FactorValue cat2 = petCovariate.createValue("CA" + "T"); + + assertSame(cat1, cat2); + + final FactorCovariate.FactorValue mouse1 = petCovariate.createValue("MOUSE"); + final FactorCovariate.FactorValue mouse2 = petCovariate.createValue("MOUS" + "E"); + + assertSame(mouse1, mouse2); + + + } + + @Test + void verifyBadLevelException(){ + final FactorCovariate petCovariate = createTestCovariate(); + final Executable badCode = () -> petCovariate.createValue("vulcan"); + + assertThrows(IllegalArgumentException.class, badCode, "vulcan is not a level in FactorCovariate pet"); + } + + @Test + void testAllSubsets(){ + final FactorCovariate petCovariate = createTestCovariate(); + + final Collection splitRules = petCovariate.generateSplitRules(null, 100); + + assertEquals(splitRules.size(), 3); + + // TODO verify the contents of the split rules + + } + + + private FactorCovariate createTestCovariate(){ + final List levels = List.of("DOG", "CAT", "MOUSE"); + + return new FactorCovariate("pet", levels); + } + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java similarity index 100% rename from src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java rename to src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java similarity index 100% rename from src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java rename to src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java new file mode 100644 index 0000000..e63f8cd --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -0,0 +1,190 @@ +package ca.joeltherrien.randomforest.workshop; + + +import ca.joeltherrien.randomforest.*; +import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.tree.Node; +import ca.joeltherrien.randomforest.tree.TreeTrainer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.Stream; + +public class TrainSingleTreeFactor { + + public static void main(String[] args) { + System.out.println("Hello world!"); + + final Random random = new Random(123); + + final int n = 10000; + final List> trainingSet = new ArrayList<>(n); + + final Covariate x1Covariate = new NumericCovariate("x1"); + final Covariate x2Covariate 
= new NumericCovariate("x2"); + final FactorCovariate x3Covariate = new FactorCovariate("x3", List.of("cat", "dog", "mouse")); + + final List> x1List = DoubleStream + .generate(() -> random.nextDouble()*10.0) + .limit(n) + .mapToObj(x1 -> x1Covariate.createValue(x1)) + .collect(Collectors.toList()); + + final List> x2List = DoubleStream + .generate(() -> random.nextDouble()*10.0) + .limit(n) + .mapToObj(x2 -> x1Covariate.createValue(x2)) + .collect(Collectors.toList()); + + final List> x3List = DoubleStream + .generate(() -> random.nextDouble()) + .limit(n) + .mapToObj(db -> { + if(db < 0.333){ + return "cat"; + } + else if(db < 0.5){ + return "mouse"; + } + else{ + return "dog"; + } + }) + .map(str -> x3Covariate.createValue(str)) + .collect(Collectors.toList()); + + + for(int i=0; i x1 = x1List.get(i); + final Covariate.Value x2 = x2List.get(i); + final Covariate.Value x3 = x3List.get(i); + + trainingSet.add(generateRow(x1, x2, x3, i)); + } + + final TreeTrainer treeTrainer = TreeTrainer.builder() + .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) + .responseCombiner(new MeanResponseCombiner()) + .maxNodeDepth(30) + .nodeSize(5) + .numberOfSplits(5) + .build(); + + final List covariateNames = List.of(x1Covariate, x2Covariate); + + final long startTime = System.currentTimeMillis(); + final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames); + final long endTime = System.currentTimeMillis(); + + System.out.println(((double)(endTime - startTime))/1000.0); + + + + final Covariate.Value cat = x3Covariate.createValue("cat"); + final Covariate.Value dog = x3Covariate.createValue("dog"); + final Covariate.Value mouse = x3Covariate.createValue("mouse"); + + + final List testSet = new ArrayList<>(); + testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), cat, 1)); // expect 1 + testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), dog, 5)); + testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), cat, 3)); + testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), dog, 0)); + testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), cat, 8)); + testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), dog, 10)); + testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), cat, 3)); + testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), dog, 2)); + testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), cat, 4)); + + + testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(9.0), mouse, 0)); + testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), mouse, 5)); + + for(final CovariateRow testCase : testSet){ + System.out.println(testCase); + System.out.println(baseNode.evaluate(testCase)); + System.out.println(); + + + } + + + + + + } + + public static Row generateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){ + double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue()); + + final Map map = Map.of("x1", x1, "x2", x2); + + return new Row<>(map, id, y); + + } + + + public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){ + final Map map = Map.of("x1", x1, "x2", x2, "x3", x3); + + 
return new CovariateRow(map, id); + + } + + + public static double generateResponse(double x1, double x2, String x3){ + + if(x3.equalsIgnoreCase("mouse")){ + if(x1 <= 5){ + return 0; + } + else{ + return 5; + } + } + + // cat & dog below + + if(x2 <= 3){ + if(x1 <= 3){ + return 3; + } + else if(x1 <= 7){ + return 5; + } + else{ + return 1; + } + } + else if(x1 >= 5){ + if(x2 > 6){ + return 2; + } + else if(x1 >= 8){ + return 0; + } + else{ + return 8; + } + } + else if(x1 <= 2){ + if(x2 >= 7){ + return 4; + } + else{ + return 3; + } + } + else{ + return 10; + } + + + } + +}