Add integrated Brier score error measure
This commit is contained in:
parent
7371dab4f1
commit
9258f75e4e
4 changed files with 287 additions and 1 deletions
|
@ -16,9 +16,11 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
public class CompetingRiskUtils {
|
public class CompetingRiskUtils {
|
||||||
|
@ -116,6 +118,44 @@ public class CompetingRiskUtils {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the Integrated Brier Score error on a list of responses and predictions.
|
||||||
|
*
|
||||||
|
* @param responses A List of responses
|
||||||
|
* @param predictions The corresponding List of predictions.
|
||||||
|
* @param censoringDistribution The censoring distribution.
|
||||||
|
* @param eventOfFocus The event we are calculating the error for.
|
||||||
|
* @param integrationUpperBound The upper bound to integrate to.
|
||||||
|
* @param isParallel Whether we should use parallel streams or not (provided because of bugs on a particular system).
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static double[] calculateIBSError(final List<CompetingRiskResponse> responses,
|
||||||
|
List<CompetingRiskFunctions> predictions,
|
||||||
|
Optional<RightContinuousStepFunction> censoringDistribution,
|
||||||
|
int eventOfFocus,
|
||||||
|
double integrationUpperBound,
|
||||||
|
boolean isParallel){
|
||||||
|
|
||||||
|
if(responses.size() != predictions.size()){
|
||||||
|
throw new IllegalArgumentException("Length of responses and predictions must be equal.");
|
||||||
|
}
|
||||||
|
|
||||||
|
final IBSCalculator calculator = new IBSCalculator(censoringDistribution);
|
||||||
|
|
||||||
|
IntStream stream = IntStream.range(0, responses.size());
|
||||||
|
|
||||||
|
if(isParallel){
|
||||||
|
stream = stream.parallel();
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream.mapToDouble(i -> {
|
||||||
|
CompetingRiskResponse response = responses.get(i);
|
||||||
|
RightContinuousStepFunction cif = predictions.get(i).getCumulativeIncidenceFunction(eventOfFocus);
|
||||||
|
|
||||||
|
return calculator.calculateError(response, cif, eventOfFocus, integrationUpperBound);
|
||||||
|
}).toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
|
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
|
||||||
final List<CompetingRiskResponse> initialRightHand,
|
final List<CompetingRiskResponse> initialRightHand,
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to calculate the Integrated Brier Score. See Section 4.2 of "Random survival forests for competing risks" by Ishwaran.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class IBSCalculator {
|
||||||
|
|
||||||
|
private final Optional<RightContinuousStepFunction> censoringDistribution;
|
||||||
|
|
||||||
|
public IBSCalculator(RightContinuousStepFunction censoringDistribution){
|
||||||
|
this.censoringDistribution = Optional.of(censoringDistribution);
|
||||||
|
}
|
||||||
|
|
||||||
|
public IBSCalculator(){
|
||||||
|
this.censoringDistribution = Optional.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public IBSCalculator(Optional<RightContinuousStepFunction> censoringDistribution){
|
||||||
|
this.censoringDistribution = censoringDistribution;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculateError(CompetingRiskResponse response, RightContinuousStepFunction cif, int eventOfInterest, double integrationUpperBound){
|
||||||
|
|
||||||
|
// return integral of weights*(I(response.getU() <= times & response.getDelta() == eventOfInterest) - cif(times))^2
|
||||||
|
// Note that if we don't have weights, just treat them all as one (i.e. don't bother multiplying)
|
||||||
|
|
||||||
|
RightContinuousStepFunction functionToIntegrate = cif;
|
||||||
|
|
||||||
|
if(response.getDelta() == eventOfInterest){
|
||||||
|
final RightContinuousStepFunction observedFunction = new RightContinuousStepFunction(new double[]{response.getU()}, new double[]{1.0}, 0.0);
|
||||||
|
functionToIntegrate = RightContinuousStepFunction.biOperation(observedFunction, functionToIntegrate, (a, b) -> (a - b) * (a - b));
|
||||||
|
} else{
|
||||||
|
functionToIntegrate = functionToIntegrate.unaryOperation(a -> a*a);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(censoringDistribution.isPresent()){
|
||||||
|
final RightContinuousStepFunction weights = calculateWeights(response, censoringDistribution.get());
|
||||||
|
functionToIntegrate = RightContinuousStepFunction.biOperation(weights, functionToIntegrate, (a, b) -> a*b);
|
||||||
|
|
||||||
|
// the censoring weights go to 0 after the response is censored, so we can speed up results by only integrating
|
||||||
|
// prior to the censor times
|
||||||
|
if(response.isCensored()){
|
||||||
|
integrationUpperBound = Math.min(integrationUpperBound, response.getU());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return functionToIntegrate.integrate(0.0, integrationUpperBound);
|
||||||
|
}
|
||||||
|
|
||||||
|
private RightContinuousStepFunction calculateWeights(CompetingRiskResponse response, RightContinuousStepFunction censoringDistribution){
|
||||||
|
final double recordedTime = response.getU();
|
||||||
|
|
||||||
|
// Function(t) = firstPart(t) + secondPart(t)/thirdPart(t) where:
|
||||||
|
// firstPart(t) = I(recordedTime <= t & !response.isCensored()) / censoringDistribution.evaluate(recordedTime);
|
||||||
|
// secondPart(t) = I(recordedTime > t) = 1 - I(recordedTime <= t)
|
||||||
|
// thirdPart(t) = censoringDistribution.evaluate(t)
|
||||||
|
|
||||||
|
final RightContinuousStepFunction secondPart = new RightContinuousStepFunction(new double[]{recordedTime}, new double[]{0.0}, 1.0);
|
||||||
|
RightContinuousStepFunction result = RightContinuousStepFunction.biOperation(secondPart, censoringDistribution,
|
||||||
|
(second, third) -> second / third);
|
||||||
|
|
||||||
|
if(!response.isCensored()){
|
||||||
|
final RightContinuousStepFunction firstPart = new RightContinuousStepFunction(
|
||||||
|
new double[]{recordedTime},
|
||||||
|
new double[]{1.0 / censoringDistribution.evaluate(recordedTime)},
|
||||||
|
0.0);
|
||||||
|
|
||||||
|
result = RightContinuousStepFunction.biOperation(firstPart, result, Double::sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -20,7 +20,7 @@ import java.util.*;
|
||||||
|
|
||||||
public final class Utils {
|
public final class Utils {
|
||||||
|
|
||||||
public static StepFunction estimateOneMinusECDF(final double[] times){
|
public static RightContinuousStepFunction estimateOneMinusECDF(final double[] times){
|
||||||
Arrays.sort(times);
|
Arrays.sort(times);
|
||||||
|
|
||||||
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
||||||
|
|
|
@ -0,0 +1,164 @@
|
||||||
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class IBSCalculatorTest {
|
||||||
|
|
||||||
|
private final RightContinuousStepFunction cif;
|
||||||
|
|
||||||
|
public IBSCalculatorTest(){
|
||||||
|
this.cif = RightContinuousStepFunction.constructFromPoints(
|
||||||
|
Utils.easyList(
|
||||||
|
new Point(1.0, 0.1),
|
||||||
|
new Point(2.0, 0.2),
|
||||||
|
new Point(3.0, 0.3),
|
||||||
|
new Point(4.0, 0.8)
|
||||||
|
), 0.0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
R code to get these results:
|
||||||
|
|
||||||
|
predicted_cif <- stepfun(1:4, c(0, 0.1, 0.2, 0.3, 0.8))
|
||||||
|
weights <- 1
|
||||||
|
recorded_time <- 2.0
|
||||||
|
recorded_status <- 1.0
|
||||||
|
event_of_interest <- 2
|
||||||
|
times <- 0:4
|
||||||
|
|
||||||
|
errors <- weights * ( as.integer(recorded_time <= times & recorded_status == event_of_interest) - predicted_cif(times))^2
|
||||||
|
sum(errors)
|
||||||
|
|
||||||
|
|
||||||
|
and run again with event_of_interest <- 1
|
||||||
|
|
||||||
|
|
||||||
|
Note that in the R code I only evaluate up to 4, while in the Java code I integrate up to 5
|
||||||
|
This is because the evaluation at 4 is giving the area of the rectangle from 4 to 5.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void resultsWithoutCensoringDistribution(){
|
||||||
|
final IBSCalculator calculator = new IBSCalculator();
|
||||||
|
|
||||||
|
final double errorDifferentEvent = calculator.calculateError(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
this.cif,
|
||||||
|
2,
|
||||||
|
5.0);
|
||||||
|
|
||||||
|
assertEquals(0.78, errorDifferentEvent, 0.000001);
|
||||||
|
|
||||||
|
final double errorSameEvent = calculator.calculateError(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
this.cif,
|
||||||
|
1,
|
||||||
|
5.0);
|
||||||
|
|
||||||
|
assertEquals(1.18, errorSameEvent, 0.000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void resultsWithCensoringDistribution(){
|
||||||
|
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||||
|
Utils.easyList(
|
||||||
|
new Point(0.0, 0.75),
|
||||||
|
new Point(1.0, 0.5),
|
||||||
|
new Point(3.0, 0.25),
|
||||||
|
new Point(5.0, 0)
|
||||||
|
), 1.0
|
||||||
|
);
|
||||||
|
|
||||||
|
final IBSCalculator calculator = new IBSCalculator(censorSurvivalFunction);
|
||||||
|
|
||||||
|
final double errorDifferentEvent = calculator.calculateError(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
this.cif,
|
||||||
|
2,
|
||||||
|
5.0);
|
||||||
|
|
||||||
|
assertEquals(1.56, errorDifferentEvent, 0.000001);
|
||||||
|
|
||||||
|
final double errorSameEvent = calculator.calculateError(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
this.cif,
|
||||||
|
1,
|
||||||
|
5.0);
|
||||||
|
|
||||||
|
assertEquals(2.36, errorSameEvent, 0.000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testStaticFunction(){
|
||||||
|
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||||
|
Utils.easyList(
|
||||||
|
new Point(0.0, 0.75),
|
||||||
|
new Point(1.0, 0.5),
|
||||||
|
new Point(3.0, 0.25),
|
||||||
|
new Point(5.0, 0)
|
||||||
|
), 1.0
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<CompetingRiskResponse> responseList = Utils.easyList(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
new CompetingRiskResponse(1, 2.0));
|
||||||
|
|
||||||
|
// for predictions; we'll construct an improper CompetingRisksFunctions
|
||||||
|
final RightContinuousStepFunction trivialFunction = RightContinuousStepFunction.constructFromPoints(
|
||||||
|
Utils.easyList(new Point(1.0, 0.0)),
|
||||||
|
1.0);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions prediction = CompetingRiskFunctions.builder()
|
||||||
|
.survivalCurve(trivialFunction)
|
||||||
|
.causeSpecificHazards(Utils.easyList(trivialFunction, trivialFunction))
|
||||||
|
.cumulativeIncidenceCurves(Utils.easyList(this.cif, trivialFunction))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final List<CompetingRiskFunctions> predictionList = Utils.easyList(prediction, prediction);
|
||||||
|
|
||||||
|
double[] errorParallel = CompetingRiskUtils.calculateIBSError(
|
||||||
|
responseList,
|
||||||
|
predictionList,
|
||||||
|
Optional.of(censorSurvivalFunction),
|
||||||
|
1,
|
||||||
|
5.0,
|
||||||
|
true);
|
||||||
|
|
||||||
|
double[] errorSerial = CompetingRiskUtils.calculateIBSError(
|
||||||
|
responseList,
|
||||||
|
predictionList,
|
||||||
|
Optional.of(censorSurvivalFunction),
|
||||||
|
1,
|
||||||
|
5.0,
|
||||||
|
false);
|
||||||
|
|
||||||
|
assertEquals(responseList.size(), errorParallel.length);
|
||||||
|
assertEquals(responseList.size(), errorSerial.length);
|
||||||
|
|
||||||
|
assertEquals(2.36, errorParallel[0], 0.000001);
|
||||||
|
assertEquals(2.36, errorParallel[1], 0.000001);
|
||||||
|
|
||||||
|
assertEquals(2.36, errorSerial[0], 0.000001);
|
||||||
|
assertEquals(2.36, errorSerial[1], 0.000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue