Add parallel support & fix fatal bug in TreeTrainer#findBestSplitRule.
This commit is contained in:
parent
df7835869a
commit
5f280d09a1
5 changed files with 121 additions and 30 deletions
|
@ -16,4 +16,9 @@ public class NumericValue implements Value<Double> {
|
||||||
public SplitRule generateSplitRule(final String covariateName) {
|
public SplitRule generateSplitRule(final String covariateName) {
|
||||||
return new NumericSplitRule(covariateName, value);
|
return new NumericSplitRule(covariateName, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return "" + value;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,12 +4,13 @@ import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public class Forest<Y> {
|
public class Forest<Y> {
|
||||||
|
|
||||||
private final List<Node<Y>> trees;
|
private final Collection<Node<Y>> trees;
|
||||||
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
||||||
|
|
||||||
public Y evaluate(CovariateRow row){
|
public Y evaluate(CovariateRow row){
|
||||||
|
|
|
@ -4,18 +4,23 @@ import ca.joeltherrien.randomforest.Bootstrapper;
|
||||||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public class ForestTrainer<Y> {
|
public class ForestTrainer<Y> {
|
||||||
|
|
||||||
private final TreeTrainer<Y> treeTrainer;
|
private final TreeTrainer<Y> treeTrainer;
|
||||||
private final Bootstrapper<Row<Y>> bootstrapper;
|
|
||||||
private final List<String> covariatesToTry;
|
private final List<String> covariatesToTry;
|
||||||
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
||||||
|
private final List<Row<Y>> data;
|
||||||
|
|
||||||
// number of covariates to randomly try
|
// number of covariates to randomly try
|
||||||
private final int mtry;
|
private final int mtry;
|
||||||
|
@ -28,18 +33,11 @@ public class ForestTrainer<Y> {
|
||||||
public Forest<Y> trainSerial(){
|
public Forest<Y> trainSerial(){
|
||||||
|
|
||||||
final List<Node<Y>> trees = new ArrayList<>(ntree);
|
final List<Node<Y>> trees = new ArrayList<>(ntree);
|
||||||
|
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||||
|
|
||||||
for(int j=0; j<ntree; j++){
|
for(int j=0; j<ntree; j++){
|
||||||
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
|
|
||||||
Collections.shuffle(treeCovariates);
|
|
||||||
|
|
||||||
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
|
trees.add(trainTree(bootstrapper));
|
||||||
treeCovariates.remove(treeIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
|
||||||
|
|
||||||
trees.add(treeTrainer.growTree(bootstrappedData, treeCovariates));
|
|
||||||
|
|
||||||
if(displayProgress){
|
if(displayProgress){
|
||||||
if(j==0) {
|
if(j==0) {
|
||||||
|
@ -59,4 +57,91 @@ public class ForestTrainer<Y> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Forest<Y> trainParallel(int threads){
|
||||||
|
|
||||||
|
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
||||||
|
// the earlier indexes being filled.
|
||||||
|
final List<Node<Y>> trees = Stream.<Node<Y>>generate(() -> null).limit(ntree).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||||
|
|
||||||
|
for(int j=0; j<ntree; j++){
|
||||||
|
final Runnable worker = new Worker(data, j, trees);
|
||||||
|
executorService.execute(worker);
|
||||||
|
}
|
||||||
|
|
||||||
|
executorService.shutdown();
|
||||||
|
while(!executorService.isTerminated()){
|
||||||
|
|
||||||
|
try{
|
||||||
|
Thread.sleep(100);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
// do nothing; who cares?
|
||||||
|
}
|
||||||
|
|
||||||
|
if(displayProgress) {
|
||||||
|
int numberTreesSet = 0;
|
||||||
|
for (final Node<Y> tree : trees) {
|
||||||
|
if (tree != null) {
|
||||||
|
numberTreesSet++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if(displayProgress){
|
||||||
|
System.out.println("\nFinished");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Forest.<Y>builder()
|
||||||
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
|
.trees(trees)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
||||||
|
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
|
||||||
|
Collections.shuffle(treeCovariates);
|
||||||
|
|
||||||
|
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
|
||||||
|
treeCovariates.remove(treeIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
||||||
|
|
||||||
|
return treeTrainer.growTree(bootstrappedData, treeCovariates);
|
||||||
|
}
|
||||||
|
|
||||||
|
private class Worker implements Runnable {
|
||||||
|
|
||||||
|
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||||
|
private final int treeIndex;
|
||||||
|
private final List<Node<Y>> treeList;
|
||||||
|
|
||||||
|
public Worker(final List<Row<Y>> data, final int treeIndex, final List<Node<Y>> treeList) {
|
||||||
|
this.bootstrapper = new Bootstrapper<>(data);
|
||||||
|
this.treeIndex = treeIndex;
|
||||||
|
this.treeList = treeList;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
|
||||||
|
final Node<Y> tree = trainTree(bootstrapper);
|
||||||
|
|
||||||
|
// should be okay as the list structure isn't changing
|
||||||
|
treeList.set(treeIndex, tree);
|
||||||
|
|
||||||
|
//if(displayProgress){
|
||||||
|
// System.out.println("Finished tree " + (treeIndex+1));
|
||||||
|
//}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.Split;
|
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.SplitRule;
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
@ -57,9 +56,12 @@ public class TreeTrainer<Y> {
|
||||||
|
|
||||||
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
|
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
|
||||||
SplitRule bestSplitRule = null;
|
SplitRule bestSplitRule = null;
|
||||||
double bestSplitScore = 0;
|
Double bestSplitScore = 0.0; // may be null
|
||||||
boolean first = true;
|
boolean first = true;
|
||||||
|
|
||||||
|
// temporary
|
||||||
|
final List<SplitRule> previousRules = new ArrayList<>();
|
||||||
|
|
||||||
for(final String covariate : covariatesToTry){
|
for(final String covariate : covariatesToTry){
|
||||||
|
|
||||||
final List<Row<Y>> shuffledData;
|
final List<Row<Y>> shuffledData;
|
||||||
|
@ -83,24 +85,19 @@ public class TreeTrainer<Y> {
|
||||||
|
|
||||||
|
|
||||||
int tries = 0;
|
int tries = 0;
|
||||||
|
|
||||||
while(tries < shuffledData.size()){
|
while(tries < shuffledData.size()){
|
||||||
final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
|
final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
|
||||||
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
||||||
|
|
||||||
|
previousRules.add(possibleRule);
|
||||||
|
|
||||||
final Double score = groupDifferentiator.differentiate(
|
final Double score = groupDifferentiator.differentiate(
|
||||||
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
|
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
|
||||||
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
);
|
);
|
||||||
|
|
||||||
/*
|
if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){
|
||||||
if( (groupDifferentiator.shouldMaximize() && score > bestSplitScore) || (!groupDifferentiator.shouldMaximize() && score < bestSplitScore) || first){
|
|
||||||
bestSplitRule = possibleRule;
|
|
||||||
bestSplitScore = score;
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
if( score != null && (score > bestSplitScore || first)){
|
|
||||||
bestSplitRule = possibleRule;
|
bestSplitRule = possibleRule;
|
||||||
bestSplitScore = score;
|
bestSplitScore = score;
|
||||||
first = false;
|
first = false;
|
||||||
|
|
|
@ -16,8 +16,8 @@ public class TrainForest {
|
||||||
public static void main(String[] args){
|
public static void main(String[] args){
|
||||||
// test creating a regression tree on a problem and see if the results are sensible.
|
// test creating a regression tree on a problem and see if the results are sensible.
|
||||||
|
|
||||||
final int n = 10000;
|
final int n = 1000000;
|
||||||
final int p =5;
|
final int p = 5;
|
||||||
|
|
||||||
final Random random = new Random();
|
final Random random = new Random();
|
||||||
|
|
||||||
|
@ -48,8 +48,8 @@ public class TrainForest {
|
||||||
|
|
||||||
|
|
||||||
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||||
.numberOfSplits(10)
|
.numberOfSplits(5)
|
||||||
.nodeSize(5)
|
.nodeSize(3)
|
||||||
.maxNodeDepth(100000000)
|
.maxNodeDepth(100000000)
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
@ -57,7 +57,7 @@ public class TrainForest {
|
||||||
|
|
||||||
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
|
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
|
||||||
.treeTrainer(treeTrainer)
|
.treeTrainer(treeTrainer)
|
||||||
.bootstrapper(new Bootstrapper<>(data))
|
.data(data)
|
||||||
.covariatesToTry(covariateNames)
|
.covariatesToTry(covariateNames)
|
||||||
.mtry(4)
|
.mtry(4)
|
||||||
.ntree(100)
|
.ntree(100)
|
||||||
|
@ -66,7 +66,10 @@ public class TrainForest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final long startTime = System.currentTimeMillis();
|
final long startTime = System.currentTimeMillis();
|
||||||
|
|
||||||
final Forest<Double> forest = forestTrainer.trainSerial();
|
final Forest<Double> forest = forestTrainer.trainSerial();
|
||||||
|
//final Forest<Double> forest = forestTrainer.trainParallel(8);
|
||||||
|
|
||||||
final long endTime = System.currentTimeMillis();
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");
|
System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");
|
||||||
|
|
Loading…
Reference in a new issue