diff --git a/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java b/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java
index 24c3f69..ebec041 100644
--- a/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java
+++ b/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java
@@ -16,9 +16,11 @@
 
 package ca.joeltherrien.randomforest.responses.competingrisk;
 
+import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
 import ca.joeltherrien.randomforest.utils.StepFunction;
 
 import java.util.*;
+import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 public class CompetingRiskUtils {
@@ -116,6 +118,45 @@ public class CompetingRiskUtils {
 
     }
 
+    /**
+     * Calculate the Integrated Brier Score error on a list of responses and predictions.
+     *
+     * @param responses A List of responses
+     * @param predictions The corresponding List of predictions.
+     * @param censoringDistribution The censoring distribution used to weight the errors; if empty, no weights are applied.
+     * @param eventOfFocus The event we are calculating the error for.
+     * @param integrationUpperBound The upper bound to integrate to.
+     * @param isParallel Whether we should use parallel streams or not (provided because of bugs on a particular system).
+     * @return An array holding the error of each response/prediction pair, in the same order as the input lists.
+     * @throws IllegalArgumentException If responses and predictions differ in length.
+     */
+    public static double[] calculateIBSError(final List<CompetingRiskResponse> responses,
+                                             List<CompetingRiskFunctions> predictions,
+                                             Optional<RightContinuousStepFunction> censoringDistribution,
+                                             int eventOfFocus,
+                                             double integrationUpperBound,
+                                             boolean isParallel){
+
+        if(responses.size() != predictions.size()){
+            throw new IllegalArgumentException("Length of responses and predictions must be equal.");
+        }
+
+        final IBSCalculator calculator = new IBSCalculator(censoringDistribution);
+
+        IntStream stream = IntStream.range(0, responses.size());
+
+        if(isParallel){
+            stream = stream.parallel();
+        }
+
+        return stream.mapToDouble(i -> {
+            CompetingRiskResponse response = responses.get(i);
+            RightContinuousStepFunction cif = predictions.get(i).getCumulativeIncidenceFunction(eventOfFocus);
+
+            return calculator.calculateError(response, cif, eventOfFocus, integrationUpperBound);
+        }).toArray();
+    }
+
 
     public static CompetingRiskSetsImpl calculateSetsEfficiently(final List initialLeftHand,
                                                                  final List initialRightHand,
diff --git a/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/IBSCalculator.java b/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/IBSCalculator.java
new file mode 100644
index 0000000..eda6ce8
--- /dev/null
+++ b/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/IBSCalculator.java
@@ -0,0 +1,104 @@
+package ca.joeltherrien.randomforest.responses.competingrisk;
+
+import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
+
+import java.util.Optional;
+
+/**
+ * Used to calculate the Integrated Brier Score. See Section 4.2 of "Random survival forests for competing risks" by Ishwaran.
+ *
+ */
+public class IBSCalculator {
+
+    private final Optional<RightContinuousStepFunction> censoringDistribution;
+
+    /** Creates a calculator that weights errors using the given censoring distribution. */
+    public IBSCalculator(RightContinuousStepFunction censoringDistribution){
+        this.censoringDistribution = Optional.of(censoringDistribution);
+    }
+
+    /** Creates a calculator that applies no censoring weights. */
+    public IBSCalculator(){
+        this.censoringDistribution = Optional.empty();
+    }
+
+    /** Creates a calculator that applies censoring weights only if a distribution is present. */
+    public IBSCalculator(Optional<RightContinuousStepFunction> censoringDistribution){
+        this.censoringDistribution = censoringDistribution;
+    }
+
+    /**
+     * Integrates, from 0 to integrationUpperBound, the squared difference between the observed indicator
+     * I(response.getU() <= t and response.getDelta() == eventOfInterest) and cif(t); the integrand is multiplied
+     * by the censoring weights when a censoring distribution is present.
+     *
+     * @param response The observed response.
+     * @param cif The predicted cumulative incidence function for eventOfInterest.
+     * @param eventOfInterest The event the error is calculated for.
+     * @param integrationUpperBound The upper bound to integrate to.
+     * @return The integrated (optionally weighted) squared error for this response.
+     */
+    public double calculateError(CompetingRiskResponse response, RightContinuousStepFunction cif, int eventOfInterest, double integrationUpperBound){
+
+        // return integral of weights*(I(response.getU() <= times & response.getDelta() == eventOfInterest) - cif(times))^2
+        // Note that if we don't have weights, just treat them all as one (i.e. don't bother multiplying)
+
+        RightContinuousStepFunction functionToIntegrate = cif;
+
+        if(response.getDelta() == eventOfInterest){
+            final RightContinuousStepFunction observedFunction = new RightContinuousStepFunction(new double[]{response.getU()}, new double[]{1.0}, 0.0);
+            functionToIntegrate = RightContinuousStepFunction.biOperation(observedFunction, functionToIntegrate, (a, b) -> (a - b) * (a - b));
+        } else{
+            functionToIntegrate = functionToIntegrate.unaryOperation(a -> a*a);
+        }
+
+        if(censoringDistribution.isPresent()){
+            final RightContinuousStepFunction weights = calculateWeights(response, censoringDistribution.get());
+            functionToIntegrate = RightContinuousStepFunction.biOperation(weights, functionToIntegrate, (a, b) -> a*b);
+
+            // the censoring weights go to 0 after the response is censored, so we can speed up results by only integrating
+            // prior to the censor times
+            if(response.isCensored()){
+                integrationUpperBound = Math.min(integrationUpperBound, response.getU());
+            }
+
+        }
+
+        return functionToIntegrate.integrate(0.0, integrationUpperBound);
+    }
+
+    /**
+     * Builds the censoring weight function for a single response; the formula is spelled out in the body.
+     *
+     * @param response The observed response.
+     * @param censoringDistribution The censoring distribution the weights are derived from.
+     * @return The weights as a function of time.
+     */
+    private RightContinuousStepFunction calculateWeights(CompetingRiskResponse response, RightContinuousStepFunction censoringDistribution){
+        final double recordedTime = response.getU();
+
+        // Function(t) = firstPart(t) + secondPart(t)/thirdPart(t) where:
+        // firstPart(t) = I(recordedTime <= t & !response.isCensored()) / censoringDistribution.evaluate(recordedTime);
+        // secondPart(t) = I(recordedTime > t) = 1 - I(recordedTime <= t)
+        // thirdPart(t) = censoringDistribution.evaluate(t)
+
+        final RightContinuousStepFunction secondPart = new RightContinuousStepFunction(new double[]{recordedTime}, new double[]{0.0}, 1.0);
+        // When secondPart is 0 the whole term is 0 whatever thirdPart is; guarding the division avoids
+        // 0/0 = NaN at times where the censoring distribution has dropped to 0.
+        RightContinuousStepFunction result = RightContinuousStepFunction.biOperation(secondPart, censoringDistribution,
+                (second, third) -> second == 0.0 ? 0.0 : second / third);
+
+        if(!response.isCensored()){
+            final RightContinuousStepFunction firstPart = new RightContinuousStepFunction(
+                    new double[]{recordedTime},
+                    new double[]{1.0 / censoringDistribution.evaluate(recordedTime)},
+                    0.0);
+
+            result = RightContinuousStepFunction.biOperation(firstPart, result, Double::sum);
+        }
+
+        return result;
+
+    }
+
+
+}
diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java
index 70e3e1f..2661b2d 100644
--- a/library/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java
+++ b/library/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java
@@ -20,7 +20,7 @@ import java.util.*;
 
 public final class Utils {
 
-    public static StepFunction estimateOneMinusECDF(final double[] times){
+    public static RightContinuousStepFunction estimateOneMinusECDF(final double[] times){
         Arrays.sort(times);
 
         final Map timeCounterMap = new HashMap<>();
diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java
new file mode 100644
index 0000000..dd3a2af
--- /dev/null
+++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/IBSCalculatorTest.java
@@ -0,0 +1,164 @@
+package ca.joeltherrien.randomforest.competingrisk;
+
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
+import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
+import ca.joeltherrien.randomforest.utils.Point;
+import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
+import ca.joeltherrien.randomforest.utils.Utils;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.Optional;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class IBSCalculatorTest {
+
+    private final RightContinuousStepFunction cif;
+
+    public IBSCalculatorTest(){
+        this.cif = RightContinuousStepFunction.constructFromPoints(
+                Utils.easyList(
+                        new Point(1.0, 0.1),
+                        new Point(2.0, 0.2),
+                        new Point(3.0, 0.3),
+                        new Point(4.0, 0.8)
+                ), 0.0
+        );
+    }
+
+    /*
+        R code to get these results:
+
+        predicted_cif <- stepfun(1:4, c(0, 0.1, 0.2, 0.3, 0.8))
+        weights <- 1
+        recorded_time <- 2.0
+        recorded_status <- 1.0
+        event_of_interest <- 2
+        times <- 0:4
+
+        errors <- weights * ( as.integer(recorded_time <= times & recorded_status == event_of_interest) - predicted_cif(times))^2
+        sum(errors)
+
+
+        and run again with event_of_interest <- 1
+
+
+        Note that in the R code I only evaluate up to 4, while in the Java code I integrate up to 5
+        This is because the evaluation at 4 is giving the area of the rectangle from 4 to 5.
+
+     */
+
+    @Test
+    public void resultsWithoutCensoringDistribution(){
+        final IBSCalculator calculator = new IBSCalculator();
+
+        final double errorDifferentEvent = calculator.calculateError(
+                new CompetingRiskResponse(1, 2.0),
+                this.cif,
+                2,
+                5.0);
+
+        assertEquals(0.78, errorDifferentEvent, 0.000001);
+
+        final double errorSameEvent = calculator.calculateError(
+                new CompetingRiskResponse(1, 2.0),
+                this.cif,
+                1,
+                5.0);
+
+        assertEquals(1.18, errorSameEvent, 0.000001);
+
+    }
+
+    @Test
+    public void resultsWithCensoringDistribution(){
+        final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
+                Utils.easyList(
+                        new Point(0.0, 0.75),
+                        new Point(1.0, 0.5),
+                        new Point(3.0, 0.25),
+                        new Point(5.0, 0)
+                ), 1.0
+        );
+
+        final IBSCalculator calculator = new IBSCalculator(censorSurvivalFunction);
+
+        final double errorDifferentEvent = calculator.calculateError(
+                new CompetingRiskResponse(1, 2.0),
+                this.cif,
+                2,
+                5.0);
+
+        assertEquals(1.56, errorDifferentEvent, 0.000001);
+
+        final double errorSameEvent = calculator.calculateError(
+                new CompetingRiskResponse(1, 2.0),
+                this.cif,
+                1,
+                5.0);
+
+        assertEquals(2.36, errorSameEvent, 0.000001);
+
+    }
+
+    @Test
+    public void testStaticFunction(){
+        final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
+                Utils.easyList(
+                        new Point(0.0, 0.75),
+                        new Point(1.0, 0.5),
+                        new Point(3.0, 0.25),
+                        new Point(5.0, 0)
+                ), 1.0
+        );
+
+        final List<CompetingRiskResponse> responseList = Utils.easyList(
+                new CompetingRiskResponse(1, 2.0),
+                new CompetingRiskResponse(1, 2.0));
+
+        // for predictions; we'll construct an improper CompetingRisksFunctions
+        final RightContinuousStepFunction trivialFunction = RightContinuousStepFunction.constructFromPoints(
+                Utils.easyList(new Point(1.0, 0.0)),
+                1.0);
+
+        final CompetingRiskFunctions prediction = CompetingRiskFunctions.builder()
+                .survivalCurve(trivialFunction)
+                .causeSpecificHazards(Utils.easyList(trivialFunction, trivialFunction))
+                .cumulativeIncidenceCurves(Utils.easyList(this.cif, trivialFunction))
+                .build();
+
+        final List<CompetingRiskFunctions> predictionList = Utils.easyList(prediction, prediction);
+
+        double[] errorParallel = CompetingRiskUtils.calculateIBSError(
+                responseList,
+                predictionList,
+                Optional.of(censorSurvivalFunction),
+                1,
+                5.0,
+                true);
+
+        double[] errorSerial = CompetingRiskUtils.calculateIBSError(
+                responseList,
+                predictionList,
+                Optional.of(censorSurvivalFunction),
+                1,
+                5.0,
+                false);
+
+        assertEquals(responseList.size(), errorParallel.length);
+        assertEquals(responseList.size(), errorSerial.length);
+
+        assertEquals(2.36, errorParallel[0], 0.000001);
+        assertEquals(2.36, errorParallel[1], 0.000001);
+
+        assertEquals(2.36, errorSerial[0], 0.000001);
+        assertEquals(2.36, errorSerial[1], 0.000001);
+
+    }
+
+
+
+}