Attempting memory optimizations

This commit is contained in:
Joel Therrien 2019-03-13 10:39:18 -07:00
parent 8014bd4629
commit cfa3a6f432
4 changed files with 55 additions and 6 deletions

View file

@ -71,10 +71,18 @@ public class Main {
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates); final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
if(settings.isSaveProgress()){ if(settings.isSaveProgress()){
if(settings.getNumberOfThreads() > 1){
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
} else{
forestTrainer.trainSerialOnDisk();
}
} }
else{ else{
if(settings.getNumberOfThreads() > 1){
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
} else{
forestTrainer.trainSerial();
}
} }
} }
else if(args[1].equalsIgnoreCase("analyze")){ else if(args[1].equalsIgnoreCase("analyze")){

View file

@ -82,8 +82,8 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
* @return * @return
*/ */
default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) { default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
final List<Row<Y>> leftHand = new LinkedList<>(); final List<Row<Y>> leftHand = new ArrayList<>(rows.size()*3/4);
final List<Row<Y>> rightHand = new LinkedList<>(); final List<Row<Y>> rightHand = new ArrayList<>(rows.size()*3/4);
final List<Row<Y>> missingValueRows = new ArrayList<>(); final List<Row<Y>> missingValueRows = new ArrayList<>();

View file

@ -93,6 +93,35 @@ public class ForestTrainer<Y, TO, FO> {
} }
/**
 * Trains the forest one tree at a time on the calling thread, handing each tree to a
 * {@link TreeSavedWorker} along with a file name of the form {@code tree-N.tree}.
 *
 * The number of {@code .tree} files already present in {@code saveTreeLocation} is used as the
 * starting index, so a previously interrupted run continues from where it left off instead of
 * starting over.
 * NOTE(review): this assumes the existing files are numbered contiguously from 1 - confirm.
 *
 * @throws IllegalArgumentException if {@code saveTreeLocation} is not a directory
 * @throws IllegalStateException if the contents of the directory cannot be listed
 */
public void trainSerialOnDisk(){
    // First we need to see how many trees there currently are
    final File folder = new File(saveTreeLocation);
    if(!folder.isDirectory()){
        throw new IllegalArgumentException("Tree directory must be a directory!");
    }

    // listFiles returns null (rather than throwing) when an I/O error occurs, even if the
    // isDirectory() check above succeeded; fail with a clear message instead of a bare NPE.
    final File[] treeFiles = folder.listFiles((dir, name) -> name.endsWith(".tree"));
    if(treeFiles == null){
        throw new IllegalStateException("Unable to list the contents of tree directory " + folder.getPath());
    }

    final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
    // Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker

    for(int j=treeCount.get(); j<ntree; j++){
        if(displayProgress) {
            // '\r' returns to the start of the line so each update overwrites the previous one
            System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
        }

        // Run the worker directly on this thread; no executor is needed for serial training.
        // presumably TreeSavedWorker increments treeCount once its tree is written - verify
        final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1) + ".tree", treeCount);
        worker.run();
    }

    if(displayProgress){
        System.out.println("\nFinished");
    }
}
public Forest<TO, FO> trainParallelInMemory(int threads){ public Forest<TO, FO> trainParallelInMemory(int threads){
// create a list that is prespecified in size (I can call the .set method at any index < ntree without // create a list that is prespecified in size (I can call the .set method at any index < ntree without

View file

@ -111,10 +111,22 @@ public class TreeTrainer<Y, O> {
} }
} }
final Node<O> leftNode;
final Node<O> rightNode;
// let's train the smaller hand first; I've seen some behaviour where a split takes only a very narrow slice
// off of the main body, and this repeats over and over again. I'd prefer to train those small nodes first so that
// we can get terminal nodes and save some memory in the heap
if(bestSplit.leftHand.size() < bestSplit.rightHand.size()){
leftNode = growNode(bestSplit.leftHand, depth+1, random);
rightNode = growNode(bestSplit.rightHand, depth+1, random);
}
else{
rightNode = growNode(bestSplit.rightHand, depth+1, random);
leftNode = growNode(bestSplit.leftHand, depth+1, random);
}
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand); return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);