diff --git a/src/main/java/ca/joeltherrien/randomforest/SplitRule.java b/src/main/java/ca/joeltherrien/randomforest/SplitRule.java index a4b9c7a..95d7545 100644 --- a/src/main/java/ca/joeltherrien/randomforest/SplitRule.java +++ b/src/main/java/ca/joeltherrien/randomforest/SplitRule.java @@ -1,9 +1,10 @@ package ca.joeltherrien.randomforest; +import java.io.Serializable; import java.util.LinkedList; import java.util.List; -public abstract class SplitRule { +public abstract class SplitRule implements Serializable { /** * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 69c1d44..277b368 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -4,13 +4,16 @@ import ca.joeltherrien.randomforest.Bootstrapper; import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.Row; import lombok.Builder; -import lombok.RequiredArgsConstructor; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -29,6 +32,7 @@ public class ForestTrainer { private final int ntree; private final boolean displayProgress; + private final String saveTreeLocation; public Forest trainSerial(){ @@ -57,7 +61,7 @@ public class ForestTrainer { } - public Forest trainParallel(int threads){ + public Forest trainParallelInMemory(int threads){ // create a list that is prespecified in size (I can call the .set method at any index < ntree without // the earlier indexes being filled. @@ -66,7 +70,7 @@ public class ForestTrainer { final ExecutorService executorService = Executors.newFixedThreadPool(threads); for(int j=0; j { } + + public void trainParallelOnDisk(int threads){ + final ExecutorService executorService = Executors.newFixedThreadPool(threads); + final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished + + for(int j=0; j trainTree(final Bootstrapper> bootstrapper){ final List treeCovariates = new ArrayList<>(covariatesToTry); Collections.shuffle(treeCovariates); @@ -116,13 +152,24 @@ public class ForestTrainer { return treeTrainer.growTree(bootstrappedData, treeCovariates); } - private class Worker implements Runnable { + public void saveTree(final Node tree, String name) throws IOException { + final String filename = saveTreeLocation + "/" + name; + + final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename)); + + outputStream.writeObject(tree); + + outputStream.close(); + + } + + private class TreeInMemoryWorker implements Runnable { private final Bootstrapper> bootstrapper; private final int treeIndex; private final List> treeList; - public Worker(final List> data, final int treeIndex, final List> treeList) { + public TreeInMemoryWorker(final List> data, final int treeIndex, final List> treeList) { this.bootstrapper = new Bootstrapper<>(data); this.treeIndex = treeIndex; this.treeList = treeList; @@ -136,10 +183,36 @@ public class ForestTrainer { // should be okay as the list structure isn't changing treeList.set(treeIndex, tree); - //if(displayProgress){ - // System.out.println("Finished tree " + (treeIndex+1)); - //} + } + } + private class TreeSavedWorker implements Runnable { + + private final Bootstrapper> bootstrapper; + private final String filename; + private final AtomicInteger treeCount; + + public TreeSavedWorker(final List> data, final String filename, final AtomicInteger treeCount) { + this.bootstrapper = new Bootstrapper<>(data); + this.filename = filename; + this.treeCount = treeCount; + } + + @Override + public void run() { + + final Node tree = trainTree(bootstrapper); + + try { + saveTree(tree, filename); + } catch (IOException e) { + System.err.println("IOException while saving " + filename); + e.printStackTrace(); + System.err.println("Quitting program"); + System.exit(1); + } + + treeCount.incrementAndGet(); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Node.java b/src/main/java/ca/joeltherrien/randomforest/tree/Node.java index 25547e0..c0e8214 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Node.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Node.java @@ -2,7 +2,9 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; -public interface Node { +import java.io.Serializable; + +public interface Node extends Serializable { Y evaluate(CovariateRow row); diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index 32779a2..c645a35 100644 --- a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -16,7 +16,7 @@ public class TrainForest { public static void main(String[] args){ // test creating a regression tree on a problem and see if the results are sensible. - final int n = 1000000; + final int n = 10000; final int p = 5; final Random random = new Random(); @@ -49,7 +49,7 @@ public class TrainForest { TreeTrainer treeTrainer = TreeTrainer.builder() .numberOfSplits(5) - .nodeSize(3) + .nodeSize(5) .maxNodeDepth(100000000) .groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .responseCombiner(new MeanResponseCombiner()) @@ -63,18 +63,21 @@ public class TrainForest { .ntree(100) .treeResponseCombiner(new MeanResponseCombiner()) .displayProgress(true) + .saveTreeLocation("/home/joel/test") .build(); final long startTime = System.currentTimeMillis(); - final Forest forest = forestTrainer.trainSerial(); + //final Forest forest = forestTrainer.trainSerial(); //final Forest forest = forestTrainer.trainParallel(8); + forestTrainer.trainParallelOnDisk(3); final long endTime = System.currentTimeMillis(); System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds."); + /* final Value zeroValue = new NumericValue(0.1); final Value point5Value = new NumericValue(0.5); @@ -87,7 +90,7 @@ public class TrainForest { System.out.println(forest.evaluate(testRow2)); System.out.println("MinY = " + minY); - + */ } }