diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 6b8b285..6c62202 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -139,7 +139,7 @@ public class ForestTrainer { int prevNumberTreesSet = -1; while(true){ try { - if (executorService.awaitTermination(5, TimeUnit.SECONDS)) break; + if (executorService.awaitTermination(1, TimeUnit.SECONDS)) break; } catch (InterruptedException e) { System.err.println("There was an InterruptedException while waiting for the forest to finish training; this is unusual but on its own shouldn't be a problem."); System.err.println("Please send a bug report about it to joelt@sfu.ca"); @@ -147,21 +147,22 @@ public class ForestTrainer { // do nothing; this shouldn't be an issue } - if(displayProgress) { - int numberTreesSet = 0; - for (final Tree tree : trees) { - if (tree != null) { - numberTreesSet++; - } + int numberTreesSet = 0; + for (final Tree tree : trees) { + if (tree != null) { + numberTreesSet++; } + } + if(displayProgress && numberTreesSet != prevNumberTreesSet) { // Only output trees set on screen if there was a change // In some environments where standard output is streamed to a file this method below causes frequent writes to output - if(numberTreesSet != prevNumberTreesSet){ - System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees"); - prevNumberTreesSet = numberTreesSet; - } + System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees"); + prevNumberTreesSet = numberTreesSet; + } + if(numberTreesSet == ntree){ + executorService.shutdown(); } } @@ -199,26 +200,27 @@ public class ForestTrainer { int prevNumberTreesSet = -1; while(true){ try { - if (executorService.awaitTermination(5, TimeUnit.SECONDS)) break; + if (executorService.awaitTermination(1, TimeUnit.SECONDS)) break; } catch (InterruptedException e) { System.err.println("There was an InterruptedException while waiting for the forest to finish training; this is unusual but on its own shouldn't be a problem."); System.err.println("Please send a bug report about it to joelt@sfu.ca"); e.printStackTrace(); // do nothing; this shouldn't be an issue } + int numberTreesSet = treeCount.get(); - if(displayProgress) { - int numberTreesSet = treeCount.get(); - + if(displayProgress && numberTreesSet != prevNumberTreesSet) { // Only output trees set on screen if there was a change // In some environments where standard output is streamed to a file this method below causes frequent writes to output - if(numberTreesSet != prevNumberTreesSet){ - System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees"); - prevNumberTreesSet = numberTreesSet; - } + System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees"); + prevNumberTreesSet = numberTreesSet; } + if(numberTreesSet == ntree){ + executorService.shutdown(); + } + } if(displayProgress){ diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 77a3db6..58dd3e7 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -109,7 +109,7 @@ public class TestSavingLoading { } @Test - public void testSavingLoading() throws IOException, ClassNotFoundException { + public void testSavingLoadingSerial() throws IOException, ClassNotFoundException { final Settings settings = getSettings(); final List covariates = settings.getCovariates(); final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); @@ -131,6 +131,47 @@ public class TestSavingLoading { + final Forest forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); + + final CovariateRow predictionRow = getPredictionRow(covariates); + + final CompetingRiskFunctions functions = forest.evaluate(predictionRow); + assertNotNull(functions); + assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2); + + + assertEquals(NTREE, forest.getTrees().size()); + + cleanup(directory); + + assertFalse(directory.exists()); + + } + + + @Test + public void testSavingLoadingParallel() throws IOException, ClassNotFoundException { + final Settings settings = getSettings(); + final List covariates = settings.getCovariates(); + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + + final File directory = new File(settings.getSaveTreeLocation()); + if(directory.exists()){ + cleanup(directory); + } + assertFalse(directory.exists()); + directory.mkdir(); + + final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); + + forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); + + assertTrue(directory.exists()); + assertTrue(directory.isDirectory()); + assertEquals(NTREE, directory.listFiles().length); + + + final Forest forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); final CovariateRow predictionRow = getPredictionRow(covariates);