Add support for saving trees as forest is being trained.

Support for loading the trees back is not yet written.
This commit is contained in:
Joel Therrien 2018-07-03 12:31:08 -07:00
parent df35a2007a
commit 254727e594
4 changed files with 93 additions and 14 deletions

View file

@ -1,9 +1,10 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import java.io.Serializable;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
public abstract class SplitRule { public abstract class SplitRule implements Serializable {
/** /**
* Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.

View file

@ -4,13 +4,16 @@ import ca.joeltherrien.randomforest.Bootstrapper;
import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.ResponseCombiner;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import lombok.Builder; import lombok.Builder;
import lombok.RequiredArgsConstructor;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -29,6 +32,7 @@ public class ForestTrainer<Y> {
private final int ntree; private final int ntree;
private final boolean displayProgress; private final boolean displayProgress;
private final String saveTreeLocation;
public Forest<Y> trainSerial(){ public Forest<Y> trainSerial(){
@ -57,7 +61,7 @@ public class ForestTrainer<Y> {
} }
public Forest<Y> trainParallel(int threads){ public Forest<Y> trainParallelInMemory(int threads){
// create a list that is prespecified in size (I can call the .set method at any index < ntree without // create a list that is prespecified in size (I can call the .set method at any index < ntree without
// the earlier indexes being filled. // the earlier indexes being filled.
@ -66,7 +70,7 @@ public class ForestTrainer<Y> {
final ExecutorService executorService = Executors.newFixedThreadPool(threads); final ExecutorService executorService = Executors.newFixedThreadPool(threads);
for(int j=0; j<ntree; j++){ for(int j=0; j<ntree; j++){
final Runnable worker = new Worker(data, j, trees); final Runnable worker = new TreeInMemoryWorker(data, j, trees);
executorService.execute(worker); executorService.execute(worker);
} }
@ -103,6 +107,38 @@ public class ForestTrainer<Y> {
} }
public void trainParallelOnDisk(int threads){
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished
for(int j=0; j<ntree; j++){
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1), treeCount);
executorService.execute(worker);
}
executorService.shutdown();
while(!executorService.isTerminated()){
try{
Thread.sleep(100);
} catch (InterruptedException e) {
// do nothing; who cares?
}
if(displayProgress) {
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
}
}
if(displayProgress){
System.out.println("\nFinished");
}
}
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){ private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
final List<String> treeCovariates = new ArrayList<>(covariatesToTry); final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
Collections.shuffle(treeCovariates); Collections.shuffle(treeCovariates);
@ -116,13 +152,24 @@ public class ForestTrainer<Y> {
return treeTrainer.growTree(bootstrappedData, treeCovariates); return treeTrainer.growTree(bootstrappedData, treeCovariates);
} }
private class Worker implements Runnable { public void saveTree(final Node<Y> tree, String name) throws IOException {
final String filename = saveTreeLocation + "/" + name;
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
outputStream.writeObject(tree);
outputStream.close();
}
private class TreeInMemoryWorker implements Runnable {
private final Bootstrapper<Row<Y>> bootstrapper; private final Bootstrapper<Row<Y>> bootstrapper;
private final int treeIndex; private final int treeIndex;
private final List<Node<Y>> treeList; private final List<Node<Y>> treeList;
public Worker(final List<Row<Y>> data, final int treeIndex, final List<Node<Y>> treeList) { public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<Y>> treeList) {
this.bootstrapper = new Bootstrapper<>(data); this.bootstrapper = new Bootstrapper<>(data);
this.treeIndex = treeIndex; this.treeIndex = treeIndex;
this.treeList = treeList; this.treeList = treeList;
@ -136,10 +183,36 @@ public class ForestTrainer<Y> {
// should be okay as the list structure isn't changing // should be okay as the list structure isn't changing
treeList.set(treeIndex, tree); treeList.set(treeIndex, tree);
//if(displayProgress){ }
// System.out.println("Finished tree " + (treeIndex+1)); }
//}
private class TreeSavedWorker implements Runnable {
private final Bootstrapper<Row<Y>> bootstrapper;
private final String filename;
private final AtomicInteger treeCount;
public TreeSavedWorker(final List<Row<Y>> data, final String filename, final AtomicInteger treeCount) {
this.bootstrapper = new Bootstrapper<>(data);
this.filename = filename;
this.treeCount = treeCount;
}
@Override
public void run() {
final Node<Y> tree = trainTree(bootstrapper);
try {
saveTree(tree, filename);
} catch (IOException e) {
System.err.println("IOException while saving " + filename);
e.printStackTrace();
System.err.println("Quitting program");
System.exit(1);
}
treeCount.incrementAndGet();
} }
} }

View file

@ -2,7 +2,9 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
public interface Node<Y> { import java.io.Serializable;
public interface Node<Y> extends Serializable {
Y evaluate(CovariateRow row); Y evaluate(CovariateRow row);

View file

@ -16,7 +16,7 @@ public class TrainForest {
public static void main(String[] args){ public static void main(String[] args){
// test creating a regression tree on a problem and see if the results are sensible. // test creating a regression tree on a problem and see if the results are sensible.
final int n = 1000000; final int n = 10000;
final int p = 5; final int p = 5;
final Random random = new Random(); final Random random = new Random();
@ -49,7 +49,7 @@ public class TrainForest {
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder() TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
.numberOfSplits(5) .numberOfSplits(5)
.nodeSize(3) .nodeSize(5)
.maxNodeDepth(100000000) .maxNodeDepth(100000000)
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
@ -63,18 +63,21 @@ public class TrainForest {
.ntree(100) .ntree(100)
.treeResponseCombiner(new MeanResponseCombiner()) .treeResponseCombiner(new MeanResponseCombiner())
.displayProgress(true) .displayProgress(true)
.saveTreeLocation("/home/joel/test")
.build(); .build();
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
final Forest<Double> forest = forestTrainer.trainSerial(); //final Forest<Double> forest = forestTrainer.trainSerial();
//final Forest<Double> forest = forestTrainer.trainParallel(8); //final Forest<Double> forest = forestTrainer.trainParallel(8);
forestTrainer.trainParallelOnDisk(3);
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();
System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds."); System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");
/*
final Value zeroValue = new NumericValue(0.1); final Value zeroValue = new NumericValue(0.1);
final Value point5Value = new NumericValue(0.5); final Value point5Value = new NumericValue(0.5);
@ -87,7 +90,7 @@ public class TrainForest {
System.out.println(forest.evaluate(testRow2)); System.out.println(forest.evaluate(testRow2));
System.out.println("MinY = " + minY); System.out.println("MinY = " + minY);
*/
} }
} }