From 7da3bd14a5234a071c1f67696ba4da2261a20115 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 5 Jun 2019 15:11:32 -0700 Subject: [PATCH] Add test verifying that CIFs are averaged together in the same way as randomForestSRC --- .../TestCompetingRiskFunctionCombiner.java | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctionCombiner.java diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctionCombiner.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctionCombiner.java new file mode 100644 index 0000000..b6eba0f --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctionCombiner.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.competingrisk; + +import ca.joeltherrien.randomforest.TestUtils; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; +import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.List; + +public class TestCompetingRiskFunctionCombiner { + private final int[] events = new int[]{1,2}; + + private CompetingRiskFunctions createFunction(List responses){ + final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events); + + return responseCombiner.combine(responses); + } + + /* Data used in R code to compare with randomForestSRC + + data <- data.frame(u =c(1,1,2,3,3, 2,2,3,4,4), + delta =c(2,1,1,1,0, 2,1,1,1,0)) + */ + + @Test + public void testFuncionCombiner(){ + final List set1 = Utils.easyList( + new CompetingRiskResponse(2, 1.0), + new CompetingRiskResponse(1, 1.0), + new CompetingRiskResponse(1, 2.0), + new CompetingRiskResponse(1, 3.0), + new CompetingRiskResponse(0, 3.0) + ); + + final List set2 = Utils.easyList( + new CompetingRiskResponse(2, 2.0), + new CompetingRiskResponse(1, 2.0), + new CompetingRiskResponse(1, 3.0), + new CompetingRiskResponse(1, 4.0), + new CompetingRiskResponse(0, 4.0) + ); + + final CompetingRiskFunctions fun1 = createFunction(set1); + final CompetingRiskFunctions fun2 = createFunction(set2); + + final CompetingRiskFunctionCombiner combiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null); + + final CompetingRiskFunctions combinedFunction = combiner.combine(Utils.easyList(fun1, fun2)); + + final RightContinuousStepFunction cif_1 = combinedFunction.getCumulativeIncidenceFunction(1); + final RightContinuousStepFunction cif_2 = combinedFunction.getCumulativeIncidenceFunction(2); + + /* Result from randomForestSRC + , , CIF.1 + + [,1] [,2] [,3] [,4] + [1,] 0.1 0.3 0.5 0.6 + + , , CIF.2 + + [,1] [,2] [,3] [,4] + [1,] 0.1 0.2 0.2 0.2 + */ + + TestUtils.closeEnough(0.1, cif_1.evaluate(1.0), 0.01); + TestUtils.closeEnough(0.3, cif_1.evaluate(2.0), 0.01); + TestUtils.closeEnough(0.5, cif_1.evaluate(3.0), 0.01); + TestUtils.closeEnough(0.6, cif_1.evaluate(4.0), 0.01); + + TestUtils.closeEnough(0.1, cif_2.evaluate(1.0), 0.01); + TestUtils.closeEnough(0.2, cif_2.evaluate(2.0), 0.01); + TestUtils.closeEnough(0.2, cif_2.evaluate(3.0), 0.01); + TestUtils.closeEnough(0.2, cif_2.evaluate(4.0), 0.01); + + + } + + +} + +/* Code to get randomForestSRC results; last tested on version 2.9.0 + +library(randomForestSRC) +data <- data.frame(u =c(1,1,2,3,3, 2,2,3,4,4), + delta =c(2,1,1,1,0, 2,1,1,1,0), + x =c(0,0,0,0,0, 1,1,1,1,1)) + +bootstrap.matrix <- matrix(0, nrow=nrow(data), ncol=2) +bootstrap.matrix[1:5,1] <- 1 +bootstrap.matrix[6:10,2] <- 1 + + +model.rfsrc <- rfsrc(Surv(u, delta) ~ x, data, + nodedepth = 0, splitrule="logrank", + bootstrap="by.user", samp=bootstrap.matrix, + ntree=2 + ) + +new.data <- data.frame(x=c(1)) + +prediction <- predict(model.rfsrc, new.data) +prediction$cif + */