WIP - Add a CompetingRiskErrorRateCalculator.
Note that tests fail.
This commit is contained in:
parent
d4853f5232
commit
7a77851f94
6 changed files with 434 additions and 14 deletions
|
@ -0,0 +1,10 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import java.lang.annotation.Documented;
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
@Documented
|
||||||
|
@Target(ElementType.METHOD)
|
||||||
|
public @interface VisibleForTesting {
|
||||||
|
}
|
|
@ -0,0 +1,287 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.VisibleForTesting;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Based on the naive version in Section 3.2 of "Concordance for Prognastic Models with Competing Risks" by Wolbers et al.
|
||||||
|
*
|
||||||
|
* Note that this is the same version implemented in randomForestSRC. The downsides of this approach is that we can expect the errors to be biased, possibly severely.
|
||||||
|
* Therefore I suggest that this measure only be used in comparing models, but not as a final output.
|
||||||
|
*/
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
|
private final CompetingRiskFunctionCombiner combiner;
|
||||||
|
private final int[] events;
|
||||||
|
|
||||||
|
public CompetingRiskErrorRateCalculator(final int[] events, final double[] times){
|
||||||
|
this.events = events;
|
||||||
|
this.combiner = new CompetingRiskFunctionCombiner(events, times);
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] calculateAll(final List<Row<CompetingRiskResponse>> rows, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
|
||||||
|
|
||||||
|
final Collection<Tree<CompetingRiskFunctions>> trees = forest.getTrees();
|
||||||
|
|
||||||
|
// This predicts for rows based on their OOB trees.
|
||||||
|
final List<CompetingRiskFunctions> riskFunctions = rows.stream()
|
||||||
|
.map(row -> {
|
||||||
|
return trees.stream().filter(tree -> !tree.idInBootstrapSample(row.getId())).map(tree -> tree.evaluate(row)).collect(Collectors.toList());
|
||||||
|
})
|
||||||
|
.map(list -> combiner.combine(list))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double[] errorRates = new double[events.length];
|
||||||
|
|
||||||
|
final List<CompetingRiskResponse> responses = rows.stream().map(row -> row.getResponse()).collect(Collectors.toList());
|
||||||
|
|
||||||
|
// Let \tau be the max time.
|
||||||
|
|
||||||
|
for(int e=0; e<events.length; e++){
|
||||||
|
final int event = events[e];
|
||||||
|
|
||||||
|
final double[] mortalityList = riskFunctions.stream()
|
||||||
|
.map(riskFunction -> riskFunction.getCumulativeIncidenceFunction(event))
|
||||||
|
.mapToDouble(cif -> functionToMortality(cif))
|
||||||
|
.toArray();
|
||||||
|
|
||||||
|
final double concordance = calculate(responses, mortalityList, event);
|
||||||
|
errorRates[e] = 1.0 - concordance;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return errorRates;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
public double calculate(final List<CompetingRiskResponse> responseList, final double[] mortalityArray, final int event){
|
||||||
|
|
||||||
|
// Let \tau be the max time.
|
||||||
|
|
||||||
|
int permissible = 0;
|
||||||
|
int numerator = 0;
|
||||||
|
|
||||||
|
for(int i = 0; i<mortalityArray.length; i++){
|
||||||
|
final CompetingRiskResponse responseI = responseList.get(i);
|
||||||
|
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||||
|
continue; // skip if it's 0
|
||||||
|
}
|
||||||
|
|
||||||
|
final double mortalityI = mortalityArray[i];
|
||||||
|
|
||||||
|
for(int j=0; j<mortalityArray.length; j++){
|
||||||
|
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||||
|
// Check that Aij or Bij == 1
|
||||||
|
if(responseI.getU() < responseJ.getU() || (responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){
|
||||||
|
permissible++;
|
||||||
|
|
||||||
|
final double mortalityJ = mortalityArray[j];
|
||||||
|
numerator += mortalityI > mortalityJ ? 1 : 0;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return (double) numerator / (double) permissible;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
public double[] calculateAll(final List<Row<CompetingRiskResponse>> rows, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
|
||||||
|
rows.sort(Comparator.comparing(row -> row.getResponse().getU())); // optimization for later loop
|
||||||
|
|
||||||
|
final Collection<Tree<CompetingRiskFunctions>> trees = forest.getTrees();
|
||||||
|
|
||||||
|
final List<CompetingRiskFunctions> riskFunctions = rows.stream()
|
||||||
|
.map(row -> {
|
||||||
|
return trees.stream().filter(tree -> !tree.idInBootstrapSample(row.getId())).map(tree -> tree.evaluate(row)).collect(Collectors.toList());
|
||||||
|
})
|
||||||
|
.map(list -> combiner.combine(list))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double[] errorRates = new double[events.length];
|
||||||
|
|
||||||
|
|
||||||
|
for(int e=0; e<events.length; e++){
|
||||||
|
final int event = events[e];
|
||||||
|
|
||||||
|
final double[] mortalityList = riskFunctions.stream()
|
||||||
|
.map(riskFunction -> riskFunction.getCumulativeIncidenceFunction(event))
|
||||||
|
.mapToDouble(cif -> functionToMortality(cif))
|
||||||
|
.toArray();
|
||||||
|
|
||||||
|
int permissible = 0;
|
||||||
|
double c = 0.0;
|
||||||
|
|
||||||
|
outer_mortality:
|
||||||
|
for(int i = 0; i< mortalityList.length; i++){
|
||||||
|
final Row<CompetingRiskResponse> leftRow = rows.get(i);
|
||||||
|
final double mortalityLeft = mortalityList[i];
|
||||||
|
final CompetingRiskResponse leftResponse = leftRow.getResponse();
|
||||||
|
|
||||||
|
for(int j=i+1; j<mortalityList.length; j++){
|
||||||
|
|
||||||
|
final Row<CompetingRiskResponse> rightRow = rows.get(j);
|
||||||
|
final double mortalityRight = mortalityList[j];
|
||||||
|
final CompetingRiskResponse rightResponse = rightRow.getResponse();
|
||||||
|
|
||||||
|
if(leftResponse.getDelta() != event && rightResponse.getU() > leftResponse.getU()){
|
||||||
|
// because we've sorted the responses earlier we will never get a permissable result for greater j.
|
||||||
|
continue outer_mortality;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check and see if pair is permissable
|
||||||
|
if(isPermissablePair(leftResponse, rightResponse, event)){
|
||||||
|
permissible++;
|
||||||
|
|
||||||
|
final double comparisonScore = compare(leftResponse, rightResponse, event);
|
||||||
|
if(comparisonScore < 0) { // left > right
|
||||||
|
// right has shorter time
|
||||||
|
if(mortalityRight > mortalityLeft){
|
||||||
|
c += 1.0;
|
||||||
|
}
|
||||||
|
else if(mortalityRight == mortalityLeft){
|
||||||
|
c += 0.5;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(comparisonScore > 0){ // left < right
|
||||||
|
// left has shorter term
|
||||||
|
if(mortalityRight < mortalityLeft){
|
||||||
|
c += 1.0;
|
||||||
|
}
|
||||||
|
else if(mortalityRight == mortalityLeft){
|
||||||
|
c += 0.5;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else{ // comparisonScore == 0
|
||||||
|
c += (mortalityLeft == mortalityRight) ? 1.0 : 0.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
final double concordance = c / (double) permissible;
|
||||||
|
errorRates[e] = 1.0 - concordance;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return errorRates;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
private boolean isPermissablePair(final CompetingRiskResponse left, final CompetingRiskResponse right){
|
||||||
|
if(left.isCensored() && right.isCensored()){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(left.getU() < right.getU() && left.isCensored()){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(left.getU() > right.getU() && right.isCensored()){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
private boolean isPermissablePair(final CompetingRiskResponse left, final CompetingRiskResponse right, int event){
|
||||||
|
if(left.getDelta() != event && right.getDelta() != event){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(left.getU() < right.getU() && left.getDelta() != event){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(left.getU() > right.getU() && right.getDelta() != event){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
private double functionToMortality(final MathFunction cif){
|
||||||
|
double summation = 0.0;
|
||||||
|
double previousTime = 0.0;
|
||||||
|
|
||||||
|
for(final Point point : cif.getPoints()){
|
||||||
|
summation += point.getY() * (point.getTime() - previousTime);
|
||||||
|
previousTime = point.getTime();
|
||||||
|
}
|
||||||
|
|
||||||
|
return summation;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compare two CompetingRiskResponses to see which is larger than the other (if it can be determined).
|
||||||
|
*
|
||||||
|
* @param left
|
||||||
|
* @param right
|
||||||
|
* @param event Event of interest. All other events are treated as censoring.
|
||||||
|
* @return -1 if left is strictly greater than right, -0.5 if left is greater than right, 0 if both are equal, 0.5 if right is greater than left, and 1 if right is strictly greater than left.
|
||||||
|
*//*
|
||||||
|
@VisibleForTesting
|
||||||
|
public double compare(final CompetingRiskResponse left, final CompetingRiskResponse right, int event){
|
||||||
|
|
||||||
|
|
||||||
|
if(left.getU() > right.getU() && right.getDelta()==event){
|
||||||
|
// left is greater
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
else if(right.getU() > left.getU() && left.getDelta()==event){
|
||||||
|
// right is greater
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
else if(left.getU() == right.getU() && left.getDelta()==event && right.getDelta()==event){
|
||||||
|
// they are equal
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
else if(left.getU() == right.getU() && left.getDelta()!=event && right.getDelta()==event){
|
||||||
|
// left is greater (note; could be unknown depending on definitions)
|
||||||
|
//return -0.5;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
else if(left.getU() == right.getU() && left.getDelta()==event && right.getDelta()!=event){
|
||||||
|
// right is greater (note; could be unknown depending on definitions)
|
||||||
|
//return 0.5;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
throw new IllegalArgumentException("Invalid comparison of " + left + " and " + right + "; did you call isPermissablePair first?");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}*/
|
||||||
|
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -14,6 +15,7 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
private final double[] times; // We may restrict ourselves to specific times.
|
private final double[] times; // We may restrict ourselves to specific times.
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompetingRiskFunctions combine(List<CompetingRiskFunctions> responses) {
|
public CompetingRiskFunctions combine(List<CompetingRiskFunctions> responses) {
|
||||||
|
|
||||||
|
|
|
@ -43,12 +43,30 @@ output.one.tree$chf[,c(11,66,103),1]
|
||||||
#output$cif[,,1] # CIF for cause 1
|
#output$cif[,,1] # CIF for cause 1
|
||||||
#output$cif[,,2] # CIF for cause 2
|
#output$cif[,,2] # CIF for cause 2
|
||||||
|
|
||||||
many.trees <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE)
|
many.trees <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE); many.trees
|
||||||
|
|
||||||
|
err.rate.1 = c()
|
||||||
|
err.rate.2 = c()
|
||||||
|
for(j in 1:100){
|
||||||
|
many.trees <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE)
|
||||||
|
err.rate.1 = c(err.rate.1, many.trees$err.rate[100,1])
|
||||||
|
err.rate.2 = c(err.rate.2, many.trees$err.rate[100,2])
|
||||||
|
}
|
||||||
|
quant.1 = quantile(err.rate.1, probs=c(0.025, 0.5, 0.975)) # 0.4727131 0.4792391 0.4862286
|
||||||
|
quant.2 = quantile(err.rate.2, probs=c(0.025, 0.5, 0.975)) # 0.4898299 0.4978300 0.5064539
|
||||||
|
|
||||||
|
(quant.1[3] + quant.1[1]) / 2
|
||||||
|
(quant.1[3] - quant.1[1]) / 2
|
||||||
|
|
||||||
|
(quant.2[3] + quant.2[1]) / 2
|
||||||
|
(quant.2[3] - quant.2[1]) / 2
|
||||||
|
|
||||||
|
|
||||||
output.many.trees = predict(many.trees, newData)
|
output.many.trees = predict(many.trees, newData)
|
||||||
output.many.trees$cif[,41,1]
|
output.many.trees$cif[,41,1]
|
||||||
output.many.trees$cif[,41,2]
|
output.many.trees$cif[,41,2]
|
||||||
|
|
||||||
many.trees.all <- rfsrc(Surv(time, status) ~ ageatfda + cd4nadir + idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE)
|
many.trees.all <- rfsrc(Surv(time, status) ~ ageatfda + cd4nadir + idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE); many.trees.all
|
||||||
output.many.trees.all = predict(many.trees.all, newData)
|
output.many.trees.all = predict(many.trees.all, newData)
|
||||||
output.many.trees.all$cif[,103,1]
|
output.many.trees.all$cif[,103,1]
|
||||||
output.many.trees.all$cif[,103,2]
|
output.many.trees.all$cif[,103,2]
|
||||||
|
@ -57,6 +75,7 @@ output.many.trees.all$cif[,103,2]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
end.numbers = c()
|
end.numbers = c()
|
||||||
end.times = c()
|
end.times = c()
|
||||||
lgths = c()
|
lgths = c()
|
||||||
|
|
|
@ -2,10 +2,7 @@ package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.*;
|
import ca.joeltherrien.randomforest.covariates.*;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.Point;
|
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
|
@ -16,11 +13,8 @@ import org.junit.jupiter.api.Test;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class TestCompetingRisk {
|
public class TestCompetingRisk {
|
||||||
|
@ -81,6 +75,10 @@ public class TestCompetingRisk {
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<Covariate> getCovariates(Settings settings){
|
||||||
|
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
||||||
return CovariateRow.createSimple(Map.of(
|
return CovariateRow.createSimple(Map.of(
|
||||||
"ageatfda", "35",
|
"ageatfda", "35",
|
||||||
|
@ -194,10 +192,6 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Covariate> getCovariates(Settings settings){
|
|
||||||
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
|
@ -232,6 +226,15 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
||||||
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
||||||
|
|
||||||
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null);
|
||||||
|
final double[] errorRates = errorRateCalculator.calculateAll(dataset, forest);
|
||||||
|
|
||||||
|
// Error rates happen to be about the same
|
||||||
|
closeEnough(0.4795, errorRates[0], 0.007);
|
||||||
|
closeEnough(0.478, errorRates[1], 0.008);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -294,8 +297,16 @@ public class TestCompetingRisk {
|
||||||
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
|
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
|
||||||
|
|
||||||
// We seem to consistently underestimate the results.
|
// We seem to consistently underestimate the results.
|
||||||
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC");
|
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||||
|
|
||||||
|
final CompetingRiskErrorRateCalculator errorRate = new CompetingRiskErrorRateCalculator((CompetingRiskFunctionCombiner) settings.getTreeCombiner(), new int[]{1,2});
|
||||||
|
final double[] errorRates = errorRate.calculateAll(dataset, forest);
|
||||||
|
|
||||||
|
System.out.println(errorRates[0]);
|
||||||
|
System.out.println(errorRates[1]);
|
||||||
|
|
||||||
|
closeEnough(0.41, errorRates[0], 0.02);
|
||||||
|
closeEnough(0.38, errorRates[1], 0.02);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
|
/*
|
||||||
|
@Test
|
||||||
|
public void testComparingResponses(){
|
||||||
|
|
||||||
|
// Large, uncensored
|
||||||
|
CompetingRiskResponse responseA = new CompetingRiskResponse(1, 10.0);
|
||||||
|
|
||||||
|
// Large, censored
|
||||||
|
CompetingRiskResponse responseB = new CompetingRiskResponse(0, 10.0);
|
||||||
|
|
||||||
|
// Large, other event
|
||||||
|
CompetingRiskResponse responseC = new CompetingRiskResponse(2, 10.0);
|
||||||
|
|
||||||
|
// Medium, uncensored
|
||||||
|
CompetingRiskResponse responseD = new CompetingRiskResponse(1, 5.0);
|
||||||
|
|
||||||
|
// Medium, censored
|
||||||
|
CompetingRiskResponse responseE = new CompetingRiskResponse(0, 5.0);
|
||||||
|
|
||||||
|
// Medium, other event
|
||||||
|
CompetingRiskResponse responseF = new CompetingRiskResponse(2, 5.0);
|
||||||
|
|
||||||
|
final int event = 1;
|
||||||
|
|
||||||
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(null, null);
|
||||||
|
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseB, responseB, event));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseC, responseC, event));
|
||||||
|
|
||||||
|
assertEquals(0.5, errorRateCalculator.compare(responseA, responseB, event));
|
||||||
|
assertEquals(-0.5, errorRateCalculator.compare(responseB, responseA, event));
|
||||||
|
|
||||||
|
assertEquals(0.0, errorRateCalculator.compare(responseA, responseA, event));
|
||||||
|
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseB, responseE, event));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseE, responseB, event));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseB, responseF, event));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseF, responseB, event));
|
||||||
|
|
||||||
|
assertEquals(-1.0, errorRateCalculator.compare(responseB, responseD, event));
|
||||||
|
assertEquals(1.0, errorRateCalculator.compare(responseD, responseB, event));
|
||||||
|
assertEquals(-1.0, errorRateCalculator.compare(responseC, responseD, event));
|
||||||
|
assertEquals(1.0, errorRateCalculator.compare(responseD, responseC, event));
|
||||||
|
|
||||||
|
assertEquals(-1.0, errorRateCalculator.compare(responseA, responseD, event));
|
||||||
|
assertEquals(1.0, errorRateCalculator.compare(responseD, responseA, event));
|
||||||
|
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseA, responseE, event));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseE, responseA, event));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseA, responseF, event));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseF, responseA, event));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConcordance(){
|
||||||
|
|
||||||
|
final CompetingRiskResponse response1 = new CompetingRiskResponse(1, 5.0);
|
||||||
|
final CompetingRiskResponse response2 = new CompetingRiskResponse(0, 6.0);
|
||||||
|
final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0);
|
||||||
|
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
||||||
|
|
||||||
|
final double[] mortalityArray = new double[]{1, 4, 3, 9};
|
||||||
|
final List<CompetingRiskResponse> responseList = List.of(response1, response2, response3, response4);
|
||||||
|
|
||||||
|
final int event = 1;
|
||||||
|
|
||||||
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null);
|
||||||
|
|
||||||
|
final double concordance = errorRateCalculator.calculate(responseList, mortalityArray, event);
|
||||||
|
|
||||||
|
// Expected value found through calculations by hand
|
||||||
|
assertEquals(3.0/5.0, concordance);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue