diff --git a/src/main/java/ca/joeltherrien/randomforest/VisibleForTesting.java b/src/main/java/ca/joeltherrien/randomforest/VisibleForTesting.java new file mode 100644 index 0000000..d91a4d5 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/VisibleForTesting.java @@ -0,0 +1,10 @@ +package ca.joeltherrien.randomforest; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Target; + +@Documented +@Target(ElementType.METHOD) +public @interface VisibleForTesting { +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java new file mode 100644 index 0000000..bb47e7c --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -0,0 +1,287 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.VisibleForTesting; +import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.tree.Tree; +import lombok.RequiredArgsConstructor; + +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Based on the naive version in Section 3.2 of "Concordance for Prognastic Models with Competing Risks" by Wolbers et al. + * + * Note that this is the same version implemented in randomForestSRC. The downsides of this approach is that we can expect the errors to be biased, possibly severely. + * Therefore I suggest that this measure only be used in comparing models, but not as a final output. + */ +@RequiredArgsConstructor +public class CompetingRiskErrorRateCalculator { + + private final CompetingRiskFunctionCombiner combiner; + private final int[] events; + + public CompetingRiskErrorRateCalculator(final int[] events, final double[] times){ + this.events = events; + this.combiner = new CompetingRiskFunctionCombiner(events, times); + } + + public double[] calculateAll(final List> rows, final Forest forest){ + + final Collection> trees = forest.getTrees(); + + // This predicts for rows based on their OOB trees. + final List riskFunctions = rows.stream() + .map(row -> { + return trees.stream().filter(tree -> !tree.idInBootstrapSample(row.getId())).map(tree -> tree.evaluate(row)).collect(Collectors.toList()); + }) + .map(list -> combiner.combine(list)) + .collect(Collectors.toList()); + + final double[] errorRates = new double[events.length]; + + final List responses = rows.stream().map(row -> row.getResponse()).collect(Collectors.toList()); + + // Let \tau be the max time. + + for(int e=0; e riskFunction.getCumulativeIncidenceFunction(event)) + .mapToDouble(cif -> functionToMortality(cif)) + .toArray(); + + final double concordance = calculate(responses, mortalityList, event); + errorRates[e] = 1.0 - concordance; + + } + + return errorRates; + + } + + @VisibleForTesting + public double calculate(final List responseList, final double[] mortalityArray, final int event){ + + // Let \tau be the max time. + + int permissible = 0; + int numerator = 0; + + for(int i = 0; i= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){ + permissible++; + + final double mortalityJ = mortalityArray[j]; + numerator += mortalityI > mortalityJ ? 1 : 0; + + } + + } + + } + + return (double) numerator / (double) permissible; + + + } + + /* + public double[] calculateAll(final List> rows, final Forest forest){ + rows.sort(Comparator.comparing(row -> row.getResponse().getU())); // optimization for later loop + + final Collection> trees = forest.getTrees(); + + final List riskFunctions = rows.stream() + .map(row -> { + return trees.stream().filter(tree -> !tree.idInBootstrapSample(row.getId())).map(tree -> tree.evaluate(row)).collect(Collectors.toList()); + }) + .map(list -> combiner.combine(list)) + .collect(Collectors.toList()); + + final double[] errorRates = new double[events.length]; + + + for(int e=0; e riskFunction.getCumulativeIncidenceFunction(event)) + .mapToDouble(cif -> functionToMortality(cif)) + .toArray(); + + int permissible = 0; + double c = 0.0; + + outer_mortality: + for(int i = 0; i< mortalityList.length; i++){ + final Row leftRow = rows.get(i); + final double mortalityLeft = mortalityList[i]; + final CompetingRiskResponse leftResponse = leftRow.getResponse(); + + for(int j=i+1; j rightRow = rows.get(j); + final double mortalityRight = mortalityList[j]; + final CompetingRiskResponse rightResponse = rightRow.getResponse(); + + if(leftResponse.getDelta() != event && rightResponse.getU() > leftResponse.getU()){ + // because we've sorted the responses earlier we will never get a permissable result for greater j. + continue outer_mortality; + } + + // check and see if pair is permissable + if(isPermissablePair(leftResponse, rightResponse, event)){ + permissible++; + + final double comparisonScore = compare(leftResponse, rightResponse, event); + if(comparisonScore < 0) { // left > right + // right has shorter time + if(mortalityRight > mortalityLeft){ + c += 1.0; + } + else if(mortalityRight == mortalityLeft){ + c += 0.5; + } + } + else if(comparisonScore > 0){ // left < right + // left has shorter term + if(mortalityRight < mortalityLeft){ + c += 1.0; + } + else if(mortalityRight == mortalityLeft){ + c += 0.5; + } + } + else{ // comparisonScore == 0 + c += (mortalityLeft == mortalityRight) ? 1.0 : 0.5; + } + + } + else{ + continue; + } + + } + + } + + final double concordance = c / (double) permissible; + errorRates[e] = 1.0 - concordance; + + + } + + + return errorRates; + } +*/ + + /* + private boolean isPermissablePair(final CompetingRiskResponse left, final CompetingRiskResponse right){ + if(left.isCensored() && right.isCensored()){ + return false; + } + + if(left.getU() < right.getU() && left.isCensored()){ + return false; + } + + if(left.getU() > right.getU() && right.isCensored()){ + return false; + } + + return true; + + } + */ + + /* + private boolean isPermissablePair(final CompetingRiskResponse left, final CompetingRiskResponse right, int event){ + if(left.getDelta() != event && right.getDelta() != event){ + return false; + } + + if(left.getU() < right.getU() && left.getDelta() != event){ + return false; + } + + if(left.getU() > right.getU() && right.getDelta() != event){ + return false; + } + + return true; + + } + */ + + + private double functionToMortality(final MathFunction cif){ + double summation = 0.0; + double previousTime = 0.0; + + for(final Point point : cif.getPoints()){ + summation += point.getY() * (point.getTime() - previousTime); + previousTime = point.getTime(); + } + + return summation; + + } + + + /** + * Compare two CompetingRiskResponses to see which is larger than the other (if it can be determined). + * + * @param left + * @param right + * @param event Event of interest. All other events are treated as censoring. + * @return -1 if left is strictly greater than right, -0.5 if left is greater than right, 0 if both are equal, 0.5 if right is greater than left, and 1 if right is strictly greater than left. + *//* + @VisibleForTesting + public double compare(final CompetingRiskResponse left, final CompetingRiskResponse right, int event){ + + + if(left.getU() > right.getU() && right.getDelta()==event){ + // left is greater + return -1; + } + else if(right.getU() > left.getU() && left.getDelta()==event){ + // right is greater + return 1; + } + else if(left.getU() == right.getU() && left.getDelta()==event && right.getDelta()==event){ + // they are equal + return 0; + } + else if(left.getU() == right.getU() && left.getDelta()!=event && right.getDelta()==event){ + // left is greater (note; could be unknown depending on definitions) + //return -0.5; + return 0; + } + else if(left.getU() == right.getU() && left.getDelta()==event && right.getDelta()!=event){ + // right is greater (note; could be unknown depending on definitions) + //return 0.5; + return 0; + } + else{ + throw new IllegalArgumentException("Invalid comparison of " + left + " and " + right + "; did you call isPermissablePair first?"); + } + + + }*/ + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java index 3e9733e..93c3e54 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java @@ -1,6 +1,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.tree.ResponseCombiner; +import lombok.Getter; import lombok.RequiredArgsConstructor; import java.util.ArrayList; @@ -14,6 +15,7 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner responses) { diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.R b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.R index 02dbe81..4e80601 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.R +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.R @@ -43,12 +43,30 @@ output.one.tree$chf[,c(11,66,103),1] #output$cif[,,1] # CIF for cause 1 #output$cif[,,2] # CIF for cause 2 -many.trees <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE) +many.trees <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE); many.trees + +err.rate.1 = c() +err.rate.2 = c() +for(j in 1:100){ + many.trees <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE) + err.rate.1 = c(err.rate.1, many.trees$err.rate[100,1]) + err.rate.2 = c(err.rate.2, many.trees$err.rate[100,2]) +} +quant.1 = quantile(err.rate.1, probs=c(0.025, 0.5, 0.975)) # 0.4727131 0.4792391 0.4862286 +quant.2 = quantile(err.rate.2, probs=c(0.025, 0.5, 0.975)) # 0.4898299 0.4978300 0.5064539 + +(quant.1[3] + quant.1[1]) / 2 +(quant.1[3] - quant.1[1]) / 2 + +(quant.2[3] + quant.2[1]) / 2 +(quant.2[3] - quant.2[1]) / 2 + + output.many.trees = predict(many.trees, newData) output.many.trees$cif[,41,1] output.many.trees$cif[,41,2] -many.trees.all <- rfsrc(Surv(time, status) ~ ageatfda + cd4nadir + idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE) +many.trees.all <- rfsrc(Surv(time, status) ~ ageatfda + cd4nadir + idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE); many.trees.all output.many.trees.all = predict(many.trees.all, newData) output.many.trees.all$cif[,103,1] output.many.trees.all$cif[,103,2] @@ -57,6 +75,7 @@ output.many.trees.all$cif[,103,2] + end.numbers = c() end.times = c() lgths = c() diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 320c077..4a6ea4c 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -2,10 +2,7 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.*; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction; -import ca.joeltherrien.randomforest.responses.competingrisk.Point; +import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.Node; @@ -16,11 +13,8 @@ import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.*; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; public class TestCompetingRisk { @@ -81,6 +75,10 @@ public class TestCompetingRisk { .build(); } + public List getCovariates(Settings settings){ + return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList()); + } + public CovariateRow getPredictionRow(List covariates){ return CovariateRow.createSimple(Map.of( "ageatfda", "35", @@ -194,10 +192,6 @@ public class TestCompetingRisk { } - public List getCovariates(Settings settings){ - return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList()); - } - @Test public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException { final Settings settings = getSettings(); @@ -232,6 +226,15 @@ public class TestCompetingRisk { closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01); closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01); + + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null); + final double[] errorRates = errorRateCalculator.calculateAll(dataset, forest); + + // Error rates happen to be about the same + closeEnough(0.4795, errorRates[0], 0.007); + closeEnough(0.478, errorRates[1], 0.008); + + } @Test @@ -294,8 +297,16 @@ public class TestCompetingRisk { final List causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints(); // We seem to consistently underestimate the results. - assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC"); + assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 + final CompetingRiskErrorRateCalculator errorRate = new CompetingRiskErrorRateCalculator((CompetingRiskFunctionCombiner) settings.getTreeCombiner(), new int[]{1,2}); + final double[] errorRates = errorRate.calculateAll(dataset, forest); + + System.out.println(errorRates[0]); + System.out.println(errorRates[1]); + + closeEnough(0.41, errorRates[0], 0.02); + closeEnough(0.38, errorRates[1], 0.02); } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java new file mode 100644 index 0000000..a6317b7 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -0,0 +1,91 @@ +package ca.joeltherrien.randomforest.competingrisk; + +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestCompetingRiskErrorRateCalculator { + + /* + @Test + public void testComparingResponses(){ + + // Large, uncensored + CompetingRiskResponse responseA = new CompetingRiskResponse(1, 10.0); + + // Large, censored + CompetingRiskResponse responseB = new CompetingRiskResponse(0, 10.0); + + // Large, other event + CompetingRiskResponse responseC = new CompetingRiskResponse(2, 10.0); + + // Medium, uncensored + CompetingRiskResponse responseD = new CompetingRiskResponse(1, 5.0); + + // Medium, censored + CompetingRiskResponse responseE = new CompetingRiskResponse(0, 5.0); + + // Medium, other event + CompetingRiskResponse responseF = new CompetingRiskResponse(2, 5.0); + + final int event = 1; + + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(null, null); + + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseB, responseB, event)); + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseC, responseC, event)); + + assertEquals(0.5, errorRateCalculator.compare(responseA, responseB, event)); + assertEquals(-0.5, errorRateCalculator.compare(responseB, responseA, event)); + + assertEquals(0.0, errorRateCalculator.compare(responseA, responseA, event)); + + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseB, responseE, event)); + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseE, responseB, event)); + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseB, responseF, event)); + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseF, responseB, event)); + + assertEquals(-1.0, errorRateCalculator.compare(responseB, responseD, event)); + assertEquals(1.0, errorRateCalculator.compare(responseD, responseB, event)); + assertEquals(-1.0, errorRateCalculator.compare(responseC, responseD, event)); + assertEquals(1.0, errorRateCalculator.compare(responseD, responseC, event)); + + assertEquals(-1.0, errorRateCalculator.compare(responseA, responseD, event)); + assertEquals(1.0, errorRateCalculator.compare(responseD, responseA, event)); + + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseA, responseE, event)); + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseE, responseA, event)); + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseA, responseF, event)); + assertThrows(IllegalArgumentException.class, () -> errorRateCalculator.compare(responseF, responseA, event)); + + + } + */ + + @Test + public void testConcordance(){ + + final CompetingRiskResponse response1 = new CompetingRiskResponse(1, 5.0); + final CompetingRiskResponse response2 = new CompetingRiskResponse(0, 6.0); + final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0); + final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0); + + final double[] mortalityArray = new double[]{1, 4, 3, 9}; + final List responseList = List.of(response1, response2, response3, response4); + + final int event = 1; + + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null); + + final double concordance = errorRateCalculator.calculate(responseList, mortalityArray, event); + + // Expected value found through calculations by hand + assertEquals(3.0/5.0, concordance); + + } + +}