diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java index 134673c..966d063 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -1,7 +1,6 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.VisibleForTesting; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.utils.MathFunction; @@ -96,7 +95,7 @@ public class CompetingRiskErrorRateCalculator { .mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau)) .toArray(); - final double concordance = calculateConcordance(responses, mortalityList, event); + final double concordance = CompetingRiskUtils.calculateConcordance(responses, mortalityList, event); errorRates[e] = 1.0 - concordance; } @@ -126,7 +125,7 @@ public class CompetingRiskErrorRateCalculator { .mapToDouble(riskFunction -> riskFunction.calculateEventSpecificMortality(event, tau)) .toArray(); - final double concordance = calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution); + final double concordance = CompetingRiskUtils.calculateIPCWConcordance(responses, mortalityList, event, censoringDistribution); errorRates[e] = 1.0 - concordance; } @@ -135,97 +134,6 @@ public class CompetingRiskErrorRateCalculator { } - @VisibleForTesting - public double calculateConcordance(final List responseList, double[] mortalityArray, final int event){ - // Let \tau be the max time. - - int permissible = 0; - double numerator = 0; - - for(int i = 0; i= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){ - permissible++; - - final double mortalityJ = mortalityArray[j]; - if(mortalityI > mortalityJ){ - numerator += 1.0; - } - else if(mortalityI == mortalityJ){ - numerator += 0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error - } - - } - - } - - } - - return numerator / (double) permissible; - - } - - - @VisibleForTesting - public double calculateIPCWConcordance(final List responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){ - - // Let \tau be the max time. - - double denominator = 0.0; - double numerator = 0.0; - - for(int i = 0; i= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1 - AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); - } - else{ - continue; - } - - denominator += AijWeightPlusBijWeight; - - final double mortalityJ = mortalityArray[j]; - if(mortalityI > mortalityJ){ - numerator += AijWeightPlusBijWeight*1.0; - } - else if(mortalityI == mortalityJ){ - numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error - } - - } - - } - - return numerator / denominator; - - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java new file mode 100644 index 0000000..44c078a --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java @@ -0,0 +1,101 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.utils.MathFunction; + +import java.util.List; + +public class CompetingRiskUtils { + + public static double calculateConcordance(final List responseList, double[] mortalityArray, final int event){ + + // Let \tau be the max time. + + int permissible = 0; + double numerator = 0; + + for(int i = 0; i= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event)){ + permissible++; + + final double mortalityJ = mortalityArray[j]; + if(mortalityI > mortalityJ){ + numerator += 1.0; + } + else if(mortalityI == mortalityJ){ + numerator += 0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error + } + + } + + } + + } + + return numerator / (double) permissible; + + } + + + public static double calculateIPCWConcordance(final List responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){ + + // Let \tau be the max time. + + double denominator = 0.0; + double numerator = 0.0; + + for(int i = 0; i= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1 + AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); + } + else{ + continue; + } + + denominator += AijWeightPlusBijWeight; + + final double mortalityJ = mortalityArray[j]; + if(mortalityI > mortalityJ){ + numerator += AijWeightPlusBijWeight*1.0; + } + else if(mortalityI == mortalityJ){ + numerator += AijWeightPlusBijWeight*0.5; // Edge case that can happen in trees with only a few BooleanCovariates, when you're looking at training error + } + + } + + } + + return numerator / denominator; + + } + + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java index c0c54ec..7a5484e 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -31,13 +31,12 @@ public class TestCompetingRiskErrorRateCalculator { final int event = 1; final Forest fakeForest = Forest.builder().build(); - final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(Collections.emptyList(), fakeForest); - final double naiveConcordance = errorRateCalculator.calculateConcordance(responseList, mortalityArray, event); + final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event); final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0)); // This distribution will make the IPCW weights == 1, giving identical results to the naive concordance. - final double ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution); + final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution); closeEnough(naiveConcordance, ipcwConcordance, 0.0001);