diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index bccd125..24d96c8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -71,10 +71,18 @@ public class Main { final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates); if(settings.isSaveProgress()){ - forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); + if(settings.getNumberOfThreads() > 1){ + forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); + } else{ + forestTrainer.trainSerialOnDisk(); + } } else{ - forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); + if(settings.getNumberOfThreads() > 1){ + forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); + } else{ + forestTrainer.trainSerial(); + } } } else if(args[1].equalsIgnoreCase("analyze")){ diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index 13b3efa..193889f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -82,8 +82,8 @@ public interface Covariate extends Serializable, Comparable { * @return */ default Split applyRule(List> rows) { - final List> leftHand = new LinkedList<>(); - final List> rightHand = new LinkedList<>(); + final List> leftHand = new ArrayList<>(rows.size()*3/4); + final List> rightHand = new ArrayList<>(rows.size()*3/4); final List> missingValueRows = new ArrayList<>(); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 769046b..31d04e8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -93,6 +93,35 @@ public class ForestTrainer { } + public void trainSerialOnDisk(){ + // First we need to see how many trees there currently are + final File folder = new File(saveTreeLocation); + if(!folder.isDirectory()){ + throw new IllegalArgumentException("Tree directory must be a directory!"); + } + + + final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); + + final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished + // Using an AtomicInteger is overkill for serial code, but this lets use reuse TreeSavedWorker + + for(int j=treeCount.get(); j trainParallelInMemory(int threads){ // create a list that is prespecified in size (I can call the .set method at any index < ntree without diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 9756c79..98292e3 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -111,10 +111,22 @@ public class TreeTrainer { } } + final Node leftNode; + final Node rightNode; + + // let's train the smaller hand first; I've seen some behaviour where a split takes only a very narrow slice + // off of the main body, and this repeats over and over again. I'd prefer to train those small nodes first so that + // we can get terminal nodes and save some memory in the heap + if(bestSplit.leftHand.size() < bestSplit.rightHand.size()){ + leftNode = growNode(bestSplit.leftHand, depth+1, random); + rightNode = growNode(bestSplit.rightHand, depth+1, random); + } + else{ + rightNode = growNode(bestSplit.rightHand, depth+1, random); + leftNode = growNode(bestSplit.leftHand, depth+1, random); + } - final Node leftNode = growNode(bestSplit.leftHand, depth+1, random); - final Node rightNode = growNode(bestSplit.rightHand, depth+1, random); return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);