From 6f318db79ee21ed59a10f5e565d8f2b019eb7382 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Fri, 10 May 2019 16:02:33 -0700 Subject: [PATCH] Add support for seeds to control randomness when training forests --- .../ca/joeltherrien/randomforest/Main.java | 2 +- .../joeltherrien/randomforest/Settings.java | 1 + .../randomforest/tree/ForestTrainer.java | 75 +++-- .../TestDeterministicForests.java | 298 ++++++++++++++++++ .../randomforest/TestSavingLoading.java | 20 +- .../joeltherrien/randomforest/TestUtils.java | 13 + .../competingrisk/TestCompetingRisk.java | 6 +- .../competingrisk/TestLogRankSplitFinder.java | 8 +- .../joeltherrien/randomforest/utils/Data.java | 35 ++ .../randomforest/workshop/TrainForest.java | 2 +- 10 files changed, 408 insertions(+), 52 deletions(-) create mode 100644 src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java create mode 100644 src/test/java/ca/joeltherrien/randomforest/utils/Data.java diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index 3aa29ab..4d3951f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -82,7 +82,7 @@ public class Main { if(settings.getNumberOfThreads() > 1){ forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); } else{ - forestTrainer.trainSerial(); + forestTrainer.trainSerialInMemory(); } } } diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index 64bd8da..6ba214d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -151,6 +151,7 @@ public class Settings { private int nodeSize = 5; private int maxNodeDepth = 1000000; // basically no maxNodeDepth private boolean checkNodePurity = false; + private Long randomSeed; private ObjectNode responseCombinerSettings = new 
ObjectNode(JsonNodeFactory.instance); private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 6c62202..76e91ea 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -28,12 +28,11 @@ import lombok.Builder; import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Random; -import java.util.UUID; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -51,8 +50,9 @@ public class ForestTrainer { // number of trees to try private final int ntree; - private final boolean displayProgress; + private final boolean displayProgress; // whether to print to standard output our progress; not always desirable private final String saveTreeLocation; + private final long randomSeed; public ForestTrainer(final Settings settings, final List> data, final List covariates){ this.ntree = settings.getNtree(); @@ -63,19 +63,25 @@ public class ForestTrainer { this.covariates = covariates; this.treeResponseCombiner = settings.getTreeCombiner(); this.treeTrainer = new TreeTrainer<>(settings, covariates); + + if(settings.getRandomSeed() != null){ + this.randomSeed = settings.getRandomSeed(); + } + else{ + this.randomSeed = System.nanoTime(); + } } - public Forest trainSerial(){ + public Forest trainSerialInMemory(){ final List> trees = new ArrayList<>(ntree); final Bootstrapper> bootstrapper = new Bootstrapper<>(data); - final Random random = new Random(); for(int j=0; j { public void trainSerialOnDisk(){ // First we need to see how 
many trees there currently are final File folder = new File(saveTreeLocation); + if(!folder.exists()){ + folder.mkdir(); + } + if(!folder.isDirectory()){ throw new IllegalArgumentException("Tree directory must be a directory!"); } - final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); - + final List treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList()); final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished // Using an AtomicInteger is overkill for serial code, but this lets use reuse TreeSavedWorker - for(int j=treeCount.get(); j { final ExecutorService executorService = Executors.newFixedThreadPool(threads); for(int j=0; j { public void trainParallelOnDisk(int threads){ // First we need to see how many trees there currently are final File folder = new File(saveTreeLocation); + if(!folder.exists()){ + folder.mkdir(); + } + if(!folder.isDirectory()){ throw new IllegalArgumentException("Tree directory must be a directory!"); } - final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); - - final ExecutorService executorService = Executors.newFixedThreadPool(threads); + final List treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList()); final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished - for(int j=treeCount.get(); j { private final Bootstrapper> bootstrapper; private final int treeIndex; private final List> treeList; + private final Random random; - TreeInMemoryWorker(final List> data, final int treeIndex, final List> treeList) { + TreeInMemoryWorker(final List> data, final int treeIndex, final List> treeList, final Random random) { this.bootstrapper = new Bootstrapper<>(data); this.treeIndex = treeIndex; this.treeList = treeList; + this.random = random; } @Override public void run() { - - // ThreadLocalRandom should make sure we 
don't duplicate seeds - final Tree tree = trainTree(bootstrapper, ThreadLocalRandom.current()); + final Tree tree = trainTree(bootstrapper, random); // should be okay as the list structure isn't changing treeList.set(treeIndex, tree); @@ -265,18 +292,18 @@ public class ForestTrainer { private final Bootstrapper> bootstrapper; private final String filename; private final AtomicInteger treeCount; + private final Random random; - public TreeSavedWorker(final List> data, final String filename, final AtomicInteger treeCount) { + public TreeSavedWorker(final List> data, final String filename, final AtomicInteger treeCount, final Random random) { this.bootstrapper = new Bootstrapper<>(data); this.filename = filename; this.treeCount = treeCount; + this.random = random; } @Override public void run() { - - // ThreadLocalRandom should make sure we don't duplicate seeds - final Tree tree = trainTree(bootstrapper, ThreadLocalRandom.current()); + final Tree tree = trainTree(bootstrapper, random); try { DataUtils.saveObject(tree, saveTreeLocation + "/" + filename); diff --git a/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java b/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java new file mode 100644 index 0000000..223d947 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java @@ -0,0 +1,298 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. 
+ * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; +import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; +import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; +import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.tree.ForestTrainer; +import ca.joeltherrien.randomforest.tree.TreeTrainer; +import ca.joeltherrien.randomforest.utils.DataUtils; +import ca.joeltherrien.randomforest.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +public class TestDeterministicForests { + + private final String saveTreeLocation = "src/test/resources/trees/"; + + private List generateCovariates(){ + final List covariateList = new ArrayList<>(); + + int index = 0; + for(int j=0; j<5; j++){ + final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index); + covariateList.add(numericCovariate); + index++; + } + + for(int j=0; j<5; j++){ + final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index); + covariateList.add(booleanCovariate); + index++; + } + + final List levels = Utils.easyList("cat", "dog", "mouse"); + for(int j=0; j<5; j++){ + final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels); + covariateList.add(factorCovariate); + index++; + } + + return covariateList; + } + + private Covariate.Value generateRandomValue(Covariate covariate, Random random){ + if(covariate instanceof NumericCovariate){ + return 
covariate.createValue(random.nextGaussian()); + } + if(covariate instanceof BooleanCovariate){ + return covariate.createValue(random.nextBoolean()); + } + if(covariate instanceof FactorCovariate){ + final double itemSelection = random.nextDouble(); + final String item; + if(itemSelection < 1.0/3.0){ + item = "cat"; + } + else if(itemSelection < 2.0/3.0){ + item = "dog"; + } + else{ + item = "mouse"; + } + + return covariate.createValue(item); + } + else{ + throw new IllegalArgumentException("Unknown covariate type of class " + covariate.getClass().getName()); + } + + } + + private List> generateTestData(List covariateList, int n, Random random){ + final List> rowList = new ArrayList<>(); + for(int i=0; i covariateList = generateCovariates(); + + final Random dataGeneratingRandom = new Random(); + + final List> trainingData = generateTestData(covariateList, 100, dataGeneratingRandom); + final List> testData = generateTestData(covariateList, 10, dataGeneratingRandom); + + // pick a new seed at random + final long trainingSeed = dataGeneratingRandom.nextLong(); + + final TreeTrainer treeTrainer = TreeTrainer.builder() + .checkNodePurity(false) + .covariates(covariateList) + .maxNodeDepth(100) + .mtry(1) + .nodeSize(10) + .numberOfSplits(1) // want results to be dominated by randomness + .responseCombiner(new MeanResponseCombiner()) + .splitFinder(new WeightedVarianceSplitFinder()) + .build(); + + final ForestTrainer forestTrainer = ForestTrainer.builder() + .treeTrainer(treeTrainer) + .covariates(covariateList) + .data(trainingData) + .displayProgress(false) + .ntree(10) + .randomSeed(trainingSeed) + .treeResponseCombiner(new MeanResponseCombiner()) + .saveTreeLocation(saveTreeLocation) + .build(); + + // By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed. 
+ final Forest referenceForest = forestTrainer.trainSerialInMemory(); + + verifySerialInMemoryTraining(referenceForest, forestTrainer, testData); + verifyParallelInMemoryTraining(referenceForest, forestTrainer, testData); + verifySerialOnDiskTraining(referenceForest, forestTrainer, testData); + verifyParallelOnDiskTraining(referenceForest, forestTrainer, testData); + + } + + /** + * Tests that if we train a forest under a specified seed for 10 trees, that it is equal to training a forest + * for 5 trees only, and then starting from that point to train the last 5. + * + * @throws IOException + * @throws ClassNotFoundException + */ + @Test + public void testInterupptedTrainingProducesSameResults() throws IOException, ClassNotFoundException { + final List covariateList = generateCovariates(); + + final Random dataGeneratingRandom = new Random(); + + final List> trainingData = generateTestData(covariateList, 100, dataGeneratingRandom); + final List> testData = generateTestData(covariateList, 10, dataGeneratingRandom); + + // pick a new seed at random + final long trainingSeed = dataGeneratingRandom.nextLong(); + + final TreeTrainer treeTrainer = TreeTrainer.builder() + .checkNodePurity(false) + .covariates(covariateList) + .maxNodeDepth(100) + .mtry(1) + .nodeSize(10) + .numberOfSplits(1) // want results to be dominated by randomness + .responseCombiner(new MeanResponseCombiner()) + .splitFinder(new WeightedVarianceSplitFinder()) + .build(); + + final ForestTrainer forestTrainer5Trees = ForestTrainer.builder() + .treeTrainer(treeTrainer) + .covariates(covariateList) + .data(trainingData) + .displayProgress(false) + .ntree(5) + .randomSeed(trainingSeed) + .treeResponseCombiner(new MeanResponseCombiner()) + .saveTreeLocation(saveTreeLocation) + .build(); + + final ForestTrainer forestTrainer10Trees = ForestTrainer.builder() + .treeTrainer(treeTrainer) + .covariates(covariateList) + .data(trainingData) + .displayProgress(false) + .ntree(10) + .randomSeed(trainingSeed) + 
.treeResponseCombiner(new MeanResponseCombiner()) + .saveTreeLocation(saveTreeLocation) + .build(); + + // By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed. + final Forest referenceForest = forestTrainer10Trees.trainSerialInMemory(); + + final File saveTreeFile = new File(saveTreeLocation); + + + forestTrainer5Trees.trainSerialOnDisk(); + forestTrainer10Trees.trainSerialOnDisk(); + final Forest forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); + TestUtils.removeFolder(saveTreeFile); + verifyTwoForestsEqual(testData, referenceForest, forestSerial); + + + forestTrainer5Trees.trainParallelOnDisk(4); + forestTrainer10Trees.trainParallelOnDisk(4); + final Forest forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); + TestUtils.removeFolder(saveTreeFile); + verifyTwoForestsEqual(testData, referenceForest, forestParallel); + + } + + private void verifySerialInMemoryTraining( + final Forest referenceForest, + ForestTrainer forestTrainer, + List> testData){ + + for(int k=0; k<3; k++){ + final Forest replicantForest = forestTrainer.trainSerialInMemory(); + verifyTwoForestsEqual(testData, referenceForest, replicantForest); + } + } + + private void verifyParallelInMemoryTraining( + Forest referenceForest, + ForestTrainer forestTrainer, + List> testData){ + + for(int k=0; k<3; k++){ + final Forest replicantForest = forestTrainer.trainParallelInMemory(4); + verifyTwoForestsEqual(testData, referenceForest, replicantForest); + } + } + + private void verifySerialOnDiskTraining( + Forest referenceForest, + ForestTrainer forestTrainer, + List> testData) throws IOException, ClassNotFoundException { + + final MeanResponseCombiner responseCombiner = new MeanResponseCombiner(); + final File saveTreeFile = new File(saveTreeLocation); + + for(int k=0; k<3; k++){ + forestTrainer.trainSerialOnDisk(); + final Forest replicantForest = 
DataUtils.loadForest(saveTreeFile, responseCombiner); + TestUtils.removeFolder(saveTreeFile); + verifyTwoForestsEqual(testData, referenceForest, replicantForest); + } + } + + private void verifyParallelOnDiskTraining( + final Forest referenceForest, ForestTrainer forestTrainer, + List> testData) throws IOException, ClassNotFoundException { + + final MeanResponseCombiner responseCombiner = new MeanResponseCombiner(); + final File saveTreeFile = new File(saveTreeLocation); + + for(int k=0; k<3; k++){ + forestTrainer.trainParallelOnDisk(4); + final Forest replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); + TestUtils.removeFolder(saveTreeFile); + verifyTwoForestsEqual(testData, referenceForest, replicantForest); + } + } + + // Technically verifies the two forests give equal predictions on a given test dataset + private void verifyTwoForestsEqual(final List> testData, + final Forest forest1, + final Forest forest2){ + + for(Row row : testData){ + final Double prediction1 = forest1.evaluate(row); + final Double prediction2 = forest2.evaluate(row); + + // I've noticed that results aren't necessarily always *identical* + TestUtils.closeEnough(prediction1, prediction2, 0.0000000001); + } + + } + + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 7ae2dea..fe71844 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -116,7 +116,7 @@ public class TestSavingLoading { final File directory = new File(settings.getSaveTreeLocation()); if(directory.exists()){ - cleanup(directory); + TestUtils.removeFolder(directory); } assertFalse(directory.exists()); directory.mkdir(); @@ -142,7 +142,7 @@ public class TestSavingLoading { assertEquals(NTREE, forest.getTrees().size()); - cleanup(directory); + TestUtils.removeFolder(directory); 
assertFalse(directory.exists()); @@ -157,7 +157,7 @@ public class TestSavingLoading { final File directory = new File(settings.getSaveTreeLocation()); if(directory.exists()){ - cleanup(directory); + TestUtils.removeFolder(directory); } assertFalse(directory.exists()); directory.mkdir(); @@ -183,24 +183,12 @@ public class TestSavingLoading { assertEquals(NTREE, forest.getTrees().size()); - cleanup(directory); + TestUtils.removeFolder(directory); assertFalse(directory.exists()); } - private void cleanup(File file){ - if(file.isFile()){ - file.delete(); - } - else{ - for(final File inner : file.listFiles()){ - cleanup(inner); - } - file.delete(); - } - } - } diff --git a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java index 24faceb..f118d05 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -20,6 +20,7 @@ import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; +import java.io.File; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -170,4 +171,16 @@ public class TestUtils { } } + public static void removeFolder(File file){ + if(file.isFile()){ + file.delete(); + } + else{ + for(final File inner : file.listFiles()){ + removeFolder(inner); + } + file.delete(); + } + } + } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 1f78096..47d6c46 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -235,7 +235,7 @@ public class TestCompetingRisk { final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - final Forest forest = 
forestTrainer.trainSerial(); + final Forest forest = forestTrainer.trainSerialInMemory(); // prediction row // time status ageatfda idu black cd4nadir @@ -328,7 +328,7 @@ public class TestCompetingRisk { final long startTime = System.currentTimeMillis(); for(int i=0; i<50; i++){ - forestTrainer.trainSerial(); + forestTrainer.trainSerialInMemory(); } final long endTime = System.currentTimeMillis(); @@ -346,7 +346,7 @@ public class TestCompetingRisk { final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - final Forest forest = forestTrainer.trainSerial(); + final Forest forest = forestTrainer.trainSerialInMemory(); // prediction row // time status ageatfda idu black cd4nadir diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java index 259410b..95c741e 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSplitFinder.java @@ -23,6 +23,7 @@ import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; import ca.joeltherrien.randomforest.tree.Split; +import ca.joeltherrien.randomforest.utils.Data; import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.SingletonIterator; import ca.joeltherrien.randomforest.utils.Utils; @@ -89,11 +90,4 @@ public class TestLogRankSplitFinder { assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual); } - @lombok.Data - @AllArgsConstructor - public static class Data { - private 
List> rows; - private List covariateList; - } - } diff --git a/src/test/java/ca/joeltherrien/randomforest/utils/Data.java b/src/test/java/ca/joeltherrien/randomforest/utils/Data.java new file mode 100644 index 0000000..22a2710 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/utils/Data.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package ca.joeltherrien.randomforest.utils; + +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; +import lombok.AllArgsConstructor; + +import java.util.List; + +/** + * Convenience class for unit tests + * + * @param The type of response. 
+ */ +@lombok.Data +@AllArgsConstructor +public class Data { + private List> rows; + private List covariateList; +} diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index b27045e..6942bf5 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -88,7 +88,7 @@ public class TrainForest { final long startTime = System.currentTimeMillis(); - //final Forest forest = forestTrainer.trainSerial(); + //final Forest forest = forestTrainer.trainSerialInMemory(); //final Forest forest = forestTrainer.trainParallelInMemory(3); forestTrainer.trainParallelOnDisk(3);