Fix a bug where CompetingRiskFunctions returns NaNs when using set times
in response combiner
This commit is contained in:
parent
62198f998d
commit
b8024275a9
2 changed files with 79 additions and 3 deletions
|
@ -56,6 +56,12 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
||||||
for(int i=0; i<timesToUse.length; i++){
|
for(int i=0; i<timesToUse.length; i++){
|
||||||
final double time_k = timesToUse[i];
|
final double time_k = timesToUse[i];
|
||||||
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
|
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
|
||||||
|
|
||||||
|
if(individualsAtRisk == 0){
|
||||||
|
// if we continue we'll get NaN
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
final double numberOfEventsAtTime = (double) responses.stream()
|
final double numberOfEventsAtTime = (double) responses.stream()
|
||||||
.filter(event -> !event.isCensored())
|
.filter(event -> !event.isCensored())
|
||||||
.filter(event -> event.getU() == time_k) // since delta != 0 we know censoring didn't occur prior to this
|
.filter(event -> event.getU() == time_k) // since delta != 0 we know censoring didn't occur prior to this
|
||||||
|
@ -81,6 +87,12 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
||||||
for(int i=0; i<timesToUse.length; i++){
|
for(int i=0; i<timesToUse.length; i++){
|
||||||
final double time_k = timesToUse[i];
|
final double time_k = timesToUse[i];
|
||||||
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
|
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
|
||||||
|
|
||||||
|
if(individualsAtRisk == 0){
|
||||||
|
// if we continue we'll get NaN
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
final double numberEventsAtTime = numberOfEventsAtTime(event, responses, time_k); // d_j(t_k)
|
final double numberEventsAtTime = numberOfEventsAtTime(event, responses, time_k); // d_j(t_k)
|
||||||
|
|
||||||
// Cause-specific cumulative hazard function
|
// Cause-specific cumulative hazard function
|
||||||
|
|
|
@ -13,7 +13,7 @@ import java.util.List;
|
||||||
|
|
||||||
public class TestCompetingRiskResponseCombiner {
|
public class TestCompetingRiskResponseCombiner {
|
||||||
|
|
||||||
private CompetingRiskFunctions generateFunctions(){
|
private CompetingRiskFunctions generateFunctions(double[] times){
|
||||||
final List<CompetingRiskResponse> data = new ArrayList<>();
|
final List<CompetingRiskResponse> data = new ArrayList<>();
|
||||||
|
|
||||||
data.add(new CompetingRiskResponse(1, 1.0));
|
data.add(new CompetingRiskResponse(1, 1.0));
|
||||||
|
@ -24,14 +24,14 @@ public class TestCompetingRiskResponseCombiner {
|
||||||
data.add(new CompetingRiskResponse(0, 1.5));
|
data.add(new CompetingRiskResponse(0, 1.5));
|
||||||
data.add(new CompetingRiskResponse(0, 2.5));
|
data.add(new CompetingRiskResponse(0, 2.5));
|
||||||
|
|
||||||
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);
|
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, times);
|
||||||
|
|
||||||
return combiner.combine(data);
|
return combiner.combine(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCompetingRiskResponseCombiner(){
|
public void testCompetingRiskResponseCombiner(){
|
||||||
final CompetingRiskFunctions functions = generateFunctions();
|
final CompetingRiskFunctions functions = generateFunctions(null);
|
||||||
|
|
||||||
final MathFunction survivalCurve = functions.getSurvivalCurve();
|
final MathFunction survivalCurve = functions.getSurvivalCurve();
|
||||||
|
|
||||||
|
@ -86,4 +86,68 @@ public class TestCompetingRiskResponseCombiner {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCompetingRiskResponseCombinerWithSetTimes(){
|
||||||
|
// By including time 3.0 (which extends past the data),
|
||||||
|
// we verify that we don't get NaNs past 3.0, which was a previous bug.
|
||||||
|
final CompetingRiskFunctions functions = generateFunctions(new double[]{1.0, 1.5, 2.0, 2.5, 3.0});
|
||||||
|
|
||||||
|
final MathFunction survivalCurve = functions.getSurvivalCurve();
|
||||||
|
|
||||||
|
// time = 1.0 1.5 2.0 2.5
|
||||||
|
// surv = 0.7142857 0.5714286 0.1904762 0.1904762
|
||||||
|
|
||||||
|
final double margin = 0.0000001;
|
||||||
|
|
||||||
|
closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin);
|
||||||
|
closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin);
|
||||||
|
closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin);
|
||||||
|
closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin);
|
||||||
|
closeEnough(0.1904762, survivalCurve.evaluate(3.0).getY(), margin);
|
||||||
|
|
||||||
|
|
||||||
|
// Time = 1.0 1.5 2.0 2.5
|
||||||
|
/* Cumulative hazard function. Each row for one event.
|
||||||
|
[,1] [,2] [,3] [,4]
|
||||||
|
[1,] 0.2857143 0.2857143 0.6190476 0.6190476
|
||||||
|
[2,] 0.0000000 0.2000000 0.5333333 0.5333333
|
||||||
|
*/
|
||||||
|
|
||||||
|
final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1);
|
||||||
|
closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin);
|
||||||
|
closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin);
|
||||||
|
closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin);
|
||||||
|
closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin);
|
||||||
|
closeEnough(0.6190476, cumHaz1.evaluate(3.0).getY(), margin);
|
||||||
|
|
||||||
|
final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2);
|
||||||
|
closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin);
|
||||||
|
closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin);
|
||||||
|
closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin);
|
||||||
|
closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin);
|
||||||
|
closeEnough(0.5333333, cumHaz2.evaluate(3.0).getY(), margin);
|
||||||
|
|
||||||
|
/* Time = 1.0 1.5 2.0 2.5
|
||||||
|
Cumulative Incidence Curve. Each row for one event.
|
||||||
|
[,1] [,2] [,3] [,4]
|
||||||
|
[1,] 0.2857143 0.2857143 0.4761905 0.4761905
|
||||||
|
[2,] 0.0000000 0.1428571 0.3333333 0.3333333
|
||||||
|
*/
|
||||||
|
|
||||||
|
final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1);
|
||||||
|
closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin);
|
||||||
|
closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin);
|
||||||
|
closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin);
|
||||||
|
closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin);
|
||||||
|
closeEnough(0.4761905, cic1.evaluate(3.0).getY(), margin);
|
||||||
|
|
||||||
|
final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2);
|
||||||
|
closeEnough(0.0, cic2.evaluate(1.0).getY(), margin);
|
||||||
|
closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin);
|
||||||
|
closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin);
|
||||||
|
closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin);
|
||||||
|
closeEnough(0.3333333, cic2.evaluate(3.0).getY(), margin);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue