Fix a bug in naive mortality error measure; implement IPCW concordance

measure if you can provide the censoring distribution.
This commit is contained in:
Joel Therrien 2018-07-26 12:45:12 -07:00
parent e1caef6d56
commit d3994212b6
12 changed files with 255 additions and 51 deletions

View file

@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.VisibleForTesting;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.utils.MathFunction;
import java.util.List;
import java.util.stream.Collectors;
@ -25,12 +26,6 @@ public class CompetingRiskErrorRateCalculator {
.collect(Collectors.toList());
}
public double[] calculateConcordance(final int[] events){
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
return calculateConcordance(events, tau);
}
/**
* Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened,
* then we add one to the error scale.
@ -39,6 +34,8 @@ public class CompetingRiskErrorRateCalculator {
*
* Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL.
*
* My observation is that this error rate isn't very useful...
*
* @return
*/
public double calculateNaiveMortalityError(final int[] events){
@ -77,6 +74,13 @@ public class CompetingRiskErrorRateCalculator {
}
public double[] calculateConcordance(final int[] events){
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
return calculateConcordance(events, tau);
}
private double[] calculateConcordance(final int[] events, final double tau){
final double[] errorRates = new double[events.length];
@ -101,6 +105,36 @@ public class CompetingRiskErrorRateCalculator {
}
public double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution){
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
return calculateIPCWConcordance(events, censoringDistribution, tau);
}
private double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution, final double tau){
final double[] errorRates = new double[events.length];
final List<CompetingRiskResponse> responses = dataset.stream().map(Row::getResponse).collect(Collectors.toList());
// Let \tau be the max time.
for(int e=0; e<events.length; e++){
final int event = events[e];
final double[] mortalityList = riskFunctions.stream()
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
.toArray();
final double concordance = calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution);
errorRates[e] = 1.0 - concordance;
}
return errorRates;
}
@VisibleForTesting
public double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
@ -141,4 +175,55 @@ public class CompetingRiskErrorRateCalculator {
}
@VisibleForTesting
public double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
// Let \tau be the max time.
double denominator = 0.0;
double numerator = 0.0;
for(int i = 0; i<mortalityArray.length; i++){
final CompetingRiskResponse responseI = responseList.get(i);
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
continue; // skip if it's 0
}
final double mortalityI = mortalityArray[i];
for(int j=0; j<mortalityArray.length; j++){
final CompetingRiskResponse responseJ = responseList.get(j);
final double AijWeightPlusBijWeight;
if(responseI.getU() < responseJ.getU()){ // Aij == 1
final double Ti = responseI.getU();
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * censoringDistribution.evaluatePrevious(Ti).getY());
}
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluatePrevious(responseI.getU()).getY() * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
}
else{
continue;
}
denominator += AijWeightPlusBijWeight;
final double mortalityJ = mortalityArray[j];
if(mortalityI > mortalityJ){
numerator += AijWeightPlusBijWeight*1.0;
}
else if(mortalityI == mortalityJ){
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
}
}
}
return numerator / denominator;
}
}

View file

@ -1,7 +1,8 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import lombok.Getter;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;

View file

