From df7835869a169a105ed8dca3bc2da7bff19278df Mon Sep 17 00:00:00 2001
From: Joel Therrien
Date: Mon, 2 Jul 2018 17:58:53 -0700
Subject: [PATCH] Add functionality to train a random forest in serial.

---
 .../randomforest/Bootstrapper.java              | 30 +++++++
 .../randomforest/tree/Forest.java               | 21 +++++
 .../randomforest/tree/ForestTrainer.java        | 62 +++++++++++++
 .../randomforest/tree/TreeTrainer.java          | 30 +++++--
 .../randomforest/workshop/TrainForest.java      | 90 +++++++++++++++++++
 .../TrainSingleTree.java}                       |  8 +-
 6 files changed, 234 insertions(+), 7 deletions(-)
 create mode 100644 src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
 create mode 100644 src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
 create mode 100644 src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
 create mode 100644 src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
 rename src/main/java/ca/joeltherrien/randomforest/{Main.java => workshop/TrainSingleTree.java} (93%)

diff --git a/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java b/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
new file mode 100644
index 0000000..88b064b
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
@@ -0,0 +1,30 @@
+package ca.joeltherrien.randomforest;
+
+import lombok.RequiredArgsConstructor;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Produces bootstrap samples (sampling with replacement) of the original data.
+ */
+@RequiredArgsConstructor
+public class Bootstrapper<T> {
+
+    final private List<T> originalData;
+    final private Random random = new Random();
+
+    public List<T> bootstrap(){
+        final int n = originalData.size();
+
+        final List<T> newList = new ArrayList<>(n);
+
+        for(int i=0; i<n; i++){
+            newList.add(originalData.get(random.nextInt(n)));
+        }
+
+        return newList;
+    }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
new file mode 100644
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
@@ -0,0 +1,21 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.CovariateRow;
+import ca.joeltherrien.randomforest.ResponseCombiner;
+import lombok.Builder;
+
+import java.util.List;
+
+@Builder
+public class Forest<Y> {
+
+    private final List<Node<Y>> trees;
+    private final ResponseCombiner<Y, ?> treeResponseCombiner;
+
+    public Y evaluate(CovariateRow row){
+        return trees.parallelStream()
+                .map(node -> node.evaluate(row))
+                .collect(treeResponseCombiner);
+    }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
new file mode 100644
index 0000000..6759056
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
@@ -0,0 +1,62 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.Bootstrapper;
+import ca.joeltherrien.randomforest.ResponseCombiner;
+import ca.joeltherrien.randomforest.Row;
+import lombok.Builder;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+@Builder
+public class ForestTrainer<Y> {
+
+    private final TreeTrainer<Y> treeTrainer;
+    private final Bootstrapper<Row<Y>> bootstrapper;
+    private final List<String> covariatesToTry;
+    private final ResponseCombiner<Y, ?> treeResponseCombiner;
+
+    // number of covariates to randomly try
+    private final int mtry;
+
+    // number of trees to train
+    private final int ntree;
+
+    private final boolean displayProgress;
+
+    public Forest<Y> trainSerial(){
+
+        final List<Node<Y>> trees = new ArrayList<>(ntree);
+
+        for(int j=0; j<ntree; j++){
+            final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
+            Collections.shuffle(treeCovariates);
+
+            for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
+                treeCovariates.remove(treeIndex);
+            }
+
+            final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
+
+            trees.add(treeTrainer.growTree(bootstrappedData, treeCovariates));
+
+            if(displayProgress){
+                if(j==0) {
+                    System.out.println();
+                }
+                System.out.print("\rFinished tree " + (j+1) + "/" + ntree);
+                if(j==ntree-1){
+                    System.out.println();
+                }
+            }
+        }
+
+        return Forest.<Y>builder()
+                .treeResponseCombiner(treeResponseCombiner)
+                .trees(trees)
+                .build();
+
+    }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
index fed92bf..88152c1 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
@@ -6,8 +6,7 @@ import ca.joeltherrien.randomforest.Split;
 import ca.joeltherrien.randomforest.SplitRule;
 import lombok.Builder;
 
-import java.util.Collections;
-import java.util.List;
+import java.util.*;
 import java.util.stream.Collectors;
 
 @Builder
@@ -24,6 +23,8 @@ public class TreeTrainer<Y> {
     private final int nodeSize;
     private final int maxNodeDepth;
 
+    private final Random random = new Random();
+
     public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
         return growNode(data, covariatesToTry, 0);
     }
@@ -60,11 +61,30 @@ public class TreeTrainer<Y> {
         boolean first = true;
 
         for(final String covariate : covariatesToTry){
-            Collections.shuffle(data);
+
+            final List<Row<Y>> shuffledData;
+            if(numberOfSplits == 0 || numberOfSplits > data.size()){
+                shuffledData = new ArrayList<>(data);
+                Collections.shuffle(shuffledData);
+            }
+            else{ // only need the top numberOfSplits entries
+                shuffledData = new ArrayList<>(numberOfSplits);
+                final Set<Integer> indexesToUse = new HashSet<>();
+
+                while(indexesToUse.size() < numberOfSplits){
+                    final int index = random.nextInt(data.size());
+
+                    if(indexesToUse.add(index)){
+                        shuffledData.add(data.get(index));
+                    }
+                }
+
+            }
+
             int tries = 0;
-            while(tries <= numberOfSplits || (numberOfSplits == 0 && tries < data.size())){
-                final SplitRule possibleRule = data.get(tries).getCovariate(covariate).generateSplitRule(covariate);
+            while(tries < shuffledData.size()){
+                final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
 
                 final Split possibleSplit = possibleRule.applyRule(data);
 
                 final Double score = groupDifferentiator.differentiate(
diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
new file mode 100644
index 0000000..afa528d
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
@@ -0,0 +1,90 @@
+package ca.joeltherrien.randomforest.workshop;
+
+import ca.joeltherrien.randomforest.*;
+import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.Forest;
+import ca.joeltherrien.randomforest.tree.ForestTrainer;
+import ca.joeltherrien.randomforest.tree.TreeTrainer;
+
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class TrainForest {
+
+    public static void main(String[] args){
+        // test creating a regression forest on a problem and see if the results are sensible.
+
+        final int n = 10000;
+        final int p = 5;
+
+        final Random random = new Random();
+
+        final List<Row<Double>> data = new ArrayList<>(n);
+
+        double minY = 1000.0;
+
+        for(int i=0; i<n; i++){
+            double y = 0.0;
+            final Map<String, Value> map = new HashMap<>();
+
+            for(int j=0; j<p; j++){
+                final double x = random.nextDouble();
+                // assumed response: the sum of the uniform covariates (a simple signal for this sanity check)
+                y += x;
+                map.put("x"+j, new NumericValue(x));
+            }
+
+            data.add(new Row<>(map, i, y));
+
+            if(y < minY){
+                minY = y;
+            }
+
+        }
+
+        final List<String> covariateNames = IntStream.range(0, p).mapToObj(j -> "x"+j).collect(Collectors.toList());
+
+        TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
+                .numberOfSplits(10)
+                .nodeSize(5)
+                .maxNodeDepth(100000000)
+                .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
+                .responseCombiner(new MeanResponseCombiner())
+                .build();
+
+        final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
+                .treeTrainer(treeTrainer)
+                .bootstrapper(new Bootstrapper<>(data))
+                .covariatesToTry(covariateNames)
+                .mtry(4)
+                .ntree(100)
+                .treeResponseCombiner(new MeanResponseCombiner())
+                .displayProgress(true)
+                .build();
+
+        final long startTime = System.currentTimeMillis();
+        final Forest<Double> forest = forestTrainer.trainSerial();
+        final long endTime = System.currentTimeMillis();
+
+        System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");
+
+        final Value zeroValue = new NumericValue(0.1);
+        final Value point5Value = new NumericValue(0.5);
+
+        // test rows
+        final CovariateRow testRow1 = new CovariateRow(Map.of("x0", zeroValue, "x1", zeroValue, "x2", zeroValue, "x3", zeroValue, "x4", zeroValue), 0);
+        final CovariateRow testRow2 = new CovariateRow(Map.of("x0", point5Value, "x1", point5Value, "x2", point5Value, "x3", point5Value, "x4", point5Value), 2);
+
+        System.out.println(forest.evaluate(testRow1));
+        System.out.println(forest.evaluate(testRow2));
+
+        System.out.println("MinY = " + minY);
+
+    }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
similarity index 93%
rename from src/main/java/ca/joeltherrien/randomforest/Main.java
rename to src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
index b2a8edb..d218687 100644
--- a/src/main/java/ca/joeltherrien/randomforest/Main.java
+++ b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
@@ -1,6 +1,10 @@
-package ca.joeltherrien.randomforest;
+package ca.joeltherrien.randomforest.workshop;
 
+import ca.joeltherrien.randomforest.CovariateRow;
+import ca.joeltherrien.randomforest.NumericValue;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.Value;
 import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
 import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
 import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
@@ -11,7 +15,7 @@ import java.util.*;
 import java.util.stream.Collectors;
 import java.util.stream.DoubleStream;
 
-public class Main {
+public class TrainSingleTree {
 
     public static void main(String[] args) {
         System.out.println("Hello world!");