Add functionality to analyze using validation sets

This commit is contained in:
Joel Therrien 2018-09-13 12:09:20 -07:00
parent 98cb97a1f1
commit 7008959999
9 changed files with 45 additions and 27 deletions

View file

@@ -39,7 +39,7 @@ public class Main {
 .map(cs -> cs.build()).collect(Collectors.toList());
 if(args[1].equalsIgnoreCase("train")){
-final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
+final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
 final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
@@ -79,7 +79,7 @@ public class Main {
 return;
 }
-final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
+final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
 // Let's reduce this down to n
 final int n = Integer.parseInt(args[2]);
@@ -88,9 +88,17 @@ public class Main {
 final File folder = new File(settings.getSaveTreeLocation());
 final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
-System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
-final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
+final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
+if(useBootstrapPredictions){
+System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
+}
+else{
+System.out.println("Finished loading trees + dataset; creating calculator and evaluating predictions");
+}
+final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, useBootstrapPredictions);
 final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
 System.out.println("Running Naive Mortality");
@@ -166,7 +174,8 @@ public class Main {
 new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
 )
 )
-.dataFileLocation("data.csv")
+.trainingDataLocation("training_data.csv")
+.validationDataLocation("validation_data.csv")
 .responseCombinerSettings(responseCombinerSettings)
 .treeCombinerSettings(treeCombinerSettings)
 .groupDifferentiatorSettings(groupDifferentiatorSettings)

View file

@@ -204,7 +204,8 @@ public class Settings {
 // number of trees to try
 private int ntree = 500;
-private String dataFileLocation = "data.csv";
+private String trainingDataLocation = "data.csv";
+private String validationDataLocation = "data.csv";
 private String saveTreeLocation = "trees/";
 private int numberOfThreads = 1;

View file

@@ -18,11 +18,17 @@ public class CompetingRiskErrorRateCalculator {
 private final List<Row<CompetingRiskResponse>> dataset;
 private final List<CompetingRiskFunctions> riskFunctions;
-public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest){
+public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest, boolean useBootstrapPredictions){
 this.dataset = dataset;
-this.riskFunctions = dataset.stream()
-.map(forest::evaluateOOB)
-.collect(Collectors.toList());
+if(useBootstrapPredictions){
+this.riskFunctions = dataset.stream()
+.map(forest::evaluateOOB)
+.collect(Collectors.toList());
+}
+else{
+this.riskFunctions = forest.evaluate(dataset);
+}
 }
 /**

View file

@@ -30,7 +30,7 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
 * @param rowList List of CovariateRows to evaluate
 * @return A List of predictions.
 */
-public List<FO> evaluate(List<CovariateRow> rowList){
+public List<FO> evaluate(List<? extends CovariateRow> rowList){
 return rowList.parallelStream()
 .map(this::evaluate)
 .collect(Collectors.toList());

View file

@@ -60,7 +60,8 @@ public class TestSavingLoading {
 new NumericCovariateSettings("cd4nadir")
 )
 )
-.dataFileLocation("src/test/resources/wihs.csv")
+.trainingDataLocation("src/test/resources/wihs.csv")
+.validationDataLocation("src/test/resources/wihs.csv")
 .responseCombinerSettings(responseCombinerSettings)
 .treeCombinerSettings(treeCombinerSettings)
 .groupDifferentiatorSettings(groupDifferentiatorSettings)
@@ -94,7 +95,7 @@ public class TestSavingLoading {
 public void testSavingLoading() throws IOException, ClassNotFoundException {
 final Settings settings = getSettings();
 final List<Covariate> covariates = getCovariates(settings);
-final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
+final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
 final File directory = new File(settings.getSaveTreeLocation());
 assertFalse(directory.exists());

View file

@@ -62,7 +62,7 @@ public class TestCompetingRisk {
 new NumericCovariateSettings("cd4nadir")
 )
 )
-.dataFileLocation("src/test/resources/wihs.csv")
+.trainingDataLocation("src/test/resources/wihs.csv")
 .responseCombinerSettings(responseCombinerSettings)
 .treeCombinerSettings(treeCombinerSettings)
 .groupDifferentiatorSettings(groupDifferentiatorSettings)
@@ -95,7 +95,7 @@ public class TestCompetingRisk {
 @Test
 public void testSingleTree() throws IOException {
 final Settings settings = getSettings();
-settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
+settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
 settings.setCovariates(Utils.easyList(
 new BooleanCovariateSettings("idu"),
 new BooleanCovariateSettings("black")
@@ -103,7 +103,7 @@ public class TestCompetingRisk {
 final List<Covariate> covariates = getCovariates(settings);
-final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
+final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
 final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
 final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
@@ -152,11 +152,11 @@ public class TestCompetingRisk {
 final Settings settings = getSettings();
 settings.setMtry(4);
 settings.setNumberOfSplits(0);
-settings.setDataFileLocation("src/test/resources/wihs.bootstrapped2.csv");
+settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
 final List<Covariate> covariates = getCovariates(settings);
-final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
+final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
 final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
 final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
@@ -206,7 +206,7 @@ public class TestCompetingRisk {
 final List<Covariate> covariates = getCovariates(settings);
-final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
+final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
 final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
@@ -231,7 +231,7 @@ public class TestCompetingRisk {
 closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
 closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
-final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
+final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
 final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
 // Error rates happen to be about the same
@@ -256,7 +256,7 @@ public class TestCompetingRisk {
 final List<Covariate> covariates = getCovariates(settings);
-final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
+final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
 // Let's count the events and make sure the data was correctly read.
 int countCensored = 0;
@@ -292,7 +292,7 @@ public class TestCompetingRisk {
 settings.setNtree(300); // results are too variable at 100
 final List<Covariate> covariates = getCovariates(settings);
-final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
+final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
 final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
 final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
@@ -313,7 +313,7 @@ public class TestCompetingRisk {
 // We seem to consistently underestimate the results.
 assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
-final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
+final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
 final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
 System.out.println(errorRates[0]);

View file

@@ -93,7 +93,7 @@ public class TestCompetingRiskErrorRateCalculator {
 when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
-final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest);
+final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest, true);
 final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});

View file

@@ -37,7 +37,7 @@ public class TestLoadingCSV {
 yVarSettings.set("name", new TextNode("y"));
 final Settings settings = Settings.builder()
-.dataFileLocation(filename)
+.trainingDataLocation(filename)
 .covariates(
 Utils.easyList(new NumericCovariateSettings("x1"),
 new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
@@ -52,7 +52,7 @@ public class TestLoadingCSV {
 final DataLoader.ResponseLoader loader = settings.getResponseLoader();
-return DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
+return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation());
 }
 @Test

View file

@@ -37,7 +37,8 @@ public class TestPersistence {
 new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
 )
 )
-.dataFileLocation("data.csv")
+.trainingDataLocation("training_data.csv")
+.validationDataLocation("validation_data.csv")
 .responseCombinerSettings(responseCombinerSettings)
 .treeCombinerSettings(treeCombinerSettings)
 .groupDifferentiatorSettings(groupDifferentiatorSettings)