diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java index 4157e17..3d95b84 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.VisibleForTesting; import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.utils.MathFunction; import java.util.List; import java.util.stream.Collectors; @@ -25,12 +26,6 @@ public class CompetingRiskErrorRateCalculator { .collect(Collectors.toList()); } - public double[] calculateConcordance(final int[] events){ - final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0); - - return calculateConcordance(events, tau); - } - /** * Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened, * then we add one to the error scale. @@ -39,6 +34,8 @@ public class CompetingRiskErrorRateCalculator { * * Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL. * + * My observation is that this error rate isn't very useful... + * * @return */ public double calculateNaiveMortalityError(final int[] events){ @@ -77,6 +74,13 @@ public class CompetingRiskErrorRateCalculator { } + + public double[] calculateConcordance(final int[] events){ + final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0); + + return calculateConcordance(events, tau); + } + private double[] calculateConcordance(final int[] events, final double tau){ final double[] errorRates = new double[events.length]; @@ -101,6 +105,36 @@ public class CompetingRiskErrorRateCalculator { } + public double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution){ + final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0); + + return calculateIPCWConcordance(events, censoringDistribution, tau); + } + + private double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution, final double tau){ + + final double[] errorRates = new double[events.length]; + + final List responses = dataset.stream().map(Row::getResponse).collect(Collectors.toList()); + + // Let \tau be the max time. + + for(int e=0; e riskFunction.calculateEventSpecificMortality(event, tau)) + .toArray(); + + final double concordance = calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution); + errorRates[e] = 1.0 - concordance; + + } + + return errorRates; + + } + @VisibleForTesting public double calculateConcordance(final List responseList, double[] mortalityArray, final int event){ @@ -141,4 +175,55 @@ public class CompetingRiskErrorRateCalculator { } + + @VisibleForTesting + public double calculateIPCWConcordance(final List responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){ + + // Let \tau be the max time. + + double denominator = 0.0; + double numerator = 0.0; + + for(int i = 0; i= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1 + AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluatePrevious(responseI.getU()).getY() * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); + } + else{ + continue; + } + + denominator += AijWeightPlusBijWeight; + + final double mortalityJ = mortalityArray[j]; + if(mortalityI > mortalityJ){ + numerator += AijWeightPlusBijWeight*1.0; + } + else if(mortalityI == mortalityJ){ + numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error + } + + } + + } + + return numerator / denominator; + + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java index 93c3e54..6dae8c0 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java @@ -1,7 +1,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.tree.ResponseCombiner; -import lombok.Getter; +import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.Point; import lombok.RequiredArgsConstructor; import java.util.ArrayList; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java index b9e36b7..e75ffd4 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java @@ -1,8 +1,9 @@ package ca.joeltherrien.randomforest.responses.competingrisk; +import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.Point; import lombok.Builder; import lombok.Getter; -import lombok.RequiredArgsConstructor; import java.io.Serializable; import java.util.Map; @@ -31,6 +32,10 @@ public class CompetingRiskFunctions implements Serializable { Point previousPoint = null; for(final Point point : cif.getPoints()){ + if(point.getTime() > tau){ + break; + } + if(previousPoint != null){ summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime()); } @@ -38,9 +43,11 @@ public class CompetingRiskFunctions implements Serializable { } - // this is to ensure that we integrate over the same range for every function and get comparable results. - // Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response - summation += previousPoint.getY() * (tau - previousPoint.getTime()); + // this is to ensure that we integrate over the proper range + if(previousPoint != null){ + summation += cif.evaluate(tau).getY() * (tau - previousPoint.getTime()); + } + return summation; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseCombiner.java index d728879..c417413 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseCombiner.java @@ -1,6 +1,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.tree.ResponseCombiner; +import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.Point; import lombok.RequiredArgsConstructor; import java.util.*; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/MathFunction.java b/src/main/java/ca/joeltherrien/randomforest/utils/MathFunction.java similarity index 82% rename from src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/MathFunction.java rename to src/main/java/ca/joeltherrien/randomforest/utils/MathFunction.java index fde343d..68d6e83 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/MathFunction.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/MathFunction.java @@ -1,4 +1,4 @@ -package ca.joeltherrien.randomforest.responses.competingrisk; +package ca.joeltherrien.randomforest.utils; import lombok.Getter; @@ -43,6 +43,14 @@ public class MathFunction implements Serializable { } + public Point evaluatePrevious(double time){ + final Optional pointOptional = points.stream() + .filter(point -> point.getTime() < time) + .max(Comparator.comparingDouble(Point::getTime)); + + return pointOptional.orElse(defaultValue); + } + @Override public String toString(){ final StringBuilder builder = new StringBuilder(); diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java b/src/main/java/ca/joeltherrien/randomforest/utils/Point.java similarity index 82% rename from src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java rename to src/main/java/ca/joeltherrien/randomforest/utils/Point.java index 56a2831..c97d194 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/Point.java @@ -1,4 +1,4 @@ -package ca.joeltherrien.randomforest.responses.competingrisk; +package ca.joeltherrien.randomforest.utils; import lombok.Data; diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java new file mode 100644 index 0000000..0b351d3 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java @@ -0,0 +1,40 @@ +package ca.joeltherrien.randomforest.utils; + +import java.util.*; + +public class Utils { + + public static MathFunction estimateOneMinusECDF(final double[] times){ + final Point defaultPoint = new Point(0.0, 1.0); + Arrays.sort(times); + + final Map timeCounterMap = new HashMap<>(); + + for(final double time : times){ + Integer existingCount = timeCounterMap.get(time); + existingCount = existingCount != null ? existingCount : 0; + + timeCounterMap.put(time, existingCount+1); + } + + final List> timeCounterList = new ArrayList<>(timeCounterMap.entrySet()); + Collections.sort(timeCounterList, Comparator.comparingDouble(Map.Entry::getKey)); + + final List pointList = new ArrayList<>(timeCounterList.size()); + + int previousCount = times.length; + final double n = times.length; + + for(final Map.Entry entry : timeCounterList){ + final int newCount = previousCount - entry.getValue(); + previousCount = newCount; + + pointList.add(new Point(entry.getKey(), (double) newCount / n)); + } + + return new MathFunction(pointList, defaultPoint); + + } + + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java new file mode 100644 index 0000000..72bee51 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -0,0 +1,73 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.Utils; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TestUtils { + + public static void closeEnough(double expected, double actual, double margin){ + assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual); + } + + /** + * We know the function is cumulative; make sure it is ordered correctly and that that function is monotone. + * + * @param function + */ + public static void assertCumulativeFunction(MathFunction function){ + Point previousPoint = null; + for(final Point point : function.getPoints()){ + + if(previousPoint != null){ + assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different"); + assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone"); + } + + + previousPoint = point; + } + } + + public static void assertSurvivalCurve(MathFunction function){ + Point previousPoint = null; + for(final Point point : function.getPoints()){ + + if(previousPoint != null){ + assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different"); + assertTrue(previousPoint.getY() >= point.getY(), "Survival functions are monotone"); + } + + + previousPoint = point; + } + } + + @Test + public void testOneMinusECDF(){ + final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0}; + final MathFunction survivalCurve = Utils.estimateOneMinusECDF(times); + + final double margin = 0.000001; + closeEnough(1.0, survivalCurve.evaluate(0.0).getY(), margin); + + closeEnough(1.0, survivalCurve.evaluatePrevious(1.0).getY(), margin); + closeEnough(4.0/6.0, survivalCurve.evaluate(1.0).getY(), margin); + + closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0).getY(), margin); + closeEnough(3.0/6.0, survivalCurve.evaluate(2.0).getY(), margin); + + closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0).getY(), margin); + closeEnough(1.0/6.0, survivalCurve.evaluate(3.0).getY(), margin); + + closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0).getY(), margin); + closeEnough(0.0, survivalCurve.evaluate(50.0).getY(), margin); + + assertSurvivalCurve(survivalCurve); + + } + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 8b7bd63..e0311ee 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -7,9 +7,13 @@ import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer; +import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.Point; import com.fasterxml.jackson.databind.node.*; import org.junit.jupiter.api.Test; +import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction; +import static ca.joeltherrien.randomforest.TestUtils.closeEnough; import static org.junit.jupiter.api.Assertions.*; import java.io.IOException; @@ -240,10 +244,10 @@ public class TestCompetingRisk { System.out.println(errorRates[1]); - closeEnough(0.452, errorRates[0], 0.01); - closeEnough(0.446, errorRates[1], 0.01); - + closeEnough(0.452, errorRates[0], 0.02); + closeEnough(0.446, errorRates[1], 0.02); + System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2})); } @Test @@ -306,7 +310,7 @@ public class TestCompetingRisk { final List causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints(); // We seem to consistently underestimate the results. - assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 + assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); @@ -322,29 +326,9 @@ public class TestCompetingRisk { // Consistency results closeEnough(0.395, errorRates[0], 0.01); closeEnough(0.345, errorRates[1], 0.01); - } - /** - * We know the function is cumulative; make sure it is ordered correctly and that that function is monotone. - * - * @param function - */ - private void assertCumulativeFunction(MathFunction function){ - Point previousPoint = null; - for(final Point point : function.getPoints()){ + System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2})); - if(previousPoint != null){ - assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different"); - assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone"); - } - - - previousPoint = point; - } - } - - private void closeEnough(double expected, double actual, double margin){ - assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual); } } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java index db11ad7..c0c54ec 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -1,16 +1,16 @@ package ca.joeltherrien.randomforest.competingrisk; - import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.Point; import org.junit.jupiter.api.Test; import java.util.Collections; import java.util.List; +import static ca.joeltherrien.randomforest.TestUtils.closeEnough; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -33,10 +33,16 @@ public class TestCompetingRiskErrorRateCalculator { final Forest fakeForest = Forest.builder().build(); final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest); - final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event); + final double naiveConcordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event); + + final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0)); + // This distribution will make the IPCW weights == 1, giving identical results to the naive concordance. + final double ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution); + + closeEnough(naiveConcordance, ipcwConcordance, 0.0001); // Expected value found through calculations by hand - assertEquals(3.0/5.0, concordance); + assertEquals(3.0/5.0, naiveConcordance); } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java index cdbd10b..ce18106 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java @@ -3,8 +3,10 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction; +import ca.joeltherrien.randomforest.utils.MathFunction; import org.junit.jupiter.api.Test; + +import static ca.joeltherrien.randomforest.TestUtils.closeEnough; import static org.junit.jupiter.api.Assertions.*; import java.util.ArrayList; @@ -87,8 +89,4 @@ public class TestCompetingRiskResponseCombiner { } - private void closeEnough(double expected, double actual, double margin){ - assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual); - } - } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java index a602e73..591ca9a 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java @@ -1,7 +1,7 @@ package ca.joeltherrien.randomforest.competingrisk; -import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction; -import ca.joeltherrien.randomforest.responses.competingrisk.Point; +import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.Point; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.*;