Add integrated Brier score error measure
This commit is contained in:
parent
7371dab4f1
commit
9258f75e4e
4 changed files with 287 additions and 1 deletions
|
@ -16,9 +16,11 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.IntStream;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class CompetingRiskUtils {
|
||||
|
@ -116,6 +118,44 @@ public class CompetingRiskUtils {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the Integrated Brier Score error on a list of responses and predictions.
|
||||
*
|
||||
* @param responses A List of responses
|
||||
* @param predictions The corresponding List of predictions.
|
||||
* @param censoringDistribution The censoring distribution.
|
||||
* @param eventOfFocus The event we are calculating the error for.
|
||||
* @param integrationUpperBound The upper bound to integrate to.
|
||||
* @param isParallel Whether we should use parallel streams or not (provided because of bugs on a particular system).
|
||||
* @return
|
||||
*/
|
||||
public static double[] calculateIBSError(final List<CompetingRiskResponse> responses,
|
||||
List<CompetingRiskFunctions> predictions,
|
||||
Optional<RightContinuousStepFunction> censoringDistribution,
|
||||
int eventOfFocus,
|
||||
double integrationUpperBound,
|
||||
boolean isParallel){
|
||||
|
||||
if(responses.size() != predictions.size()){
|
||||
throw new IllegalArgumentException("Length of responses and predictions must be equal.");
|
||||
}
|
||||
|
||||
final IBSCalculator calculator = new IBSCalculator(censoringDistribution);
|
||||
|
||||
IntStream stream = IntStream.range(0, responses.size());
|
||||
|
||||
if(isParallel){
|
||||
stream = stream.parallel();
|
||||
}
|
||||
|
||||
return stream.mapToDouble(i -> {
|
||||
CompetingRiskResponse response = responses.get(i);
|
||||
RightContinuousStepFunction cif = predictions.get(i).getCumulativeIncidenceFunction(eventOfFocus);
|
||||
|
||||
return calculator.calculateError(response, cif, eventOfFocus, integrationUpperBound);
|
||||
}).toArray();
|
||||
}
|
||||
|
||||
|
||||
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
|
||||
final List<CompetingRiskResponse> initialRightHand,
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Used to calculate the Integrated Brier Score. See Section 4.2 of "Random survival forests for competing risks" by Ishwaran.
|
||||
*
|
||||
*/
|
||||
public class IBSCalculator {
|
||||
|
||||
private final Optional<RightContinuousStepFunction> censoringDistribution;
|
||||
|
||||
public IBSCalculator(RightContinuousStepFunction censoringDistribution){
|
||||
this.censoringDistribution = Optional.of(censoringDistribution);
|
||||
}
|
||||
|
||||
public IBSCalculator(){
|
||||
this.censoringDistribution = Optional.empty();
|
||||
}
|
||||
|
||||
public IBSCalculator(Optional<RightContinuousStepFunction> censoringDistribution){
|
||||
this.censoringDistribution = censoringDistribution;
|
||||
}
|
||||
|
||||
public double calculateError(CompetingRiskResponse response, RightContinuousStepFunction cif, int eventOfInterest, double integrationUpperBound){
|
||||
|
||||
// return integral of weights*(I(response.getU() <= times & response.getDelta() == eventOfInterest) - cif(times))^2
|
||||
// Note that if we don't have weights, just treat them all as one (i.e. don't bother multiplying)
|
||||
|
||||
RightContinuousStepFunction functionToIntegrate = cif;
|
||||
|
||||
if(response.getDelta() == eventOfInterest){
|
||||
final RightContinuousStepFunction observedFunction = new RightContinuousStepFunction(new double[]{response.getU()}, new double[]{1.0}, 0.0);
|
||||
functionToIntegrate = RightContinuousStepFunction.biOperation(observedFunction, functionToIntegrate, (a, b) -> (a - b) * (a - b));
|
||||
} else{
|
||||
functionToIntegrate = functionToIntegrate.unaryOperation(a -> a*a);
|
||||
}
|
||||
|
||||
if(censoringDistribution.isPresent()){
|
||||
final RightContinuousStepFunction weights = calculateWeights(response, censoringDistribution.get());
|
||||
functionToIntegrate = RightContinuousStepFunction.biOperation(weights, functionToIntegrate, (a, b) -> a*b);
|
||||
|
||||
// the censoring weights go to 0 after the response is censored, so we can speed up results by only integrating
|
||||
// prior to the censor times
|
||||
if(response.isCensored()){
|
||||
integrationUpperBound = Math.min(integrationUpperBound, response.getU());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return functionToIntegrate.integrate(0.0, integrationUpperBound);
|
||||
}
|
||||
|
||||
private RightContinuousStepFunction calculateWeights(CompetingRiskResponse response, RightContinuousStepFunction censoringDistribution){
|
||||
final double recordedTime = response.getU();
|
||||
|
||||
// Function(t) = firstPart(t) + secondPart(t)/thirdPart(t) where:
|
||||
// firstPart(t) = I(recordedTime <= t & !response.isCensored()) / censoringDistribution.evaluate(recordedTime);
|
||||
// secondPart(t) = I(recordedTime > t) = 1 - I(recordedTime <= t)
|
||||
// thirdPart(t) = censoringDistribution.evaluate(t)
|
||||
|
||||
final RightContinuousStepFunction secondPart = new RightContinuousStepFunction(new double[]{recordedTime}, new double[]{0.0}, 1.0);
|
||||
RightContinuousStepFunction result = RightContinuousStepFunction.biOperation(secondPart, censoringDistribution,
|
||||
(second, third) -> second / third);
|
||||
|
||||
if(!response.isCensored()){
|
||||
final RightContinuousStepFunction firstPart = new RightContinuousStepFunction(
|
||||
new double[]{recordedTime},
|
||||
new double[]{1.0 / censoringDistribution.evaluate(recordedTime)},
|
||||
0.0);
|
||||
|
||||
result = RightContinuousStepFunction.biOperation(firstPart, result, Double::sum);
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -20,7 +20,7 @@ import java.util.*;
|
|||
|
||||
public final class Utils {
|
||||
|
||||
public static StepFunction estimateOneMinusECDF(final double[] times){
|
||||
public static RightContinuousStepFunction estimateOneMinusECDF(final double[] times){
|
||||
Arrays.sort(times);
|
||||
|
||||
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class IBSCalculatorTest {
|
||||
|
||||
private final RightContinuousStepFunction cif;
|
||||
|
||||
public IBSCalculatorTest(){
|
||||
this.cif = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(
|
||||
new Point(1.0, 0.1),
|
||||
new Point(2.0, 0.2),
|
||||
new Point(3.0, 0.3),
|
||||
new Point(4.0, 0.8)
|
||||
), 0.0
|
||||
);
|
||||
}
|
||||
|
||||
/*
|
||||
R code to get these results:
|
||||
|
||||
predicted_cif <- stepfun(1:4, c(0, 0.1, 0.2, 0.3, 0.8))
|
||||
weights <- 1
|
||||
recorded_time <- 2.0
|
||||
recorded_status <- 1.0
|
||||
event_of_interest <- 2
|
||||
times <- 0:4
|
||||
|
||||
errors <- weights * ( as.integer(recorded_time <= times & recorded_status == event_of_interest) - predicted_cif(times))^2
|
||||
sum(errors)
|
||||
|
||||
|
||||
and run again with event_of_interest <- 1
|
||||
|
||||
|
||||
Note that in the R code I only evaluate up to 4, while in the Java code I integrate up to 5
|
||||
This is because the evaluation at 4 is giving the area of the rectangle from 4 to 5.
|
||||
|
||||
*/
|
||||
|
||||
@Test
|
||||
public void resultsWithoutCensoringDistribution(){
|
||||
final IBSCalculator calculator = new IBSCalculator();
|
||||
|
||||
final double errorDifferentEvent = calculator.calculateError(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
this.cif,
|
||||
2,
|
||||
5.0);
|
||||
|
||||
assertEquals(0.78, errorDifferentEvent, 0.000001);
|
||||
|
||||
final double errorSameEvent = calculator.calculateError(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
this.cif,
|
||||
1,
|
||||
5.0);
|
||||
|
||||
assertEquals(1.18, errorSameEvent, 0.000001);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void resultsWithCensoringDistribution(){
|
||||
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(
|
||||
new Point(0.0, 0.75),
|
||||
new Point(1.0, 0.5),
|
||||
new Point(3.0, 0.25),
|
||||
new Point(5.0, 0)
|
||||
), 1.0
|
||||
);
|
||||
|
||||
final IBSCalculator calculator = new IBSCalculator(censorSurvivalFunction);
|
||||
|
||||
final double errorDifferentEvent = calculator.calculateError(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
this.cif,
|
||||
2,
|
||||
5.0);
|
||||
|
||||
assertEquals(1.56, errorDifferentEvent, 0.000001);
|
||||
|
||||
final double errorSameEvent = calculator.calculateError(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
this.cif,
|
||||
1,
|
||||
5.0);
|
||||
|
||||
assertEquals(2.36, errorSameEvent, 0.000001);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStaticFunction(){
|
||||
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(
|
||||
new Point(0.0, 0.75),
|
||||
new Point(1.0, 0.5),
|
||||
new Point(3.0, 0.25),
|
||||
new Point(5.0, 0)
|
||||
), 1.0
|
||||
);
|
||||
|
||||
final List<CompetingRiskResponse> responseList = Utils.easyList(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
new CompetingRiskResponse(1, 2.0));
|
||||
|
||||
// for predictions; we'll construct an improper CompetingRisksFunctions
|
||||
final RightContinuousStepFunction trivialFunction = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(new Point(1.0, 0.0)),
|
||||
1.0);
|
||||
|
||||
final CompetingRiskFunctions prediction = CompetingRiskFunctions.builder()
|
||||
.survivalCurve(trivialFunction)
|
||||
.causeSpecificHazards(Utils.easyList(trivialFunction, trivialFunction))
|
||||
.cumulativeIncidenceCurves(Utils.easyList(this.cif, trivialFunction))
|
||||
.build();
|
||||
|
||||
final List<CompetingRiskFunctions> predictionList = Utils.easyList(prediction, prediction);
|
||||
|
||||
double[] errorParallel = CompetingRiskUtils.calculateIBSError(
|
||||
responseList,
|
||||
predictionList,
|
||||
Optional.of(censorSurvivalFunction),
|
||||
1,
|
||||
5.0,
|
||||
true);
|
||||
|
||||
double[] errorSerial = CompetingRiskUtils.calculateIBSError(
|
||||
responseList,
|
||||
predictionList,
|
||||
Optional.of(censorSurvivalFunction),
|
||||
1,
|
||||
5.0,
|
||||
false);
|
||||
|
||||
assertEquals(responseList.size(), errorParallel.length);
|
||||
assertEquals(responseList.size(), errorSerial.length);
|
||||
|
||||
assertEquals(2.36, errorParallel[0], 0.000001);
|
||||
assertEquals(2.36, errorParallel[1], 0.000001);
|
||||
|
||||
assertEquals(2.36, errorSerial[0], 0.000001);
|
||||
assertEquals(2.36, errorSerial[1], 0.000001);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
Loading…
Reference in a new issue