Broke two methods in CompetingRiskErrorRateCalculator into static

methods in a new class.
This commit is contained in:
Joel Therrien 2018-08-27 11:18:56 -07:00
parent e92abdab13
commit c85cebb59f
3 changed files with 105 additions and 97 deletions

View file

@ -1,7 +1,6 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.VisibleForTesting;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.MathFunction;
@ -96,7 +95,7 @@ public class CompetingRiskErrorRateCalculator {
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau)) .mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
.toArray(); .toArray();
final double concordance = calculateConcordance(responses, mortalityList, event); final double concordance = CompetingRiskUtils.calculateConcordance(responses, mortalityList, event);
errorRates[e] = 1.0 - concordance; errorRates[e] = 1.0 - concordance;
} }
@ -126,7 +125,7 @@ public class CompetingRiskErrorRateCalculator {
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau)) .mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
.toArray(); .toArray();
final double concordance = calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution); final double concordance = CompetingRiskUtils.calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution);
errorRates[e] = 1.0 - concordance; errorRates[e] = 1.0 - concordance;
} }
@ -135,97 +134,6 @@ public class CompetingRiskErrorRateCalculator {
} }
@VisibleForTesting
public double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
// Let \tau be the max time.
int permissible = 0;
double numerator = 0;
for(int i = 0; i<mortalityArray.length; i++){
final CompetingRiskResponse responseI = responseList.get(i);
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
continue; // skip if it's 0
}
final double mortalityI = mortalityArray[i];
for(int j=0; j<mortalityArray.length; j++){
final CompetingRiskResponse responseJ = responseList.get(j);
// Check that Aij or Bij == 1
if(responseI.getU() < responseJ.getU() || (responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){
permissible++;
final double mortalityJ = mortalityArray[j];
if(mortalityI > mortalityJ){
numerator += 1.0;
}
else if(mortalityI == mortalityJ){
numerator += 0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
}
}
}
}
return numerator / (double) permissible;
}
@VisibleForTesting
public double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
// Let \tau be the max time.
double denominator = 0.0;
double numerator = 0.0;
for(int i = 0; i<mortalityArray.length; i++){
final CompetingRiskResponse responseI = responseList.get(i);
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
continue; // skip if it's 0
}
final double mortalityI = mortalityArray[i];
final double Ti = responseI.getU();
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
for(int j=0; j<mortalityArray.length; j++){
final CompetingRiskResponse responseJ = responseList.get(j);
final double AijWeightPlusBijWeight;
if(responseI.getU() < responseJ.getU()){ // Aij == 1
AijWeightPlusBijWeight = AijWeight;
}
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
}
else{
continue;
}
denominator += AijWeightPlusBijWeight;
final double mortalityJ = mortalityArray[j];
if(mortalityI > mortalityJ){
numerator += AijWeightPlusBijWeight*1.0;
}
else if(mortalityI == mortalityJ){
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
}
}
}
return numerator / denominator;
}
} }

View file

@ -0,0 +1,101 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
import java.util.List;
public class CompetingRiskUtils {
public static double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
// Let \tau be the max time.
int permissible = 0;
double numerator = 0;
for(int i = 0; i<mortalityArray.length; i++){
final CompetingRiskResponse responseI = responseList.get(i);
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
continue; // skip if it's 0
}
final double mortalityI = mortalityArray[i];
for(int j=0; j<mortalityArray.length; j++){
final CompetingRiskResponse responseJ = responseList.get(j);
// Check that Aij or Bij == 1
if(responseI.getU() < responseJ.getU() || (responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){
permissible++;
final double mortalityJ = mortalityArray[j];
if(mortalityI > mortalityJ){
numerator += 1.0;
}
else if(mortalityI == mortalityJ){
numerator += 0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
}
}
}
}
return numerator / (double) permissible;
}
public static double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
// Let \tau be the max time.
double denominator = 0.0;
double numerator = 0.0;
for(int i = 0; i<mortalityArray.length; i++){
final CompetingRiskResponse responseI = responseList.get(i);
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
continue; // skip if it's 0
}
final double mortalityI = mortalityArray[i];
final double Ti = responseI.getU();
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
for(int j=0; j<mortalityArray.length; j++){
final CompetingRiskResponse responseJ = responseList.get(j);
final double AijWeightPlusBijWeight;
if(responseI.getU() < responseJ.getU()){ // Aij == 1
AijWeightPlusBijWeight = AijWeight;
}
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
}
else{
continue;
}
denominator += AijWeightPlusBijWeight;
final double mortalityJ = mortalityArray[j];
if(mortalityI > mortalityJ){
numerator += AijWeightPlusBijWeight*1.0;
}
else if(mortalityI == mortalityJ){
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
}
}
}
return numerator / denominator;
}
}

View file

@ -31,13 +31,12 @@ public class TestCompetingRiskErrorRateCalculator {
final int event = 1; final int event = 1;
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
final double naiveConcordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event); final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0)); final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0));
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance. // This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
final double ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution); final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
closeEnough(naiveConcordance, ipcwConcordance, 0.0001); closeEnough(naiveConcordance, ipcwConcordance, 0.0001);