diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java index f6a4aea..ee8292a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java @@ -19,6 +19,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.Utils; import lombok.RequiredArgsConstructor; import java.util.ArrayList; @@ -57,40 +58,63 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner functions.getSurvivalCurve().evaluate(time) / n) - .sum(); - } + /* + We're going to try to efficiently put our predictions together - + Assumptions - for each event on a response, the hazard and CIF functions share the same x points - final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0); + Plan - go through the time on each response and make use of that so that when we search for a time index + to evaluate the function at, we don't need to re-search the earlier times. - final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); - final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); + */ - for(final int event : events){ - final double[] cumulativeHazardFunctionY = new double[timesToUse.length]; - final double[] cumulativeIncidenceFunctionY = new double[timesToUse.length]; + for(final CompetingRiskFunctions currentFunctions : responses){ + final double[] survivalXPoints = currentFunctions.getSurvivalCurve().getX(); + final double[][] eventSpecificXPoints = new double[events.length][]; + + for(final int event : events){ + eventSpecificXPoints[event-1] = currentFunctions.getCumulativeIncidenceFunction(event) + .getX(); + } + + int previousSurvivalIndex = 0; + final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value for(int i=0; i functions.getCauseSpecificHazardFunction(event).evaluate(time) / n) - .sum(); + // Survival curve + final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time); + survivalY[i] = survivalY[i] + currentFunctions.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / n; + previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1. + // -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called. - cumulativeIncidenceFunctionY[i] = responses.stream() - .mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time) / n) - .sum(); + // CHFs and CIFs + for(final int event : events){ + final double[] xPoints = eventSpecificXPoints[event-1]; + final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length, + xPoints, time); + csCHFY[event-1][i] = csCHFY[event-1][i] + currentFunctions.getCauseSpecificHazardFunction(event) + .evaluateByIndex(eventTimeIndex) / n; + cifY[event-1][i] = cifY[event-1][i] + currentFunctions.getCumulativeIncidenceFunction(event) + .evaluateByIndex(eventTimeIndex) / n; + previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0); + } } - causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeHazardFunctionY, 0)); - cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeIncidenceFunctionY, 0)); + } + final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0); + final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); + final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); + + for(final int event : events){ + causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0)); + cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0)); } return CompetingRiskFunctions.builder()