Fix a bug in naive mortality error measure; implement IPCW concordance
measure if you can provide the censoring distribution.
This commit is contained in:
parent
e1caef6d56
commit
d3994212b6
12 changed files with 255 additions and 51 deletions
|
@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.VisibleForTesting;
|
import ca.joeltherrien.randomforest.VisibleForTesting;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
@ -25,12 +26,6 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[] calculateConcordance(final int[] events){
|
|
||||||
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
|
|
||||||
|
|
||||||
return calculateConcordance(events, tau);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened,
|
* Idea for this error rate; go through every observation I have and calculate its mortality for the different events. If the event with the highest mortality is not the one that happened,
|
||||||
* then we add one to the error scale.
|
* then we add one to the error scale.
|
||||||
|
@ -39,6 +34,8 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
*
|
*
|
||||||
* Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL.
|
* Possible extensions might involve counting how many other events had higher mortality, instead of just a single PASS / FAIL.
|
||||||
*
|
*
|
||||||
|
* My observation is that this error rate isn't very useful...
|
||||||
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public double calculateNaiveMortalityError(final int[] events){
|
public double calculateNaiveMortalityError(final int[] events){
|
||||||
|
@ -77,6 +74,13 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public double[] calculateConcordance(final int[] events){
|
||||||
|
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
|
||||||
|
|
||||||
|
return calculateConcordance(events, tau);
|
||||||
|
}
|
||||||
|
|
||||||
private double[] calculateConcordance(final int[] events, final double tau){
|
private double[] calculateConcordance(final int[] events, final double tau){
|
||||||
|
|
||||||
final double[] errorRates = new double[events.length];
|
final double[] errorRates = new double[events.length];
|
||||||
|
@ -101,6 +105,36 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution){
|
||||||
|
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
|
||||||
|
|
||||||
|
return calculateIPCWConcordance(events, censoringDistribution, tau);
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution, final double tau){
|
||||||
|
|
||||||
|
final double[] errorRates = new double[events.length];
|
||||||
|
|
||||||
|
final List<CompetingRiskResponse> responses = dataset.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
// Let \tau be the max time.
|
||||||
|
|
||||||
|
for(int e=0; e<events.length; e++){
|
||||||
|
final int event = events[e];
|
||||||
|
|
||||||
|
final double[] mortalityList = riskFunctions.stream()
|
||||||
|
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
|
||||||
|
.toArray();
|
||||||
|
|
||||||
|
final double concordance = calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution);
|
||||||
|
errorRates[e] = 1.0 - concordance;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return errorRates;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
public double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
|
public double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
|
||||||
|
|
||||||
|
@ -141,4 +175,55 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
public double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
|
||||||
|
|
||||||
|
// Let \tau be the max time.
|
||||||
|
|
||||||
|
double denominator = 0.0;
|
||||||
|
double numerator = 0.0;
|
||||||
|
|
||||||
|
for(int i = 0; i<mortalityArray.length; i++){
|
||||||
|
final CompetingRiskResponse responseI = responseList.get(i);
|
||||||
|
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||||
|
continue; // skip if it's 0
|
||||||
|
}
|
||||||
|
|
||||||
|
final double mortalityI = mortalityArray[i];
|
||||||
|
|
||||||
|
for(int j=0; j<mortalityArray.length; j++){
|
||||||
|
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||||
|
|
||||||
|
final double AijWeightPlusBijWeight;
|
||||||
|
|
||||||
|
if(responseI.getU() < responseJ.getU()){ // Aij == 1
|
||||||
|
final double Ti = responseI.getU();
|
||||||
|
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * censoringDistribution.evaluatePrevious(Ti).getY());
|
||||||
|
}
|
||||||
|
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
|
||||||
|
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluatePrevious(responseI.getU()).getY() * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
denominator += AijWeightPlusBijWeight;
|
||||||
|
|
||||||
|
final double mortalityJ = mortalityArray[j];
|
||||||
|
if(mortalityI > mortalityJ){
|
||||||
|
numerator += AijWeightPlusBijWeight*1.0;
|
||||||
|
}
|
||||||
|
else if(mortalityI == mortalityJ){
|
||||||
|
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return numerator / denominator;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
import lombok.Getter;
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -31,6 +32,10 @@ public class CompetingRiskFunctions implements Serializable {
|
||||||
Point previousPoint = null;
|
Point previousPoint = null;
|
||||||
|
|
||||||
for(final Point point : cif.getPoints()){
|
for(final Point point : cif.getPoints()){
|
||||||
|
if(point.getTime() > tau){
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
if(previousPoint != null){
|
if(previousPoint != null){
|
||||||
summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
|
summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
|
||||||
}
|
}
|
||||||
|
@ -38,9 +43,11 @@ public class CompetingRiskFunctions implements Serializable {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is to ensure that we integrate over the same range for every function and get comparable results.
|
// this is to ensure that we integrate over the proper range
|
||||||
// Don't need to assert whether previousPoint is null or not; if it is null then the MathFunction was incorrectly made as there will always be at least one point for a response
|
if(previousPoint != null){
|
||||||
summation += previousPoint.getY() * (tau - previousPoint.getTime());
|
summation += cif.evaluate(tau).getY() * (tau - previousPoint.getTime());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
return summation;
|
return summation;
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
|
@ -43,6 +43,14 @@ public class MathFunction implements Serializable {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Point evaluatePrevious(double time){
|
||||||
|
final Optional<Point> pointOptional = points.stream()
|
||||||
|
.filter(point -> point.getTime() < time)
|
||||||
|
.max(Comparator.comparingDouble(Point::getTime));
|
||||||
|
|
||||||
|
return pointOptional.orElse(defaultValue);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString(){
|
public String toString(){
|
||||||
final StringBuilder builder = new StringBuilder();
|
final StringBuilder builder = new StringBuilder();
|
|
@ -1,4 +1,4 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
40
src/main/java/ca/joeltherrien/randomforest/utils/Utils.java
Normal file
40
src/main/java/ca/joeltherrien/randomforest/utils/Utils.java
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
public class Utils {
|
||||||
|
|
||||||
|
public static MathFunction estimateOneMinusECDF(final double[] times){
|
||||||
|
final Point defaultPoint = new Point(0.0, 1.0);
|
||||||
|
Arrays.sort(times);
|
||||||
|
|
||||||
|
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
||||||
|
|
||||||
|
for(final double time : times){
|
||||||
|
Integer existingCount = timeCounterMap.get(time);
|
||||||
|
existingCount = existingCount != null ? existingCount : 0;
|
||||||
|
|
||||||
|
timeCounterMap.put(time, existingCount+1);
|
||||||
|
}
|
||||||
|
|
||||||
|
final List<Map.Entry<Double, Integer>> timeCounterList = new ArrayList<>(timeCounterMap.entrySet());
|
||||||
|
Collections.sort(timeCounterList, Comparator.comparingDouble(Map.Entry::getKey));
|
||||||
|
|
||||||
|
final List<Point> pointList = new ArrayList<>(timeCounterList.size());
|
||||||
|
|
||||||
|
int previousCount = times.length;
|
||||||
|
final double n = times.length;
|
||||||
|
|
||||||
|
for(final Map.Entry<Double, Integer> entry : timeCounterList){
|
||||||
|
final int newCount = previousCount - entry.getValue();
|
||||||
|
previousCount = newCount;
|
||||||
|
|
||||||
|
pointList.add(new Point(entry.getKey(), (double) newCount / n));
|
||||||
|
}
|
||||||
|
|
||||||
|
return new MathFunction(pointList, defaultPoint);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
73
src/test/java/ca/joeltherrien/randomforest/TestUtils.java
Normal file
73
src/test/java/ca/joeltherrien/randomforest/TestUtils.java
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
public class TestUtils {
|
||||||
|
|
||||||
|
public static void closeEnough(double expected, double actual, double margin){
|
||||||
|
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We know the function is cumulative; make sure it is ordered correctly and that that function is monotone.
|
||||||
|
*
|
||||||
|
* @param function
|
||||||
|
*/
|
||||||
|
public static void assertCumulativeFunction(MathFunction function){
|
||||||
|
Point previousPoint = null;
|
||||||
|
for(final Point point : function.getPoints()){
|
||||||
|
|
||||||
|
if(previousPoint != null){
|
||||||
|
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
||||||
|
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
previousPoint = point;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void assertSurvivalCurve(MathFunction function){
|
||||||
|
Point previousPoint = null;
|
||||||
|
for(final Point point : function.getPoints()){
|
||||||
|
|
||||||
|
if(previousPoint != null){
|
||||||
|
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
||||||
|
assertTrue(previousPoint.getY() >= point.getY(), "Survival functions are monotone");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
previousPoint = point;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOneMinusECDF(){
|
||||||
|
final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0};
|
||||||
|
final MathFunction survivalCurve = Utils.estimateOneMinusECDF(times);
|
||||||
|
|
||||||
|
final double margin = 0.000001;
|
||||||
|
closeEnough(1.0, survivalCurve.evaluate(0.0).getY(), margin);
|
||||||
|
|
||||||
|
closeEnough(1.0, survivalCurve.evaluatePrevious(1.0).getY(), margin);
|
||||||
|
closeEnough(4.0/6.0, survivalCurve.evaluate(1.0).getY(), margin);
|
||||||
|
|
||||||
|
closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0).getY(), margin);
|
||||||
|
closeEnough(3.0/6.0, survivalCurve.evaluate(2.0).getY(), margin);
|
||||||
|
|
||||||
|
closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0).getY(), margin);
|
||||||
|
closeEnough(1.0/6.0, survivalCurve.evaluate(3.0).getY(), margin);
|
||||||
|
|
||||||
|
closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0).getY(), margin);
|
||||||
|
closeEnough(0.0, survivalCurve.evaluate(50.0).getY(), margin);
|
||||||
|
|
||||||
|
assertSurvivalCurve(survivalCurve);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -7,9 +7,13 @@ import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
import com.fasterxml.jackson.databind.node.*;
|
import com.fasterxml.jackson.databind.node.*;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
|
||||||
|
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -240,10 +244,10 @@ public class TestCompetingRisk {
|
||||||
System.out.println(errorRates[1]);
|
System.out.println(errorRates[1]);
|
||||||
|
|
||||||
|
|
||||||
closeEnough(0.452, errorRates[0], 0.01);
|
closeEnough(0.452, errorRates[0], 0.02);
|
||||||
closeEnough(0.446, errorRates[1], 0.01);
|
closeEnough(0.446, errorRates[1], 0.02);
|
||||||
|
|
||||||
|
|
||||||
|
System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -306,7 +310,7 @@ public class TestCompetingRisk {
|
||||||
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
|
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
|
||||||
|
|
||||||
// We seem to consistently underestimate the results.
|
// We seem to consistently underestimate the results.
|
||||||
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||||
|
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
||||||
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||||
|
@ -322,29 +326,9 @@ public class TestCompetingRisk {
|
||||||
// Consistency results
|
// Consistency results
|
||||||
closeEnough(0.395, errorRates[0], 0.01);
|
closeEnough(0.395, errorRates[0], 0.01);
|
||||||
closeEnough(0.345, errorRates[1], 0.01);
|
closeEnough(0.345, errorRates[1], 0.01);
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1,2}));
|
||||||
* We know the function is cumulative; make sure it is ordered correctly and that that function is monotone.
|
|
||||||
*
|
|
||||||
* @param function
|
|
||||||
*/
|
|
||||||
private void assertCumulativeFunction(MathFunction function){
|
|
||||||
Point previousPoint = null;
|
|
||||||
for(final Point point : function.getPoints()){
|
|
||||||
|
|
||||||
if(previousPoint != null){
|
|
||||||
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
|
||||||
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
previousPoint = point;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void closeEnough(double expected, double actual, double margin){
|
|
||||||
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
@ -33,10 +33,16 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
|
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
|
||||||
|
|
||||||
final double concordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
|
final double naiveConcordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
|
||||||
|
|
||||||
|
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0));
|
||||||
|
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
|
||||||
|
final double ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
|
||||||
|
|
||||||
|
closeEnough(naiveConcordance, ipcwConcordance, 0.0001);
|
||||||
|
|
||||||
// Expected value found through calculations by hand
|
// Expected value found through calculations by hand
|
||||||
assertEquals(3.0/5.0, concordance);
|
assertEquals(3.0/5.0, naiveConcordance);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,10 @@ package ca.joeltherrien.randomforest.competingrisk;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -87,8 +89,4 @@ public class TestCompetingRiskResponseCombiner {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void closeEnough(double expected, double actual, double margin){
|
|
||||||
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.Point;
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue