Add functionality to analyze using validation sets

This commit is contained in:
Joel Therrien 2018-09-13 12:09:20 -07:00
parent 98cb97a1f1
commit 7008959999
9 changed files with 45 additions and 27 deletions

View file

@ -39,7 +39,7 @@ public class Main {
.map(cs -> cs.build()).collect(Collectors.toList());
if(args[1].equalsIgnoreCase("train")){
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
@ -79,7 +79,7 @@ public class Main {
return;
}
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
// Let's reduce this down to n
final int n = Integer.parseInt(args[2]);
@ -88,9 +88,17 @@ public class Main {
final File folder = new File(settings.getSaveTreeLocation());
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
if(useBootstrapPredictions){
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
}
else{
System.out.println("Finished loading trees + dataset; creating calculator and evaluating predictions");
}
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, useBootstrapPredictions);
final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
System.out.println("Running Naive Mortality");
@ -166,7 +174,8 @@ public class Main {
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
)
)
.dataFileLocation("data.csv")
.trainingDataLocation("training_data.csv")
.validationDataLocation("validation_data.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)

View file

@ -204,7 +204,8 @@ public class Settings {
// number of trees to try
private int ntree = 500;
private String dataFileLocation = "data.csv";
private String trainingDataLocation = "data.csv";
private String validationDataLocation = "data.csv";
private String saveTreeLocation = "trees/";
private int numberOfThreads = 1;

View file

@ -18,11 +18,17 @@ public class CompetingRiskErrorRateCalculator {
private final List<Row<CompetingRiskResponse>> dataset;
private final List<CompetingRiskFunctions> riskFunctions;
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest){
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest, boolean useBootstrapPredictions){
this.dataset = dataset;
this.riskFunctions = dataset.stream()
.map(forest::evaluateOOB)
.collect(Collectors.toList());
if(useBootstrapPredictions){
this.riskFunctions = dataset.stream()
.map(forest::evaluateOOB)
.collect(Collectors.toList());
}
else{
this.riskFunctions = forest.evaluate(dataset);
}
}
/**

View file

@ -30,7 +30,7 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
* @param rowList List of CovariateRows to evaluate
* @return A List of predictions.
*/
public List<FO> evaluate(List<CovariateRow> rowList){
public List<FO> evaluate(List<? extends CovariateRow> rowList){
return rowList.parallelStream()
.map(this::evaluate)
.collect(Collectors.toList());

View file

@ -60,7 +60,8 @@ public class TestSavingLoading {
new NumericCovariateSettings("cd4nadir")
)
)
.dataFileLocation("src/test/resources/wihs.csv")
.trainingDataLocation("src/test/resources/wihs.csv")
.validationDataLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
@ -94,7 +95,7 @@ public class TestSavingLoading {
public void testSavingLoading() throws IOException, ClassNotFoundException {
final Settings settings = getSettings();
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final File directory = new File(settings.getSaveTreeLocation());
assertFalse(directory.exists());

View file

@ -62,7 +62,7 @@ public class TestCompetingRisk {
new NumericCovariateSettings("cd4nadir")
)
)
.dataFileLocation("src/test/resources/wihs.csv")
.trainingDataLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
@ -95,7 +95,7 @@ public class TestCompetingRisk {
@Test
public void testSingleTree() throws IOException {
final Settings settings = getSettings();
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
settings.setCovariates(Utils.easyList(
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black")
@ -103,7 +103,7 @@ public class TestCompetingRisk {
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
@ -152,11 +152,11 @@ public class TestCompetingRisk {
final Settings settings = getSettings();
settings.setMtry(4);
settings.setNumberOfSplits(0);
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped2.csv");
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
@ -206,7 +206,7 @@ public class TestCompetingRisk {
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
@ -231,7 +231,7 @@ public class TestCompetingRisk {
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
// Error rates happen to be about the same
@ -256,7 +256,7 @@ public class TestCompetingRisk {
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
// Let's count the events and make sure the data was correctly read.
int countCensored = 0;
@ -292,7 +292,7 @@ public class TestCompetingRisk {
settings.setNtree(300); // results are too variable at 100
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
@ -313,7 +313,7 @@ public class TestCompetingRisk {
// We seem to consistently underestimate the results.
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
System.out.println(errorRates[0]);

View file

@ -93,7 +93,7 @@ public class TestCompetingRiskErrorRateCalculator {
when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest);
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest, true);
final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});

View file

@ -37,7 +37,7 @@ public class TestLoadingCSV {
yVarSettings.set("name", new TextNode("y"));
final Settings settings = Settings.builder()
.dataFileLocation(filename)
.trainingDataLocation(filename)
.covariates(
Utils.easyList(new NumericCovariateSettings("x1"),
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
@ -52,7 +52,7 @@ public class TestLoadingCSV {
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
return DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation());
}
@Test

View file

@ -37,7 +37,8 @@ public class TestPersistence {
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
)
)
.dataFileLocation("data.csv")
.trainingDataLocation("training_data.csv")
.validationDataLocation("validation_data.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)