Add support for saving trees to disk while the forest is being trained.
Support for loading the trees back is not yet written.
This commit is contained in:
parent
df35a2007a
commit
254727e594
4 changed files with 93 additions and 14 deletions
|
@ -1,9 +1,10 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.LinkedList;
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public abstract class SplitRule {
|
public abstract class SplitRule implements Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.
|
* Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.
|
||||||
|
|
|
@ -4,13 +4,16 @@ import ca.joeltherrien.randomforest.Bootstrapper;
|
||||||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
|
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
@ -29,6 +32,7 @@ public class ForestTrainer<Y> {
|
||||||
private final int ntree;
|
private final int ntree;
|
||||||
|
|
||||||
private final boolean displayProgress;
|
private final boolean displayProgress;
|
||||||
|
private final String saveTreeLocation;
|
||||||
|
|
||||||
public Forest<Y> trainSerial(){
|
public Forest<Y> trainSerial(){
|
||||||
|
|
||||||
|
@ -57,7 +61,7 @@ public class ForestTrainer<Y> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Forest<Y> trainParallel(int threads){
|
public Forest<Y> trainParallelInMemory(int threads){
|
||||||
|
|
||||||
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
||||||
// the earlier indexes being filled.
|
// the earlier indexes being filled.
|
||||||
|
@ -66,7 +70,7 @@ public class ForestTrainer<Y> {
|
||||||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||||
|
|
||||||
for(int j=0; j<ntree; j++){
|
for(int j=0; j<ntree; j++){
|
||||||
final Runnable worker = new Worker(data, j, trees);
|
final Runnable worker = new TreeInMemoryWorker(data, j, trees);
|
||||||
executorService.execute(worker);
|
executorService.execute(worker);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,6 +107,38 @@ public class ForestTrainer<Y> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public void trainParallelOnDisk(int threads){
|
||||||
|
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||||
|
final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished
|
||||||
|
|
||||||
|
for(int j=0; j<ntree; j++){
|
||||||
|
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1), treeCount);
|
||||||
|
executorService.execute(worker);
|
||||||
|
}
|
||||||
|
|
||||||
|
executorService.shutdown();
|
||||||
|
while(!executorService.isTerminated()){
|
||||||
|
|
||||||
|
try{
|
||||||
|
Thread.sleep(100);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
// do nothing; who cares?
|
||||||
|
}
|
||||||
|
|
||||||
|
if(displayProgress) {
|
||||||
|
|
||||||
|
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if(displayProgress){
|
||||||
|
System.out.println("\nFinished");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
||||||
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
|
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
|
||||||
Collections.shuffle(treeCovariates);
|
Collections.shuffle(treeCovariates);
|
||||||
|
@ -116,13 +152,24 @@ public class ForestTrainer<Y> {
|
||||||
return treeTrainer.growTree(bootstrappedData, treeCovariates);
|
return treeTrainer.growTree(bootstrappedData, treeCovariates);
|
||||||
}
|
}
|
||||||
|
|
||||||
private class Worker implements Runnable {
|
public void saveTree(final Node<Y> tree, String name) throws IOException {
|
||||||
|
final String filename = saveTreeLocation + "/" + name;
|
||||||
|
|
||||||
|
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
|
||||||
|
|
||||||
|
outputStream.writeObject(tree);
|
||||||
|
|
||||||
|
outputStream.close();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private class TreeInMemoryWorker implements Runnable {
|
||||||
|
|
||||||
private final Bootstrapper<Row<Y>> bootstrapper;
|
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||||
private final int treeIndex;
|
private final int treeIndex;
|
||||||
private final List<Node<Y>> treeList;
|
private final List<Node<Y>> treeList;
|
||||||
|
|
||||||
public Worker(final List<Row<Y>> data, final int treeIndex, final List<Node<Y>> treeList) {
|
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<Y>> treeList) {
|
||||||
this.bootstrapper = new Bootstrapper<>(data);
|
this.bootstrapper = new Bootstrapper<>(data);
|
||||||
this.treeIndex = treeIndex;
|
this.treeIndex = treeIndex;
|
||||||
this.treeList = treeList;
|
this.treeList = treeList;
|
||||||
|
@ -136,10 +183,36 @@ public class ForestTrainer<Y> {
|
||||||
// should be okay as the list structure isn't changing
|
// should be okay as the list structure isn't changing
|
||||||
treeList.set(treeIndex, tree);
|
treeList.set(treeIndex, tree);
|
||||||
|
|
||||||
//if(displayProgress){
|
}
|
||||||
// System.out.println("Finished tree " + (treeIndex+1));
|
}
|
||||||
//}
|
|
||||||
|
|
||||||
|
private class TreeSavedWorker implements Runnable {
|
||||||
|
|
||||||
|
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||||
|
private final String filename;
|
||||||
|
private final AtomicInteger treeCount;
|
||||||
|
|
||||||
|
public TreeSavedWorker(final List<Row<Y>> data, final String filename, final AtomicInteger treeCount) {
|
||||||
|
this.bootstrapper = new Bootstrapper<>(data);
|
||||||
|
this.filename = filename;
|
||||||
|
this.treeCount = treeCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
|
||||||
|
final Node<Y> tree = trainTree(bootstrapper);
|
||||||
|
|
||||||
|
try {
|
||||||
|
saveTree(tree, filename);
|
||||||
|
} catch (IOException e) {
|
||||||
|
System.err.println("IOException while saving " + filename);
|
||||||
|
e.printStackTrace();
|
||||||
|
System.err.println("Quitting program");
|
||||||
|
System.exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
treeCount.incrementAndGet();
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,9 @@ package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
|
||||||
public interface Node<Y> {
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
public interface Node<Y> extends Serializable {
|
||||||
|
|
||||||
Y evaluate(CovariateRow row);
|
Y evaluate(CovariateRow row);
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ public class TrainForest {
|
||||||
public static void main(String[] args){
|
public static void main(String[] args){
|
||||||
// test creating a regression tree on a problem and see if the results are sensible.
|
// test creating a regression tree on a problem and see if the results are sensible.
|
||||||
|
|
||||||
final int n = 1000000;
|
final int n = 10000;
|
||||||
final int p = 5;
|
final int p = 5;
|
||||||
|
|
||||||
final Random random = new Random();
|
final Random random = new Random();
|
||||||
|
@ -49,7 +49,7 @@ public class TrainForest {
|
||||||
|
|
||||||
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||||
.numberOfSplits(5)
|
.numberOfSplits(5)
|
||||||
.nodeSize(3)
|
.nodeSize(5)
|
||||||
.maxNodeDepth(100000000)
|
.maxNodeDepth(100000000)
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
@ -63,18 +63,21 @@ public class TrainForest {
|
||||||
.ntree(100)
|
.ntree(100)
|
||||||
.treeResponseCombiner(new MeanResponseCombiner())
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
.displayProgress(true)
|
.displayProgress(true)
|
||||||
|
.saveTreeLocation("/home/joel/test")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final long startTime = System.currentTimeMillis();
|
final long startTime = System.currentTimeMillis();
|
||||||
|
|
||||||
final Forest<Double> forest = forestTrainer.trainSerial();
|
//final Forest<Double> forest = forestTrainer.trainSerial();
|
||||||
//final Forest<Double> forest = forestTrainer.trainParallel(8);
|
//final Forest<Double> forest = forestTrainer.trainParallel(8);
|
||||||
|
forestTrainer.trainParallelOnDisk(3);
|
||||||
|
|
||||||
final long endTime = System.currentTimeMillis();
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");
|
System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
final Value zeroValue = new NumericValue(0.1);
|
final Value zeroValue = new NumericValue(0.1);
|
||||||
final Value point5Value = new NumericValue(0.5);
|
final Value point5Value = new NumericValue(0.5);
|
||||||
|
|
||||||
|
@ -87,7 +90,7 @@ public class TrainForest {
|
||||||
System.out.println(forest.evaluate(testRow2));
|
System.out.println(forest.evaluate(testRow2));
|
||||||
|
|
||||||
System.out.println("MinY = " + minY);
|
System.out.println("MinY = " + minY);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue