Add functionality to analyze using validation sets
This commit is contained in:
parent
98cb97a1f1
commit
7008959999
9 changed files with 45 additions and 27 deletions
|
@ -39,7 +39,7 @@ public class Main {
|
|||
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||
|
||||
if(args[1].equalsIgnoreCase("train")){
|
||||
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
|
||||
|
||||
|
@ -79,7 +79,7 @@ public class Main {
|
|||
return;
|
||||
}
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
|
||||
|
||||
// Let's reduce this down to n
|
||||
final int n = Integer.parseInt(args[2]);
|
||||
|
@ -88,9 +88,17 @@ public class Main {
|
|||
final File folder = new File(settings.getSaveTreeLocation());
|
||||
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
||||
|
||||
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
|
||||
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
||||
if(useBootstrapPredictions){
|
||||
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
|
||||
}
|
||||
else{
|
||||
System.out.println("Finished loading trees + dataset; creating calculator and evaluating predictions");
|
||||
}
|
||||
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, useBootstrapPredictions);
|
||||
final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
|
||||
|
||||
System.out.println("Running Naive Mortality");
|
||||
|
@ -166,7 +174,8 @@ public class Main {
|
|||
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||
)
|
||||
)
|
||||
.dataFileLocation("data.csv")
|
||||
.trainingDataLocation("training_data.csv")
|
||||
.validationDataLocation("validation_data.csv")
|
||||
.responseCombinerSettings(responseCombinerSettings)
|
||||
.treeCombinerSettings(treeCombinerSettings)
|
||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||
|
|
|
@ -204,7 +204,8 @@ public class Settings {
|
|||
// number of trees to try
|
||||
private int ntree = 500;
|
||||
|
||||
private String dataFileLocation = "data.csv";
|
||||
private String trainingDataLocation = "data.csv";
|
||||
private String validationDataLocation = "data.csv";
|
||||
private String saveTreeLocation = "trees/";
|
||||
|
||||
private int numberOfThreads = 1;
|
||||
|
|
|
@ -18,11 +18,17 @@ public class CompetingRiskErrorRateCalculator {
|
|||
private final List<Row<CompetingRiskResponse>> dataset;
|
||||
private final List<CompetingRiskFunctions> riskFunctions;
|
||||
|
||||
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest){
|
||||
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest, boolean useBootstrapPredictions){
|
||||
this.dataset = dataset;
|
||||
this.riskFunctions = dataset.stream()
|
||||
.map(forest::evaluateOOB)
|
||||
.collect(Collectors.toList());
|
||||
if(useBootstrapPredictions){
|
||||
this.riskFunctions = dataset.stream()
|
||||
.map(forest::evaluateOOB)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
else{
|
||||
this.riskFunctions = forest.evaluate(dataset);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -30,7 +30,7 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
|||
* @param rowList List of CovariateRows to evaluate
|
||||
* @return A List of predictions.
|
||||
*/
|
||||
public List<FO> evaluate(List<CovariateRow> rowList){
|
||||
public List<FO> evaluate(List<? extends CovariateRow> rowList){
|
||||
return rowList.parallelStream()
|
||||
.map(this::evaluate)
|
||||
.collect(Collectors.toList());
|
||||
|
|
|
@ -60,7 +60,8 @@ public class TestSavingLoading {
|
|||
new NumericCovariateSettings("cd4nadir")
|
||||
)
|
||||
)
|
||||
.dataFileLocation("src/test/resources/wihs.csv")
|
||||
.trainingDataLocation("src/test/resources/wihs.csv")
|
||||
.validationDataLocation("src/test/resources/wihs.csv")
|
||||
.responseCombinerSettings(responseCombinerSettings)
|
||||
.treeCombinerSettings(treeCombinerSettings)
|
||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||
|
@ -94,7 +95,7 @@ public class TestSavingLoading {
|
|||
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
||||
final Settings settings = getSettings();
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final File directory = new File(settings.getSaveTreeLocation());
|
||||
assertFalse(directory.exists());
|
||||
|
|
|
@ -62,7 +62,7 @@ public class TestCompetingRisk {
|
|||
new NumericCovariateSettings("cd4nadir")
|
||||
)
|
||||
)
|
||||
.dataFileLocation("src/test/resources/wihs.csv")
|
||||
.trainingDataLocation("src/test/resources/wihs.csv")
|
||||
.responseCombinerSettings(responseCombinerSettings)
|
||||
.treeCombinerSettings(treeCombinerSettings)
|
||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||
|
@ -95,7 +95,7 @@ public class TestCompetingRisk {
|
|||
@Test
|
||||
public void testSingleTree() throws IOException {
|
||||
final Settings settings = getSettings();
|
||||
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
|
||||
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
|
||||
settings.setCovariates(Utils.easyList(
|
||||
new BooleanCovariateSettings("idu"),
|
||||
new BooleanCovariateSettings("black")
|
||||
|
@ -103,7 +103,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
||||
|
@ -152,11 +152,11 @@ public class TestCompetingRisk {
|
|||
final Settings settings = getSettings();
|
||||
settings.setMtry(4);
|
||||
settings.setNumberOfSplits(0);
|
||||
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped2.csv");
|
||||
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
|
||||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
||||
|
@ -206,7 +206,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
|
@ -231,7 +231,7 @@ public class TestCompetingRisk {
|
|||
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
||||
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
|
||||
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||
|
||||
// Error rates happen to be about the same
|
||||
|
@ -256,7 +256,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
// Let's count the events and make sure the data was correctly read.
|
||||
int countCensored = 0;
|
||||
|
@ -292,7 +292,7 @@ public class TestCompetingRisk {
|
|||
settings.setNtree(300); // results are too variable at 100
|
||||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||
|
||||
|
@ -313,7 +313,7 @@ public class TestCompetingRisk {
|
|||
// We seem to consistently underestimate the results.
|
||||
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
|
||||
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||
|
||||
System.out.println(errorRates[0]);
|
||||
|
|
|
@ -93,7 +93,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
|||
when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
|
||||
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest);
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest, true);
|
||||
|
||||
final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ public class TestLoadingCSV {
|
|||
yVarSettings.set("name", new TextNode("y"));
|
||||
|
||||
final Settings settings = Settings.builder()
|
||||
.dataFileLocation(filename)
|
||||
.trainingDataLocation(filename)
|
||||
.covariates(
|
||||
Utils.easyList(new NumericCovariateSettings("x1"),
|
||||
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
||||
|
@ -52,7 +52,7 @@ public class TestLoadingCSV {
|
|||
|
||||
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||
|
||||
return DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
|
||||
return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -37,7 +37,8 @@ public class TestPersistence {
|
|||
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||
)
|
||||
)
|
||||
.dataFileLocation("data.csv")
|
||||
.trainingDataLocation("training_data.csv")
|
||||
.validationDataLocation("validation_data.csv")
|
||||
.responseCombinerSettings(responseCombinerSettings)
|
||||
.treeCombinerSettings(treeCombinerSettings)
|
||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||
|
|
Loading…
Reference in a new issue