From 700895999957e32e42a8b71a915480e8f50588e3 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Thu, 13 Sep 2018 12:09:20 -0700 Subject: [PATCH] Add functionality to analyze using validation sets --- .../ca/joeltherrien/randomforest/Main.java | 19 +++++++++++++----- .../joeltherrien/randomforest/Settings.java | 3 ++- .../CompetingRiskErrorRateCalculator.java | 14 +++++++++---- .../randomforest/tree/Forest.java | 2 +- .../randomforest/TestSavingLoading.java | 5 +++-- .../competingrisk/TestCompetingRisk.java | 20 +++++++++---------- .../TestCompetingRiskErrorRateCalculator.java | 2 +- .../randomforest/csv/TestLoadingCSV.java | 4 ++-- .../settings/TestPersistence.java | 3 ++- 9 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index 5fd76a7..d90328d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -39,7 +39,7 @@ public class Main { .map(cs -> cs.build()).collect(Collectors.toList()); if(args[1].equalsIgnoreCase("train")){ - final List dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + final List dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates); @@ -79,7 +79,7 @@ public class Main { return; } - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation()); // Let's reduce this down to n final int n = Integer.parseInt(args[2]); @@ -88,9 +88,17 @@ public class Main { final File folder = new File(settings.getSaveTreeLocation()); final Forest forest = DataLoader.loadForest(folder, responseCombiner); - 
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions"); + final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation()); - final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest); + if(useBootstrapPredictions){ + System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions"); + } + else{ + System.out.println("Finished loading trees + dataset; creating calculator and evaluating predictions"); + } + + + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, useBootstrapPredictions); final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt"); System.out.println("Running Naive Mortality"); @@ -166,7 +174,8 @@ public class Main { new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) ) ) - .dataFileLocation("data.csv") + .trainingDataLocation("training_data.csv") + .validationDataLocation("validation_data.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) .groupDifferentiatorSettings(groupDifferentiatorSettings) diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index 28b0784..e53374b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -204,7 +204,8 @@ public class Settings { // number of trees to try private int ntree = 500; - private String dataFileLocation = "data.csv"; + private String trainingDataLocation = "data.csv"; + private String validationDataLocation = "data.csv"; private String saveTreeLocation = "trees/"; private int numberOfThreads = 1; diff --git 
a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java index 966d063..ac2e386 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -18,11 +18,17 @@ public class CompetingRiskErrorRateCalculator { private final List> dataset; private final List riskFunctions; - public CompetingRiskErrorRateCalculator(final List> dataset, final Forest forest){ + public CompetingRiskErrorRateCalculator(final List> dataset, final Forest forest, boolean useBootstrapPredictions){ this.dataset = dataset; - this.riskFunctions = dataset.stream() - .map(forest::evaluateOOB) - .collect(Collectors.toList()); + if(useBootstrapPredictions){ + this.riskFunctions = dataset.stream() + .map(forest::evaluateOOB) + .collect(Collectors.toList()); + } + else{ + this.riskFunctions = forest.evaluate(dataset); + } + } /** diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index 9d6913f..848eca8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -30,7 +30,7 @@ public class Forest { // O = output of trees, FO = forest output. In prac * @param rowList List of CovariateRows to evaluate * @return A List of predictions. 
*/ - public List<FO> evaluate(List<CovariateRow> rowList){ + public List<FO> evaluate(List<? extends CovariateRow> rowList){ return rowList.parallelStream() .map(this::evaluate) .collect(Collectors.toList()); diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 49572c9..f6f800b 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -60,7 +60,8 @@ public class TestSavingLoading { new NumericCovariateSettings("cd4nadir") ) ) - .dataFileLocation("src/test/resources/wihs.csv") + .trainingDataLocation("src/test/resources/wihs.csv") + .validationDataLocation("src/test/resources/wihs.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) .groupDifferentiatorSettings(groupDifferentiatorSettings) @@ -94,7 +95,7 @@ public class TestSavingLoading { public void testSavingLoading() throws IOException, ClassNotFoundException { final Settings settings = getSettings(); final List covariates = getCovariates(settings); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final File directory = new File(settings.getSaveTreeLocation()); assertFalse(directory.exists()); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index db6d0c4..da18ce4 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -62,7 +62,7 @@ public class TestCompetingRisk { new NumericCovariateSettings("cd4nadir") ) ) - .dataFileLocation("src/test/resources/wihs.csv") +
.trainingDataLocation("src/test/resources/wihs.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) .groupDifferentiatorSettings(groupDifferentiatorSettings) @@ -95,7 +95,7 @@ public class TestCompetingRisk { @Test public void testSingleTree() throws IOException { final Settings settings = getSettings(); - settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv"); + settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv"); settings.setCovariates(Utils.easyList( new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("black") @@ -103,7 +103,7 @@ public class TestCompetingRisk { final List covariates = getCovariates(settings); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final TreeTrainer treeTrainer = new TreeTrainer<>(settings, covariates); final Node node = treeTrainer.growTree(dataset); @@ -152,11 +152,11 @@ public class TestCompetingRisk { final Settings settings = getSettings(); settings.setMtry(4); settings.setNumberOfSplits(0); - settings.setDataFileLocation("src/test/resources/wihs.bootstrapped2.csv"); + settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv"); final List covariates = getCovariates(settings); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final TreeTrainer treeTrainer = new TreeTrainer<>(settings, covariates); final Node node = treeTrainer.growTree(dataset); @@ -206,7 +206,7 @@ public class TestCompetingRisk { final List covariates = getCovariates(settings); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), 
settings.getDataFileLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); @@ -231,7 +231,7 @@ public class TestCompetingRisk { closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01); closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01); - final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest); + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); // Error rates happen to be about the same @@ -256,7 +256,7 @@ public class TestCompetingRisk { final List covariates = getCovariates(settings); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); // Let's count the events and make sure the data was correctly read. int countCensored = 0; @@ -292,7 +292,7 @@ public class TestCompetingRisk { settings.setNtree(300); // results are too variable at 100 final List covariates = getCovariates(settings); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final Forest forest = forestTrainer.trainSerial(); @@ -313,7 +313,7 @@ public class TestCompetingRisk { // We seem to consistently underestimate the results. 
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 - final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest); + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); System.out.println(errorRates[0]); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java index 3e1d367..6dee34f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -93,7 +93,7 @@ public class TestCompetingRiskErrorRateCalculator { when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4); - final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest); + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest, true); final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}); diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index fcb2c1d..5decd46 100644 --- a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -37,7 +37,7 @@ public class TestLoadingCSV { yVarSettings.set("name", new TextNode("y")); final 
Settings settings = Settings.builder() - .dataFileLocation(filename) + .trainingDataLocation(filename) .covariates( Utils.easyList(new NumericCovariateSettings("x1"), new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")), @@ -52,7 +52,7 @@ public class TestLoadingCSV { final DataLoader.ResponseLoader loader = settings.getResponseLoader(); - return DataLoader.loadData(covariates, loader, settings.getDataFileLocation()); + return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation()); } @Test diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index 87f8449..a7f11a8 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -37,7 +37,8 @@ public class TestPersistence { new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) ) ) - .dataFileLocation("data.csv") + .trainingDataLocation("training_data.csv") + .validationDataLocation("validation_data.csv") .responseCombinerSettings(responseCombinerSettings) .treeCombinerSettings(treeCombinerSettings) .groupDifferentiatorSettings(groupDifferentiatorSettings)