diff --git a/pom.xml b/pom.xml
index cdb9617..cf2a8fd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -56,6 +56,13 @@
<scope>test</scope>
+ <dependency>
+     <groupId>org.mockito</groupId>
+     <artifactId>mockito-core</artifactId>
+     <version>2.20.0</version>
+     <scope>test</scope>
+ </dependency>
+
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java
index 619e85d..4157e17 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java
@@ -3,10 +3,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.VisibleForTesting;
import ca.joeltherrien.randomforest.tree.Forest;
-import ca.joeltherrien.randomforest.tree.Tree;
-import lombok.RequiredArgsConstructor;
-import java.util.*;
+import java.util.List;
import java.util.stream.Collectors;
/**
@@ -15,42 +13,75 @@ import java.util.stream.Collectors;
* Note that this is the same version implemented in randomForestSRC. The downside of this approach is that we can expect the errors to be biased, possibly severely.
* Therefore I suggest that this measure only be used for comparing models, and not as a final output.
*/
-@RequiredArgsConstructor
public class CompetingRiskErrorRateCalculator {
- private final CompetingRiskFunctionCombiner combiner;
- private final int[] events;
+ private final List<Row<CompetingRiskResponse>> dataset;
+ private final List<CompetingRiskFunctions> riskFunctions;
- public CompetingRiskErrorRateCalculator(final int[] events, final double[] times){
- this.events = events;
- this.combiner = new CompetingRiskFunctionCombiner(events, times);
- }
-
- public double[] calculateConcordance(final List<Row<CompetingRiskResponse>> rows, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
- final double tau = rows.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
-
- return calculateConcordance(rows, forest, tau);
- }
-
- private double[] calculateConcordance(final List<Row<CompetingRiskResponse>> rows, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest, final double tau){
-
- final Collection<Tree<CompetingRiskFunctions>> trees = forest.getTrees();
-
- // This predicts for rows based on their OOB trees.
-
- final List<CompetingRiskFunctions> riskFunctions = rows.stream()
- .map(row -> {
- return trees.stream().filter(tree -> !tree.idInBootstrapSample(row.getId())).map(tree -> tree.evaluate(row)).collect(Collectors.toList());
- })
- .map(combiner::combine)
+ public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
+ this.dataset = dataset;
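+ // precompute the out-of-bag prediction for every row; these functions are reused by the error measures below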
+ this.riskFunctions = dataset.stream()
+ .map(forest::evaluateOOB)
.collect(Collectors.toList());
+ }
+
+ public double[] calculateConcordance(final int[] events){
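+ // use the largest observed time as tau so that every CIF is integrated over the same range and the mortalities are comparable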
+ final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
+
+ return calculateConcordance(events, tau);
+ }
+
+ /**
+ * Idea for this error rate: go through every observation and calculate its mortality for each of the different events. If the event with the highest mortality is not the one that actually occurred,
+ * then we count that observation as an error.
+ *
+ * Censored observations are ignored.
+ *
+ * Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL.
+ *
+ * @return The proportion of non-censored observations whose observed event did not have the highest predicted mortality.
+ */
+ public double calculateNaiveMortalityError(final int[] events){
+ int failures = 0;
+ int attempts = 0;
+
+ response_loop:
+ for(int i=0; i<dataset.size(); i++){
+     final CompetingRiskResponse response = dataset.get(i).getResponse();
+
+     // censored observations are ignored
+     if(response.isCensored()){
+         continue;
+     }
+     attempts++;
+
+     final double actualEventMortality = riskFunctions.get(i).calculateEventSpecificMortality(response.getDelta(), response.getU());
+
+     // if any other event has a strictly higher mortality than the event that occurred, count this observation as a failure
+     for(final int event : events){
+         if(event != response.getDelta() && riskFunctions.get(i).calculateEventSpecificMortality(event, response.getU()) > actualEventMortality){
+             failures++;
+             continue response_loop;
+         }
+     }
+ }
+
+ return (double) failures / (double) attempts;
+ }
+
- final List<CompetingRiskFunctions> riskFunctions = rows.stream().map(row -> forest.evaluate(row)).collect(Collectors.toList());
+ private double[] calculateConcordance(final int[] events, final double tau){
final double[] errorRates = new double[events.length];
- final List<CompetingRiskResponse> responses = rows.stream().map(Row::getResponse).collect(Collectors.toList());
+ final List<CompetingRiskResponse> responses = dataset.stream().map(Row::getResponse).collect(Collectors.toList());
// Let \tau be the max time.
@@ -58,8 +89,7 @@ public class CompetingRiskErrorRateCalculator {
final int event = events[e];
final double[] mortalityList = riskFunctions.stream()
- .map(riskFunction -> riskFunction.getCumulativeIncidenceFunction(event))
- .mapToDouble(cif -> functionToMortality(cif, tau))
+ .mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
.toArray();
final double concordance = calculateConcordance(responses, mortalityList, event);
@@ -111,25 +141,4 @@ public class CompetingRiskErrorRateCalculator {
}
- private double functionToMortality(final MathFunction cif, final double tau){
- double summation = 0.0;
- Point previousPoint = null;
-
- for(final Point point : cif.getPoints()){
- if(previousPoint != null){
- summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
- }
- previousPoint = point;
-
- }
-
- // this is to ensure that we integrate over the same range for every function and get comparable results.
- // Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response
- summation += previousPoint.getY() * (tau - previousPoint.getTime());
-
- return summation;
-
- }
-
-
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java
index 743e507..b9e36b7 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java
@@ -23,4 +23,26 @@ public class CompetingRiskFunctions implements Serializable {
public MathFunction getCumulativeIncidenceFunction(int cause) {
return cumulativeIncidenceFunctionMap.get(cause);
}
+
+ public double calculateEventSpecificMortality(final int event, final double tau){
+ final MathFunction cif = getCumulativeIncidenceFunction(event);
+
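+ // event-specific mortality is the area under the cumulative incidence function from time 0 to tau,
+ // computed as the step-function integral over the CIF's points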
+ double summation = 0.0;
+ Point previousPoint = null;
+
+ for(final Point point : cif.getPoints()){
+ if(previousPoint != null){
+ summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
+ }
+ previousPoint = point;
+
+ }
+
+ // this is to ensure that we integrate over the same range for every function and get comparable results.
+ // We don't need to check whether previousPoint is null; if it were, the MathFunction was constructed incorrectly, as there is always at least one point for a response.
+ summation += previousPoint.getY() * (tau - previousPoint.getTime());
+
+ return summation;
+
+ }
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
index 449cfcd..f9e8a34 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
@@ -2,6 +2,7 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Builder;
+import lombok.RequiredArgsConstructor;
import java.util.Collection;
import java.util.Collections;
@@ -23,6 +24,17 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
}
+ public FO evaluateOOB(CovariateRow row){
+
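+ // out-of-bag prediction: only trees whose bootstrap sample excluded this row contribute to the combined result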
+ return treeResponseCombiner.combine(
+ trees.stream()
+ .filter(tree -> !tree.idInBootstrapSample(row.getId()))
+ .map(node -> node.evaluate(row))
+ .collect(Collectors.toList())
+ );
+
+ }
+
public Collection> getTrees(){
return Collections.unmodifiableCollection(trees);
}
diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java
index 33268db..8b7bd63 100644
--- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java
+++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java
@@ -227,8 +227,8 @@ public class TestCompetingRisk {
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
- final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null);
- final double[] errorRates = errorRateCalculator.calculateConcordance(dataset, forest);
+ final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
+ final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
// Error rates happen to be about the same
/* randomForestSRC results; ignored for now
@@ -308,8 +308,8 @@ public class TestCompetingRisk {
// We seem to consistently underestimate the results.
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
- final CompetingRiskErrorRateCalculator errorRate = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null);
- final double[] errorRates = errorRate.calculateConcordance(dataset, forest);
+ final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
+ final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
System.out.println(errorRates[0]);
System.out.println(errorRates[1]);
diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java
index d72f075..db11ad7 100644
--- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java
+++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java
@@ -1,12 +1,19 @@
package ca.joeltherrien.randomforest.competingrisk;
+
+import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
+import ca.joeltherrien.randomforest.tree.Forest;
import org.junit.jupiter.api.Test;
+import java.util.Collections;
import java.util.List;
-import static org.junit.jupiter.api.Assertions.*;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
public class TestCompetingRiskErrorRateCalculator {
@@ -23,7 +30,8 @@ public class TestCompetingRiskErrorRateCalculator {
final int event = 1;
- final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(new int[]{1,2}, null);
+ final Forest fakeForest = Forest.builder().build();
+ final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
@@ -32,4 +40,61 @@ public class TestCompetingRiskErrorRateCalculator {
}
+ @Test
+ public void testNaiveMortality(){
+ final CompetingRiskResponse response1 = new CompetingRiskResponse(1, 5.0);
+ final CompetingRiskResponse response2 = new CompetingRiskResponse(0, 6.0);
+ final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0);
+ final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
+
+ final List<Row<CompetingRiskResponse>> dataset = List.of(
+ new Row<>(Collections.emptyMap(), 1, response1),
+ new Row<>(Collections.emptyMap(), 2, response2),
+ new Row<>(Collections.emptyMap(), 3, response3),
+ new Row<>(Collections.emptyMap(), 4, response4)
+ );
+
+ final double[] mortalityOneArray = new double[]{1, 4, 3, 9};
+ final double[] mortalityTwoArray = new double[]{2, 3, 4, 7};
+
+ // response1 was predicted incorrectly
+ // response2 doesn't matter; censored
+ // response3 was correctly predicted
+ // response4 was correctly predicted
+
+ // Expect 1/3 for my error
+
+ final CompetingRiskFunctions function1 = mock(CompetingRiskFunctions.class);
+ when(function1.calculateEventSpecificMortality(1, response1.getU())).thenReturn(mortalityOneArray[0]);
+ when(function1.calculateEventSpecificMortality(2, response1.getU())).thenReturn(mortalityTwoArray[0]);
+
+ final CompetingRiskFunctions function2 = mock(CompetingRiskFunctions.class);
+ when(function2.calculateEventSpecificMortality(1, response2.getU())).thenReturn(mortalityOneArray[1]);
+ when(function2.calculateEventSpecificMortality(2, response2.getU())).thenReturn(mortalityTwoArray[1]);
+
+ final CompetingRiskFunctions function3 = mock(CompetingRiskFunctions.class);
+ when(function3.calculateEventSpecificMortality(1, response3.getU())).thenReturn(mortalityOneArray[2]);
+ when(function3.calculateEventSpecificMortality(2, response3.getU())).thenReturn(mortalityTwoArray[2]);
+
+ final CompetingRiskFunctions function4 = mock(CompetingRiskFunctions.class);
+ when(function4.calculateEventSpecificMortality(1, response4.getU())).thenReturn(mortalityOneArray[3]);
+ when(function4.calculateEventSpecificMortality(2, response4.getU())).thenReturn(mortalityTwoArray[3]);
+
+ final Forest mockForest = mock(Forest.class);
+ when(mockForest.evaluateOOB(dataset.get(0))).thenReturn(function1);
+ when(mockForest.evaluateOOB(dataset.get(1))).thenReturn(function2);
+ when(mockForest.evaluateOOB(dataset.get(2))).thenReturn(function3);
+ when(mockForest.evaluateOOB(dataset.get(3))).thenReturn(function4);
+
+
+ final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, mockForest);
+
+ final double error = errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2});
+
+ assertEquals(1.0/3.0, error);
+
+ }
+
+
+
}