Add test verifying that CIFs are averaged together in the same way as randomForestSRC
This commit is contained in:
parent
4aac73b868
commit
7da3bd14a5
1 changed files with 123 additions and 0 deletions
|
@ -0,0 +1,123 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.TestUtils;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class TestCompetingRiskFunctionCombiner {
|
||||||
|
private final int[] events = new int[]{1,2};
|
||||||
|
|
||||||
|
private CompetingRiskFunctions createFunction(List<CompetingRiskResponse> responses){
|
||||||
|
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events);
|
||||||
|
|
||||||
|
return responseCombiner.combine(responses);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Data used in R code to compare with randomForestSRC
|
||||||
|
|
||||||
|
data <- data.frame(u =c(1,1,2,3,3, 2,2,3,4,4),
|
||||||
|
delta =c(2,1,1,1,0, 2,1,1,1,0))
|
||||||
|
*/
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFuncionCombiner(){
|
||||||
|
final List<CompetingRiskResponse> set1 = Utils.easyList(
|
||||||
|
new CompetingRiskResponse(2, 1.0),
|
||||||
|
new CompetingRiskResponse(1, 1.0),
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
new CompetingRiskResponse(1, 3.0),
|
||||||
|
new CompetingRiskResponse(0, 3.0)
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<CompetingRiskResponse> set2 = Utils.easyList(
|
||||||
|
new CompetingRiskResponse(2, 2.0),
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
new CompetingRiskResponse(1, 3.0),
|
||||||
|
new CompetingRiskResponse(1, 4.0),
|
||||||
|
new CompetingRiskResponse(0, 4.0)
|
||||||
|
);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions fun1 = createFunction(set1);
|
||||||
|
final CompetingRiskFunctions fun2 = createFunction(set2);
|
||||||
|
|
||||||
|
final CompetingRiskFunctionCombiner combiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions combinedFunction = combiner.combine(Utils.easyList(fun1, fun2));
|
||||||
|
|
||||||
|
final RightContinuousStepFunction cif_1 = combinedFunction.getCumulativeIncidenceFunction(1);
|
||||||
|
final RightContinuousStepFunction cif_2 = combinedFunction.getCumulativeIncidenceFunction(2);
|
||||||
|
|
||||||
|
/* Result from randomForestSRC
|
||||||
|
, , CIF.1
|
||||||
|
|
||||||
|
[,1] [,2] [,3] [,4]
|
||||||
|
[1,] 0.1 0.3 0.5 0.6
|
||||||
|
|
||||||
|
, , CIF.2
|
||||||
|
|
||||||
|
[,1] [,2] [,3] [,4]
|
||||||
|
[1,] 0.1 0.2 0.2 0.2
|
||||||
|
*/
|
||||||
|
|
||||||
|
TestUtils.closeEnough(0.1, cif_1.evaluate(1.0), 0.01);
|
||||||
|
TestUtils.closeEnough(0.3, cif_1.evaluate(2.0), 0.01);
|
||||||
|
TestUtils.closeEnough(0.5, cif_1.evaluate(3.0), 0.01);
|
||||||
|
TestUtils.closeEnough(0.6, cif_1.evaluate(4.0), 0.01);
|
||||||
|
|
||||||
|
TestUtils.closeEnough(0.1, cif_2.evaluate(1.0), 0.01);
|
||||||
|
TestUtils.closeEnough(0.2, cif_2.evaluate(2.0), 0.01);
|
||||||
|
TestUtils.closeEnough(0.2, cif_2.evaluate(3.0), 0.01);
|
||||||
|
TestUtils.closeEnough(0.2, cif_2.evaluate(4.0), 0.01);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Code to get randomForestSRC results; last tested on version 2.9.0
|
||||||
|
|
||||||
|
library(randomForestSRC)
|
||||||
|
data <- data.frame(u =c(1,1,2,3,3, 2,2,3,4,4),
|
||||||
|
delta =c(2,1,1,1,0, 2,1,1,1,0),
|
||||||
|
x =c(0,0,0,0,0, 1,1,1,1,1))
|
||||||
|
|
||||||
|
bootstrap.matrix <- matrix(0, nrow=nrow(data), ncol=2)
|
||||||
|
bootstrap.matrix[1:5,1] <- 1
|
||||||
|
bootstrap.matrix[6:10,2] <- 1
|
||||||
|
|
||||||
|
|
||||||
|
model.rfsrc <- rfsrc(Surv(u, delta) ~ x, data,
|
||||||
|
nodedepth = 0, splitrule="logrank",
|
||||||
|
bootstrap="by.user", samp=bootstrap.matrix,
|
||||||
|
ntree=2
|
||||||
|
)
|
||||||
|
|
||||||
|
new.data <- data.frame(x=c(1))
|
||||||
|
|
||||||
|
prediction <- predict(model.rfsrc, new.data)
|
||||||
|
prediction$cif
|
||||||
|
*/
|
Loading…
Reference in a new issue