diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericValue.java b/src/main/java/ca/joeltherrien/randomforest/NumericValue.java index 9defa06..a3d79bc 100644 --- a/src/main/java/ca/joeltherrien/randomforest/NumericValue.java +++ b/src/main/java/ca/joeltherrien/randomforest/NumericValue.java @@ -16,4 +16,9 @@ public class NumericValue implements Value { public SplitRule generateSplitRule(final String covariateName) { return new NumericSplitRule(covariateName, value); } + + @Override + public String toString(){ + return "" + value; + } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index 5036d9e..30e222b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -4,12 +4,13 @@ import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.ResponseCombiner; import lombok.Builder; +import java.util.Collection; import java.util.List; @Builder public class Forest { - private final List> trees; + private final Collection> trees; private final ResponseCombiner treeResponseCombiner; public Y evaluate(CovariateRow row){ diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 6759056..69c1d44 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -4,18 +4,23 @@ import ca.joeltherrien.randomforest.Bootstrapper; import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.Row; import lombok.Builder; +import lombok.RequiredArgsConstructor; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.stream.Collectors; +import java.util.stream.Stream; @Builder public class ForestTrainer { private final TreeTrainer treeTrainer; - private final Bootstrapper> bootstrapper; private final List covariatesToTry; private final ResponseCombiner treeResponseCombiner; + private final List> data; // number of covariates to randomly try private final int mtry; @@ -28,18 +33,11 @@ public class ForestTrainer { public Forest trainSerial(){ final List> trees = new ArrayList<>(ntree); + final Bootstrapper> bootstrapper = new Bootstrapper<>(data); for(int j=0; j treeCovariates = new ArrayList<>(covariatesToTry); - Collections.shuffle(treeCovariates); - for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){ - treeCovariates.remove(treeIndex); - } - - final List> bootstrappedData = bootstrapper.bootstrap(); - - trees.add(treeTrainer.growTree(bootstrappedData, treeCovariates)); + trees.add(trainTree(bootstrapper)); if(displayProgress){ if(j==0) { @@ -59,4 +57,91 @@ public class ForestTrainer { } + public Forest trainParallel(int threads){ + + // create a list that is prespecified in size (I can call the .set method at any index < ntree without + // the earlier indexes being filled. + final List> trees = Stream.>generate(() -> null).limit(ntree).collect(Collectors.toList()); + + final ExecutorService executorService = Executors.newFixedThreadPool(threads); + + for(int j=0; j tree : trees) { + if (tree != null) { + numberTreesSet++; + } + } + + System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees"); + } + + } + + if(displayProgress){ + System.out.println("\nFinished"); + } + + return Forest.builder() + .treeResponseCombiner(treeResponseCombiner) + .trees(trees) + .build(); + + } + + private Node trainTree(final Bootstrapper> bootstrapper){ + final List treeCovariates = new ArrayList<>(covariatesToTry); + Collections.shuffle(treeCovariates); + + for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){ + treeCovariates.remove(treeIndex); + } + + final List> bootstrappedData = bootstrapper.bootstrap(); + + return treeTrainer.growTree(bootstrappedData, treeCovariates); + } + + private class Worker implements Runnable { + + private final Bootstrapper> bootstrapper; + private final int treeIndex; + private final List> treeList; + + public Worker(final List> data, final int treeIndex, final List> treeList) { + this.bootstrapper = new Bootstrapper<>(data); + this.treeIndex = treeIndex; + this.treeList = treeList; + } + + @Override + public void run() { + + final Node tree = trainTree(bootstrapper); + + // should be okay as the list structure isn't changing + treeList.set(treeIndex, tree); + + //if(displayProgress){ + // System.out.println("Finished tree " + (treeIndex+1)); + //} + + + } + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 88152c1..c140164 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -1,9 +1,8 @@ package ca.joeltherrien.randomforest.tree; -import ca.joeltherrien.randomforest.ResponseCombiner; -import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.Split; -import ca.joeltherrien.randomforest.SplitRule; +import ca.joeltherrien.randomforest.*; +import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; import lombok.Builder; import java.util.*; @@ -57,9 +56,12 @@ public class TreeTrainer { private SplitRule findBestSplitRule(List> data, List covariatesToTry){ SplitRule bestSplitRule = null; - double bestSplitScore = 0; + Double bestSplitScore = 0.0; // may be null boolean first = true; + // temporary + final List previousRules = new ArrayList<>(); + for(final String covariate : covariatesToTry){ final List> shuffledData; @@ -83,24 +85,19 @@ public class TreeTrainer { int tries = 0; + while(tries < shuffledData.size()){ final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate); final Split possibleSplit = possibleRule.applyRule(data); + previousRules.add(possibleRule); + final Double score = groupDifferentiator.differentiate( possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()), possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()) ); - /* - if( (groupDifferentiator.shouldMaximize() && score > bestSplitScore) || (!groupDifferentiator.shouldMaximize() && score < bestSplitScore) || first){ - bestSplitRule = possibleRule; - bestSplitScore = score; - first = false; - } - */ - - if( score != null && (score > bestSplitScore || first)){ + if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){ bestSplitRule = possibleRule; bestSplitScore = score; first = false; diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index afa528d..32779a2 100644 --- a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -16,8 +16,8 @@ public class TrainForest { public static void main(String[] args){ // test creating a regression tree on a problem and see if the results are sensible. - final int n = 10000; - final int p =5; + final int n = 1000000; + final int p = 5; final Random random = new Random(); @@ -48,8 +48,8 @@ public class TrainForest { TreeTrainer treeTrainer = TreeTrainer.builder() - .numberOfSplits(10) - .nodeSize(5) + .numberOfSplits(5) + .nodeSize(3) .maxNodeDepth(100000000) .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .responseCombiner(new MeanResponseCombiner()) @@ -57,7 +57,7 @@ public class TrainForest { final ForestTrainer forestTrainer = ForestTrainer.builder() .treeTrainer(treeTrainer) - .bootstrapper(new Bootstrapper<>(data)) + .data(data) .covariatesToTry(covariateNames) .mtry(4) .ntree(100) @@ -66,7 +66,10 @@ public class TrainForest { .build(); final long startTime = System.currentTimeMillis(); + final Forest forest = forestTrainer.trainSerial(); + //final Forest forest = forestTrainer.trainParallel(8); + final long endTime = System.currentTimeMillis(); System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");