Optimize CompetingRiskResponseCombiner
This commit is contained in:
parent
aa733d5eba
commit
7fba964af9
3 changed files with 72 additions and 138 deletions
|
@ -129,15 +129,8 @@ public class Settings {
|
||||||
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
|
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
|
||||||
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
|
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
|
||||||
|
|
||||||
double[] times = null;
|
|
||||||
// note that times may be null
|
|
||||||
if(node.hasNonNull("times")){
|
|
||||||
final List<Double> timeList = new ArrayList<>();
|
|
||||||
node.get("times").elements().forEachRemaining(time -> timeList.add(time.asDouble()));
|
|
||||||
times = timeList.stream().mapToDouble(db -> db).toArray();
|
|
||||||
}
|
|
||||||
|
|
||||||
return new CompetingRiskResponseCombiner(events, times);
|
return new CompetingRiskResponseCombiner(events);
|
||||||
|
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -167,15 +160,8 @@ public class Settings {
|
||||||
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
|
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
|
||||||
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
|
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
|
||||||
|
|
||||||
double[] times = null;
|
|
||||||
// note that times may be null
|
|
||||||
if(node.hasNonNull("times")){
|
|
||||||
final List<Double> timeList = new ArrayList<>();
|
|
||||||
node.get("times").elements().forEachRemaining(time -> timeList.add(time.asDouble()));
|
|
||||||
times = timeList.stream().mapToDouble(db -> db).toArray();
|
|
||||||
}
|
|
||||||
|
|
||||||
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times);
|
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events);
|
||||||
return new CompetingRiskListCombiner(responseCombiner);
|
return new CompetingRiskListCombiner(responseCombiner);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,84 +16,108 @@ import java.util.*;
|
||||||
* See https://kogalur.github.io/randomForestSRC/theory.html for details.
|
* See https://kogalur.github.io/randomForestSRC/theory.html for details.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class CompetingRiskResponseCombiner implements ResponseCombiner<CompetingRiskResponse, CompetingRiskFunctions> {
|
public class CompetingRiskResponseCombiner implements ResponseCombiner<CompetingRiskResponse, CompetingRiskFunctions> {
|
||||||
|
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
private final double[] times; // We may restrict ourselves to specific times.
|
|
||||||
|
public CompetingRiskResponseCombiner(final int[] events){
|
||||||
|
this.events = events.clone();
|
||||||
|
|
||||||
|
// Check to make sure that events go from 1 to the right order
|
||||||
|
for(int i=0; i<events.length; i++){
|
||||||
|
if(events[i] != (i+1)){
|
||||||
|
throw new IllegalArgumentException("The events parameter must be in the form 1,2,3,...J with no gaps");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public int[] getEvents(){
|
public int[] getEvents(){
|
||||||
return events.clone();
|
return events.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[] getTimes(){
|
|
||||||
return times.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
|
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
|
||||||
|
|
||||||
final List<MathFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
final List<MathFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||||
final List<MathFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
final List<MathFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||||
|
|
||||||
final double[] timesToUse;
|
Collections.sort(responses, (y1, y2) -> {
|
||||||
if(times != null){
|
if(y1.getU() < y2.getU()){
|
||||||
timesToUse = this.times;
|
return -1;
|
||||||
|
}
|
||||||
|
else if(y1.getU() > y2.getU()){
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
timesToUse = responses.stream()
|
return 0;
|
||||||
.filter(response -> !response.isCensored())
|
|
||||||
.mapToDouble(response -> response.getU())
|
|
||||||
.sorted().distinct()
|
|
||||||
.toArray();
|
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
final double[] individualsAtRiskArray = Arrays.stream(timesToUse).map(time -> riskSet(responses, time)).toArray();
|
final int n = responses.size();
|
||||||
|
|
||||||
|
int[] numberOfCurrentEvents = new int[events.length+1];
|
||||||
|
|
||||||
// First we need to develop the overall survival curve!
|
|
||||||
final List<Point> survivalPoints = new ArrayList<>(timesToUse.length);
|
|
||||||
double previousSurvivalValue = 1.0;
|
double previousSurvivalValue = 1.0;
|
||||||
for(int i=0; i<timesToUse.length; i++){
|
final List<Point> survivalPoints = new ArrayList<>(n); // better to be too large than too small
|
||||||
final double time_k = timesToUse[i];
|
|
||||||
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
|
|
||||||
|
|
||||||
if(individualsAtRisk == 0){
|
// Also track riskSet variables and numberOfEvents, and timesToUse
|
||||||
// if we continue we'll get NaN
|
final List<Double> timesToUseList = new ArrayList<>(n);
|
||||||
break;
|
final List<Integer> riskSetList = new ArrayList<>(n);
|
||||||
|
final List<int[]> numberOfEvents = new ArrayList<>(n);
|
||||||
|
|
||||||
|
|
||||||
|
for(int i=0; i<n; i++){
|
||||||
|
final CompetingRiskResponse currentResponse = responses.get(i);
|
||||||
|
final boolean lastOfTime = (i+1)==n || responses.get(i+1).getU() > currentResponse.getU();
|
||||||
|
|
||||||
|
numberOfCurrentEvents[currentResponse.getDelta()]++;
|
||||||
|
|
||||||
|
if(lastOfTime){
|
||||||
|
int totalNumberOfCurrentEvents = 0;
|
||||||
|
for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events
|
||||||
|
totalNumberOfCurrentEvents += numberOfCurrentEvents[e];
|
||||||
}
|
}
|
||||||
|
|
||||||
final double numberOfEventsAtTime = (double) responses.stream()
|
if(totalNumberOfCurrentEvents > 0){
|
||||||
.filter(event -> !event.isCensored())
|
// Add point
|
||||||
.filter(event -> event.getU() == time_k) // since delta != 0 we know censoring didn't occur prior to this
|
final double currentTime = currentResponse.getU();
|
||||||
.count();
|
final int riskSet = n - (i+1) + totalNumberOfCurrentEvents + numberOfCurrentEvents[0];
|
||||||
|
final double newValue = previousSurvivalValue * (1.0 - (double) totalNumberOfCurrentEvents / (double) riskSet);
|
||||||
final double newValue = previousSurvivalValue * (1.0 - numberOfEventsAtTime / individualsAtRisk);
|
survivalPoints.add(new Point(currentTime, newValue));
|
||||||
survivalPoints.add(new Point(time_k, newValue));
|
|
||||||
previousSurvivalValue = newValue;
|
previousSurvivalValue = newValue;
|
||||||
|
|
||||||
|
timesToUseList.add(currentTime);
|
||||||
|
riskSetList.add(riskSet);
|
||||||
|
numberOfEvents.add(numberOfCurrentEvents);
|
||||||
|
|
||||||
|
}
|
||||||
|
// reset counters
|
||||||
|
numberOfCurrentEvents = new int[events.length+1];
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
final MathFunction survivalCurve = new MathFunction(survivalPoints, new Point(0.0, 1.0));
|
final MathFunction survivalCurve = new MathFunction(survivalPoints, new Point(0.0, 1.0));
|
||||||
|
|
||||||
|
|
||||||
for(final int event : events){
|
for(final int event : events){
|
||||||
|
|
||||||
final List<Point> hazardFunctionPoints = new ArrayList<>(timesToUse.length);
|
final List<Point> hazardFunctionPoints = new ArrayList<>(timesToUseList.size());
|
||||||
Point previousHazardFunctionPoint = new Point(0.0, 0.0);
|
Point previousHazardFunctionPoint = new Point(0.0, 0.0);
|
||||||
|
|
||||||
final List<Point> cifPoints = new ArrayList<>(timesToUse.length);
|
final List<Point> cifPoints = new ArrayList<>(timesToUseList.size());
|
||||||
Point previousCIFPoint = new Point(0.0, 0.0);
|
Point previousCIFPoint = new Point(0.0, 0.0);
|
||||||
|
|
||||||
for(int i=0; i<timesToUse.length; i++){
|
for(int i=0; i<timesToUseList.size(); i++){
|
||||||
final double time_k = timesToUse[i];
|
final double time_k = timesToUseList.get(i);
|
||||||
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
|
final double individualsAtRisk = riskSetList.get(i); // Y(t_k)
|
||||||
|
|
||||||
if(individualsAtRisk == 0){
|
if(individualsAtRisk == 0){
|
||||||
// if we continue we'll get NaN
|
// if we continue we'll get NaN
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
final double numberEventsAtTime = numberOfEventsAtTime(event, responses, time_k); // d_j(t_k)
|
final double numberEventsAtTime = numberOfEvents.get(i)[event]; // d_j(t_k)
|
||||||
|
|
||||||
// Cause-specific cumulative hazard function
|
// Cause-specific cumulative hazard function
|
||||||
final double hazardDeltaY = numberEventsAtTime / individualsAtRisk;
|
final double hazardDeltaY = numberEventsAtTime / individualsAtRisk;
|
||||||
|
@ -105,7 +129,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
||||||
// Cumulative incidence function
|
// Cumulative incidence function
|
||||||
// TODO - confirm this behaviour
|
// TODO - confirm this behaviour
|
||||||
//final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY();
|
//final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY();
|
||||||
final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : 1.0;
|
final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUseList.get(i-1)).getY() : 1.0;
|
||||||
|
|
||||||
final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk);
|
final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk);
|
||||||
final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY);
|
final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY);
|
||||||
|
@ -130,18 +154,5 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private double riskSet(List<CompetingRiskResponse> eventList, double time) {
|
|
||||||
return eventList.stream()
|
|
||||||
.filter(event -> event.getU() >= time)
|
|
||||||
.count();
|
|
||||||
}
|
|
||||||
|
|
||||||
private double numberOfEventsAtTime(int eventOfFocus, List<CompetingRiskResponse> eventList, double time){
|
|
||||||
return (double) eventList.stream()
|
|
||||||
.filter(event -> event.getDelta() == eventOfFocus)
|
|
||||||
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
|
|
||||||
.count();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ import java.util.List;
|
||||||
|
|
||||||
public class TestCompetingRiskResponseCombiner {
|
public class TestCompetingRiskResponseCombiner {
|
||||||
|
|
||||||
private CompetingRiskFunctions generateFunctions(double[] times){
|
private CompetingRiskFunctions generateFunctions(){
|
||||||
final List<CompetingRiskResponse> data = new ArrayList<>();
|
final List<CompetingRiskResponse> data = new ArrayList<>();
|
||||||
|
|
||||||
data.add(new CompetingRiskResponse(1, 1.0));
|
data.add(new CompetingRiskResponse(1, 1.0));
|
||||||
|
@ -24,14 +24,14 @@ public class TestCompetingRiskResponseCombiner {
|
||||||
data.add(new CompetingRiskResponse(0, 1.5));
|
data.add(new CompetingRiskResponse(0, 1.5));
|
||||||
data.add(new CompetingRiskResponse(0, 2.5));
|
data.add(new CompetingRiskResponse(0, 2.5));
|
||||||
|
|
||||||
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, times);
|
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2});
|
||||||
|
|
||||||
return combiner.combine(data);
|
return combiner.combine(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCompetingRiskResponseCombiner(){
|
public void testCompetingRiskResponseCombiner(){
|
||||||
final CompetingRiskFunctions functions = generateFunctions(null);
|
final CompetingRiskFunctions functions = generateFunctions();
|
||||||
|
|
||||||
final MathFunction survivalCurve = functions.getSurvivalCurve();
|
final MathFunction survivalCurve = functions.getSurvivalCurve();
|
||||||
|
|
||||||
|
@ -86,68 +86,5 @@ public class TestCompetingRiskResponseCombiner {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCompetingRiskResponseCombinerWithSetTimes(){
|
|
||||||
// By including time 3.0 (which extends past the data),
|
|
||||||
// we verify that we don't get NaNs past 3.0, which was a previous bug.
|
|
||||||
final CompetingRiskFunctions functions = generateFunctions(new double[]{1.0, 1.5, 2.0, 2.5, 3.0});
|
|
||||||
|
|
||||||
final MathFunction survivalCurve = functions.getSurvivalCurve();
|
|
||||||
|
|
||||||
// time = 1.0 1.5 2.0 2.5
|
|
||||||
// surv = 0.7142857 0.5714286 0.1904762 0.1904762
|
|
||||||
|
|
||||||
final double margin = 0.0000001;
|
|
||||||
|
|
||||||
closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin);
|
|
||||||
closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin);
|
|
||||||
closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin);
|
|
||||||
closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin);
|
|
||||||
closeEnough(0.1904762, survivalCurve.evaluate(3.0).getY(), margin);
|
|
||||||
|
|
||||||
|
|
||||||
// Time = 1.0 1.5 2.0 2.5
|
|
||||||
/* Cumulative hazard function. Each row for one event.
|
|
||||||
[,1] [,2] [,3] [,4]
|
|
||||||
[1,] 0.2857143 0.2857143 0.6190476 0.6190476
|
|
||||||
[2,] 0.0000000 0.2000000 0.5333333 0.5333333
|
|
||||||
*/
|
|
||||||
|
|
||||||
final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1);
|
|
||||||
closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin);
|
|
||||||
closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin);
|
|
||||||
closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin);
|
|
||||||
closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin);
|
|
||||||
closeEnough(0.6190476, cumHaz1.evaluate(3.0).getY(), margin);
|
|
||||||
|
|
||||||
final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2);
|
|
||||||
closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin);
|
|
||||||
closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin);
|
|
||||||
closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin);
|
|
||||||
closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin);
|
|
||||||
closeEnough(0.5333333, cumHaz2.evaluate(3.0).getY(), margin);
|
|
||||||
|
|
||||||
/* Time = 1.0 1.5 2.0 2.5
|
|
||||||
Cumulative Incidence Curve. Each row for one event.
|
|
||||||
[,1] [,2] [,3] [,4]
|
|
||||||
[1,] 0.2857143 0.2857143 0.4761905 0.4761905
|
|
||||||
[2,] 0.0000000 0.1428571 0.3333333 0.3333333
|
|
||||||
*/
|
|
||||||
|
|
||||||
final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1);
|
|
||||||
closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin);
|
|
||||||
closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin);
|
|
||||||
closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin);
|
|
||||||
closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin);
|
|
||||||
closeEnough(0.4761905, cic1.evaluate(3.0).getY(), margin);
|
|
||||||
|
|
||||||
final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2);
|
|
||||||
closeEnough(0.0, cic2.evaluate(1.0).getY(), margin);
|
|
||||||
closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin);
|
|
||||||
closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin);
|
|
||||||
closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin);
|
|
||||||
closeEnough(0.3333333, cic2.evaluate(3.0).getY(), margin);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue