diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 26adfc0..69f7ab0 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -8,6 +8,7 @@ import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Builder; +import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; @@ -122,10 +123,19 @@ public class ForestTrainer { public void trainParallelOnDisk(int threads){ - final ExecutorService executorService = Executors.newFixedThreadPool(threads); - final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished + // First we need to see how many trees there currently are + final File folder = new File(saveTreeLocation); + if(!folder.isDirectory()){ + throw new IllegalArgumentException("Tree directory must be a directory!"); + } - for(int j=0; j s.endsWith(".tree"))); + + final ExecutorService executorService = Executors.newFixedThreadPool(threads); + final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished + + for(int j=treeCount.get(); j