diff --git a/src/main/java/ca/joeltherrien/randomforest/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/Covariate.java new file mode 100644 index 0000000..86848d8 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/Covariate.java @@ -0,0 +1,58 @@ +package ca.joeltherrien.randomforest; + +import java.io.Serializable; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +public interface Covariate extends Serializable { + + String getName(); + + Collection> generateSplitRules(final List> data, final int number); + + Value createValue(V value); + + interface Value extends Serializable{ + + Covariate getParent(); + + V getValue(); + + } + + interface SplitRule extends Serializable{ + + Covariate getParent(); + + /** + * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. + * This method is primarily used during the training of a tree when splits are being tested. + * + * @param rows + * @param + * @return + */ + default Split applyRule(List> rows) { + final List> leftHand = new LinkedList<>(); + final List> rightHand = new LinkedList<>(); + + for(final Row row : rows) { + + if(isLeftHand(row)){ + leftHand.add(row); + } + else{ + rightHand.add(row); + } + + } + + return new Split<>(leftHand, rightHand); + } + + boolean isLeftHand(CovariateRow row); + } + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java index beccc3a..181d80b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java +++ b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java @@ -8,12 +8,12 @@ import java.util.Map; @RequiredArgsConstructor public class CovariateRow { - private final Map valueMap; + private final Map valueMap; @Getter private final int id; - public Value getCovariate(String name){ + public Covariate.Value getCovariateValue(String name){ return valueMap.get(name); } 
diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java new file mode 100644 index 0000000..76c83b3 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java @@ -0,0 +1,111 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.exceptions.MissingValueException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +@RequiredArgsConstructor +public class NumericCovariate implements Covariate{ + + @Getter + private final String name; + + /** + * Samples up to {@code number} entries of {@code data} at random and turns every distinct sampled value into a threshold rule. + * + * @param data covariate values from which candidate thresholds are drawn + * @param number how many entries to sample; when it is at least {@code data.size()} every entry is used + * @return the candidate rules, one per distinct sampled threshold + */ + @Override + public Collection generateSplitRules(List> data, int number) { + + // for this implementation we need to shuffle the data + final List> shuffledData; + if(number >= data.size()){ // every entry is wanted, so a single shuffle replaces the rejection sampling below + shuffledData = new ArrayList<>(data); + Collections.shuffle(shuffledData); + } + else{ // only need the top number entries + shuffledData = new ArrayList<>(number); + final Set indexesToUse = new HashSet<>(); + + while(indexesToUse.size() < number){ + final int index = ThreadLocalRandom.current().nextInt(data.size()); + + if(indexesToUse.add(index)){ + shuffledData.add(data.get(index)); + } + } + + } + + return shuffledData.stream() + .mapToDouble(v -> v.getValue()) + .distinct() + .mapToObj(threshold -> new NumericSplitRule(threshold)) + .collect(Collectors.toSet()); + // distinct() drops thresholds repeated by bootstrapping; NumericSplitRule defines no equals/hashCode, so the Set alone would not de-duplicate + + + } + + @Override + public NumericValue createValue(Double value) { + return new NumericValue(value); + } + + public class NumericValue implements Covariate.Value{ + + private final double value; + + private NumericValue(final double value){ + this.value = value; + } + + @Override + public Covariate getParent() { + return NumericCovariate.this; + } + + @Override + public Double getValue() { + return value; + } + } + + 
public class NumericSplitRule implements Covariate.SplitRule{ + + private final double threshold; + + private NumericSplitRule(final double threshold){ + this.threshold = threshold; + } + + @Override + public final String toString() { + return "NumericSplitRule on " + getParent().getName() + " at " + threshold; + } + + @Override + public Covariate getParent() { + return NumericCovariate.this; + } + + @Override + public boolean isLeftHand(CovariateRow row) { + final Covariate.Value x = row.getCovariateValue(getParent().getName()); + if(x == null) { + throw new MissingValueException(row, this); + } + + final double xNum = (Double) x.getValue(); + + return xNum <= threshold; + } + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java deleted file mode 100644 index 505a58c..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/NumericSplitRule.java +++ /dev/null @@ -1,33 +0,0 @@ -package ca.joeltherrien.randomforest; - -import java.util.LinkedList; -import java.util.List; - -import ca.joeltherrien.randomforest.exceptions.MissingValueException; -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class NumericSplitRule extends SplitRule{ - - public final String covariateName; - public final double threshold; - - @Override - public final String toString() { - return "NumericSplitRule on " + covariateName + " at " + threshold; - } - - @Override - public boolean isLeftHand(CovariateRow row) { - final Value x = row.getCovariate(covariateName); - if(x == null) { - throw new MissingValueException(row, this); - } - - final double xNum = (Double) x.getValue(); - - return xNum <= threshold; - } - - -} diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericValue.java b/src/main/java/ca/joeltherrien/randomforest/NumericValue.java deleted file mode 100644 index a3d79bc..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/NumericValue.java +++ /dev/null @@ 
-1,24 +0,0 @@ -package ca.joeltherrien.randomforest; - -import lombok.RequiredArgsConstructor; - -@RequiredArgsConstructor -public class NumericValue implements Value { - - private final double value; - - @Override - public Double getValue() { - return value; - } - - @Override - public SplitRule generateSplitRule(final String covariateName) { - return new NumericSplitRule(covariateName, value); - } - - @Override - public String toString(){ - return "" + value; - } -} diff --git a/src/main/java/ca/joeltherrien/randomforest/Row.java b/src/main/java/ca/joeltherrien/randomforest/Row.java index 898229b..b762f15 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Row.java +++ b/src/main/java/ca/joeltherrien/randomforest/Row.java @@ -7,7 +7,7 @@ public class Row extends CovariateRow { private final Y response; - public Row(Map valueMap, int id, Y response){ + public Row(Map valueMap, int id, Y response){ super(valueMap, id); this.response = response; } diff --git a/src/main/java/ca/joeltherrien/randomforest/SplitRule.java b/src/main/java/ca/joeltherrien/randomforest/SplitRule.java deleted file mode 100644 index 95d7545..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/SplitRule.java +++ /dev/null @@ -1,37 +0,0 @@ -package ca.joeltherrien.randomforest; - -import java.io.Serializable; -import java.util.LinkedList; -import java.util.List; - -public abstract class SplitRule implements Serializable { - - /** - * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. - * This method is primarily used during the training of a tree when splits are being tested. 
- * - * @param rows - * @param - * @return - */ - public Split applyRule(List> rows) { - final List> leftHand = new LinkedList<>(); - final List> rightHand = new LinkedList<>(); - - for(final Row row : rows) { - - if(isLeftHand(row)){ - leftHand.add(row); - } - else{ - rightHand.add(row); - } - - } - - return new Split<>(leftHand, rightHand); - } - - public abstract boolean isLeftHand(CovariateRow row); - -} diff --git a/src/main/java/ca/joeltherrien/randomforest/Value.java b/src/main/java/ca/joeltherrien/randomforest/Value.java deleted file mode 100644 index fd16563..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/Value.java +++ /dev/null @@ -1,11 +0,0 @@ -package ca.joeltherrien.randomforest; - - -public interface Value { - - V getValue(); - - SplitRule generateSplitRule(String covariateName); - - -} diff --git a/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java index 6e58b93..d340fdc 100644 --- a/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java +++ b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java @@ -1,8 +1,7 @@ package ca.joeltherrien.randomforest.exceptions; +import ca.joeltherrien.randomforest.Covariate; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.SplitRule; public class MissingValueException extends RuntimeException{ @@ -11,7 +10,7 @@ public class MissingValueException extends RuntimeException{ */ private static final long serialVersionUID = 6808060079431207726L; - public MissingValueException(CovariateRow row, SplitRule rule) { + public MissingValueException(CovariateRow row, Covariate.SplitRule rule) { super("Missing value at CovariateRow " + row + rule); } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java 
b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 277b368..bae9f9f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -1,6 +1,7 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Bootstrapper; +import ca.joeltherrien.randomforest.Covariate; import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.Row; import lombok.Builder; @@ -21,7 +22,7 @@ import java.util.stream.Stream; public class ForestTrainer { private final TreeTrainer treeTrainer; - private final List covariatesToTry; + private final List covariatesToTry; private final ResponseCombiner treeResponseCombiner; private final List> data; @@ -140,7 +141,7 @@ public class ForestTrainer { } private Node trainTree(final Bootstrapper> bootstrapper){ - final List treeCovariates = new ArrayList<>(covariatesToTry); + final List treeCovariates = new ArrayList<>(covariatesToTry); Collections.shuffle(treeCovariates); for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){ diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java index 1cedfb5..3f5dfdb 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java @@ -1,8 +1,7 @@ package ca.joeltherrien.randomforest.tree; +import ca.joeltherrien.randomforest.Covariate; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.SplitRule; import lombok.Builder; @Builder @@ -10,7 +9,7 @@ public class SplitNode implements Node { private final Node leftHand; private final Node rightHand; - private final SplitRule splitRule; + private final Covariate.SplitRule splitRule; @Override public Y evaluate(CovariateRow row) { diff --git 
a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index e830f4d..94a1dad 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -1,8 +1,6 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.*; -import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; import lombok.Builder; import java.util.*; @@ -22,17 +20,15 @@ public class TreeTrainer { private final int nodeSize; private final int maxNodeDepth; - private final Random random = new Random(); - - public Node growTree(List> data, List covariatesToTry){ + public Node growTree(List> data, List covariatesToTry){ return growNode(data, covariatesToTry, 0); } - private Node growNode(List> data, List covariatesToTry, int depth){ + private Node growNode(List> data, List covariatesToTry, int depth){ // TODO; what is minimum per tree? 
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ - final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); + final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); if(bestSplitRule == null){ return new TerminalNode<>( @@ -63,37 +59,24 @@ public class TreeTrainer { } - private SplitRule findBestSplitRule(List> data, List covariatesToTry){ - SplitRule bestSplitRule = null; + private Covariate.SplitRule findBestSplitRule(List> data, List covariatesToTry){ + Covariate.SplitRule bestSplitRule = null; double bestSplitScore = 0.0; boolean first = true; - for(final String covariate : covariatesToTry){ + for(final Covariate covariate : covariatesToTry){ - final List> shuffledData; - if(numberOfSplits == 0 || numberOfSplits > data.size()){ - shuffledData = new ArrayList<>(data); - Collections.shuffle(shuffledData); - } - else{ // only need the top numberOfSplits entries - shuffledData = new ArrayList<>(numberOfSplits); - final Set indexesToUse = new HashSet<>(); + final int numberToTry = numberOfSplits==0 ? 
data.size() : numberOfSplits; - while(indexesToUse.size() < numberOfSplits){ - final int index = random.nextInt(data.size()); + final Collection splitRulesToTry = covariate + .generateSplitRules( + data + .stream() + .map(row -> row.getCovariateValue(covariate.getName())) + .collect(Collectors.toList()) + , numberToTry); - if(indexesToUse.add(index)){ - shuffledData.add(data.get(index)); - } - } - - } - - - int tries = 0; - - while(tries < shuffledData.size()){ - final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate); + for(final Covariate.SplitRule possibleRule : splitRulesToTry){ final Split possibleSplit = possibleRule.applyRule(data); final Double score = groupDifferentiator.differentiate( @@ -106,8 +89,6 @@ public class TreeTrainer { bestSplitScore = score; first = false; } - - tries++; } } diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index c645a35..003256c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -19,21 +19,28 @@ public class TrainForest { final int n = 10000; final int p = 5; + final Random random = new Random(); final List> data = new ArrayList<>(n); double minY = 1000.0; + final List covariateList = new ArrayList<>(p); + for(int j =0; j < p; j++){ + final NumericCovariate covariate = new NumericCovariate("x"+j); + covariateList.add(covariate); + } + for(int i=0; i map = new HashMap<>(); + final Map map = new HashMap<>(); - for(int j=0; j(map, i, y)); @@ -44,10 +51,8 @@ public class TrainForest { } - final List covariateNames = IntStream.range(0, p).mapToObj(j -> "x"+j).collect(Collectors.toList()); - - TreeTrainer treeTrainer = TreeTrainer.builder() + final TreeTrainer treeTrainer = TreeTrainer.builder() .numberOfSplits(5) .nodeSize(5) .maxNodeDepth(100000000) @@ -58,7 +63,7 @@ 
public class TrainForest { final ForestTrainer forestTrainer = ForestTrainer.builder() .treeTrainer(treeTrainer) .data(data) - .covariatesToTry(covariateNames) + .covariatesToTry(covariateList) .mtry(4) .ntree(100) .treeResponseCombiner(new MeanResponseCombiner()) @@ -69,7 +74,7 @@ public class TrainForest { final long startTime = System.currentTimeMillis(); //final Forest forest = forestTrainer.trainSerial(); - //final Forest forest = forestTrainer.trainParallel(8); + //final Forest forest = forestTrainer.trainParallelInMemory(3); forestTrainer.trainParallelOnDisk(3); final long endTime = System.currentTimeMillis(); @@ -88,9 +93,9 @@ public class TrainForest { System.out.println(forest.evaluate(testRow1)); System.out.println(forest.evaluate(testRow2)); - - System.out.println("MinY = " + minY); */ + System.out.println("MinY = " + minY); + } } diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index d218687..72654a8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -1,11 +1,10 @@ package ca.joeltherrien.randomforest.workshop; +import ca.joeltherrien.randomforest.Covariate; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.NumericValue; +import ca.joeltherrien.randomforest.NumericCovariate; import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.Value; -import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator; import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.Node; @@ -25,30 +24,30 @@ public class TrainSingleTree { final int n = 1000; final List> trainingSet = new ArrayList<>(n); - final List> x1List = DoubleStream + final 
Covariate x1Covariate = new NumericCovariate("x1"); + final Covariate x2Covariate = new NumericCovariate("x2"); + + final List> x1List = DoubleStream .generate(() -> random.nextDouble()*10.0) .limit(n) - .mapToObj(x1 -> new NumericValue(x1)) + .mapToObj(x1 -> x1Covariate.createValue(x1)) .collect(Collectors.toList()); - final List> x2List = DoubleStream + final List> x2List = DoubleStream .generate(() -> random.nextDouble()*10.0) .limit(n) - .mapToObj(x1 -> new NumericValue(x1)) + .mapToObj(x2 -> x2Covariate.createValue(x2)) .collect(Collectors.toList()); for(int i=0; i x1 = x1List.get(i); + final Covariate.Value x2 = x2List.get(i); trainingSet.add(generateRow(x1, x2, i)); } - - final long startTime = System.currentTimeMillis(); - final TreeTrainer treeTrainer = TreeTrainer.builder() .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .responseCombiner(new MeanResponseCombiner()) @@ -57,25 +56,29 @@ public class TrainSingleTree { .numberOfSplits(0) .build(); + final List covariateNames = List.of(x1Covariate, x2Covariate); + + final long startTime = System.currentTimeMillis(); + final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames); final long endTime = System.currentTimeMillis(); System.out.println(((double)(endTime - startTime))/1000.0); - final List covariateNames = List.of("x1", "x2"); - final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames); + + final List testSet = new ArrayList<>(); - testSet.add(generateCovariateRow(9, 2, 1)); // expect 1 - testSet.add(generateCovariateRow(5, 2, 5)); - testSet.add(generateCovariateRow(2, 2, 3)); - testSet.add(generateCovariateRow(9, 5, 0)); - testSet.add(generateCovariateRow(6, 5, 8)); - testSet.add(generateCovariateRow(3, 5, 10)); - testSet.add(generateCovariateRow(1, 5, 3)); - testSet.add(generateCovariateRow(7, 9, 2)); - testSet.add(generateCovariateRow(1, 9, 4)); + testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), 1)); // expect 1 + 
testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), 5)); + testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), 3)); + testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), 0)); + testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), 8)); + testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), 10)); + testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), 3)); + testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), 2)); + testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), 4)); for(final CovariateRow testCase : testSet){ System.out.println(testCase); @@ -91,18 +94,18 @@ public class TrainSingleTree { } - public static Row generateRow(double x1, double x2, int id){ - double y = generateResponse(x1, x2); + public static Row generateRow(Covariate.Value x1, Covariate.Value x2, int id){ + double y = generateResponse(x1.getValue(), x2.getValue()); - final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2)); + final Map map = Map.of("x1", x1, "x2", x2); return new Row<>(map, id, y); } - public static CovariateRow generateCovariateRow(double x1, double x2, int id){ - final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2)); + public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){ + final Map map = Map.of("x1", x1, "x2", x2); return new CovariateRow(map, id);