Improve performance of CompetingRiskFunctionCombiner
Estimate of time improvement is at least 10x faster
This commit is contained in:
parent
c8269ae285
commit
bf168bc2a5
1 changed files with 44 additions and 20 deletions
|
@ -19,6 +19,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -57,40 +58,63 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
|
||||||
final double n = responses.size();
|
final double n = responses.size();
|
||||||
|
|
||||||
final double[] survivalY = new double[timesToUse.length];
|
final double[] survivalY = new double[timesToUse.length];
|
||||||
|
final double[][] csCHFY = new double[events.length][timesToUse.length];
|
||||||
|
final double[][] cifY = new double[events.length][timesToUse.length];
|
||||||
|
|
||||||
|
/*
|
||||||
|
We're going to try to efficiently put our predictions together -
|
||||||
|
Assumptions - for each event on a response, the hazard and CIF functions share the same x points
|
||||||
|
|
||||||
|
Plan - go through the time on each response and make use of that so that when we search for a time index
|
||||||
|
to evaluate the function at, we don't need to re-search the earlier times.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
for(final CompetingRiskFunctions currentFunctions : responses){
|
||||||
|
final double[] survivalXPoints = currentFunctions.getSurvivalCurve().getX();
|
||||||
|
final double[][] eventSpecificXPoints = new double[events.length][];
|
||||||
|
|
||||||
|
for(final int event : events){
|
||||||
|
eventSpecificXPoints[event-1] = currentFunctions.getCumulativeIncidenceFunction(event)
|
||||||
|
.getX();
|
||||||
|
}
|
||||||
|
|
||||||
|
int previousSurvivalIndex = 0;
|
||||||
|
final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
|
||||||
|
|
||||||
for(int i=0; i<timesToUse.length; i++){
|
for(int i=0; i<timesToUse.length; i++){
|
||||||
final double time = timesToUse[i];
|
final double time = timesToUse[i];
|
||||||
survivalY[i] = responses.stream()
|
|
||||||
.mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time) / n)
|
// Survival curve
|
||||||
.sum();
|
final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
|
||||||
|
survivalY[i] = survivalY[i] + currentFunctions.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / n;
|
||||||
|
previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
|
||||||
|
// -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
|
||||||
|
|
||||||
|
// CHFs and CIFs
|
||||||
|
for(final int event : events){
|
||||||
|
final double[] xPoints = eventSpecificXPoints[event-1];
|
||||||
|
final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
|
||||||
|
xPoints, time);
|
||||||
|
csCHFY[event-1][i] = csCHFY[event-1][i] + currentFunctions.getCauseSpecificHazardFunction(event)
|
||||||
|
.evaluateByIndex(eventTimeIndex) / n;
|
||||||
|
cifY[event-1][i] = cifY[event-1][i] + currentFunctions.getCumulativeIncidenceFunction(event)
|
||||||
|
.evaluateByIndex(eventTimeIndex) / n;
|
||||||
|
|
||||||
|
previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
|
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
|
||||||
|
|
||||||
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||||
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||||
|
|
||||||
for(final int event : events){
|
for(final int event : events){
|
||||||
|
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
|
||||||
final double[] cumulativeHazardFunctionY = new double[timesToUse.length];
|
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
|
||||||
final double[] cumulativeIncidenceFunctionY = new double[timesToUse.length];
|
|
||||||
|
|
||||||
for(int i=0; i<timesToUse.length; i++){
|
|
||||||
final double time = timesToUse[i];
|
|
||||||
|
|
||||||
cumulativeHazardFunctionY[i] = responses.stream()
|
|
||||||
.mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time) / n)
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
cumulativeIncidenceFunctionY[i] = responses.stream()
|
|
||||||
.mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time) / n)
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeHazardFunctionY, 0));
|
|
||||||
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeIncidenceFunctionY, 0));
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return CompetingRiskFunctions.builder()
|
return CompetingRiskFunctions.builder()
|
||||||
|
|
Loading…
Reference in a new issue