Broke two methods in CompetingRiskErrorRateCalculator into static
methods in a new class.
This commit is contained in:
parent
e92abdab13
commit
c85cebb59f
3 changed files with 105 additions and 97 deletions
|
@ -1,7 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.VisibleForTesting;
|
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
|
||||||
|
@ -96,7 +95,7 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
|
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
|
||||||
.toArray();
|
.toArray();
|
||||||
|
|
||||||
final double concordance = calculateConcordance(responses, mortalityList, event);
|
final double concordance = CompetingRiskUtils.calculateConcordance(responses, mortalityList, event);
|
||||||
errorRates[e] = 1.0 - concordance;
|
errorRates[e] = 1.0 - concordance;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -126,7 +125,7 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
|
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
|
||||||
.toArray();
|
.toArray();
|
||||||
|
|
||||||
final double concordance = calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution);
|
final double concordance = CompetingRiskUtils.calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution);
|
||||||
errorRates[e] = 1.0 - concordance;
|
errorRates[e] = 1.0 - concordance;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -135,97 +134,6 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
|
||||||
public double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
|
|
||||||
|
|
||||||
// Let \tau be the max time.
|
|
||||||
|
|
||||||
int permissible = 0;
|
|
||||||
double numerator = 0;
|
|
||||||
|
|
||||||
for(int i = 0; i<mortalityArray.length; i++){
|
|
||||||
final CompetingRiskResponse responseI = responseList.get(i);
|
|
||||||
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
|
||||||
continue; // skip if it's 0
|
|
||||||
}
|
|
||||||
|
|
||||||
final double mortalityI = mortalityArray[i];
|
|
||||||
|
|
||||||
for(int j=0; j<mortalityArray.length; j++){
|
|
||||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
|
||||||
// Check that Aij or Bij == 1
|
|
||||||
if(responseI.getU() < responseJ.getU() || (responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){
|
|
||||||
permissible++;
|
|
||||||
|
|
||||||
final double mortalityJ = mortalityArray[j];
|
|
||||||
if(mortalityI > mortalityJ){
|
|
||||||
numerator += 1.0;
|
|
||||||
}
|
|
||||||
else if(mortalityI == mortalityJ){
|
|
||||||
numerator += 0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return numerator / (double) permissible;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@VisibleForTesting
|
|
||||||
public double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
|
|
||||||
|
|
||||||
// Let \tau be the max time.
|
|
||||||
|
|
||||||
double denominator = 0.0;
|
|
||||||
double numerator = 0.0;
|
|
||||||
|
|
||||||
for(int i = 0; i<mortalityArray.length; i++){
|
|
||||||
final CompetingRiskResponse responseI = responseList.get(i);
|
|
||||||
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
|
||||||
continue; // skip if it's 0
|
|
||||||
}
|
|
||||||
|
|
||||||
final double mortalityI = mortalityArray[i];
|
|
||||||
final double Ti = responseI.getU();
|
|
||||||
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
|
|
||||||
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
|
|
||||||
|
|
||||||
for(int j=0; j<mortalityArray.length; j++){
|
|
||||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
|
||||||
|
|
||||||
final double AijWeightPlusBijWeight;
|
|
||||||
|
|
||||||
if(responseI.getU() < responseJ.getU()){ // Aij == 1
|
|
||||||
AijWeightPlusBijWeight = AijWeight;
|
|
||||||
}
|
|
||||||
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
|
|
||||||
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
denominator += AijWeightPlusBijWeight;
|
|
||||||
|
|
||||||
final double mortalityJ = mortalityArray[j];
|
|
||||||
if(mortalityI > mortalityJ){
|
|
||||||
numerator += AijWeightPlusBijWeight*1.0;
|
|
||||||
}
|
|
||||||
else if(mortalityI == mortalityJ){
|
|
||||||
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return numerator / denominator;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class CompetingRiskUtils {
|
||||||
|
|
||||||
|
public static double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
|
||||||
|
|
||||||
|
// Let \tau be the max time.
|
||||||
|
|
||||||
|
int permissible = 0;
|
||||||
|
double numerator = 0;
|
||||||
|
|
||||||
|
for(int i = 0; i<mortalityArray.length; i++){
|
||||||
|
final CompetingRiskResponse responseI = responseList.get(i);
|
||||||
|
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||||
|
continue; // skip if it's 0
|
||||||
|
}
|
||||||
|
|
||||||
|
final double mortalityI = mortalityArray[i];
|
||||||
|
|
||||||
|
for(int j=0; j<mortalityArray.length; j++){
|
||||||
|
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||||
|
// Check that Aij or Bij == 1
|
||||||
|
if(responseI.getU() < responseJ.getU() || (responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){
|
||||||
|
permissible++;
|
||||||
|
|
||||||
|
final double mortalityJ = mortalityArray[j];
|
||||||
|
if(mortalityI > mortalityJ){
|
||||||
|
numerator += 1.0;
|
||||||
|
}
|
||||||
|
else if(mortalityI == mortalityJ){
|
||||||
|
numerator += 0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return numerator / (double) permissible;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
|
||||||
|
|
||||||
|
// Let \tau be the max time.
|
||||||
|
|
||||||
|
double denominator = 0.0;
|
||||||
|
double numerator = 0.0;
|
||||||
|
|
||||||
|
for(int i = 0; i<mortalityArray.length; i++){
|
||||||
|
final CompetingRiskResponse responseI = responseList.get(i);
|
||||||
|
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||||
|
continue; // skip if it's 0
|
||||||
|
}
|
||||||
|
|
||||||
|
final double mortalityI = mortalityArray[i];
|
||||||
|
final double Ti = responseI.getU();
|
||||||
|
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
|
||||||
|
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
|
||||||
|
|
||||||
|
for(int j=0; j<mortalityArray.length; j++){
|
||||||
|
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||||
|
|
||||||
|
final double AijWeightPlusBijWeight;
|
||||||
|
|
||||||
|
if(responseI.getU() < responseJ.getU()){ // Aij == 1
|
||||||
|
AijWeightPlusBijWeight = AijWeight;
|
||||||
|
}
|
||||||
|
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
|
||||||
|
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
denominator += AijWeightPlusBijWeight;
|
||||||
|
|
||||||
|
final double mortalityJ = mortalityArray[j];
|
||||||
|
if(mortalityI > mortalityJ){
|
||||||
|
numerator += AijWeightPlusBijWeight*1.0;
|
||||||
|
}
|
||||||
|
else if(mortalityI == mortalityJ){
|
||||||
|
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return numerator / denominator;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -31,13 +31,12 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
final int event = 1;
|
final int event = 1;
|
||||||
|
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
|
|
||||||
|
|
||||||
final double naiveConcordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
|
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
||||||
|
|
||||||
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0));
|
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0));
|
||||||
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
|
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
|
||||||
final double ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
|
final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
|
||||||
|
|
||||||
closeEnough(naiveConcordance, ipcwConcordance, 0.0001);
|
closeEnough(naiveConcordance, ipcwConcordance, 0.0001);
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue