Add functionality to analyze using validation sets
This commit is contained in:
parent
98cb97a1f1
commit
7008959999
9 changed files with 45 additions and 27 deletions
|
@ -39,7 +39,7 @@ public class Main {
|
||||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||||
|
|
||||||
if(args[1].equalsIgnoreCase("train")){
|
if(args[1].equalsIgnoreCase("train")){
|
||||||
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
|
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ public class Main {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
|
||||||
|
|
||||||
// Let's reduce this down to n
|
// Let's reduce this down to n
|
||||||
final int n = Integer.parseInt(args[2]);
|
final int n = Integer.parseInt(args[2]);
|
||||||
|
@ -88,9 +88,17 @@ public class Main {
|
||||||
final File folder = new File(settings.getSaveTreeLocation());
|
final File folder = new File(settings.getSaveTreeLocation());
|
||||||
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
||||||
|
|
||||||
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
|
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
|
||||||
|
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
if(useBootstrapPredictions){
|
||||||
|
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
System.out.println("Finished loading trees + dataset; creating calculator and evaluating predictions");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, useBootstrapPredictions);
|
||||||
final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
|
final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
|
||||||
|
|
||||||
System.out.println("Running Naive Mortality");
|
System.out.println("Running Naive Mortality");
|
||||||
|
@ -166,7 +174,8 @@ public class Main {
|
||||||
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.dataFileLocation("data.csv")
|
.trainingDataLocation("training_data.csv")
|
||||||
|
.validationDataLocation("validation_data.csv")
|
||||||
.responseCombinerSettings(responseCombinerSettings)
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeCombinerSettings(treeCombinerSettings)
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
|
|
|
@ -204,7 +204,8 @@ public class Settings {
|
||||||
// number of trees to try
|
// number of trees to try
|
||||||
private int ntree = 500;
|
private int ntree = 500;
|
||||||
|
|
||||||
private String dataFileLocation = "data.csv";
|
private String trainingDataLocation = "data.csv";
|
||||||
|
private String validationDataLocation = "data.csv";
|
||||||
private String saveTreeLocation = "trees/";
|
private String saveTreeLocation = "trees/";
|
||||||
|
|
||||||
private int numberOfThreads = 1;
|
private int numberOfThreads = 1;
|
||||||
|
|
|
@ -18,11 +18,17 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
private final List<Row<CompetingRiskResponse>> dataset;
|
private final List<Row<CompetingRiskResponse>> dataset;
|
||||||
private final List<CompetingRiskFunctions> riskFunctions;
|
private final List<CompetingRiskFunctions> riskFunctions;
|
||||||
|
|
||||||
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest){
|
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest, boolean useBootstrapPredictions){
|
||||||
this.dataset = dataset;
|
this.dataset = dataset;
|
||||||
this.riskFunctions = dataset.stream()
|
if(useBootstrapPredictions){
|
||||||
.map(forest::evaluateOOB)
|
this.riskFunctions = dataset.stream()
|
||||||
.collect(Collectors.toList());
|
.map(forest::evaluateOOB)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
this.riskFunctions = forest.evaluate(dataset);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -30,7 +30,7 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
* @param rowList List of CovariateRows to evaluate
|
* @param rowList List of CovariateRows to evaluate
|
||||||
* @return A List of predictions.
|
* @return A List of predictions.
|
||||||
*/
|
*/
|
||||||
public List<FO> evaluate(List<CovariateRow> rowList){
|
public List<FO> evaluate(List<? extends CovariateRow> rowList){
|
||||||
return rowList.parallelStream()
|
return rowList.parallelStream()
|
||||||
.map(this::evaluate)
|
.map(this::evaluate)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
|
@ -60,7 +60,8 @@ public class TestSavingLoading {
|
||||||
new NumericCovariateSettings("cd4nadir")
|
new NumericCovariateSettings("cd4nadir")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.dataFileLocation("src/test/resources/wihs.csv")
|
.trainingDataLocation("src/test/resources/wihs.csv")
|
||||||
|
.validationDataLocation("src/test/resources/wihs.csv")
|
||||||
.responseCombinerSettings(responseCombinerSettings)
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeCombinerSettings(treeCombinerSettings)
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
|
@ -94,7 +95,7 @@ public class TestSavingLoading {
|
||||||
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final File directory = new File(settings.getSaveTreeLocation());
|
final File directory = new File(settings.getSaveTreeLocation());
|
||||||
assertFalse(directory.exists());
|
assertFalse(directory.exists());
|
||||||
|
|
|
@ -62,7 +62,7 @@ public class TestCompetingRisk {
|
||||||
new NumericCovariateSettings("cd4nadir")
|
new NumericCovariateSettings("cd4nadir")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.dataFileLocation("src/test/resources/wihs.csv")
|
.trainingDataLocation("src/test/resources/wihs.csv")
|
||||||
.responseCombinerSettings(responseCombinerSettings)
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeCombinerSettings(treeCombinerSettings)
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
|
@ -95,7 +95,7 @@ public class TestCompetingRisk {
|
||||||
@Test
|
@Test
|
||||||
public void testSingleTree() throws IOException {
|
public void testSingleTree() throws IOException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
|
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
|
||||||
settings.setCovariates(Utils.easyList(
|
settings.setCovariates(Utils.easyList(
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black")
|
new BooleanCovariateSettings("black")
|
||||||
|
@ -103,7 +103,7 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
||||||
|
@ -152,11 +152,11 @@ public class TestCompetingRisk {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setMtry(4);
|
settings.setMtry(4);
|
||||||
settings.setNumberOfSplits(0);
|
settings.setNumberOfSplits(0);
|
||||||
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped2.csv");
|
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
||||||
|
@ -206,7 +206,7 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
|
@ -231,7 +231,7 @@ public class TestCompetingRisk {
|
||||||
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
||||||
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
||||||
|
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
|
||||||
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||||
|
|
||||||
// Error rates happen to be about the same
|
// Error rates happen to be about the same
|
||||||
|
@ -256,7 +256,7 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
// Let's count the events and make sure the data was correctly read.
|
// Let's count the events and make sure the data was correctly read.
|
||||||
int countCensored = 0;
|
int countCensored = 0;
|
||||||
|
@ -292,7 +292,7 @@ public class TestCompetingRisk {
|
||||||
settings.setNtree(300); // results are too variable at 100
|
settings.setNtree(300); // results are too variable at 100
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||||
|
|
||||||
|
@ -313,7 +313,7 @@ public class TestCompetingRisk {
|
||||||
// We seem to consistently underestimate the results.
|
// We seem to consistently underestimate the results.
|
||||||
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||||
|
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
|
||||||
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||||
|
|
||||||
System.out.println(errorRates[0]);
|
System.out.println(errorRates[0]);
|
||||||
|
|
|
@ -93,7 +93,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
|
when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
|
||||||
|
|
||||||
|
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest);
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest, true);
|
||||||
|
|
||||||
final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});
|
final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ public class TestLoadingCSV {
|
||||||
yVarSettings.set("name", new TextNode("y"));
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
final Settings settings = Settings.builder()
|
final Settings settings = Settings.builder()
|
||||||
.dataFileLocation(filename)
|
.trainingDataLocation(filename)
|
||||||
.covariates(
|
.covariates(
|
||||||
Utils.easyList(new NumericCovariateSettings("x1"),
|
Utils.easyList(new NumericCovariateSettings("x1"),
|
||||||
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
||||||
|
@ -52,7 +52,7 @@ public class TestLoadingCSV {
|
||||||
|
|
||||||
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||||
|
|
||||||
return DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
|
return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -37,7 +37,8 @@ public class TestPersistence {
|
||||||
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.dataFileLocation("data.csv")
|
.trainingDataLocation("training_data.csv")
|
||||||
|
.validationDataLocation("validation_data.csv")
|
||||||
.responseCombinerSettings(responseCombinerSettings)
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeCombinerSettings(treeCombinerSettings)
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
|
|
Loading…
Reference in a new issue