diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index 0c15b57..7d737bb 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -129,15 +129,8 @@ public class Settings { node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt())); final int[] events = eventList.stream().mapToInt(i -> i).toArray(); - double[] times = null; - // note that times may be null - if(node.hasNonNull("times")){ - final List timeList = new ArrayList<>(); - node.get("times").elements().forEachRemaining(time -> timeList.add(time.asDouble())); - times = timeList.stream().mapToDouble(db -> db).toArray(); - } - return new CompetingRiskResponseCombiner(events, times); + return new CompetingRiskResponseCombiner(events); } ); @@ -167,15 +160,8 @@ public class Settings { node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt())); final int[] events = eventList.stream().mapToInt(i -> i).toArray(); - double[] times = null; - // note that times may be null - if(node.hasNonNull("times")){ - final List timeList = new ArrayList<>(); - node.get("times").elements().forEachRemaining(time -> timeList.add(time.asDouble())); - times = timeList.stream().mapToDouble(db -> db).toArray(); - } - final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times); + final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events); return new CompetingRiskListCombiner(responseCombiner); } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java index 43de0e2..edd4591 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java @@ -16,84 +16,108 @@ import java.util.*; * See https://kogalur.github.io/randomForestSRC/theory.html for details. * */ -@RequiredArgsConstructor public class CompetingRiskResponseCombiner implements ResponseCombiner { private final int[] events; - private final double[] times; // We may restrict ourselves to specific times. + + public CompetingRiskResponseCombiner(final int[] events){ + this.events = events.clone(); + + // Check to make sure that events go from 1 to the right order + for(int i=0; i responses) { final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); - final double[] timesToUse; - if(times != null){ - timesToUse = this.times; - } - else{ - timesToUse = responses.stream() - .filter(response -> !response.isCensored()) - .mapToDouble(response -> response.getU()) - .sorted().distinct() - .toArray(); - } + Collections.sort(responses, (y1, y2) -> { + if(y1.getU() < y2.getU()){ + return -1; + } + else if(y1.getU() > y2.getU()){ + return 1; + } + else{ + return 0; + } + }); - final double[] individualsAtRiskArray = Arrays.stream(timesToUse).map(time -> riskSet(responses, time)).toArray(); + final int n = responses.size(); + + int[] numberOfCurrentEvents = new int[events.length+1]; - // First we need to develop the overall survival curve! - final List survivalPoints = new ArrayList<>(timesToUse.length); double previousSurvivalValue = 1.0; - for(int i=0; i survivalPoints = new ArrayList<>(n); // better to be too large than too small + + // Also track riskSet variables and numberOfEvents, and timesToUse + final List timesToUseList = new ArrayList<>(n); + final List riskSetList = new ArrayList<>(n); + final List numberOfEvents = new ArrayList<>(n); + + + for(int i=0; i currentResponse.getU(); + + numberOfCurrentEvents[currentResponse.getDelta()]++; + + if(lastOfTime){ + int totalNumberOfCurrentEvents = 0; + for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events + totalNumberOfCurrentEvents += numberOfCurrentEvents[e]; + } + + if(totalNumberOfCurrentEvents > 0){ + // Add point + final double currentTime = currentResponse.getU(); + final int riskSet = n - (i+1) + totalNumberOfCurrentEvents + numberOfCurrentEvents[0]; + final double newValue = previousSurvivalValue * (1.0 - (double) totalNumberOfCurrentEvents / (double) riskSet); + survivalPoints.add(new Point(currentTime, newValue)); + previousSurvivalValue = newValue; + + timesToUseList.add(currentTime); + riskSetList.add(riskSet); + numberOfEvents.add(numberOfCurrentEvents); + + } + // reset counters + numberOfCurrentEvents = new int[events.length+1]; - if(individualsAtRisk == 0){ - // if we continue we'll get NaN - break; } - final double numberOfEventsAtTime = (double) responses.stream() - .filter(event -> !event.isCensored()) - .filter(event -> event.getU() == time_k) // since delta != 0 we know censoring didn't occur prior to this - .count(); - - final double newValue = previousSurvivalValue * (1.0 - numberOfEventsAtTime / individualsAtRisk); - survivalPoints.add(new Point(time_k, newValue)); - previousSurvivalValue = newValue; - } - final MathFunction survivalCurve = new MathFunction(survivalPoints, new Point(0.0, 1.0)); for(final int event : events){ - final List hazardFunctionPoints = new ArrayList<>(timesToUse.length); + final List hazardFunctionPoints = new ArrayList<>(timesToUseList.size()); Point previousHazardFunctionPoint = new Point(0.0, 0.0); - final List cifPoints = new ArrayList<>(timesToUse.length); + final List cifPoints = new ArrayList<>(timesToUseList.size()); Point previousCIFPoint = new Point(0.0, 0.0); - for(int i=0; i 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY(); - final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : 1.0; + final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUseList.get(i-1)).getY() : 1.0; final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk); final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY); @@ -130,18 +154,5 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner eventList, double time) { - return eventList.stream() - .filter(event -> event.getU() >= time) - .count(); - } - - private double numberOfEventsAtTime(int eventOfFocus, List eventList, double time){ - return (double) eventList.stream() - .filter(event -> event.getDelta() == eventOfFocus) - .filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this - .count(); - - } } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java index 1dcfa93..4cfea67 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java @@ -13,7 +13,7 @@ import java.util.List; public class TestCompetingRiskResponseCombiner { - private CompetingRiskFunctions generateFunctions(double[] times){ + private CompetingRiskFunctions generateFunctions(){ final List data = new ArrayList<>(); data.add(new CompetingRiskResponse(1, 1.0)); @@ -24,14 +24,14 @@ public class TestCompetingRiskResponseCombiner { data.add(new CompetingRiskResponse(0, 1.5)); data.add(new CompetingRiskResponse(0, 2.5)); - final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, times); + final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}); return combiner.combine(data); } @Test public void testCompetingRiskResponseCombiner(){ - final CompetingRiskFunctions functions = generateFunctions(null); + final CompetingRiskFunctions functions = generateFunctions(); final MathFunction survivalCurve = functions.getSurvivalCurve(); @@ -86,68 +86,5 @@ public class TestCompetingRiskResponseCombiner { } - @Test - public void testCompetingRiskResponseCombinerWithSetTimes(){ - // By including time 3.0 (which extends past the data), - // we verify that we don't get NaNs past 3.0, which was a previous bug. - final CompetingRiskFunctions functions = generateFunctions(new double[]{1.0, 1.5, 2.0, 2.5, 3.0}); - - final MathFunction survivalCurve = functions.getSurvivalCurve(); - - // time = 1.0 1.5 2.0 2.5 - // surv = 0.7142857 0.5714286 0.1904762 0.1904762 - - final double margin = 0.0000001; - - closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin); - closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin); - closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin); - closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin); - closeEnough(0.1904762, survivalCurve.evaluate(3.0).getY(), margin); - - - // Time = 1.0 1.5 2.0 2.5 - /* Cumulative hazard function. Each row for one event. - [,1] [,2] [,3] [,4] - [1,] 0.2857143 0.2857143 0.6190476 0.6190476 - [2,] 0.0000000 0.2000000 0.5333333 0.5333333 - */ - - final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1); - closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin); - closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin); - closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin); - closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin); - closeEnough(0.6190476, cumHaz1.evaluate(3.0).getY(), margin); - - final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2); - closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin); - closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin); - closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin); - closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin); - closeEnough(0.5333333, cumHaz2.evaluate(3.0).getY(), margin); - - /* Time = 1.0 1.5 2.0 2.5 - Cumulative Incidence Curve. Each row for one event. - [,1] [,2] [,3] [,4] - [1,] 0.2857143 0.2857143 0.4761905 0.4761905 - [2,] 0.0000000 0.1428571 0.3333333 0.3333333 - */ - - final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1); - closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin); - closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin); - closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin); - closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin); - closeEnough(0.4761905, cic1.evaluate(3.0).getY(), margin); - - final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2); - closeEnough(0.0, cic2.evaluate(1.0).getY(), margin); - closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin); - closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin); - closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin); - closeEnough(0.3333333, cic2.evaluate(3.0).getY(), margin); - - } }