Implement naive mortality error measure
parent 650579a430
commit e1caef6d56
6 changed files with 174 additions and 59 deletions
pom.xml (7 changes)

@@ -56,6 +56,13 @@
             <scope>test</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>2.20.0</version>
+            <scope>test</scope>
+        </dependency>
+
     </dependencies>
 
     <build>
CompetingRiskErrorRateCalculator.java

@@ -3,10 +3,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
 import ca.joeltherrien.randomforest.Row;
 import ca.joeltherrien.randomforest.VisibleForTesting;
 import ca.joeltherrien.randomforest.tree.Forest;
-import ca.joeltherrien.randomforest.tree.Tree;
-import lombok.RequiredArgsConstructor;
 
-import java.util.*;
+import java.util.List;
 import java.util.stream.Collectors;
 
 /**
@@ -15,42 +13,75 @@ import java.util.stream.Collectors;
  * Note that this is the same version implemented in randomForestSRC. The downsides of this approach is that we can expect the errors to be biased, possibly severely.
  * Therefore I suggest that this measure only be used in comparing models, but not as a final output.
  */
-@RequiredArgsConstructor
 public class CompetingRiskErrorRateCalculator {
 
-    private final CompetingRiskFunctionCombiner combiner;
-    private final int[] events;
+    private final List<Row<CompetingRiskResponse>> dataset;
+    private final List<CompetingRiskFunctions> riskFunctions;
 
-    public CompetingRiskErrorRateCalculator(final int[] events, final double[] times){
-        this.events = events;
-        this.combiner = new CompetingRiskFunctionCombiner(events, times);
-    }
-
-    public double[] calculateConcordance(final List<Row<CompetingRiskResponse>> rows, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
-        final double tau = rows.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
-
-        return calculateConcordance(rows, forest, tau);
-    }
-
-    private double[] calculateConcordance(final List<Row<CompetingRiskResponse>> rows, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest, final double tau){
-
-        final Collection<Tree<CompetingRiskFunctions>> trees = forest.getTrees();
-
-        // This predicts for rows based on their OOB trees.
-        final List<CompetingRiskFunctions> riskFunctions = rows.stream()
-                .map(row -> {
-                    return trees.stream().filter(tree -> !tree.idInBootstrapSample(row.getId())).map(tree -> tree.evaluate(row)).collect(Collectors.toList());
-                })
-                .map(combiner::combine)
+    public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
+        this.dataset = dataset;
+        this.riskFunctions = dataset.stream()
+                .map(forest::evaluateOOB)
                 .collect(Collectors.toList());
+    }
+
+    public double[] calculateConcordance(final int[] events){
+        final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
+
+        return calculateConcordance(events, tau);
+    }
+
+    /**
+     * Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened,
+     * then we add one to the error scale.
+     *
+     * Ignore censored observations.
+     *
+     * Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL.
+     *
+     * @return
+     */
+    public double calculateNaiveMortalityError(final int[] events){
+        int failures = 0;
+        int attempts = 0;
+
+        response_loop:
+        for(int i=0; i<dataset.size(); i++){
+            final CompetingRiskResponse response = dataset.get(i).getResponse();
+
+            if(response.isCensored()){
+                continue;
+            }
+            attempts++;
+
+            final CompetingRiskFunctions functions = riskFunctions.get(i);
+            final int delta = response.getDelta();
+            final double time = response.getU();
+            final double shouldBeHighestMortality = functions.calculateEventSpecificMortality(delta, time);
+
+            for(final int event : events){
+                if(event != delta){
+                    final double otherEventMortality = functions.calculateEventSpecificMortality(event, time);
+
+                    if(shouldBeHighestMortality < otherEventMortality){
+                        failures++;
+                        continue response_loop;
+                    }
+
+                }
+            }
+        }
+
+        return (double) failures / (double) attempts;
+    }
 
-        //final List<CompetingRiskFunctions> riskFunctions = rows.stream().map(row -> forest.evaluate(row)).collect(Collectors.toList());
+    private double[] calculateConcordance(final int[] events, final double tau){
 
         final double[] errorRates = new double[events.length];
 
-        final List<CompetingRiskResponse> responses = rows.stream().map(Row::getResponse).collect(Collectors.toList());
+        final List<CompetingRiskResponse> responses = dataset.stream().map(Row::getResponse).collect(Collectors.toList());
 
         // Let \tau be the max time.
 
@@ -58,8 +89,7 @@ public class CompetingRiskErrorRateCalculator {
             final int event = events[e];
 
             final double[] mortalityList = riskFunctions.stream()
-                    .map(riskFunction -> riskFunction.getCumulativeIncidenceFunction(event))
-                    .mapToDouble(cif -> functionToMortality(cif, tau))
+                    .mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
                     .toArray();
 
             final double concordance = calculateConcordance(responses, mortalityList, event);
@@ -111,25 +141,4 @@ public class CompetingRiskErrorRateCalculator {
 
     }
 
-    private double functionToMortality(final MathFunction cif, final double tau){
-        double summation = 0.0;
-        Point previousPoint = null;
-
-        for(final Point point : cif.getPoints()){
-            if(previousPoint != null){
-                summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
-            }
-            previousPoint = point;
-
-        }
-
-        // this is to ensure that we integrate over the same range for every function and get comparable results.
-        // Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response
-        summation += previousPoint.getY() * (tau - previousPoint.getTime());
-
-        return summation;
-
-    }
-
 }
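The Javadoc in the hunk above describes the naive measure in prose: for every uncensored observation, check whether the observed event also has the highest predicted mortality, and count a failure if it does not. Below is a minimal, self-contained sketch of that idea, separate from the commit's own classes; NaiveMortalityExample, its Observation record, and the mortalityByEvent array are hypothetical stand-ins for the dataset, CompetingRiskResponse, and the calculateEventSpecificMortality calls used in the diff.

// Hypothetical illustration of the naive mortality error described above; not part of the commit.
import java.util.List;

public final class NaiveMortalityExample {

    // Stand-in for CompetingRiskResponse: delta == 0 means censored, otherwise the observed event (1-based).
    record Observation(int delta, double[] mortalityByEvent) {}

    // Fraction of uncensored observations whose observed event does NOT have the highest predicted mortality.
    static double naiveMortalityError(List<Observation> observations, int[] events) {
        int failures = 0;
        int attempts = 0;

        for (Observation obs : observations) {
            if (obs.delta() == 0) {
                continue; // censored observations are ignored
            }
            attempts++;

            final double observed = obs.mortalityByEvent()[obs.delta() - 1];
            for (int event : events) {
                if (event != obs.delta() && observed < obs.mortalityByEvent()[event - 1]) {
                    failures++; // some other event was predicted to be more likely
                    break;      // one failure per observation at most
                }
            }
        }
        return (double) failures / attempts;
    }

    public static void main(String[] args) {
        // Mirrors the unit test in this commit: one wrong prediction out of three uncensored observations -> 1/3.
        List<Observation> data = List.of(
                new Observation(1, new double[]{1, 2}),  // event 1 observed, but event 2 has higher mortality
                new Observation(0, new double[]{4, 3}),  // censored, ignored
                new Observation(2, new double[]{3, 4}),  // correctly ranked
                new Observation(1, new double[]{9, 7})); // correctly ranked
        System.out.println(naiveMortalityError(data, new int[]{1, 2})); // 0.3333...
    }
}

In the commit itself the per-observation predictions are precomputed in the constructor via forest::evaluateOOB, so both the concordance and the naive error reuse the same out-of-bag risk functions.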
CompetingRiskFunctions.java

@@ -23,4 +23,26 @@ public class CompetingRiskFunctions implements Serializable {
     public MathFunction getCumulativeIncidenceFunction(int cause) {
         return cumulativeIncidenceFunctionMap.get(cause);
     }
+
+    public double calculateEventSpecificMortality(final int event, final double tau){
+        final MathFunction cif = getCauseSpecificHazardFunction(event);
+
+        double summation = 0.0;
+        Point previousPoint = null;
+
+        for(final Point point : cif.getPoints()){
+            if(previousPoint != null){
+                summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
+            }
+            previousPoint = point;
+
+        }
+
+        // this is to ensure that we integrate over the same range for every function and get comparable results.
+        // Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response
+        summation += previousPoint.getY() * (tau - previousPoint.getTime());
+
+        return summation;
+
+    }
 }
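The new calculateEventSpecificMortality method integrates a right-continuous step function up to a common upper limit tau, so that mortalities computed from different functions cover the same range and stay comparable. As a sketch of what the loop computes, for a step function with points (t_1, y_1), ..., (t_k, y_k), and under the assumption (mine, not stated in the diff) that the function contributes nothing before its first point:

% Mortality as the area under a step function, integrated up to a common upper limit \tau
M(\tau) \;=\; \int_0^{\tau} f(t)\,dt \;=\; \sum_{i=1}^{k-1} y_i\,(t_{i+1} - t_i) \;+\; y_k\,(\tau - t_k)

The final term is the "integrate over the same range" correction mentioned in the code comment: it extends the last step from t_k out to tau.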
Forest.java

@@ -2,6 +2,7 @@ package ca.joeltherrien.randomforest.tree;
 
 import ca.joeltherrien.randomforest.CovariateRow;
 import lombok.Builder;
+import lombok.RequiredArgsConstructor;
 
 import java.util.Collection;
 import java.util.Collections;
@@ -23,6 +24,17 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
 
     }
 
+    public FO evaluateOOB(CovariateRow row){
+
+        return treeResponseCombiner.combine(
+                trees.stream()
+                        .filter(tree -> !tree.idInBootstrapSample(row.getId()))
+                        .map(node -> node.evaluate(row))
+                        .collect(Collectors.toList())
+        );
+
+    }
+
     public Collection<Tree<O>> getTrees(){
         return Collections.unmodifiableCollection(trees);
     }
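evaluateOOB combines predictions only from trees that did not see the row during training, i.e. the trees for which the row is out-of-bag. A minimal, self-contained sketch of that filtering step, with a hypothetical ToyTree type standing in for Tree and a plain average standing in for treeResponseCombiner:

// Hypothetical illustration of out-of-bag evaluation; not the repository's Forest class.
import java.util.List;
import java.util.Set;

public final class OobExample {

    // Stand-in for Tree: remembers which row ids were in its bootstrap sample and predicts a number.
    record ToyTree(Set<Integer> bootstrapIds, double prediction) {
        boolean idInBootstrapSample(int id) {
            return bootstrapIds.contains(id);
        }
    }

    // Average the predictions of the trees that never saw rowId (the out-of-bag trees).
    static double evaluateOOB(List<ToyTree> trees, int rowId) {
        return trees.stream()
                .filter(tree -> !tree.idInBootstrapSample(rowId))
                .mapToDouble(ToyTree::prediction)
                .average()
                .orElse(Double.NaN); // no OOB trees exist for this row
    }

    public static void main(String[] args) {
        List<ToyTree> trees = List.of(
                new ToyTree(Set.of(1, 2), 0.2),
                new ToyTree(Set.of(2, 3), 0.4),
                new ToyTree(Set.of(1, 3), 0.9));
        System.out.println(evaluateOOB(trees, 1)); // only the second tree is OOB for id 1 -> 0.4
    }
}

Using only out-of-bag trees is what lets the error rate calculator treat its predictions as (approximately) honest, since no tree contributing to a row's prediction was fit on that row.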
|
@ -227,8 +227,8 @@ public class TestCompetingRisk {
|
||||||
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
||||||
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
||||||
|
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null);
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
||||||
final double[] errorRates = errorRateCalculator.calculateConcordance(dataset, forest);
|
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||||
|
|
||||||
// Error rates happen to be about the same
|
// Error rates happen to be about the same
|
||||||
/* randomForestSRC results; ignored for now
|
/* randomForestSRC results; ignored for now
|
||||||
|
@ -308,8 +308,8 @@ public class TestCompetingRisk {
|
||||||
// We seem to consistently underestimate the results.
|
// We seem to consistently underestimate the results.
|
||||||
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||||
|
|
||||||
final CompetingRiskErrorRateCalculator errorRate = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null);
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
||||||
final double[] errorRates = errorRate.calculateConcordance(dataset, forest);
|
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||||
|
|
||||||
System.out.println(errorRates[0]);
|
System.out.println(errorRates[0]);
|
||||||
System.out.println(errorRates[1]);
|
System.out.println(errorRates[1]);
|
||||||
|
|
|
@ -1,12 +1,19 @@
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
public class TestCompetingRiskErrorRateCalculator {
|
public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
|
@ -23,7 +30,8 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
final int event = 1;
|
final int event = 1;
|
||||||
|
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null);
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||||
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
|
||||||
|
|
||||||
final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
|
final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
|
||||||
|
|
||||||
|
@ -32,4 +40,61 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNaiveMortality(){
|
||||||
|
final CompetingRiskResponse response1 = new CompetingRiskResponse(1, 5.0);
|
||||||
|
final CompetingRiskResponse response2 = new CompetingRiskResponse(0, 6.0);
|
||||||
|
final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0);
|
||||||
|
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
||||||
|
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = List.of(
|
||||||
|
new Row<>(Collections.emptyMap(), 1, response1),
|
||||||
|
new Row<>(Collections.emptyMap(), 2, response2),
|
||||||
|
new Row<>(Collections.emptyMap(), 3, response3),
|
||||||
|
new Row<>(Collections.emptyMap(), 4, response4)
|
||||||
|
);
|
||||||
|
|
||||||
|
final double[] mortalityOneArray = new double[]{1, 4, 3, 9};
|
||||||
|
final double[] mortalityTwoArray = new double[]{2, 3, 4, 7};
|
||||||
|
|
||||||
|
// response1 was predicted incorrectly
|
||||||
|
// response2 doesn't matter; censored
|
||||||
|
// response3 was correctly predicted
|
||||||
|
// response4 was correctly predicted
|
||||||
|
|
||||||
|
// Expect 1/3 for my error
|
||||||
|
|
||||||
|
final CompetingRiskFunctions function1 = mock(CompetingRiskFunctions.class);
|
||||||
|
when(function1.calculateEventSpecificMortality(1, response1.getU())).thenReturn(mortalityOneArray[0]);
|
||||||
|
when(function1.calculateEventSpecificMortality(2, response1.getU())).thenReturn(mortalityTwoArray[0]);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions function2 = mock(CompetingRiskFunctions.class);
|
||||||
|
when(function2.calculateEventSpecificMortality(1, response2.getU())).thenReturn(mortalityOneArray[1]);
|
||||||
|
when(function2.calculateEventSpecificMortality(2, response2.getU())).thenReturn(mortalityTwoArray[1]);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions function3 = mock(CompetingRiskFunctions.class);
|
||||||
|
when(function3.calculateEventSpecificMortality(1, response3.getU())).thenReturn(mortalityOneArray[2]);
|
||||||
|
when(function3.calculateEventSpecificMortality(2, response3.getU())).thenReturn(mortalityTwoArray[2]);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions function4 = mock(CompetingRiskFunctions.class);
|
||||||
|
when(function4.calculateEventSpecificMortality(1, response4.getU())).thenReturn(mortalityOneArray[3]);
|
||||||
|
when(function4.calculateEventSpecificMortality(2, response4.getU())).thenReturn(mortalityTwoArray[3]);
|
||||||
|
|
||||||
|
final Forest mockForest = mock(Forest.class);
|
||||||
|
when(mockForest.evaluateOOB(dataset.get(0))).thenReturn(function1);
|
||||||
|
when(mockForest.evaluateOOB(dataset.get(1))).thenReturn(function2);
|
||||||
|
when(mockForest.evaluateOOB(dataset.get(2))).thenReturn(function3);
|
||||||
|
when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
|
||||||
|
|
||||||
|
|
||||||
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest);
|
||||||
|
|
||||||
|
final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});
|
||||||
|
|
||||||
|
assertEquals(1.0/3.0, error);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||