Attempting memory optimizations
This commit is contained in:
parent
8014bd4629
commit
cfa3a6f432
4 changed files with 55 additions and 6 deletions
|
@ -71,10 +71,18 @@ public class Main {
|
|||
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
|
||||
|
||||
if(settings.isSaveProgress()){
|
||||
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||
if(settings.getNumberOfThreads() > 1){
|
||||
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||
} else{
|
||||
forestTrainer.trainSerialOnDisk();
|
||||
}
|
||||
}
|
||||
else{
|
||||
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
|
||||
if(settings.getNumberOfThreads() > 1){
|
||||
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
|
||||
} else{
|
||||
forestTrainer.trainSerial();
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(args[1].equalsIgnoreCase("analyze")){
|
||||
|
|
|
@ -82,8 +82,8 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
|
|||
* @return
|
||||
*/
|
||||
default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
|
||||
final List<Row<Y>> leftHand = new LinkedList<>();
|
||||
final List<Row<Y>> rightHand = new LinkedList<>();
|
||||
final List<Row<Y>> leftHand = new ArrayList<>(rows.size()*3/4);
|
||||
final List<Row<Y>> rightHand = new ArrayList<>(rows.size()*3/4);
|
||||
|
||||
final List<Row<Y>> missingValueRows = new ArrayList<>();
|
||||
|
||||
|
|
|
@ -93,6 +93,35 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
}
|
||||
|
||||
/**
 * Trains the remaining trees one at a time on the calling thread, resuming from
 * however many {@code .tree} files already exist in {@code saveTreeLocation}.
 *
 * Each tree is handled by a {@code TreeSavedWorker} whose {@code run()} is invoked
 * directly (no executor). Presumably the worker writes the tree to disk and
 * increments the shared counter; TreeSavedWorker is not visible here, so confirm.
 *
 * @throws IllegalArgumentException if {@code saveTreeLocation} is not a directory
 */
public void trainSerialOnDisk(){
|
||||
// First, count how many trees have already been saved so training can resume from there
|
||||
final File folder = new File(saveTreeLocation);
|
||||
if(!folder.isDirectory()){
|
||||
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||
}
|
||||
|
||||
|
||||
// NOTE(review): File.listFiles can return null on an I/O error even for a directory;
// treeFiles.length below would then throw a NullPointerException - confirm whether a guard is wanted
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
||||
|
||||
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
|
||||
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
||||
|
||||
// The start index comes from the number of existing .tree files, not from their names;
// assumes existing files are numbered contiguously from tree-1.tree - TODO confirm
for(int j=treeCount.get(); j<ntree; j++){
|
||||
if(displayProgress) {
|
||||
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
|
||||
}
|
||||
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1) + ".tree", treeCount);
|
||||
// Called synchronously: run() executes on this thread, so trees are trained strictly in order
worker.run();
|
||||
|
||||
}
|
||||
|
||||
if(displayProgress){
|
||||
System.out.println("\nFinished");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public Forest<TO, FO> trainParallelInMemory(int threads){
|
||||
|
||||
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
||||
|
|
|
@ -111,10 +111,22 @@ public class TreeTrainer<Y, O> {
|
|||
}
|
||||
}
|
||||
|
||||
final Node<O> leftNode;
|
||||
final Node<O> rightNode;
|
||||
|
||||
// let's train the smaller hand first; I've seen some behaviour where a split takes only a very narrow slice
|
||||
// off of the main body, and this repeats over and over again. I'd prefer to train those small nodes first so that
|
||||
// we can get terminal nodes and save some memory in the heap
|
||||
if(bestSplit.leftHand.size() < bestSplit.rightHand.size()){
|
||||
leftNode = growNode(bestSplit.leftHand, depth+1, random);
|
||||
rightNode = growNode(bestSplit.rightHand, depth+1, random);
|
||||
}
|
||||
else{
|
||||
rightNode = growNode(bestSplit.rightHand, depth+1, random);
|
||||
leftNode = growNode(bestSplit.leftHand, depth+1, random);
|
||||
}
|
||||
|
||||
|
||||
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
|
||||
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
|
||||
|
||||
return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);
|
||||
|
||||
|
|
Loading…
Reference in a new issue