From e1caef6d5636b7ff5e8d69df183afbd02e5a71e0 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 25 Jul 2018 15:29:09 -0700 Subject: [PATCH] Implement naive mortality error measure --- pom.xml | 7 ++ .../CompetingRiskErrorRateCalculator.java | 115 ++++++++++-------- .../competingrisk/CompetingRiskFunctions.java | 22 ++++ .../randomforest/tree/Forest.java | 12 ++ .../competingrisk/TestCompetingRisk.java | 8 +- .../TestCompetingRiskErrorRateCalculator.java | 69 ++++++++++- 6 files changed, 174 insertions(+), 59 deletions(-) diff --git a/pom.xml b/pom.xml index cdb9617..cf2a8fd 100644 --- a/pom.xml +++ b/pom.xml @@ -56,6 +56,13 @@ test + + org.mockito + mockito-core + 2.20.0 + test + + diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java index 619e85d..4157e17 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -3,10 +3,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.VisibleForTesting; import ca.joeltherrien.randomforest.tree.Forest; -import ca.joeltherrien.randomforest.tree.Tree; -import lombok.RequiredArgsConstructor; -import java.util.*; +import java.util.List; import java.util.stream.Collectors; /** @@ -15,42 +13,75 @@ import java.util.stream.Collectors; * Note that this is the same version implemented in randomForestSRC. The downsides of this approach is that we can expect the errors to be biased, possibly severely. * Therefore I suggest that this measure only be used in comparing models, but not as a final output. */ -@RequiredArgsConstructor public class CompetingRiskErrorRateCalculator { - private final CompetingRiskFunctionCombiner combiner; - private final int[] events; + private final List> dataset; + private final List riskFunctions; - public CompetingRiskErrorRateCalculator(final int[] events, final double[] times){ - this.events = events; - this.combiner = new CompetingRiskFunctionCombiner(events, times); - } - - public double[] calculateConcordance(final List> rows, final Forest forest){ - final double tau = rows.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0); - - return calculateConcordance(rows, forest, tau); - } - - private double[] calculateConcordance(final List> rows, final Forest forest, final double tau){ - - final Collection> trees = forest.getTrees(); - - // This predicts for rows based on their OOB trees. - - final List riskFunctions = rows.stream() - .map(row -> { - return trees.stream().filter(tree -> !tree.idInBootstrapSample(row.getId())).map(tree -> tree.evaluate(row)).collect(Collectors.toList()); - }) - .map(combiner::combine) + public CompetingRiskErrorRateCalculator(final List> dataset, final Forest forest){ + this.dataset = dataset; + this.riskFunctions = dataset.stream() + .map(forest::evaluateOOB) .collect(Collectors.toList()); + } + + public double[] calculateConcordance(final int[] events){ + final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0); + + return calculateConcordance(events, tau); + } + + /** + * Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened, + * then we add one to the error scale. + * + * Ignore censored observations. + * + * Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL. + * + * @return + */ + public double calculateNaiveMortalityError(final int[] events){ + int failures = 0; + int attempts = 0; + + response_loop: + for(int i=0; i riskFunctions = rows.stream().map(row -> forest.evaluate(row)).collect(Collectors.toList()); + } + + private double[] calculateConcordance(final int[] events, final double tau){ final double[] errorRates = new double[events.length]; - final List responses = rows.stream().map(Row::getResponse).collect(Collectors.toList()); + final List responses = dataset.stream().map(Row::getResponse).collect(Collectors.toList()); // Let \tau be the max time. @@ -58,8 +89,7 @@ public class CompetingRiskErrorRateCalculator { final int event = events[e]; final double[] mortalityList = riskFunctions.stream() - .map(riskFunction -> riskFunction.getCumulativeIncidenceFunction(event)) - .mapToDouble(cif -> functionToMortality(cif, tau)) + .mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau)) .toArray(); final double concordance = calculateConcordance(responses, mortalityList, event); @@ -111,25 +141,4 @@ public class CompetingRiskErrorRateCalculator { } - private double functionToMortality(final MathFunction cif, final double tau){ - double summation = 0.0; - Point previousPoint = null; - - for(final Point point : cif.getPoints()){ - if(previousPoint != null){ - summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime()); - } - previousPoint = point; - - } - - // this is to ensure that we integrate over the same range for every function and get comparable results. - // Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response - summation += previousPoint.getY() * (tau - previousPoint.getTime()); - - return summation; - - } - - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java index 743e507..b9e36b7 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java @@ -23,4 +23,26 @@ public class CompetingRiskFunctions implements Serializable { public MathFunction getCumulativeIncidenceFunction(int cause) { return cumulativeIncidenceFunctionMap.get(cause); } + + public double calculateEventSpecificMortality(final int event, final double tau){ + final MathFunction cif = getCauseSpecificHazardFunction(event); + + double summation = 0.0; + Point previousPoint = null; + + for(final Point point : cif.getPoints()){ + if(previousPoint != null){ + summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime()); + } + previousPoint = point; + + } + + // this is to ensure that we integrate over the same range for every function and get comparable results. + // Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response + summation += previousPoint.getY() * (tau - previousPoint.getTime()); + + return summation; + + } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index 449cfcd..f9e8a34 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -2,6 +2,7 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; import lombok.Builder; +import lombok.RequiredArgsConstructor; import java.util.Collection; import java.util.Collections; @@ -23,6 +24,17 @@ public class Forest { // O = output of trees, FO = forest output. In prac } + public FO evaluateOOB(CovariateRow row){ + + return treeResponseCombiner.combine( + trees.stream() + .filter(tree -> !tree.idInBootstrapSample(row.getId())) + .map(node -> node.evaluate(row)) + .collect(Collectors.toList()) + ); + + } + public Collection> getTrees(){ return Collections.unmodifiableCollection(trees); } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 33268db..8b7bd63 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -227,8 +227,8 @@ public class TestCompetingRisk { closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01); closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01); - final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null); - final double[] errorRates = errorRateCalculator.calculateConcordance(dataset, forest); + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest); + final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); // Error rates happen to be about the same /* randomForestSRC results; ignored for now @@ -308,8 +308,8 @@ public class TestCompetingRisk { // We seem to consistently underestimate the results. assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 - final CompetingRiskErrorRateCalculator errorRate = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null); - final double[] errorRates = errorRate.calculateConcordance(dataset, forest); + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest); + final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); System.out.println(errorRates[0]); System.out.println(errorRates[1]); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java index d72f075..db11ad7 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -1,12 +1,19 @@ package ca.joeltherrien.randomforest.competingrisk; + +import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.tree.Forest; import org.junit.jupiter.api.Test; +import java.util.Collections; import java.util.List; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class TestCompetingRiskErrorRateCalculator { @@ -23,7 +30,8 @@ public class TestCompetingRiskErrorRateCalculator { final int event = 1; - final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null); + final Forest fakeForest = Forest.builder().build(); + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest); final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event); @@ -32,4 +40,61 @@ public class TestCompetingRiskErrorRateCalculator { } + @Test + public void testNaiveMortality(){ + final CompetingRiskResponse response1 = new CompetingRiskResponse(1, 5.0); + final CompetingRiskResponse response2 = new CompetingRiskResponse(0, 6.0); + final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0); + final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0); + + final List> dataset = List.of( + new Row<>(Collections.emptyMap(), 1, response1), + new Row<>(Collections.emptyMap(), 2, response2), + new Row<>(Collections.emptyMap(), 3, response3), + new Row<>(Collections.emptyMap(), 4, response4) + ); + + final double[] mortalityOneArray = new double[]{1, 4, 3, 9}; + final double[] mortalityTwoArray = new double[]{2, 3, 4, 7}; + + // response1 was predicted incorrectly + // response2 doesn't matter; censored + // response3 was correctly predicted + // response4 was correctly predicted + + // Expect 1/3 for my error + + final CompetingRiskFunctions function1 = mock(CompetingRiskFunctions.class); + when(function1.calculateEventSpecificMortality(1, response1.getU())).thenReturn(mortalityOneArray[0]); + when(function1.calculateEventSpecificMortality(2, response1.getU())).thenReturn(mortalityTwoArray[0]); + + final CompetingRiskFunctions function2 = mock(CompetingRiskFunctions.class); + when(function2.calculateEventSpecificMortality(1, response2.getU())).thenReturn(mortalityOneArray[1]); + when(function2.calculateEventSpecificMortality(2, response2.getU())).thenReturn(mortalityTwoArray[1]); + + final CompetingRiskFunctions function3 = mock(CompetingRiskFunctions.class); + when(function3.calculateEventSpecificMortality(1, response3.getU())).thenReturn(mortalityOneArray[2]); + when(function3.calculateEventSpecificMortality(2, response3.getU())).thenReturn(mortalityTwoArray[2]); + + final CompetingRiskFunctions function4 = mock(CompetingRiskFunctions.class); + when(function4.calculateEventSpecificMortality(1, response4.getU())).thenReturn(mortalityOneArray[3]); + when(function4.calculateEventSpecificMortality(2, response4.getU())).thenReturn(mortalityTwoArray[3]); + + final Forest mockForest = mock(Forest.class); + when(mockForest.evaluateOOB(dataset.get(0))).thenReturn(function1); + when(mockForest.evaluateOOB(dataset.get(1))).thenReturn(function2); + when(mockForest.evaluateOOB(dataset.get(2))).thenReturn(function3); + when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4); + + + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest); + + final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}); + + assertEquals(1.0/3.0, error); + + } + + + }