diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index 4d3951f..0659ab6 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -39,6 +39,7 @@ import java.io.File; import java.io.IOException; import java.io.PrintWriter; import java.util.List; +import java.util.Optional; import java.util.Random; public class Main { @@ -73,16 +74,16 @@ public class Main { if(settings.isSaveProgress()){ if(settings.getNumberOfThreads() > 1){ - forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); + forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads()); } else{ - forestTrainer.trainSerialOnDisk(); + forestTrainer.trainSerialOnDisk(Optional.empty()); } } else{ if(settings.getNumberOfThreads() > 1){ - forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); + forestTrainer.trainParallelInMemory(Optional.empty(), settings.getNumberOfThreads()); } else{ - forestTrainer.trainSerialInMemory(); + forestTrainer.trainSerialInMemory(Optional.empty()); } } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 04c53e5..f20e505 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -21,16 +21,11 @@ import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.utils.DataUtils; -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.Builder; +import lombok.*; import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; +import java.util.*; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -48,7 +43,9 @@ public class ForestTrainer { private final List> data; // number of trees to try - private final int ntree; + @Getter + @Setter + private int ntree; private final boolean displayProgress; // whether to print to standard output our progress; not always desirable private final String saveTreeLocation; @@ -72,12 +69,21 @@ public class ForestTrainer { } } - public Forest trainSerialInMemory(){ + /** + * Train a forest in memory using a single core + * + * @param initialForest An Optional possibly containing a pre-trained forest, + * in which case its trees are combined with the new one. + * @return A trained forest. + */ + public Forest trainSerialInMemory(Optional> initialForest){ final List> trees = new ArrayList<>(ntree); + initialForest.ifPresent(forest -> trees.addAll(forest.getTrees())); + final Bootstrapper> bootstrapper = new Bootstrapper<>(data); - for(int j=0; j { } - public void trainSerialOnDisk(){ + /** + * Train a forest on the disk using a single core. + * + * @param initialForest An Optional possibly containing a pre-trained forest, + * in which case its trees are combined with the new one. + * There cannot be existing trees if the initial forest is + * specified. + */ + public void trainSerialOnDisk(Optional> initialForest){ // First we need to see how many trees there currently are final File folder = new File(saveTreeLocation); if(!folder.exists()){ @@ -112,21 +126,42 @@ public class ForestTrainer { final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); final List treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList()); - final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished - // Using an AtomicInteger is overkill for serial code, but this lets use reuse TreeSavedWorker - for(int j=0; j 0){ + throw new IllegalArgumentException("An initial forest is present but trees are also present; not clear how to integrate initial forest into new forest"); + } + + final AtomicInteger treeCount; // tracks how many trees are finished + // Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker + if(initialForest.isPresent()){ + final List> initialTrees = initialForest.get().getTrees(); + + for(int j=0; j tree = initialTrees.get(j); + + saveTree(tree, filename); + + } + + treeCount = new AtomicInteger(initialTrees.size()); + } else{ + treeCount = new AtomicInteger(treeFiles.length); + } + + + while(treeCount.get() < ntree){ if(displayProgress) { System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees"); } - final String treeFileName = "tree-" + (j+1) + ".tree"; + final String treeFileName = "tree-" + (treeCount.get() + 1) + ".tree"; if(treeFileNames.contains(treeFileName)){ continue; } - final Random random = new Random(this.randomSeed + j); + final Random random = new Random(this.randomSeed + treeCount.get()); final Runnable worker = new TreeSavedWorker(data, treeFileName, treeCount, random); worker.run(); @@ -139,15 +174,34 @@ public class ForestTrainer { } - public Forest trainParallelInMemory(int threads){ + /** + * Train a forest in memory using the specified number of threads. + * + * @param initialForest An Optional possibly containing a pre-trained forest, + * in which case its trees are combined with the new one. + * @param threads The number of trees to train at once. + */ + public Forest trainParallelInMemory(Optional> initialForest, int threads){ - // create a list that is prespecified in size (I can call the .set method at any index < ntree without + // create a list that is pre-specified in size (I can call the .set method at any index < ntree without // the earlier indexes being filled. final List> trees = Stream.>generate(() -> null).limit(ntree).collect(Collectors.toList()); + final int startingCount; + if(initialForest.isPresent()){ + final List> initialTrees = initialForest.get().getTrees(); + for(int j=0; j { } - - public void trainParallelOnDisk(int threads){ + /** + * Train a forest on the disk using a specified number of threads. + * + * @param initialForest An Optional possibly containing a pre-trained forest, + * in which case its trees are combined with the new one. + * There cannot be existing trees if the initial forest is + * specified. + * @param threads The number of trees to train at once. + */ + public void trainParallelOnDisk(Optional> initialForest, int threads){ // First we need to see how many trees there currently are final File folder = new File(saveTreeLocation); if(!folder.exists()){ @@ -205,11 +267,31 @@ public class ForestTrainer { final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); final List treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList()); - final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished + + if(initialForest.isPresent() & treeFiles.length > 0){ + throw new IllegalArgumentException("An initial forest is present but trees are also present; not clear how to integrate initial forest into new forest"); + } + + final AtomicInteger treeCount; // tracks how many trees are finished + if(initialForest.isPresent()){ + final List> initialTrees = initialForest.get().getTrees(); + + for(int j=0; j tree = initialTrees.get(j); + + saveTree(tree, filename); + + } + + treeCount = new AtomicInteger(initialTrees.size()); + } else{ + treeCount = new AtomicInteger(treeFiles.length); + } final ExecutorService executorService = Executors.newFixedThreadPool(threads); - for(int j=0; j { return treeTrainer.growTree(bootstrappedData, random); } + private void saveTree(Tree tree, String filename){ + try { + DataUtils.saveObject(tree, saveTreeLocation + "/" + filename); + } catch (IOException e) { + System.err.println("IOException while saving " + filename); + e.printStackTrace(); + System.err.println("Quitting program"); + System.exit(1); + } + } + private class TreeInMemoryWorker implements Runnable { @@ -297,14 +390,7 @@ public class ForestTrainer { public void run() { final Tree tree = trainTree(bootstrapper, random); - try { - DataUtils.saveObject(tree, saveTreeLocation + "/" + filename); - } catch (IOException e) { - System.err.println("IOException while saving " + filename); - e.printStackTrace(); - System.err.println("Quitting program"); - System.exit(1); - } + saveTree(tree, filename); treeCount.incrementAndGet(); diff --git a/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java b/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java index 223d947..506e939 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java @@ -33,6 +33,7 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.Random; public class TestDeterministicForests { @@ -144,7 +145,7 @@ public class TestDeterministicForests { .build(); // By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed. - final Forest referenceForest = forestTrainer.trainSerialInMemory(); + final Forest referenceForest = forestTrainer.trainSerialInMemory(Optional.empty()); verifySerialInMemoryTraining(referenceForest, forestTrainer, testData); verifyParallelInMemoryTraining(referenceForest, forestTrainer, testData); @@ -206,20 +207,20 @@ public class TestDeterministicForests { .build(); // By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed. - final Forest referenceForest = forestTrainer10Trees.trainSerialInMemory(); + final Forest referenceForest = forestTrainer10Trees.trainSerialInMemory(Optional.empty()); final File saveTreeFile = new File(saveTreeLocation); - forestTrainer5Trees.trainSerialOnDisk(); - forestTrainer10Trees.trainSerialOnDisk(); + forestTrainer5Trees.trainSerialOnDisk(Optional.empty()); + forestTrainer10Trees.trainSerialOnDisk(Optional.empty()); final Forest forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); TestUtils.removeFolder(saveTreeFile); verifyTwoForestsEqual(testData, referenceForest, forestSerial); - forestTrainer5Trees.trainParallelOnDisk(4); - forestTrainer10Trees.trainParallelOnDisk(4); + forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4); + forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4); final Forest forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); TestUtils.removeFolder(saveTreeFile); verifyTwoForestsEqual(testData, referenceForest, forestParallel); @@ -232,7 +233,7 @@ public class TestDeterministicForests { List> testData){ for(int k=0; k<3; k++){ - final Forest replicantForest = forestTrainer.trainSerialInMemory(); + final Forest replicantForest = forestTrainer.trainSerialInMemory(Optional.empty()); verifyTwoForestsEqual(testData, referenceForest, replicantForest); } } @@ -243,7 +244,7 @@ public class TestDeterministicForests { List> testData){ for(int k=0; k<3; k++){ - final Forest replicantForest = forestTrainer.trainParallelInMemory(4); + final Forest replicantForest = forestTrainer.trainParallelInMemory(Optional.empty(), 4); verifyTwoForestsEqual(testData, referenceForest, replicantForest); } } @@ -257,7 +258,7 @@ public class TestDeterministicForests { final File saveTreeFile = new File(saveTreeLocation); for(int k=0; k<3; k++){ - forestTrainer.trainSerialOnDisk(); + forestTrainer.trainSerialOnDisk(Optional.empty()); final Forest replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); TestUtils.removeFolder(saveTreeFile); verifyTwoForestsEqual(testData, referenceForest, replicantForest); @@ -272,7 +273,7 @@ public class TestDeterministicForests { final File saveTreeFile = new File(saveTreeLocation); for(int k=0; k<3; k++){ - forestTrainer.trainParallelOnDisk(4); + forestTrainer.trainParallelOnDisk(Optional.empty(), 4); final Forest replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); TestUtils.removeFolder(saveTreeFile); verifyTwoForestsEqual(testData, referenceForest, replicantForest); diff --git a/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java b/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java new file mode 100644 index 0000000..68143c4 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; +import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; +import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; +import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.tree.ForestTrainer; +import ca.joeltherrien.randomforest.tree.Tree; +import ca.joeltherrien.randomforest.tree.TreeTrainer; +import ca.joeltherrien.randomforest.utils.DataUtils; +import ca.joeltherrien.randomforest.utils.Utils; +import com.fasterxml.jackson.databind.node.*; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestProvidingInitialForest { + + private Forest initialForest; + private List covariateList; + private List> data; + + public TestProvidingInitialForest(){ + covariateList = Collections.singletonList(new NumericCovariate("x", 0)); + + data = Utils.easyList( + Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0), + Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 2, 1.5), + Row.createSimple(Utils.easyMap("x", "2.0"), covariateList, 3, 5.0), + Row.createSimple(Utils.easyMap("x", "2.0"), covariateList, 4, 6.0) + ); + + final TreeTrainer treeTrainer = TreeTrainer.builder() + .splitFinder(new WeightedVarianceSplitFinder()) + .responseCombiner(new MeanResponseCombiner()) + .checkNodePurity(false) + .numberOfSplits(0) + .nodeSize(1) + .mtry(1) + .maxNodeDepth(100000) + .covariates(covariateList) + .build(); + + final ForestTrainer forestTrainer = ForestTrainer.builder() + .treeResponseCombiner(new MeanResponseCombiner()) + .ntree(10) + .displayProgress(false) + .data(data) + .covariates(covariateList) + .treeTrainer(treeTrainer) + .build(); + + initialForest = forestTrainer.trainSerialInMemory(Optional.empty()); + } + + private final int NTREE = 10; + + private ForestTrainer getForestTrainer(String saveTreeLocation, int ntree){ + final TreeTrainer treeTrainer = TreeTrainer.builder() + .splitFinder(new WeightedVarianceSplitFinder()) + .responseCombiner(new MeanResponseCombiner()) + .checkNodePurity(false) + .numberOfSplits(0) + .nodeSize(1) + .mtry(1) + .maxNodeDepth(100000) + .covariates(covariateList) + .build(); + + final ForestTrainer forestTrainer = ForestTrainer.builder() + .treeResponseCombiner(new MeanResponseCombiner()) + .ntree(ntree) + .displayProgress(false) + .data(data) + .covariates(covariateList) + .treeTrainer(treeTrainer) + .saveTreeLocation(saveTreeLocation) + .build(); + + return forestTrainer; + } + + @Test + public void testSerialInMemory(){ + final ForestTrainer forestTrainer = getForestTrainer(null, 20); + + final Forest newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest)); + assertEquals(20, newForest.getTrees().size()); + + for(Tree initialTree : initialForest.getTrees()){ + assertTrue(newForest.getTrees().contains(initialTree)); + } + for(int j=10; j<20; j++){ + final Tree newTree = newForest.getTrees().get(j); + assertFalse(initialForest.getTrees().contains(newTree)); + } + + } + + @Test + public void testParallelInMemory(){ + final ForestTrainer forestTrainer = getForestTrainer(null, 20); + + final Forest newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2); + assertEquals(20, newForest.getTrees().size()); + + for(Tree initialTree : initialForest.getTrees()){ + assertTrue(newForest.getTrees().contains(initialTree)); + } + for(int j=10; j<20; j++){ + final Tree newTree = newForest.getTrees().get(j); + assertFalse(initialForest.getTrees().contains(newTree)); + } + } + + @Test + public void testParallelOnDisk() throws IOException, ClassNotFoundException { + final String filePath = "src/test/resources/trees/"; + final File directory = new File(filePath); + if(directory.exists()){ + TestUtils.removeFolder(directory); + } + + final ForestTrainer forestTrainer = getForestTrainer(filePath, 20); + + forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2); + + assertEquals(20, directory.listFiles().length); + final Forest newForest = DataUtils.loadForest(directory, new MeanResponseCombiner()); + + + + assertEquals(20, newForest.getTrees().size()); + + final List newForestTreesAsStrings = newForest.getTrees().stream() + .map(tree -> tree.toString()).collect(Collectors.toList()); + + for(Tree initialTree : initialForest.getTrees()){ + assertTrue(newForestTreesAsStrings.contains(initialTree.toString())); + } + + TestUtils.removeFolder(directory); + } + + @Test + public void testSerialOnDisk() throws IOException, ClassNotFoundException { + final String filePath = "src/test/resources/trees/"; + final File directory = new File(filePath); + if(directory.exists()){ + TestUtils.removeFolder(directory); + } + final ForestTrainer forestTrainer = getForestTrainer(filePath, 20); + + + forestTrainer.trainSerialOnDisk(Optional.of(initialForest)); + + assertEquals(20, directory.listFiles().length); + + final Forest newForest = DataUtils.loadForest(directory, new MeanResponseCombiner()); + + assertEquals(20, newForest.getTrees().size()); + + final List newForestTreesAsStrings = newForest.getTrees().stream() + .map(tree -> tree.toString()).collect(Collectors.toList()); + + for(Tree initialTree : initialForest.getTrees()){ + assertTrue(newForestTreesAsStrings.contains(initialTree.toString())); + } + + TestUtils.removeFolder(directory); + } + + /* + We throw IllegalArgumentExceptions when we try providing an initial forest when trees were already saved, because + it's not clear if the forest being provided is the same one that trees were saved from. + */ + @Test + public void verifyExceptions(){ + final String filePath = "src/test/resources/trees/"; + final File directory = new File(filePath); + if(directory.exists()){ + TestUtils.removeFolder(directory); + } + final ForestTrainer forestTrainer = getForestTrainer(filePath, 10); + forestTrainer.trainSerialOnDisk(Optional.empty()); + + forestTrainer.setNtree(20); + assertThrows(IllegalArgumentException.class, () -> forestTrainer.trainSerialOnDisk(Optional.of(initialForest))); + assertThrows(IllegalArgumentException.class, () -> forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2)); + + TestUtils.removeFolder(directory); + } + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index fe71844..53f5bef 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -33,6 +33,7 @@ import static org.junit.jupiter.api.Assertions.*; import java.io.File; import java.io.IOException; import java.util.List; +import java.util.Optional; public class TestSavingLoading { @@ -123,7 +124,7 @@ public class TestSavingLoading { final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - forestTrainer.trainSerialOnDisk(); + forestTrainer.trainSerialOnDisk(Optional.empty()); assertTrue(directory.exists()); assertTrue(directory.isDirectory()); @@ -164,7 +165,7 @@ public class TestSavingLoading { final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); + forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads()); assertTrue(directory.exists()); assertTrue(directory.isDirectory()); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 47d6c46..992c132 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test; import java.io.IOException; import java.util.List; +import java.util.Optional; import java.util.Random; import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction; @@ -235,7 +236,7 @@ public class TestCompetingRisk { final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - final Forest forest = forestTrainer.trainSerialInMemory(); + final Forest forest = forestTrainer.trainSerialInMemory(Optional.empty()); // prediction row // time status ageatfda idu black cd4nadir @@ -328,7 +329,7 @@ public class TestCompetingRisk { final long startTime = System.currentTimeMillis(); for(int i=0; i<50; i++){ - forestTrainer.trainSerialInMemory(); + forestTrainer.trainSerialInMemory(Optional.empty()); } final long endTime = System.currentTimeMillis(); @@ -346,7 +347,7 @@ public class TestCompetingRisk { final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - final Forest forest = forestTrainer.trainSerialInMemory(); + final Forest forest = forestTrainer.trainSerialInMemory(Optional.empty()); // prediction row // time status ageatfda idu black cd4nadir diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index 6942bf5..c9a031a 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -90,7 +90,7 @@ public class TrainForest { //final Forest forest = forestTrainer.trainSerialInMemory(); //final Forest forest = forestTrainer.trainParallelInMemory(3); - forestTrainer.trainParallelOnDisk(3); + forestTrainer.trainParallelOnDisk(Optional.empty(), 3); final long endTime = System.currentTimeMillis();