@ -1,8 +1,9 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import lombok.Builder;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.io.Serializable;
import java.util.Map;
@ -31,6 +32,10 @@ public class CompetingRiskFunctions implements Serializable {
Point previousPoint = null;
for(final Point point : cif.getPoints()){
if(point.getTime() > tau){
break;
}
if(previousPoint != null){
summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
}
@ -38,9 +43,11 @@ public class CompetingRiskFunctions implements Serializable {
}
// this is to ensure that we integrate over the same range for every function and get comparable results.
// Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response
summation += previousPoint.getY() * (tau - previousPoint.getTime());
// this is to ensure that we integrate over the proper range
if(previousPoint != null){
summation += cif.evaluate(tau).getY() * (tau - previousPoint.getTime());
}
return summation;

View file

@ -1,6 +1,8 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import lombok.RequiredArgsConstructor;
import java.util.*;

View file

@ -1,4 +1,4 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
package ca.joeltherrien.randomforest.utils;
import lombok.Getter;
@ -43,6 +43,14 @@ public class MathFunction implements Serializable {
}
public Point evaluatePrevious(double time){
final Optional<Point> pointOptional = points.stream()
.filter(point -> point.getTime() < time)
.max(Comparator.comparingDouble(Point::getTime));
return pointOptional.orElse(defaultValue);
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();

View file

@ -1,4 +1,4 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
package ca.joeltherrien.randomforest.utils;
import lombok.Data;

View file

@ -0,0 +1,40 @@
package ca.joeltherrien.randomforest.utils;
import java.util.*;
public class Utils {
public static MathFunction estimateOneMinusECDF(final double[] times){
final Point defaultPoint = new Point(0.0, 1.0);
Arrays.sort(times);
final Map<Double, Integer> timeCounterMap = new HashMap<>();
for(final double time : times){
Integer existingCount = timeCounterMap.get(time);
existingCount = existingCount != null ? existingCount : 0;
timeCounterMap.put(time, existingCount+1);
}
final List<Map.Entry<Double, Integer>> timeCounterList = new ArrayList<>(timeCounterMap.entrySet());
Collections.sort(timeCounterList, Comparator.comparingDouble(Map.Entry::getKey));
final List<Point> pointList = new ArrayList<>(timeCounterList.size());
int previousCount = times.length;
final double n = times.length;
for(final Map.Entry<Double, Integer> entry : timeCounterList){
final int newCount = previousCount - entry.getValue();
previousCount = newCount;
pointList.add(new Point(entry.getKey(), (double) newCount / n));
}
return new MathFunction(pointList, defaultPoint);
}
}

View file

@ -0,0 +1,73 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestUtils {
public static void closeEnough(double expected, double actual, double margin){
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
}
/**
* We know the function is cumulative; make sure it is ordered correctly and that that function is monotone.
*
* @param function
*/
public static void assertCumulativeFunction(MathFunction function){
Point previousPoint = null;
for(final Point point : function.getPoints()){
if(previousPoint != null){
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
}
previousPoint = point;
}
}
public static void assertSurvivalCurve(MathFunction function){
Point previousPoint = null;
for(final Point point : function.getPoints()){
if(previousPoint != null){
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
assertTrue(previousPoint.getY() >= point.getY(), "Survival functions are monotone");
}
previousPoint = point;
}
}
@Test
public void testOneMinusECDF(){
final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0};
final MathFunction survivalCurve = Utils.estimateOneMinusECDF(times);
final double margin = 0.000001;
closeEnough(1.0, survivalCurve.evaluate(0.0).getY(), margin);
closeEnough(1.0, survivalCurve.evaluatePrevious(1.0).getY(), margin);
closeEnough(4.0/6.0, survivalCurve.evaluate(1.0).getY(), margin);
closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0).getY(), margin);
closeEnough(3.0/6.0, survivalCurve.evaluate(2.0).getY(), margin);
closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0).getY(), margin);
closeEnough(1.0/6.0, survivalCurve.evaluate(3.0).getY(), margin);
closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0).getY(), margin);
closeEnough(0.0, survivalCurve.evaluate(50.0).getY(), margin);
assertSurvivalCurve(survivalCurve);
}
}

View file

@ -7,9 +7,13 @@ import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test;
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException;
@ -240,10 +244,10 @@ public class TestCompetingRisk {
System.out.println(errorRates[1]);
closeEnough(0.452, errorRates[0], 0.01);
closeEnough(0.446, errorRates[1], 0.01);
closeEnough(0.452, errorRates[0], 0.02);
closeEnough(0.446, errorRates[1], 0.02);
System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}));
}
@Test
@ -306,7 +310,7 @@ public class TestCompetingRisk {
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
// We seem to consistently underestimate the results.
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
@ -322,29 +326,9 @@ public class TestCompetingRisk {
// Consistency results
closeEnough(0.395, errorRates[0], 0.01);
closeEnough(0.345, errorRates[1], 0.01);
}
/**
* We know the function is cumulative; make sure it is ordered correctly and that that function is monotone.
*
* @param function
*/
private void assertCumulativeFunction(MathFunction function){
Point previousPoint = null;
for(final Point point : function.getPoints()){
System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}));
if(previousPoint != null){
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
}
previousPoint = point;
}
}
private void closeEnough(double expected, double actual, double margin){
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
}
}

View file

@ -1,16 +1,16 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.List;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -33,10 +33,16 @@ public class TestCompetingRiskErrorRateCalculator {
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
final double naiveConcordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0));
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
final double ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
closeEnough(naiveConcordance, ipcwConcordance, 0.0001);
// Expected value found through calculations by hand
assertEquals(3.0/5.0, concordance);
assertEquals(3.0/5.0, naiveConcordance);
}

View file

@ -3,8 +3,10 @@ package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
import ca.joeltherrien.randomforest.utils.MathFunction;
import org.junit.jupiter.api.Test;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
import static org.junit.jupiter.api.Assertions.*;
import java.util.ArrayList;
@ -87,8 +89,4 @@ public class TestCompetingRiskResponseCombiner {
}
private void closeEnough(double expected, double actual, double margin){
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
}
}

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
import ca.joeltherrien.randomforest.responses.competingrisk.Point;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;