Broke two methods in CompetingRiskErrorRateCalculator into static
methods in a new class.
This commit is contained in:
parent
e92abdab13
commit
c85cebb59f
3 changed files with 105 additions and 97 deletions
|
@ -1,7 +1,6 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.VisibleForTesting;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
|
||||
|
@ -96,7 +95,7 @@ public class CompetingRiskErrorRateCalculator {
|
|||
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
|
||||
.toArray();
|
||||
|
||||
final double concordance = calculateConcordance(responses, mortalityList, event);
|
||||
final double concordance = CompetingRiskUtils.calculateConcordance(responses, mortalityList, event);
|
||||
errorRates[e] = 1.0 - concordance;
|
||||
|
||||
}
|
||||
|
@ -126,7 +125,7 @@ public class CompetingRiskErrorRateCalculator {
|
|||
.mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau))
|
||||
.toArray();
|
||||
|
||||
final double concordance = calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution);
|
||||
final double concordance = CompetingRiskUtils.calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution);
|
||||
errorRates[e] = 1.0 - concordance;
|
||||
|
||||
}
|
||||
|
@ -135,97 +134,6 @@ public class CompetingRiskErrorRateCalculator {
|
|||
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
|
||||
|
||||
// Let \tau be the max time.
|
||||
|
||||
int permissible = 0;
|
||||
double numerator = 0;
|
||||
|
||||
for(int i = 0; i<mortalityArray.length; i++){
|
||||
final CompetingRiskResponse responseI = responseList.get(i);
|
||||
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||
continue; // skip if it's 0
|
||||
}
|
||||
|
||||
final double mortalityI = mortalityArray[i];
|
||||
|
||||
for(int j=0; j<mortalityArray.length; j++){
|
||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||
// Check that Aij or Bij == 1
|
||||
if(responseI.getU() < responseJ.getU() || (responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){
|
||||
permissible++;
|
||||
|
||||
final double mortalityJ = mortalityArray[j];
|
||||
if(mortalityI > mortalityJ){
|
||||
numerator += 1.0;
|
||||
}
|
||||
else if(mortalityI == mortalityJ){
|
||||
numerator += 0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return numerator / (double) permissible;
|
||||
|
||||
}
|
||||
|
||||
|
||||
@VisibleForTesting
|
||||
public double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
|
||||
|
||||
// Let \tau be the max time.
|
||||
|
||||
double denominator = 0.0;
|
||||
double numerator = 0.0;
|
||||
|
||||
for(int i = 0; i<mortalityArray.length; i++){
|
||||
final CompetingRiskResponse responseI = responseList.get(i);
|
||||
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||
continue; // skip if it's 0
|
||||
}
|
||||
|
||||
final double mortalityI = mortalityArray[i];
|
||||
final double Ti = responseI.getU();
|
||||
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
|
||||
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
|
||||
|
||||
for(int j=0; j<mortalityArray.length; j++){
|
||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||
|
||||
final double AijWeightPlusBijWeight;
|
||||
|
||||
if(responseI.getU() < responseJ.getU()){ // Aij == 1
|
||||
AijWeightPlusBijWeight = AijWeight;
|
||||
}
|
||||
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
|
||||
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
||||
}
|
||||
else{
|
||||
continue;
|
||||
}
|
||||
|
||||
denominator += AijWeightPlusBijWeight;
|
||||
|
||||
final double mortalityJ = mortalityArray[j];
|
||||
if(mortalityI > mortalityJ){
|
||||
numerator += AijWeightPlusBijWeight*1.0;
|
||||
}
|
||||
else if(mortalityI == mortalityJ){
|
||||
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return numerator / denominator;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class CompetingRiskUtils {
|
||||
|
||||
public static double calculateConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event){
|
||||
|
||||
// Let \tau be the max time.
|
||||
|
||||
int permissible = 0;
|
||||
double numerator = 0;
|
||||
|
||||
for(int i = 0; i<mortalityArray.length; i++){
|
||||
final CompetingRiskResponse responseI = responseList.get(i);
|
||||
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||
continue; // skip if it's 0
|
||||
}
|
||||
|
||||
final double mortalityI = mortalityArray[i];
|
||||
|
||||
for(int j=0; j<mortalityArray.length; j++){
|
||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||
// Check that Aij or Bij == 1
|
||||
if(responseI.getU() < responseJ.getU() || (responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){
|
||||
permissible++;
|
||||
|
||||
final double mortalityJ = mortalityArray[j];
|
||||
if(mortalityI > mortalityJ){
|
||||
numerator += 1.0;
|
||||
}
|
||||
else if(mortalityI == mortalityJ){
|
||||
numerator += 0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return numerator / (double) permissible;
|
||||
|
||||
}
|
||||
|
||||
|
||||
public static double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
|
||||
|
||||
// Let \tau be the max time.
|
||||
|
||||
double denominator = 0.0;
|
||||
double numerator = 0.0;
|
||||
|
||||
for(int i = 0; i<mortalityArray.length; i++){
|
||||
final CompetingRiskResponse responseI = responseList.get(i);
|
||||
if(responseI.getDelta() != event){ // \tilde{N}_i^1(\tau) == 1 check
|
||||
continue; // skip if it's 0
|
||||
}
|
||||
|
||||
final double mortalityI = mortalityArray[i];
|
||||
final double Ti = responseI.getU();
|
||||
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
|
||||
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
|
||||
|
||||
for(int j=0; j<mortalityArray.length; j++){
|
||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||
|
||||
final double AijWeightPlusBijWeight;
|
||||
|
||||
if(responseI.getU() < responseJ.getU()){ // Aij == 1
|
||||
AijWeightPlusBijWeight = AijWeight;
|
||||
}
|
||||
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
|
||||
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
||||
}
|
||||
else{
|
||||
continue;
|
||||
}
|
||||
|
||||
denominator += AijWeightPlusBijWeight;
|
||||
|
||||
final double mortalityJ = mortalityArray[j];
|
||||
if(mortalityI > mortalityJ){
|
||||
numerator += AijWeightPlusBijWeight*1.0;
|
||||
}
|
||||
else if(mortalityI == mortalityJ){
|
||||
numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return numerator / denominator;
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -31,13 +31,12 @@ public class TestCompetingRiskErrorRateCalculator {
|
|||
final int event = 1;
|
||||
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest);
|
||||
|
||||
final double naiveConcordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event);
|
||||
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
||||
|
||||
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0));
|
||||
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
|
||||
final double ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
|
||||
final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
|
||||
|
||||
closeEnough(naiveConcordance, ipcwConcordance, 0.0001);
|
||||
|
||||
|
|
Loading…
Reference in a new issue