Attempting memory optimizations
parent 8014bd4629
commit cfa3a6f432

4 changed files with 55 additions and 6 deletions
@@ -71,10 +71,18 @@ public class Main {
             final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
 
             if(settings.isSaveProgress()){
+                if(settings.getNumberOfThreads() > 1){
                     forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
+                } else{
+                    forestTrainer.trainSerialOnDisk();
+                }
             }
             else{
+                if(settings.getNumberOfThreads() > 1){
                     forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
+                } else{
+                    forestTrainer.trainSerial();
+                }
             }
         }
         else if(args[1].equalsIgnoreCase("analyze")){
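The branching above reduces to two flags: whether progress is being saved to disk and whether more than one thread is available. For illustration only (hypothetical names, not code from this commit), the four resulting training modes can be expressed as a small pure helper:

public class TrainingModeSketch {
    enum Mode { PARALLEL_ON_DISK, SERIAL_ON_DISK, PARALLEL_IN_MEMORY, SERIAL_IN_MEMORY }

    // mirrors the dispatch in Main: save-progress picks on-disk vs in-memory,
    // thread count picks parallel vs serial
    static Mode select(boolean saveProgress, int threads) {
        if (saveProgress) {
            return threads > 1 ? Mode.PARALLEL_ON_DISK : Mode.SERIAL_ON_DISK;
        }
        return threads > 1 ? Mode.PARALLEL_IN_MEMORY : Mode.SERIAL_IN_MEMORY;
    }

    public static void main(String[] args) {
        System.out.println(select(true, 1));   // SERIAL_ON_DISK, the path added in this commit
        System.out.println(select(false, 4));  // PARALLEL_IN_MEMORY
    }
}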
@@ -82,8 +82,8 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
      * @return
      */
     default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
-        final List<Row<Y>> leftHand = new LinkedList<>();
-        final List<Row<Y>> rightHand = new LinkedList<>();
+        final List<Row<Y>> leftHand = new ArrayList<>(rows.size()*3/4);
+        final List<Row<Y>> rightHand = new ArrayList<>(rows.size()*3/4);
 
         final List<Row<Y>> missingValueRows = new ArrayList<>();
 
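Swapping the LinkedList instances for ArrayLists pre-sized to three quarters of the incoming rows trades LinkedList's per-element node objects for a single backing array that rarely needs to grow. A minimal, self-contained sketch of that allocation pattern, with placeholder types standing in for Row<Y>:

import java.util.ArrayList;
import java.util.List;

public class PresizeSketch {
    public static void main(String[] args) {
        final int rows = 1_000;                                // stand-in for rows.size()
        // pre-size both hands to 3/4 of the input, as in the change above
        final List<Integer> leftHand = new ArrayList<>(rows * 3 / 4);
        final List<Integer> rightHand = new ArrayList<>(rows * 3 / 4);
        for (int i = 0; i < rows; i++) {
            (i % 2 == 0 ? leftHand : rightHand).add(i);        // rough stand-in for applying a split rule
        }
        System.out.println(leftHand.size() + " / " + rightHand.size());
    }
}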
@@ -93,6 +93,35 @@ public class ForestTrainer<Y, TO, FO> {
 
     }
 
+    public void trainSerialOnDisk(){
+        // First we need to see how many trees there currently are
+        final File folder = new File(saveTreeLocation);
+        if(!folder.isDirectory()){
+            throw new IllegalArgumentException("Tree directory must be a directory!");
+        }
+
+        final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
+
+        final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
+        // Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
+
+        for(int j=treeCount.get(); j<ntree; j++){
+            if(displayProgress) {
+                System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
+            }
+
+            final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1) + ".tree", treeCount);
+            worker.run();
+
+        }
+
+        if(displayProgress){
+            System.out.println("\nFinished");
+        }
+
+    }
+
     public Forest<TO, FO> trainParallelInMemory(int threads){
 
         // create a list that is prespecified in size (I can call the .set method at any index < ntree without
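trainSerialOnDisk leans on TreeSavedWorker, which is not part of this diff. Going only by the call site above (constructor takes the data, a "tree-N.tree" filename, and a shared AtomicInteger, and run() is invoked directly), a worker of roughly that shape might look like the sketch below; the real class may differ:

import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Hypothetical sketch only -- not the repo's actual TreeSavedWorker.
public class SaveTreeTaskSketch implements Runnable {

    private final Supplier<? extends Serializable> trainOneTree; // stand-in for the real training step
    private final String filename;
    private final AtomicInteger finishedCount;

    public SaveTreeTaskSketch(Supplier<? extends Serializable> trainOneTree,
                              String filename,
                              AtomicInteger finishedCount) {
        this.trainOneTree = trainOneTree;
        this.filename = filename;
        this.finishedCount = finishedCount;
    }

    @Override
    public void run() {
        final Serializable tree = trainOneTree.get();  // grow one tree in memory
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(Paths.get(filename)))) {
            out.writeObject(tree);                     // persist it so the in-heap copy can be discarded
        } catch (IOException e) {
            throw new RuntimeException("Could not save " + filename, e);
        }
        finishedCount.incrementAndGet();               // lets the caller report "Finished x/ntree trees"
    }
}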
@@ -111,10 +111,22 @@ public class TreeTrainer<Y, O> {
             }
         }
 
-        final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
-        final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
+        final Node<O> leftNode;
+        final Node<O> rightNode;
+
+        // let's train the smaller hand first; I've seen some behaviour where a split takes only a very narrow slice
+        // off of the main body, and this repeats over and over again. I'd prefer to train those small nodes first so that
+        // we can get terminal nodes and save some memory in the heap
+        if(bestSplit.leftHand.size() < bestSplit.rightHand.size()){
+            leftNode = growNode(bestSplit.leftHand, depth+1, random);
+            rightNode = growNode(bestSplit.rightHand, depth+1, random);
+        }
+        else{
+            rightNode = growNode(bestSplit.rightHand, depth+1, random);
+            leftNode = growNode(bestSplit.leftHand, depth+1, random);
+        }
 
 
         return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);
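The reordering matters because the recursion on the smaller hand usually reaches terminal nodes quickly, so its intermediate row lists can become garbage before the long-running recursion on the larger hand starts holding onto memory. A self-contained toy sketch of that ordering (hypothetical structure, not the project's TreeTrainer):

import java.util.ArrayList;
import java.util.List;

public class SmallerFirstSketch {
    static int grow(List<Integer> rows, int depth) {
        if (rows.size() <= 1 || depth >= 32) {
            return depth;                              // terminal node
        }
        // imitate a very lopsided split: a narrow slice versus the main body
        final int cut = Math.max(1, rows.size() / 10);
        final List<Integer> small = new ArrayList<>(rows.subList(0, cut));
        final List<Integer> large = new ArrayList<>(rows.subList(cut, rows.size()));
        final int first;
        final int second;
        if (small.size() < large.size()) {
            first = grow(small, depth + 1);            // small side finishes quickly, its lists become garbage
            second = grow(large, depth + 1);           // only then does the deep branch hold memory
        } else {
            second = grow(large, depth + 1);
            first = grow(small, depth + 1);
        }
        return Math.max(first, second);
    }

    public static void main(String[] args) {
        final List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 1_000; i++) rows.add(i);
        System.out.println("max depth reached: " + grow(rows, 0));
    }
}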