Add integrated Brier score error measure

This commit is contained in:
Joel Therrien 2019-07-22 11:22:07 -07:00
parent 7371dab4f1
commit 9258f75e4e
4 changed files with 287 additions and 1 deletions

View file

@ -16,9 +16,11 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import java.util.*; import java.util.*;
import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
public class CompetingRiskUtils { public class CompetingRiskUtils {
@ -116,6 +118,44 @@ public class CompetingRiskUtils {
} }
/**
 * Computes the Integrated Brier Score error of each prediction against its matching response.
 *
 * @param responses The observed responses.
 * @param predictions The predictions; the prediction at index i is scored against the response at index i.
 * @param censoringDistribution The censoring distribution handed to the {@link IBSCalculator}; may be empty.
 * @param eventOfFocus The event whose cumulative incidence function is scored.
 * @param integrationUpperBound The upper bound to integrate to.
 * @param isParallel Whether a parallel stream should be used (provided because of bugs on a particular system).
 * @return An array holding one error per response, in the same order as {@code responses}.
 * @throws IllegalArgumentException if {@code responses} and {@code predictions} differ in size.
 */
public static double[] calculateIBSError(final List<CompetingRiskResponse> responses,
                                         List<CompetingRiskFunctions> predictions,
                                         Optional<RightContinuousStepFunction> censoringDistribution,
                                         int eventOfFocus,
                                         double integrationUpperBound,
                                         boolean isParallel){
    final int count = responses.size();
    if(count != predictions.size()){
        throw new IllegalArgumentException("Length of responses and predictions must be equal.");
    }

    // One calculator is shared by every index; it only holds the censoring distribution.
    final IBSCalculator ibsCalculator = new IBSCalculator(censoringDistribution);

    final IntStream sequentialIndices = IntStream.range(0, count);
    final IntStream indices = isParallel ? sequentialIndices.parallel() : sequentialIndices;

    return indices
            .mapToDouble(index -> ibsCalculator.calculateError(
                    responses.get(index),
                    predictions.get(index).getCumulativeIncidenceFunction(eventOfFocus),
                    eventOfFocus,
                    integrationUpperBound))
            .toArray();
}
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand, public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
final List<CompetingRiskResponse> initialRightHand, final List<CompetingRiskResponse> initialRightHand,

View file

@ -0,0 +1,82 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import java.util.Optional;
/**
* Used to calculate the Integrated Brier Score. See Section 4.2 of "Random survival forests for competing risks" by Ishwaran.
*
*/
public class IBSCalculator {
private final Optional<RightContinuousStepFunction> censoringDistribution;
public IBSCalculator(RightContinuousStepFunction censoringDistribution){
this.censoringDistribution = Optional.of(censoringDistribution);
}
public IBSCalculator(){
this.censoringDistribution = Optional.empty();
}
public IBSCalculator(Optional<RightContinuousStepFunction> censoringDistribution){
this.censoringDistribution = censoringDistribution;
}
public double calculateError(CompetingRiskResponse response, RightContinuousStepFunction cif, int eventOfInterest, double integrationUpperBound){
// return integral of weights*(I(response.getU() <= times & response.getDelta() == eventOfInterest) - cif(times))^2
// Note that if we don't have weights, just treat them all as one (i.e. don't bother multiplying)
RightContinuousStepFunction functionToIntegrate = cif;
if(response.getDelta() == eventOfInterest){
final RightContinuousStepFunction observedFunction = new RightContinuousStepFunction(new double[]{response.getU()}, new double[]{1.0}, 0.0);
functionToIntegrate = RightContinuousStepFunction.biOperation(observedFunction, functionToIntegrate, (a, b) -> (a - b) * (a - b));
} else{
functionToIntegrate = functionToIntegrate.unaryOperation(a -> a*a);
}
if(censoringDistribution.isPresent()){
final RightContinuousStepFunction weights = calculateWeights(response, censoringDistribution.get());
functionToIntegrate = RightContinuousStepFunction.biOperation(weights, functionToIntegrate, (a, b) -> a*b);
// the censoring weights go to 0 after the response is censored, so we can speed up results by only integrating
// prior to the censor times
if(response.isCensored()){
integrationUpperBound = Math.min(integrationUpperBound, response.getU());
}
}
return functionToIntegrate.integrate(0.0, integrationUpperBound);
}
private RightContinuousStepFunction calculateWeights(CompetingRiskResponse response, RightContinuousStepFunction censoringDistribution){
final double recordedTime = response.getU();
// Function(t) = firstPart(t) + secondPart(t)/thirdPart(t) where:
// firstPart(t) = I(recordedTime <= t & !response.isCensored()) / censoringDistribution.evaluate(recordedTime);
// secondPart(t) = I(recordedTime > t) = 1 - I(recordedTime <= t)
// thirdPart(t) = censoringDistribution.evaluate(t)
final RightContinuousStepFunction secondPart = new RightContinuousStepFunction(new double[]{recordedTime}, new double[]{0.0}, 1.0);
RightContinuousStepFunction result = RightContinuousStepFunction.biOperation(secondPart, censoringDistribution,
(second, third) -> second / third);
if(!response.isCensored()){
final RightContinuousStepFunction firstPart = new RightContinuousStepFunction(
new double[]{recordedTime},
new double[]{1.0 / censoringDistribution.evaluate(recordedTime)},
0.0);
result = RightContinuousStepFunction.biOperation(firstPart, result, Double::sum);
}
return result;
}
}

View file

@ -20,7 +20,7 @@ import java.util.*;
public final class Utils { public final class Utils {
public static StepFunction estimateOneMinusECDF(final double[] times){ public static RightContinuousStepFunction estimateOneMinusECDF(final double[] times){
Arrays.sort(times); Arrays.sort(times);
final Map<Double, Integer> timeCounterMap = new HashMap<>(); final Map<Double, Integer> timeCounterMap = new HashMap<>();

View file

@ -0,0 +1,164 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Tests for {@link IBSCalculator} and for {@code CompetingRiskUtils.calculateIBSError}.
 */
public class IBSCalculatorTest {

    private static final double TOLERANCE = 0.000001;

    private final RightContinuousStepFunction cif;

    public IBSCalculatorTest(){
        this.cif = buildPredictedCif();
    }

    /** Predicted CIF shared by every test: 0 before time 1, then 0.1, 0.2, 0.3 and 0.8 from times 1, 2, 3 and 4. */
    private static RightContinuousStepFunction buildPredictedCif(){
        return RightContinuousStepFunction.constructFromPoints(
                Utils.easyList(
                        new Point(1.0, 0.1),
                        new Point(2.0, 0.2),
                        new Point(3.0, 0.3),
                        new Point(4.0, 0.8)
                ), 0.0
        );
    }

    /** Censoring survival function used by the weighted tests. */
    private static RightContinuousStepFunction buildCensorSurvivalFunction(){
        return RightContinuousStepFunction.constructFromPoints(
                Utils.easyList(
                        new Point(0.0, 0.75),
                        new Point(1.0, 0.5),
                        new Point(3.0, 0.25),
                        new Point(5.0, 0)
                ), 1.0
        );
    }

    /*
    R code that produces the expected values of the unweighted test:

    predicted_cif <- stepfun(1:4, c(0, 0.1, 0.2, 0.3, 0.8))
    weights <- 1
    recorded_time <- 2.0
    recorded_status <- 1.0
    event_of_interest <- 2
    times <- 0:4

    errors <- weights * ( as.integer(recorded_time <= times & recorded_status == event_of_interest) - predicted_cif(times))^2
    sum(errors)

    and run again with event_of_interest <- 1

    Note that the R code only evaluates up to 4, while the Java code integrates up to 5.
    This is because the evaluation at 4 gives the area of the rectangle from 4 to 5.
     */
    @Test
    public void resultsWithoutCensoringDistribution(){
        final IBSCalculator unweighted = new IBSCalculator();
        final CompetingRiskResponse observed = new CompetingRiskResponse(1, 2.0);

        final double errorDifferentEvent = unweighted.calculateError(observed, this.cif, 2, 5.0);
        assertEquals(0.78, errorDifferentEvent, TOLERANCE);

        final double errorSameEvent = unweighted.calculateError(observed, this.cif, 1, 5.0);
        assertEquals(1.18, errorSameEvent, TOLERANCE);
    }

    @Test
    public void resultsWithCensoringDistribution(){
        final IBSCalculator weighted = new IBSCalculator(buildCensorSurvivalFunction());
        final CompetingRiskResponse observed = new CompetingRiskResponse(1, 2.0);

        final double errorDifferentEvent = weighted.calculateError(observed, this.cif, 2, 5.0);
        assertEquals(1.56, errorDifferentEvent, TOLERANCE);

        final double errorSameEvent = weighted.calculateError(observed, this.cif, 1, 5.0);
        assertEquals(2.36, errorSameEvent, TOLERANCE);
    }

    @Test
    public void testStaticFunction(){
        final RightContinuousStepFunction censorSurvivalFunction = buildCensorSurvivalFunction();

        final List<CompetingRiskResponse> responseList = Utils.easyList(
                new CompetingRiskResponse(1, 2.0),
                new CompetingRiskResponse(1, 2.0));

        // for predictions; we'll construct an improper CompetingRisksFunctions
        final RightContinuousStepFunction trivialFunction = RightContinuousStepFunction.constructFromPoints(
                Utils.easyList(new Point(1.0, 0.0)),
                1.0);

        final CompetingRiskFunctions prediction = CompetingRiskFunctions.builder()
                .survivalCurve(trivialFunction)
                .causeSpecificHazards(Utils.easyList(trivialFunction, trivialFunction))
                .cumulativeIncidenceCurves(Utils.easyList(this.cif, trivialFunction))
                .build();

        final List<CompetingRiskFunctions> predictionList = Utils.easyList(prediction, prediction);

        // The parallel and the serial code paths must produce the same per-response errors.
        final double[] errorParallel = CompetingRiskUtils.calculateIBSError(
                responseList,
                predictionList,
                Optional.of(censorSurvivalFunction),
                1,
                5.0,
                true);

        final double[] errorSerial = CompetingRiskUtils.calculateIBSError(
                responseList,
                predictionList,
                Optional.of(censorSurvivalFunction),
                1,
                5.0,
                false);

        assertEquals(responseList.size(), errorParallel.length);
        assertEquals(responseList.size(), errorSerial.length);

        for(int i = 0; i < 2; i++){
            assertEquals(2.36, errorParallel[i], TOLERANCE);
        }
        for(int i = 0; i < 2; i++){
            assertEquals(2.36, errorSerial[i], TOLERANCE);
        }
    }

}