Add parallel support & fix fatal bug in TreeTrainer#findBestSplitRule.

This commit is contained in:
Joel Therrien 2018-07-02 23:16:20 -07:00
parent df7835869a
commit 5f280d09a1
5 changed files with 121 additions and 30 deletions

View file

@ -16,4 +16,9 @@ public class NumericValue implements Value<Double> {
public SplitRule generateSplitRule(final String covariateName) { public SplitRule generateSplitRule(final String covariateName) {
return new NumericSplitRule(covariateName, value); return new NumericSplitRule(covariateName, value);
} }
@Override
public String toString(){
return "" + value;
}
} }

View file

@ -4,12 +4,13 @@ import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.ResponseCombiner;
import lombok.Builder; import lombok.Builder;
import java.util.Collection;
import java.util.List; import java.util.List;
@Builder @Builder
public class Forest<Y> { public class Forest<Y> {
private final List<Node<Y>> trees; private final Collection<Node<Y>> trees;
private final ResponseCombiner<Y, ?> treeResponseCombiner; private final ResponseCombiner<Y, ?> treeResponseCombiner;
public Y evaluate(CovariateRow row){ public Y evaluate(CovariateRow row){

View file

@ -4,18 +4,23 @@ import ca.joeltherrien.randomforest.Bootstrapper;
import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.ResponseCombiner;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import lombok.Builder; import lombok.Builder;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@Builder @Builder
public class ForestTrainer<Y> { public class ForestTrainer<Y> {
private final TreeTrainer<Y> treeTrainer; private final TreeTrainer<Y> treeTrainer;
private final Bootstrapper<Row<Y>> bootstrapper;
private final List<String> covariatesToTry; private final List<String> covariatesToTry;
private final ResponseCombiner<Y, ?> treeResponseCombiner; private final ResponseCombiner<Y, ?> treeResponseCombiner;
private final List<Row<Y>> data;
// number of covariates to randomly try // number of covariates to randomly try
private final int mtry; private final int mtry;
@ -28,18 +33,11 @@ public class ForestTrainer<Y> {
public Forest<Y> trainSerial(){ public Forest<Y> trainSerial(){
final List<Node<Y>> trees = new ArrayList<>(ntree); final List<Node<Y>> trees = new ArrayList<>(ntree);
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
for(int j=0; j<ntree; j++){ for(int j=0; j<ntree; j++){
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
Collections.shuffle(treeCovariates);
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){ trees.add(trainTree(bootstrapper));
treeCovariates.remove(treeIndex);
}
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
trees.add(treeTrainer.growTree(bootstrappedData, treeCovariates));
if(displayProgress){ if(displayProgress){
if(j==0) { if(j==0) {
@ -59,4 +57,91 @@ public class ForestTrainer<Y> {
} }
public Forest<Y> trainParallel(int threads){
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
// the earlier indexes being filled.
final List<Node<Y>> trees = Stream.<Node<Y>>generate(() -> null).limit(ntree).collect(Collectors.toList());
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
for(int j=0; j<ntree; j++){
final Runnable worker = new Worker(data, j, trees);
executorService.execute(worker);
}
executorService.shutdown();
while(!executorService.isTerminated()){
try{
Thread.sleep(100);
} catch (InterruptedException e) {
// do nothing; who cares?
}
if(displayProgress) {
int numberTreesSet = 0;
for (final Node<Y> tree : trees) {
if (tree != null) {
numberTreesSet++;
}
}
System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees");
}
}
if(displayProgress){
System.out.println("\nFinished");
}
return Forest.<Y>builder()
.treeResponseCombiner(treeResponseCombiner)
.trees(trees)
.build();
}
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
Collections.shuffle(treeCovariates);
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
treeCovariates.remove(treeIndex);
}
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
return treeTrainer.growTree(bootstrappedData, treeCovariates);
}
private class Worker implements Runnable {
private final Bootstrapper<Row<Y>> bootstrapper;
private final int treeIndex;
private final List<Node<Y>> treeList;
public Worker(final List<Row<Y>> data, final int treeIndex, final List<Node<Y>> treeList) {
this.bootstrapper = new Bootstrapper<>(data);
this.treeIndex = treeIndex;
this.treeList = treeList;
}
@Override
public void run() {
final Node<Y> tree = trainTree(bootstrapper);
// should be okay as the list structure isn't changing
treeList.set(treeIndex, tree);
//if(displayProgress){
// System.out.println("Finished tree " + (treeIndex+1));
//}
}
}
} }

View file

@ -1,9 +1,8 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.Split; import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.SplitRule;
import lombok.Builder; import lombok.Builder;
import java.util.*; import java.util.*;
@ -57,9 +56,12 @@ public class TreeTrainer<Y> {
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){ private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
SplitRule bestSplitRule = null; SplitRule bestSplitRule = null;
double bestSplitScore = 0; Double bestSplitScore = 0.0; // may be null
boolean first = true; boolean first = true;
// temporary
final List<SplitRule> previousRules = new ArrayList<>();
for(final String covariate : covariatesToTry){ for(final String covariate : covariatesToTry){
final List<Row<Y>> shuffledData; final List<Row<Y>> shuffledData;
@ -83,24 +85,19 @@ public class TreeTrainer<Y> {
int tries = 0; int tries = 0;
while(tries < shuffledData.size()){ while(tries < shuffledData.size()){
final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate); final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
final Split<Y> possibleSplit = possibleRule.applyRule(data); final Split<Y> possibleSplit = possibleRule.applyRule(data);
previousRules.add(possibleRule);
final Double score = groupDifferentiator.differentiate( final Double score = groupDifferentiator.differentiate(
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()), possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()) possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
); );
/* if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){
if( (groupDifferentiator.shouldMaximize() && score > bestSplitScore) || (!groupDifferentiator.shouldMaximize() && score < bestSplitScore) || first){
bestSplitRule = possibleRule;
bestSplitScore = score;
first = false;
}
*/
if( score != null && (score > bestSplitScore || first)){
bestSplitRule = possibleRule; bestSplitRule = possibleRule;
bestSplitScore = score; bestSplitScore = score;
first = false; first = false;

View file

@ -16,8 +16,8 @@ public class TrainForest {
public static void main(String[] args){ public static void main(String[] args){
// test creating a regression tree on a problem and see if the results are sensible. // test creating a regression tree on a problem and see if the results are sensible.
final int n = 10000; final int n = 1000000;
final int p =5; final int p = 5;
final Random random = new Random(); final Random random = new Random();
@ -48,8 +48,8 @@ public class TrainForest {
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder() TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
.numberOfSplits(10) .numberOfSplits(5)
.nodeSize(5) .nodeSize(3)
.maxNodeDepth(100000000) .maxNodeDepth(100000000)
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
@ -57,7 +57,7 @@ public class TrainForest {
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder() final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
.treeTrainer(treeTrainer) .treeTrainer(treeTrainer)
.bootstrapper(new Bootstrapper<>(data)) .data(data)
.covariatesToTry(covariateNames) .covariatesToTry(covariateNames)
.mtry(4) .mtry(4)
.ntree(100) .ntree(100)
@ -66,7 +66,10 @@ public class TrainForest {
.build(); .build();
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
final Forest<Double> forest = forestTrainer.trainSerial(); final Forest<Double> forest = forestTrainer.trainSerial();
//final Forest<Double> forest = forestTrainer.trainParallel(8);
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();
System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds."); System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");