Fix a bug in the naive mortality error measure; implement an IPCW concordance
measure that can be used when the censoring distribution is provided.
This commit is contained in:
parent
e1caef6d56
commit
d3994212b6
12 changed files with 255 additions and 51 deletions
|
@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
|||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.VisibleForTesting;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -25,12 +26,6 @@ public class CompetingRiskErrorRateCalculator {
|
|||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public double[] calculateConcordance(final int[] events){
|
||||
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
|
||||
|
||||
return calculateConcordance(events, tau);
|
||||
}
|
||||
|
||||
/**
|
||||
* Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened,
|
||||
* then we add one to the error scale.
|
||||
|
@ -39,6 +34,8 @@ public class CompetingRiskErrorRateCalculator {
|
|||
*
|
||||
* Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL.
|
||||
*
|
||||
* My observation is that this error rate isn't very useful...
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public double calculateNaiveMortalityError(final int[] events){
|
||||
|
@ -77,6 +74,13 @@ public class CompetingRiskErrorRateCalculator {
|
|||
|
||||
}
|
||||
|
||||
|
||||
public double[] calculateConcordance(final int[] events){
|
||||
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
|
||||
|
||||
return calculateConcordance(events, tau);
|
||||
}
|
||||
|
||||
private double[] calculateConcordance(final int[] events, final double tau){
|
||||
|
||||
final double[] errorRates = new double[events.length];
|
||||
|
@ -101,6 +105,36 @@ public class CompetingRiskErrorRateCalculator {
|
|||
|
||||
}
|
||||
|
||||
public double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution){
|
||||
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
|
||||
|
||||
return calculateIPCWConcordance(events, censoringDistribution, tau);
|
||||
}
|
||||
|
||||
private double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution, final double tau){
|
||||
|
||||
final double[] errorRates = new double[events.length];
|
||||
|
||||
final List<CompetingRiskResponse> responses = dataset.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
|
||||
// Let \tau be the max time.
|
||||
|
||||
for(int e=0; e<events.length; e++){
|
||||
final int event = events[e];
|
||||
|
||||
final double[] mortalityList = riskFunctions.stream()
|
||||
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
|
||||
.toArray();
|
||||
|
||||
final double concordance = calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution);
|
||||
errorRates[e] = 1.0 - concordance;
|
||||
|
||||
}
|
||||
|
||||
return errorRates;
|
||||
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
|
||||
|
||||
|
@ -141,4 +175,55 @@ public class CompetingRiskErrorRateCalculator {
|
|||
|
||||
}
|
||||
|
||||
|
||||
@VisibleForTesting
|
||||
public double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
|
||||
|
||||
// Let \tau be the max time.
|
||||
|
||||
double denominator = 0.0;
|
||||
double numerator = 0.0;
|
||||
|
||||
for(int i = 0; i<mortalityArray.length; i++){
|
||||
final CompetingRiskResponse responseI = responseList.get(i);
|
||||
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||
continue; // skip if it's 0
|
||||
}
|
||||
|
||||
final double mortalityI = mortalityArray[i];
|
||||
|
||||
for(int j=0; j<mortalityArray.length; j++){
|
||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||
|
||||
final double AijWeightPlusBijWeight;
|
||||
|
||||
if(responseI.getU() < responseJ.getU()){ // Aij == 1
|
||||
final double Ti = responseI.getU();
|
||||
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * censoringDistribution.evaluatePrevious(Ti).getY());
|
||||
}
|
||||
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
|
||||
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluatePrevious(responseI.getU()).getY() * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
||||
}
|
||||
else{
|
||||
continue;
|
||||
}
|
||||
|
||||
denominator += AijWeightPlusBijWeight;
|
||||
|
||||
final double mortalityJ = mortalityArray[j];
|
||||
if(mortalityI > mortalityJ){
|
||||
numerator += AijWeightPlusBijWeight*1.0;
|
||||
}
|
||||
else if(mortalityI == mortalityJ){
|
||||
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return numerator / denominator;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import lombok.Getter;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
@ -31,6 +32,10 @@ public class CompetingRiskFunctions implements Serializable {
|
|||
Point previousPoint = null;
|
||||
|
||||
for(final Point point : cif.getPoints()){
|
||||
if(point.getTime() > tau){
|
||||
break;
|
||||
}
|
||||
|
||||
if(previousPoint != null){
|
||||
summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
|
||||
}
|
||||
|
@ -38,9 +43,11 @@ public class CompetingRiskFunctions implements Serializable {
|
|||
|
||||
}
|
||||
|
||||
// this is to ensure that we integrate over the same range for every function and get comparable results.
|
||||
// Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response
|
||||
summation += previousPoint.getY() * (tau - previousPoint.getTime());
|
||||
// this is to ensure that we integrate over the proper range
|
||||
if(previousPoint != null){
|
||||
summation += cif.evaluate(tau).getY() * (tau - previousPoint.getTime());
|
||||
}
|
||||
|
||||
|
||||
return summation;
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.*;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
|
@ -43,6 +43,14 @@ public class MathFunction implements Serializable {
|
|||
|
||||
}
|
||||
|
||||
public Point evaluatePrevious(double time){
|
||||
final Optional<Point> pointOptional = points.stream()
|
||||
.filter(point -> point.getTime() < time)
|
||||
.max(Comparator.comparingDouble(Point::getTime));
|
||||
|
||||
return pointOptional.orElse(defaultValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
|
@ -1,4 +1,4 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import lombok.Data;
|
||||
|
40
src/main/java/ca/joeltherrien/randomforest/utils/Utils.java
Normal file
40
src/main/java/ca/joeltherrien/randomforest/utils/Utils.java
Normal file
|
@ -0,0 +1,40 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class Utils {
|
||||
|
||||
public static MathFunction estimateOneMinusECDF(final double[] times){
|
||||
final Point defaultPoint = new Point(0.0, 1.0);
|
||||
Arrays.sort(times);
|
||||
|
||||
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
||||
|
||||
for(final double time : times){
|
||||
Integer existingCount = timeCounterMap.get(time);
|
||||
existingCount = existingCount != null ? existingCount : 0;
|
||||
|
||||
timeCounterMap.put(time, existingCount+1);
|
||||
}
|
||||
|
||||
final List<Map.Entry<Double, Integer>> timeCounterList = new ArrayList<>(timeCounterMap.entrySet());
|
||||
Collections.sort(timeCounterList, Comparator.comparingDouble(Map.Entry::getKey));
|
||||
|
||||
final List<Point> pointList = new ArrayList<>(timeCounterList.size());
|
||||
|
||||
int previousCount = times.length;
|
||||
final double n = times.length;
|
||||
|
||||
for(final Map.Entry<Double, Integer> entry : timeCounterList){
|
||||
final int newCount = previousCount - entry.getValue();
|
||||
previousCount = newCount;
|
||||
|
||||
pointList.add(new Point(entry.getKey(), (double) newCount / n));
|
||||
}
|
||||
|
||||
return new MathFunction(pointList, defaultPoint);
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
73
src/test/java/ca/joeltherrien/randomforest/TestUtils.java
Normal file
73
src/test/java/ca/joeltherrien/randomforest/TestUtils.java
Normal file
|
@ -0,0 +1,73 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestUtils {
|
||||
|
||||
public static void closeEnough(double expected, double actual, double margin){
|
||||
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
||||
}
|
||||
|
||||
/**
|
||||
* We know the function is cumulative; make sure it is ordered correctly and that that function is monotone.
|
||||
*
|
||||
* @param function
|
||||
*/
|
||||
public static void assertCumulativeFunction(MathFunction function){
|
||||
Point previousPoint = null;
|
||||
for(final Point point : function.getPoints()){
|
||||
|
||||
if(previousPoint != null){
|
||||
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
||||
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
|
||||
}
|
||||
|
||||
|
||||
previousPoint = point;
|
||||
}
|
||||
}
|
||||
|
||||
public static void assertSurvivalCurve(MathFunction function){
|
||||
Point previousPoint = null;
|
||||
for(final Point point : function.getPoints()){
|
||||
|
||||
if(previousPoint != null){
|
||||
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
||||
assertTrue(previousPoint.getY() >= point.getY(), "Survival functions are monotone");
|
||||
}
|
||||
|
||||
|
||||
previousPoint = point;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOneMinusECDF(){
|
||||
final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0};
|
||||
final MathFunction survivalCurve = Utils.estimateOneMinusECDF(times);
|
||||
|
||||
final double margin = 0.000001;
|
||||
closeEnough(1.0, survivalCurve.evaluate(0.0).getY(), margin);
|
||||
|
||||
closeEnough(1.0, survivalCurve.evaluatePrevious(1.0).getY(), margin);
|
||||
closeEnough(4.0/6.0, survivalCurve.evaluate(1.0).getY(), margin);
|
||||
|
||||
closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0).getY(), margin);
|
||||
closeEnough(3.0/6.0, survivalCurve.evaluate(2.0).getY(), margin);
|
||||
|
||||
closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0).getY(), margin);
|
||||
closeEnough(1.0/6.0, survivalCurve.evaluate(3.0).getY(), margin);
|
||||
|
||||
closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0).getY(), margin);
|
||||
closeEnough(0.0, survivalCurve.evaluate(50.0).getY(), margin);
|
||||
|
||||
assertSurvivalCurve(survivalCurve);
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -7,9 +7,13 @@ import ca.joeltherrien.randomforest.tree.Forest;
|
|||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.Node;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import com.fasterxml.jackson.databind.node.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
|
||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -240,10 +244,10 @@ public class TestCompetingRisk {
|
|||
System.out.println(errorRates[1]);
|
||||
|
||||
|
||||
closeEnough(0.452, errorRates[0], 0.01);
|
||||
closeEnough(0.446, errorRates[1], 0.01);
|
||||
|
||||
closeEnough(0.452, errorRates[0], 0.02);
|
||||
closeEnough(0.446, errorRates[1], 0.02);
|
||||
|
||||
System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -306,7 +310,7 @@ public class TestCompetingRisk {
|
|||
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
|
||||
|
||||
// We seem to consistently underestimate the results.
|
||||
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
||||
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||
|
@ -322,29 +326,9 @@ public class TestCompetingRisk {
|
|||
// Consistency results
|
||||
closeEnough(0.395, errorRates[0], 0.01);
|
||||
closeEnough(0.345, errorRates[1], 0.01);
|
||||
}
|
||||
|
||||
/**
|
||||
* We know the function is cumulative; make sure it is ordered correctly and that the function is monotone.
|
||||
*
|
||||
* @param function
|
||||
*/
|
||||
private void assertCumulativeFunction(MathFunction function){
|
||||
Point previousPoint = null;
|
||||
for(final Point point : function.getPoints()){
|
||||
System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}));
|
||||
|
||||
if(previousPoint != null){
|
||||
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
||||
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
|
||||
}
|
||||
|
||||
|
||||
previousPoint = point;
|
||||
}
|
||||
}
|
||||
|
||||
private void closeEnough(double expected, double actual, double margin){
|
||||
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
@ -33,10 +33,16 @@ public class TestCompetingRiskErrorRateCalculator {
|
|||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
|
||||
|
||||
final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
|
||||
final double naiveConcordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
|
||||
|
||||
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0));
|
||||
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
|
||||
final double ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
|
||||
|
||||
closeEnough(naiveConcordance, ipcwConcordance, 0.0001);
|
||||
|
||||
// Expected value found through calculations by hand
|
||||
assertEquals(3.0/5.0, concordance);
|
||||
assertEquals(3.0/5.0, naiveConcordance);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -3,8 +3,10 @@ package ca.joeltherrien.randomforest.competingrisk;
|
|||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -87,8 +89,4 @@ public class TestCompetingRiskResponseCombiner {
|
|||
|
||||
}
|
||||
|
||||
private void closeEnough(double expected, double actual, double margin){
|
||||
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.Point;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
|
|
Loading…
Reference in a new issue