Removed naive mortality error measurement
Naive mortality error was an ad-hoc method I implemented earlier on. It didn't provide any useful performance, nor was it theoretically grounded. It's better to remove it before someone accidently uses it.
This commit is contained in:
parent
a887a3cc15
commit
ae40a2e664
4 changed files with 0 additions and 115 deletions
|
@ -100,13 +100,6 @@ public class Main {
|
|||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, useBootstrapPredictions);
|
||||
final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
|
||||
|
||||
System.out.println("Running Naive Mortality");
|
||||
|
||||
final double naiveMortality = errorRateCalculator.calculateNaiveMortalityError(events);
|
||||
printWriter.write("Naive Mortality: ");
|
||||
printWriter.write(Double.toString(naiveMortality));
|
||||
printWriter.write('\n');
|
||||
|
||||
System.out.println("Running Naive Concordance");
|
||||
|
||||
final double[] naiveConcordance = errorRateCalculator.calculateConcordance(events);
|
||||
|
|
|
@ -31,54 +31,6 @@ public class CompetingRiskErrorRateCalculator {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened,
|
||||
* then we add one to the error scale.
|
||||
*
|
||||
* Ignore censored observations.
|
||||
*
|
||||
* Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL.
|
||||
*
|
||||
* My observation is that this error rate isn't very useful...
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public double calculateNaiveMortalityError(final int[] events){
|
||||
int failures = 0;
|
||||
int attempts = 0;
|
||||
|
||||
response_loop:
|
||||
for(int i=0; i<dataset.size(); i++){
|
||||
final CompetingRiskResponse response = dataset.get(i).getResponse();
|
||||
|
||||
if(response.isCensored()){
|
||||
continue;
|
||||
}
|
||||
attempts++;
|
||||
|
||||
final CompetingRiskFunctions functions = riskFunctions.get(i);
|
||||
final int delta = response.getDelta();
|
||||
final double time = response.getU();
|
||||
final double shouldBeHighestMortality = functions.calculateEventSpecificMortality(delta, time);
|
||||
|
||||
for(final int event : events){
|
||||
if(event != delta){
|
||||
final double otherEventMortality = functions.calculateEventSpecificMortality(event, time);
|
||||
|
||||
if(shouldBeHighestMortality < otherEventMortality){
|
||||
failures++;
|
||||
continue response_loop;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (double) failures / (double) attempts;
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
public double[] calculateConcordance(final int[] events){
|
||||
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
|
||||
|
|
|
@ -244,8 +244,6 @@ public class TestCompetingRisk {
|
|||
|
||||
closeEnough(0.452, errorRates[0], 0.02);
|
||||
closeEnough(0.446, errorRates[1], 0.02);
|
||||
|
||||
System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -326,8 +324,6 @@ public class TestCompetingRisk {
|
|||
closeEnough(0.395, errorRates[0], 0.02);
|
||||
closeEnough(0.345, errorRates[1], 0.02);
|
||||
|
||||
System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}));
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -47,61 +47,5 @@ public class TestCompetingRiskErrorRateCalculator {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNaiveMortality(){
|
||||
final CompetingRiskResponse response1 = new CompetingRiskResponse(1, 5.0);
|
||||
final CompetingRiskResponse response2 = new CompetingRiskResponse(0, 6.0);
|
||||
final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0);
|
||||
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = Utils.easyList(
|
||||
new Row<>(new Covariate.Value[]{}, 1, response1),
|
||||
new Row<>(new Covariate.Value[]{}, 2, response2),
|
||||
new Row<>(new Covariate.Value[]{}, 3, response3),
|
||||
new Row<>(new Covariate.Value[]{}, 4, response4)
|
||||
);
|
||||
|
||||
final double[] mortalityOneArray = new double[]{1, 4, 3, 9};
|
||||
final double[] mortalityTwoArray = new double[]{2, 3, 4, 7};
|
||||
|
||||
// response1 was predicted incorrectly
|
||||
// response2 doesn't matter; censored
|
||||
// response3 was correctly predicted
|
||||
// response4 was correctly predicted
|
||||
|
||||
// Expect 1/3 for my error
|
||||
|
||||
final CompetingRiskFunctions function1 = mock(CompetingRiskFunctions.class);
|
||||
when(function1.calculateEventSpecificMortality(1, response1.getU())).thenReturn(mortalityOneArray[0]);
|
||||
when(function1.calculateEventSpecificMortality(2, response1.getU())).thenReturn(mortalityTwoArray[0]);
|
||||
|
||||
final CompetingRiskFunctions function2 = mock(CompetingRiskFunctions.class);
|
||||
when(function2.calculateEventSpecificMortality(1, response2.getU())).thenReturn(mortalityOneArray[1]);
|
||||
when(function2.calculateEventSpecificMortality(2, response2.getU())).thenReturn(mortalityTwoArray[1]);
|
||||
|
||||
final CompetingRiskFunctions function3 = mock(CompetingRiskFunctions.class);
|
||||
when(function3.calculateEventSpecificMortality(1, response3.getU())).thenReturn(mortalityOneArray[2]);
|
||||
when(function3.calculateEventSpecificMortality(2, response3.getU())).thenReturn(mortalityTwoArray[2]);
|
||||
|
||||
final CompetingRiskFunctions function4 = mock(CompetingRiskFunctions.class);
|
||||
when(function4.calculateEventSpecificMortality(1, response4.getU())).thenReturn(mortalityOneArray[3]);
|
||||
when(function4.calculateEventSpecificMortality(2, response4.getU())).thenReturn(mortalityTwoArray[3]);
|
||||
|
||||
final Forest mockForest = mock(Forest.class);
|
||||
when(mockForest.evaluateOOB(dataset.get(0))).thenReturn(function1);
|
||||
when(mockForest.evaluateOOB(dataset.get(1))).thenReturn(function2);
|
||||
when(mockForest.evaluateOOB(dataset.get(2))).thenReturn(function3);
|
||||
when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
|
||||
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest, true);
|
||||
|
||||
final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});
|
||||
|
||||
assertEquals(1.0/3.0, error);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue