Implement naive mortality error measure

This commit is contained in:
Joel Therrien 2018-07-25 15:29:09 -07:00
parent 650579a430
commit e1caef6d56
6 changed files with 174 additions and 59 deletions

View file

@ -56,6 +56,13 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>2.20.0</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<build> <build>

View file

@ -3,10 +3,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.VisibleForTesting; import ca.joeltherrien.randomforest.VisibleForTesting;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.Tree;
import lombok.RequiredArgsConstructor;
import java.util.*; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
@ -15,42 +13,75 @@ import java.util.stream.Collectors;
* Note that this is the same version implemented in randomForestSRC. The downsides of this approach is that we can expect the errors to be biased, possibly severely. * Note that this is the same version implemented in randomForestSRC. The downsides of this approach is that we can expect the errors to be biased, possibly severely.
* Therefore I suggest that this measure only be used in comparing models, but not as a final output. * Therefore I suggest that this measure only be used in comparing models, but not as a final output.
*/ */
@RequiredArgsConstructor
public class CompetingRiskErrorRateCalculator { public class CompetingRiskErrorRateCalculator {
private final CompetingRiskFunctionCombiner combiner; private final List<Row<CompetingRiskResponse>> dataset;
private final int[] events; private final List<CompetingRiskFunctions> riskFunctions;
public CompetingRiskErrorRateCalculator(final int[] events, final double[] times){ public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
this.events = events; this.dataset = dataset;
this.combiner = new CompetingRiskFunctionCombiner(events, times); this.riskFunctions = dataset.stream()
} .map(forest::evaluateOOB)
public double[] calculateConcordance(final List<Row<CompetingRiskResponse>> rows, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
final double tau = rows.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
return calculateConcordance(rows, forest, tau);
}
private double[] calculateConcordance(final List<Row<CompetingRiskResponse>> rows, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest, final double tau){
final Collection<Tree<CompetingRiskFunctions>> trees = forest.getTrees();
// This predicts for rows based on their OOB trees.
final List<CompetingRiskFunctions> riskFunctions = rows.stream()
.map(row -> {
return trees.stream().filter(tree -> !tree.idInBootstrapSample(row.getId())).map(tree -> tree.evaluate(row)).collect(Collectors.toList());
})
.map(combiner::combine)
.collect(Collectors.toList()); .collect(Collectors.toList());
}
public double[] calculateConcordance(final int[] events){
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
return calculateConcordance(events, tau);
}
/**
* Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened,
* then we add one to the error scale.
*
* Ignore censored observations.
*
* Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL.
*
* @return
*/
public double calculateNaiveMortalityError(final int[] events){
int failures = 0;
int attempts = 0;
response_loop:
for(int i=0; i<dataset.size(); i++){
final CompetingRiskResponse response = dataset.get(i).getResponse();
if(response.isCensored()){
continue;
}
attempts++;
final CompetingRiskFunctions functions = riskFunctions.get(i);
final int delta = response.getDelta();
final double time = response.getU();
final double shouldBeHighestMortality = functions.calculateEventSpecificMortality(delta, time);
for(final int event : events){
if(event != delta){
final double otherEventMortality = functions.calculateEventSpecificMortality(event, time);
if(shouldBeHighestMortality < otherEventMortality){
failures++;
continue response_loop;
}
}
}
}
return (double) failures / (double) attempts;
//final List<CompetingRiskFunctions> riskFunctions = rows.stream().map(row -> forest.evaluate(row)).collect(Collectors.toList()); }
private double[] calculateConcordance(final int[] events, final double tau){
final double[] errorRates = new double[events.length]; final double[] errorRates = new double[events.length];
final List<CompetingRiskResponse> responses = rows.stream().map(Row::getResponse).collect(Collectors.toList()); final List<CompetingRiskResponse> responses = dataset.stream().map(Row::getResponse).collect(Collectors.toList());
// Let \tau be the max time. // Let \tau be the max time.
@ -58,8 +89,7 @@ public class CompetingRiskErrorRateCalculator {
final int event = events[e]; final int event = events[e];
final double[] mortalityList = riskFunctions.stream() final double[] mortalityList = riskFunctions.stream()
.map(riskFunction -> riskFunction.getCumulativeIncidenceFunction(event)) .mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
.mapToDouble(cif -> functionToMortality(cif, tau))
.toArray(); .toArray();
final double concordance = calculateConcordance(responses, mortalityList, event); final double concordance = calculateConcordance(responses, mortalityList, event);
@ -111,25 +141,4 @@ public class CompetingRiskErrorRateCalculator {
} }
private double functionToMortality(final MathFunction cif, final double tau){
double summation = 0.0;
Point previousPoint = null;
for(final Point point : cif.getPoints()){
if(previousPoint != null){
summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
}
previousPoint = point;
}
// this is to ensure that we integrate over the same range for every function and get comparable results.
// Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response
summation += previousPoint.getY() * (tau - previousPoint.getTime());
return summation;
}
} }

View file

@ -23,4 +23,26 @@ public class CompetingRiskFunctions implements Serializable {
public MathFunction getCumulativeIncidenceFunction(int cause) { public MathFunction getCumulativeIncidenceFunction(int cause) {
return cumulativeIncidenceFunctionMap.get(cause); return cumulativeIncidenceFunctionMap.get(cause);
} }
public double calculateEventSpecificMortality(final int event, final double tau){
final MathFunction cif = getCauseSpecificHazardFunction(event);
double summation = 0.0;
Point previousPoint = null;
for(final Point point : cif.getPoints()){
if(previousPoint != null){
summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
}
previousPoint = point;
}
// this is to ensure that we integrate over the same range for every function and get comparable results.
// Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response
summation += previousPoint.getY() * (tau - previousPoint.getTime());
return summation;
}
} }

View file

@ -2,6 +2,7 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Builder; import lombok.Builder;
import lombok.RequiredArgsConstructor;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@ -23,6 +24,17 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
} }
public FO evaluateOOB(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
public Collection<Tree<O>> getTrees(){ public Collection<Tree<O>> getTrees(){
return Collections.unmodifiableCollection(trees); return Collections.unmodifiableCollection(trees);
} }

View file

@ -227,8 +227,8 @@ public class TestCompetingRisk {
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01); closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01); closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null); final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
final double[] errorRates = errorRateCalculator.calculateConcordance(dataset, forest); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
// Error rates happen to be about the same // Error rates happen to be about the same
/* randomForestSRC results; ignored for now /* randomForestSRC results; ignored for now
@ -308,8 +308,8 @@ public class TestCompetingRisk {
// We seem to consistently underestimate the results. // We seem to consistently underestimate the results.
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
final CompetingRiskErrorRateCalculator errorRate = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null); final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
final double[] errorRates = errorRate.calculateConcordance(dataset, forest); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
System.out.println(errorRates[0]); System.out.println(errorRates[0]);
System.out.println(errorRates[1]); System.out.println(errorRates[1]);

View file

@ -1,12 +1,19 @@
package ca.joeltherrien.randomforest.competingrisk; package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.tree.Forest;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TestCompetingRiskErrorRateCalculator { public class TestCompetingRiskErrorRateCalculator {
@ -23,7 +30,8 @@ public class TestCompetingRiskErrorRateCalculator {
final int event = 1; final int event = 1;
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event); final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
@ -32,4 +40,61 @@ public class TestCompetingRiskErrorRateCalculator {
} }
@Test
public void testNaiveMortality(){
final CompetingRiskResponse response1 = new CompetingRiskResponse(1, 5.0);
final CompetingRiskResponse response2 = new CompetingRiskResponse(0, 6.0);
final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0);
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
final List<Row<CompetingRiskResponse>> dataset = List.of(
new Row<>(Collections.emptyMap(), 1, response1),
new Row<>(Collections.emptyMap(), 2, response2),
new Row<>(Collections.emptyMap(), 3, response3),
new Row<>(Collections.emptyMap(), 4, response4)
);
final double[] mortalityOneArray = new double[]{1, 4, 3, 9};
final double[] mortalityTwoArray = new double[]{2, 3, 4, 7};
// response1 was predicted incorrectly
// response2 doesn't matter; censored
// response3 was correctly predicted
// response4 was correctly predicted
// Expect 1/3 for my error
final CompetingRiskFunctions function1 = mock(CompetingRiskFunctions.class);
when(function1.calculateEventSpecificMortality(1, response1.getU())).thenReturn(mortalityOneArray[0]);
when(function1.calculateEventSpecificMortality(2, response1.getU())).thenReturn(mortalityTwoArray[0]);
final CompetingRiskFunctions function2 = mock(CompetingRiskFunctions.class);
when(function2.calculateEventSpecificMortality(1, response2.getU())).thenReturn(mortalityOneArray[1]);
when(function2.calculateEventSpecificMortality(2, response2.getU())).thenReturn(mortalityTwoArray[1]);
final CompetingRiskFunctions function3 = mock(CompetingRiskFunctions.class);
when(function3.calculateEventSpecificMortality(1, response3.getU())).thenReturn(mortalityOneArray[2]);
when(function3.calculateEventSpecificMortality(2, response3.getU())).thenReturn(mortalityTwoArray[2]);
final CompetingRiskFunctions function4 = mock(CompetingRiskFunctions.class);
when(function4.calculateEventSpecificMortality(1, response4.getU())).thenReturn(mortalityOneArray[3]);
when(function4.calculateEventSpecificMortality(2, response4.getU())).thenReturn(mortalityTwoArray[3]);
final Forest mockForest = mock(Forest.class);
when(mockForest.evaluateOOB(dataset.get(0))).thenReturn(function1);
when(mockForest.evaluateOOB(dataset.get(1))).thenReturn(function2);
when(mockForest.evaluateOOB(dataset.get(2))).thenReturn(function3);
when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest);
final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});
assertEquals(1.0/3.0, error);
}
} }