Massive optimizations;
Refactored how MathFunctions are structured to use more primitives and less objects. Optimized competing risk group differentiators to run faster. Removed alternative competing risk response combiners (may be added back later)
This commit is contained in:
parent
cce5ad1e0f
commit
c68f67e47a
36 changed files with 1223 additions and 474 deletions
|
@ -3,20 +3,17 @@ package ca.joeltherrien.randomforest;
|
|||
import ca.joeltherrien.randomforest.covariates.*;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import com.fasterxml.jackson.databind.node.TextNode;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class Main {
|
||||
|
||||
|
@ -68,9 +65,6 @@ public class Main {
|
|||
if(responseCombiner instanceof CompetingRiskFunctionCombiner){
|
||||
events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents();
|
||||
}
|
||||
else if(responseCombiner instanceof CompetingRiskListCombiner){
|
||||
events = ((CompetingRiskListCombiner) responseCombiner).getOriginalCombiner().getEvents();
|
||||
}
|
||||
else{
|
||||
System.out.println("Unsupported tree combiner");
|
||||
return;
|
||||
|
@ -123,7 +117,7 @@ public class Main {
|
|||
final double[] censorTimes = dataset.stream()
|
||||
.mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC())
|
||||
.toArray();
|
||||
final MathFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes);
|
||||
final StepFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes);
|
||||
|
||||
System.out.println("Finished generating censor distribution - running concordance");
|
||||
|
||||
|
|
|
@ -2,11 +2,10 @@ package ca.joeltherrien.randomforest;
|
|||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskResponseCombinerToList;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
||||
|
@ -22,7 +21,10 @@ import com.fasterxml.jackson.databind.ObjectMapper;
|
|||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||
import lombok.*;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
@ -76,7 +78,13 @@ public class Settings {
|
|||
(objectNode) -> {
|
||||
final int eventOfFocus = objectNode.get("eventOfFocus").asInt();
|
||||
|
||||
return new LogRankSingleGroupDifferentiator(eventOfFocus);
|
||||
final Iterator<JsonNode> elements = objectNode.get("events").elements();
|
||||
final List<JsonNode> elementList = new ArrayList<>();
|
||||
elements.forEachRemaining(node -> elementList.add(node));
|
||||
|
||||
final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray();
|
||||
|
||||
return new LogRankSingleGroupDifferentiator(eventOfFocus, eventArray);
|
||||
}
|
||||
);
|
||||
registerGroupDifferentiatorConstructor("GrayLogRankMultipleGroupDifferentiator",
|
||||
|
@ -105,7 +113,14 @@ public class Settings {
|
|||
(objectNode) -> {
|
||||
final int eventOfFocus = objectNode.get("eventOfFocus").asInt();
|
||||
|
||||
return new GrayLogRankSingleGroupDifferentiator(eventOfFocus);
|
||||
final Iterator<JsonNode> elements = objectNode.get("events").elements();
|
||||
final List<JsonNode> elementList = new ArrayList<>();
|
||||
elements.forEachRemaining(node -> elementList.add(node));
|
||||
|
||||
final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray();
|
||||
|
||||
|
||||
return new GrayLogRankSingleGroupDifferentiator(eventOfFocus, eventArray);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
@ -154,23 +169,6 @@ public class Settings {
|
|||
}
|
||||
);
|
||||
|
||||
registerResponseCombinerConstructor("CompetingRiskListCombiner",
|
||||
(node) -> {
|
||||
final List<Integer> eventList = new ArrayList<>();
|
||||
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
|
||||
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
|
||||
|
||||
|
||||
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events);
|
||||
return new CompetingRiskListCombiner(responseCombiner);
|
||||
|
||||
}
|
||||
);
|
||||
|
||||
registerResponseCombinerConstructor("CompetingRiskResponseCombinerToList",
|
||||
(node) -> new CompetingRiskResponseCombinerToList()
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
|||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -110,13 +110,13 @@ public class CompetingRiskErrorRateCalculator {
|
|||
|
||||
}
|
||||
|
||||
public double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution){
|
||||
public double[] calculateIPCWConcordance(final int[] events, final StepFunction censoringDistribution){
|
||||
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
|
||||
|
||||
return calculateIPCWConcordance(events, censoringDistribution, tau);
|
||||
}
|
||||
|
||||
private double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution, final double tau){
|
||||
private double[] calculateIPCWConcordance(final int[] events, final StepFunction censoringDistribution, final double tau){
|
||||
|
||||
final double[] errorRates = new double[events.length];
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
|
||||
|
@ -11,44 +10,48 @@ import java.util.List;
|
|||
@Builder
|
||||
public class CompetingRiskFunctions implements Serializable {
|
||||
|
||||
private final List<MathFunction> causeSpecificHazards;
|
||||
private final List<MathFunction> cumulativeIncidenceCurves;
|
||||
private final List<StepFunction> causeSpecificHazards;
|
||||
private final List<StepFunction> cumulativeIncidenceCurves;
|
||||
|
||||
@Getter
|
||||
private final MathFunction survivalCurve;
|
||||
private final StepFunction survivalCurve;
|
||||
|
||||
public MathFunction getCauseSpecificHazardFunction(int cause){
|
||||
public StepFunction getCauseSpecificHazardFunction(int cause){
|
||||
return causeSpecificHazards.get(cause-1);
|
||||
}
|
||||
|
||||
public MathFunction getCumulativeIncidenceFunction(int cause) {
|
||||
public StepFunction getCumulativeIncidenceFunction(int cause) {
|
||||
return cumulativeIncidenceCurves.get(cause-1);
|
||||
}
|
||||
|
||||
public double calculateEventSpecificMortality(final int event, final double tau){
|
||||
final MathFunction cif = getCumulativeIncidenceFunction(event);
|
||||
final StepFunction cif = getCumulativeIncidenceFunction(event);
|
||||
|
||||
double summation = 0.0;
|
||||
Point previousPoint = null;
|
||||
|
||||
for(final Point point : cif.getPoints()){
|
||||
if(point.getTime() > tau){
|
||||
Double previousTime = null;
|
||||
Double previousY = null;
|
||||
|
||||
final double[] cifTimes = cif.getX();
|
||||
for(int i=0; i<cifTimes.length; i++){
|
||||
final double time = cifTimes[i];
|
||||
|
||||
if(time > tau){
|
||||
break;
|
||||
}
|
||||
|
||||
if(previousPoint != null){
|
||||
summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime());
|
||||
if(previousTime != null){
|
||||
summation += previousY * (time - previousTime);
|
||||
}
|
||||
previousPoint = point;
|
||||
|
||||
previousTime = time;
|
||||
previousY = cif.evaluateByIndex(i);
|
||||
}
|
||||
|
||||
// this is to ensure that we integrate over the proper range
|
||||
if(previousPoint != null){
|
||||
summation += cif.evaluate(tau).getY() * (tau - previousPoint.getTime());
|
||||
if(previousTime != null){
|
||||
summation += cif.evaluate(tau) * (tau - previousTime);
|
||||
}
|
||||
|
||||
|
||||
return summation;
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Represents a response from CompetingRiskUtils#calculateGraySetsEfficiently
|
||||
*
|
||||
*/
|
||||
@Builder
|
||||
@Getter
|
||||
public class CompetingRiskGraySetsImpl implements CompetingRiskSets{
|
||||
|
||||
private final List<Double> eventTimes;
|
||||
private final MathFunction[] riskSet;
|
||||
private final Map<Double, int[]> numberOfEvents;
|
||||
|
||||
@Override
|
||||
public MathFunction getRiskSet(int event){
|
||||
return(riskSet[event-1]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumberOfEvents(Double time, int event){
|
||||
if(numberOfEvents.containsKey(time)){
|
||||
return numberOfEvents.get(time)[event];
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface CompetingRiskSets {
|
||||
|
||||
MathFunction getRiskSet(int event);
|
||||
int getNumberOfEvents(Double time, int event);
|
||||
List<Double> getEventTimes();
|
||||
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Represents a response from CompetingRiskUtils#calculateSetsEfficiently
|
||||
*
|
||||
*/
|
||||
@Builder
|
||||
@Getter
|
||||
public class CompetingRiskSetsImpl implements CompetingRiskSets{
|
||||
|
||||
private final List<Double> eventTimes;
|
||||
private final MathFunction riskSet;
|
||||
private final Map<Double, int[]> numberOfEvents;
|
||||
|
||||
@Override
|
||||
public MathFunction getRiskSet(int event){
|
||||
return riskSet;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumberOfEvents(Double time, int event){
|
||||
if(numberOfEvents.containsKey(time)){
|
||||
return numberOfEvents.get(time)[event];
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,8 +1,11 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.VeryDiscontinuousStepFunction;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
import java.util.stream.DoubleStream;
|
||||
|
||||
public class CompetingRiskUtils {
|
||||
|
||||
|
@ -46,7 +49,9 @@ public class CompetingRiskUtils {
|
|||
}
|
||||
|
||||
|
||||
public static double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){
|
||||
public static double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList,
|
||||
double[] mortalityArray, final int event,
|
||||
final StepFunction censoringDistribution){
|
||||
|
||||
// Let \tau be the max time.
|
||||
|
||||
|
@ -61,8 +66,8 @@ public class CompetingRiskUtils {
|
|||
|
||||
final double mortalityI = mortalityArray[i];
|
||||
final double Ti = responseI.getU();
|
||||
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
|
||||
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
|
||||
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti);
|
||||
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti) * G_Ti_minus);
|
||||
|
||||
for(int j=0; j<mortalityArray.length; j++){
|
||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||
|
@ -73,7 +78,7 @@ public class CompetingRiskUtils {
|
|||
AijWeightPlusBijWeight = AijWeight;
|
||||
}
|
||||
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
|
||||
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
||||
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()));
|
||||
}
|
||||
else{
|
||||
continue;
|
||||
|
@ -97,5 +102,152 @@ public class CompetingRiskUtils {
|
|||
|
||||
}
|
||||
|
||||
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> responses, int[] eventsOfFocus){
|
||||
final int n = responses.size();
|
||||
int[] numberOfCurrentEvents = new int[eventsOfFocus.length+1];
|
||||
|
||||
final Map<Double, int[]> numberOfEvents = new HashMap<>();
|
||||
|
||||
final List<Double> eventTimes = new ArrayList<>(n);
|
||||
final List<Double> eventAndCensorTimes = new ArrayList<>(n);
|
||||
final List<Integer> riskSetNumberList = new ArrayList<>(n);
|
||||
|
||||
// need to first sort responses
|
||||
Collections.sort(responses, (y1, y2) -> {
|
||||
if(y1.getU() < y2.getU()){
|
||||
return -1;
|
||||
}
|
||||
else if(y1.getU() > y2.getU()){
|
||||
return 1;
|
||||
}
|
||||
else{
|
||||
return 0;
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
final CompetingRiskResponse currentResponse = responses.get(i);
|
||||
final boolean lastOfTime = (i+1)==n || responses.get(i+1).getU() > currentResponse.getU();
|
||||
|
||||
numberOfCurrentEvents[currentResponse.getDelta()]++;
|
||||
|
||||
if(lastOfTime){
|
||||
int totalNumberOfCurrentEvents = 0;
|
||||
for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events
|
||||
totalNumberOfCurrentEvents += numberOfCurrentEvents[e];
|
||||
}
|
||||
|
||||
final double currentTime = currentResponse.getU();
|
||||
|
||||
if(totalNumberOfCurrentEvents > 0){ // add numberOfCurrentEvents
|
||||
// Add point
|
||||
eventTimes.add(currentTime);
|
||||
numberOfEvents.put(currentTime, numberOfCurrentEvents);
|
||||
}
|
||||
|
||||
// Always do risk set
|
||||
// remember that the LeftContinuousFunction takes into account that at this currentTime the risk value is the previous value
|
||||
final int riskSet = n - (i+1);
|
||||
riskSetNumberList.add(riskSet);
|
||||
eventAndCensorTimes.add(currentTime);
|
||||
|
||||
// reset counters
|
||||
numberOfCurrentEvents = new int[eventsOfFocus.length+1];
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
final double[] riskSetArray = new double[eventAndCensorTimes.size()];
|
||||
final double[] timesArray = new double[eventAndCensorTimes.size()];
|
||||
for(int i=0; i<riskSetArray.length; i++){
|
||||
timesArray[i] = eventAndCensorTimes.get(i);
|
||||
riskSetArray[i] = riskSetNumberList.get(i);
|
||||
}
|
||||
|
||||
final LeftContinuousStepFunction riskSetFunction = new LeftContinuousStepFunction(timesArray, riskSetArray, n);
|
||||
|
||||
return CompetingRiskSetsImpl.builder()
|
||||
.numberOfEvents(numberOfEvents)
|
||||
.riskSet(riskSetFunction)
|
||||
.eventTimes(eventTimes)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
public static CompetingRiskGraySetsImpl calculateGraySetsEfficiently(final List<CompetingRiskResponseWithCensorTime> responses, int[] eventsOfFocus){
|
||||
final List sillyList = responses; // annoying Java generic work-around
|
||||
final CompetingRiskSetsImpl originalSets = calculateSetsEfficiently(sillyList, eventsOfFocus);
|
||||
|
||||
final double[] allTimes = DoubleStream.concat(
|
||||
responses.stream()
|
||||
.mapToDouble(CompetingRiskResponseWithCensorTime::getC),
|
||||
responses.stream()
|
||||
.mapToDouble(CompetingRiskResponseWithCensorTime::getU)
|
||||
).sorted().distinct().toArray();
|
||||
|
||||
|
||||
|
||||
final VeryDiscontinuousStepFunction[] riskSets = new VeryDiscontinuousStepFunction[eventsOfFocus.length];
|
||||
|
||||
for(final int event : eventsOfFocus){
|
||||
final double[] yAt = new double[allTimes.length];
|
||||
final double[] yRight = new double[allTimes.length];
|
||||
|
||||
for(final CompetingRiskResponseWithCensorTime response : responses){
|
||||
if(response.getDelta() == event){
|
||||
// traditional case only; increment on time t when I(t <= Ui)
|
||||
final double time = response.getU();
|
||||
final int index = Arrays.binarySearch(allTimes, time);
|
||||
|
||||
if(index < 0){ // TODO remove once code is stable
|
||||
throw new IllegalStateException("Index shouldn't be negative!");
|
||||
}
|
||||
|
||||
// All yAts up to and including index are incremented;
|
||||
// All yRights up to index are incremented
|
||||
yAt[index]++;
|
||||
for(int i=0; i<index; i++){
|
||||
yAt[i]++;
|
||||
yRight[i]++;
|
||||
}
|
||||
}
|
||||
else{
|
||||
// need to increment on time t on following conditions; I(t <= Ui | t < Ci)
|
||||
// Fact: Ci >= Ui.
|
||||
|
||||
// increment yAt up to Ci. If Ui==Ci, increment yAt at Ci.
|
||||
final double time = response.getC();
|
||||
final int index = Arrays.binarySearch(allTimes, time);
|
||||
|
||||
if(index < 0){ // TODO remove once code is stable
|
||||
throw new IllegalStateException("Index shouldn't be negative!");
|
||||
}
|
||||
|
||||
for(int i=0; i<index; i++){
|
||||
yAt[i]++;
|
||||
yRight[i]++;
|
||||
}
|
||||
if(response.getU() == response.getC()){
|
||||
yAt[index]++;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
riskSets[event-1] = new VeryDiscontinuousStepFunction(allTimes, yAt, yRight, responses.size());
|
||||
|
||||
}
|
||||
|
||||
return CompetingRiskGraySetsImpl.builder()
|
||||
.numberOfEvents(originalSets.getNumberOfEvents())
|
||||
.eventTimes(originalSets.getEventTimes())
|
||||
.riskSet(riskSets)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -2,11 +2,12 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
|||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
|
@ -34,51 +35,46 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
|
|||
timesToUse = responses.stream()
|
||||
.map(functions -> functions.getSurvivalCurve())
|
||||
.flatMapToDouble(
|
||||
function -> function.getPoints().stream()
|
||||
.mapToDouble(point -> point.getTime())
|
||||
function -> Arrays.stream(function.getX())
|
||||
).sorted().distinct().toArray();
|
||||
}
|
||||
|
||||
final double n = responses.size();
|
||||
|
||||
final List<Point> survivalPoints = new ArrayList<>(timesToUse.length);
|
||||
for(final double time : timesToUse){
|
||||
final double[] survivalY = new double[timesToUse.length];
|
||||
|
||||
final double survivalY = responses.stream()
|
||||
.mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time).getY() / n)
|
||||
for(int i=0; i<timesToUse.length; i++){
|
||||
final double time = timesToUse[i];
|
||||
survivalY[i] = responses.stream()
|
||||
.mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time) / n)
|
||||
.sum();
|
||||
|
||||
survivalPoints.add(new Point(time, survivalY));
|
||||
|
||||
}
|
||||
|
||||
final MathFunction survivalFunction = new MathFunction(survivalPoints, new Point(0.0, 1.0));
|
||||
final StepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
|
||||
|
||||
final List<MathFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||
final List<MathFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||
final List<StepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||
final List<StepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||
|
||||
for(final int event : events){
|
||||
|
||||
final List<Point> cumulativeHazardFunctionPoints = new ArrayList<>(timesToUse.length);
|
||||
final List<Point> cumulativeIncidenceFunctionPoints = new ArrayList<>(timesToUse.length);
|
||||
final double[] cumulativeHazardFunctionY = new double[timesToUse.length];
|
||||
final double[] cumulativeIncidenceFunctionY = new double[timesToUse.length];
|
||||
|
||||
for(final double time : timesToUse){
|
||||
for(int i=0; i<timesToUse.length; i++){
|
||||
final double time = timesToUse[i];
|
||||
|
||||
final double hazardY = responses.stream()
|
||||
.mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time).getY() / n)
|
||||
cumulativeHazardFunctionY[i] = responses.stream()
|
||||
.mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time) / n)
|
||||
.sum();
|
||||
|
||||
final double incidenceY = responses.stream()
|
||||
.mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time).getY() / n)
|
||||
cumulativeIncidenceFunctionY[i] = responses.stream()
|
||||
.mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time) / n)
|
||||
.sum();
|
||||
|
||||
cumulativeHazardFunctionPoints.add(new Point(time, hazardY));
|
||||
cumulativeIncidenceFunctionPoints.add(new Point(time, incidenceY));
|
||||
|
||||
}
|
||||
|
||||
causeSpecificCumulativeHazardFunctionList.add(event-1, new MathFunction(cumulativeHazardFunctionPoints));
|
||||
cumulativeIncidenceFunctionList.add(event-1, new MathFunction(cumulativeIncidenceFunctionPoints));
|
||||
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeHazardFunctionY, 0));
|
||||
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeIncidenceFunctionY, 0));
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
|||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
@ -38,8 +38,8 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
|||
@Override
|
||||
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
|
||||
|
||||
final List<MathFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||
final List<MathFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||
final List<StepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||
final List<StepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||
|
||||
Collections.sort(responses, (y1, y2) -> {
|
||||
if(y1.getU() < y2.getU()){
|
||||
|
@ -97,7 +97,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
|||
}
|
||||
|
||||
}
|
||||
final MathFunction survivalCurve = new MathFunction(survivalPoints, new Point(0.0, 1.0));
|
||||
final StepFunction survivalCurve = RightContinuousStepFunction.constructFromPoints(survivalPoints, 1.0);
|
||||
|
||||
|
||||
for(final int event : events){
|
||||
|
@ -129,7 +129,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
|||
// Cumulative incidence function
|
||||
// TODO - confirm this behaviour
|
||||
//final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY();
|
||||
final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUseList.get(i-1)).getY() : 1.0;
|
||||
final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluateByIndex(i-1) : 1.0;
|
||||
|
||||
final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk);
|
||||
final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY);
|
||||
|
@ -138,10 +138,10 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
|||
|
||||
}
|
||||
|
||||
final MathFunction causeSpecificCumulativeHazardFunction = new MathFunction(hazardFunctionPoints);
|
||||
final StepFunction causeSpecificCumulativeHazardFunction = RightContinuousStepFunction.constructFromPoints(hazardFunctionPoints, 0.0);
|
||||
causeSpecificCumulativeHazardFunctionList.add(event-1, causeSpecificCumulativeHazardFunction);
|
||||
|
||||
final MathFunction cifFunction = new MathFunction(cifPoints);
|
||||
final StepFunction cifFunction = RightContinuousStepFunction.constructFromPoints(cifPoints, 0.0);
|
||||
cumulativeIncidenceFunctionList.add(event-1, cifFunction);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class CompetingRiskListCombiner implements ResponseCombiner<CompetingRiskResponse[], CompetingRiskFunctions> {
|
||||
|
||||
@Getter
|
||||
private final CompetingRiskResponseCombiner originalCombiner;
|
||||
|
||||
@Override
|
||||
public CompetingRiskFunctions combine(List<CompetingRiskResponse[]> responses) {
|
||||
final List<CompetingRiskResponse> completeList = responses.stream().flatMap(Arrays::stream).collect(Collectors.toList());
|
||||
|
||||
return originalCombiner.combine(completeList);
|
||||
}
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This class takes all of the observations in a terminal node and 'combines' them into just a list of the observations.
|
||||
*
|
||||
* This is used in the alternative approach to only compute the functions at the final stage when combining trees.
|
||||
*
|
||||
*/
|
||||
public class CompetingRiskResponseCombinerToList implements ResponseCombiner<CompetingRiskResponse, CompetingRiskResponse[]> {
|
||||
|
||||
@Override
|
||||
public CompetingRiskResponse[] combine(List<CompetingRiskResponse> responses) {
|
||||
final CompetingRiskResponse[] array = new CompetingRiskResponse[responses.size()];
|
||||
|
||||
for(int i=0; i<array.length; i++){
|
||||
array[i] = responses.get(i);
|
||||
}
|
||||
|
||||
return array;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
@ -18,31 +19,22 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskRe
|
|||
@Override
|
||||
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||
|
||||
abstract double riskSet(final List<Y> eventList, double time, int eventOfFocus);
|
||||
|
||||
private double numberOfEventsAtTime(int eventOfFocus, List<Y> eventList, double time){
|
||||
return (double) eventList.stream()
|
||||
.filter(event -> event.getDelta() == eventOfFocus)
|
||||
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
|
||||
.count();
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
||||
*
|
||||
* @param eventOfFocus
|
||||
* @param leftHand A non-empty list of CompetingRiskResponse
|
||||
* @param rightHand A non-empty list of CompetingRiskResponse
|
||||
* @param competingRiskSetsLeft A summary of the different sets used in the calculation for the left side
|
||||
* @param competingRiskSetsRight A summary of the different sets used in the calculation for the right side
|
||||
* @return
|
||||
*/
|
||||
LogRankValue specificLogRankValue(final int eventOfFocus, List<Y> leftHand, List<Y> rightHand){
|
||||
LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSetsLeft, final CompetingRiskSets competingRiskSetsRight){
|
||||
|
||||
final double[] distinctEventTimes = Stream.concat(
|
||||
leftHand.stream(), rightHand.stream()
|
||||
)
|
||||
.filter(event -> !event.isCensored())
|
||||
.mapToDouble(event -> event.getU())
|
||||
competingRiskSetsLeft.getEventTimes().stream(),
|
||||
competingRiskSetsRight.getEventTimes().stream())
|
||||
.mapToDouble(Double::doubleValue)
|
||||
.sorted()
|
||||
.distinct()
|
||||
.toArray();
|
||||
|
||||
|
@ -51,12 +43,12 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskRe
|
|||
|
||||
for(final double time_k : distinctEventTimes){
|
||||
final double weight = weight(time_k); // W_j(t_k)
|
||||
final double numberEventsAtTimeDaughterLeft = numberOfEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k)
|
||||
final double numberEventsAtTimeDaughterRight = numberOfEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k)
|
||||
final double numberEventsAtTimeDaughterLeft = competingRiskSetsLeft.getNumberOfEvents(time_k, eventOfFocus); // // d_{j,l}(t_k)
|
||||
final double numberEventsAtTimeDaughterRight = competingRiskSetsRight.getNumberOfEvents(time_k, eventOfFocus); // d_{j,r}(t_k)
|
||||
final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)
|
||||
|
||||
final double individualsAtRiskDaughterLeft = riskSet(leftHand, time_k, eventOfFocus); // Y_l(t_k)
|
||||
final double individualsAtRiskDaughterRight = riskSet(rightHand, time_k, eventOfFocus); // Y_r(t_k)
|
||||
final double individualsAtRiskDaughterLeft = competingRiskSetsLeft.getRiskSet(eventOfFocus).evaluate(time_k); // Y_l(t_k)
|
||||
final double individualsAtRiskDaughterRight = competingRiskSetsRight.getRiskSet(eventOfFocus).evaluate(time_k); // Y_r(t_k)
|
||||
final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)
|
||||
|
||||
final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -20,11 +22,14 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
|
|||
return null;
|
||||
}
|
||||
|
||||
final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
|
||||
final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
|
||||
|
||||
double numerator = 0.0;
|
||||
double denominatorSquared = 0.0;
|
||||
|
||||
for(final int eventOfFocus : events){
|
||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
|
||||
|
||||
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
|
||||
denominatorSquared += valueOfInterest.getVariance();
|
||||
|
@ -35,13 +40,5 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
|
|||
|
||||
}
|
||||
|
||||
@Override
|
||||
double riskSet(List<CompetingRiskResponseWithCensorTime> eventList, double time, int eventOfFocus) {
|
||||
return eventList.stream()
|
||||
.filter(event -> event.getU() >= time ||
|
||||
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
|
||||
)
|
||||
.count();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -13,6 +15,7 @@ import java.util.List;
|
|||
public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> {
|
||||
|
||||
private final int eventOfFocus;
|
||||
private final int[] events;
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||
|
@ -20,19 +23,13 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
|
|||
return null;
|
||||
}
|
||||
|
||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||
final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
|
||||
final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
|
||||
|
||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
|
||||
|
||||
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
double riskSet(List<CompetingRiskResponseWithCensorTime> eventList, double time, int eventOfFocus) {
|
||||
return eventList.stream()
|
||||
.filter(event -> event.getU() >= time ||
|
||||
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
|
||||
)
|
||||
.count();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -20,11 +22,14 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
|
|||
return null;
|
||||
}
|
||||
|
||||
final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
|
||||
final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
|
||||
|
||||
double numerator = 0.0;
|
||||
double denominatorSquared = 0.0;
|
||||
|
||||
for(final int eventOfFocus : events){
|
||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
|
||||
|
||||
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
|
||||
denominatorSquared += valueOfInterest.getVariance();
|
||||
|
@ -35,11 +40,4 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
|
|||
|
||||
}
|
||||
|
||||
@Override
|
||||
double riskSet(List<CompetingRiskResponse> eventList, double time, int eventOfFocus) {
|
||||
return eventList.stream()
|
||||
.filter(event -> event.getU() >= time)
|
||||
.count();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -13,6 +15,7 @@ import java.util.List;
|
|||
public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> {
|
||||
|
||||
private final int eventOfFocus;
|
||||
private final int[] events;
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||
|
@ -20,17 +23,13 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
|
|||
return null;
|
||||
}
|
||||
|
||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||
final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
|
||||
final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
|
||||
|
||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
|
||||
|
||||
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
double riskSet(List<CompetingRiskResponse> eventList, double time, int eventOfFocus) {
|
||||
return eventList.stream()
|
||||
.filter(event -> event.getU() >= time)
|
||||
.count();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
/**
|
||||
* Represents a function represented by discrete points. However, the function may be right-continuous or left-continuous
|
||||
* at a given point, with no consistency. This function tracks that.
|
||||
*/
|
||||
public final class DiscontinuousStepFunction extends StepFunction {
|
||||
|
||||
private final double[] y;
|
||||
private final boolean[] isLeftContinuous;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
*/
|
||||
private final double defaultY;
|
||||
|
||||
public DiscontinuousStepFunction(double[] x, double[] y, boolean[] isLeftContinuous, double defaultY) {
|
||||
super(x);
|
||||
this.y = y;
|
||||
this.isLeftContinuous = isLeftContinuous;
|
||||
this.defaultY = defaultY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluate(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return evaluateByIndex(index);
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public double evaluatePrevious(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return evaluateByIndex(index);
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluateByIndex(int i) {
|
||||
if(isLeftContinuous[i]){
|
||||
i -= 1;
|
||||
}
|
||||
|
||||
if(i < 0){
|
||||
return defaultY;
|
||||
}
|
||||
|
||||
return y[i];
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
builder.append("Default point: ");
|
||||
builder.append(defaultY);
|
||||
builder.append("\n");
|
||||
|
||||
for(int i=0; i<x.length; i++){
|
||||
builder.append("x:");
|
||||
builder.append(x[i]);
|
||||
builder.append('\t');
|
||||
|
||||
if(isLeftContinuous[i]){
|
||||
builder.append("*y:");
|
||||
}
|
||||
else{
|
||||
builder.append("y*:");
|
||||
}
|
||||
builder.append(y[i]);
|
||||
builder.append("\n");
|
||||
}
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ListIterator;
|
||||
|
||||
/**
|
||||
* Represents a function represented by discrete points. We assume that the function is a stepwise left-continuous
|
||||
* function, constant at the value of the previous encountered point.
|
||||
*
|
||||
*/
|
||||
public final class LeftContinuousStepFunction extends StepFunction {
|
||||
|
||||
private final double[] y;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
*/
|
||||
private final double defaultY;
|
||||
|
||||
public LeftContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
||||
super(x);
|
||||
this.y = y;
|
||||
this.defaultY = defaultY;
|
||||
}
|
||||
|
||||
/**
|
||||
* This isn't a formal constructor because of limitations with abstract classes.
|
||||
*
|
||||
* @param pointList
|
||||
* @param defaultY
|
||||
* @return
|
||||
*/
|
||||
public static LeftContinuousStepFunction constructFromPoints(final List<Point> pointList, final double defaultY){
|
||||
|
||||
final double[] x = new double[pointList.size()];
|
||||
final double[] y = new double[pointList.size()];
|
||||
|
||||
final ListIterator<Point> pointIterator = pointList.listIterator();
|
||||
while(pointIterator.hasNext()){
|
||||
final int index = pointIterator.nextIndex();
|
||||
final Point currentPoint = pointIterator.next();
|
||||
|
||||
x[index] = currentPoint.getTime();
|
||||
y[index] = currentPoint.getY();
|
||||
}
|
||||
|
||||
return new LeftContinuousStepFunction(x, y, defaultY);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluate(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return evaluateByIndex(index-1);
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluatePrevious(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return evaluateByIndex(index-1);
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluateByIndex(int i) {
|
||||
if(i < 0){
|
||||
return defaultY;
|
||||
}
|
||||
|
||||
return y[i];
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
builder.append("Default point: ");
|
||||
builder.append(defaultY);
|
||||
builder.append("\n");
|
||||
|
||||
for(int i=0; i<x.length; i++){
|
||||
builder.append("x:");
|
||||
builder.append(x[i]);
|
||||
builder.append("\ty:");
|
||||
builder.append(y[i]);
|
||||
builder.append("\n");
|
||||
}
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
}
|
|
@ -1,114 +1,9 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Represents a function represented by discrete points. We assume that the function is a stepwise continuous function,
|
||||
* constant at the value of the previous encountered point.
|
||||
*
|
||||
*/
|
||||
public class MathFunction implements Serializable {
|
||||
|
||||
@Getter
|
||||
private final List<Point> points;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
*/
|
||||
private final Point defaultValue;
|
||||
|
||||
public MathFunction(final List<Point> points){
|
||||
this(points, new Point(0.0, 0.0));
|
||||
}
|
||||
|
||||
public MathFunction(final List<Point> points, final Point defaultValue){
|
||||
this.points = Collections.unmodifiableList(points);
|
||||
this.defaultValue = defaultValue;
|
||||
}
|
||||
|
||||
public Point evaluate(double time){
|
||||
int index = binarySearch(points, time);
|
||||
if(index < 0){
|
||||
return defaultValue;
|
||||
}
|
||||
else{
|
||||
return points.get(index);
|
||||
}
|
||||
}
|
||||
|
||||
public Point evaluatePrevious(double time){
|
||||
|
||||
int index = binarySearch(points, time) - 1;
|
||||
if(index < 0){
|
||||
return defaultValue;
|
||||
}
|
||||
else{
|
||||
return points.get(index);
|
||||
}
|
||||
public interface MathFunction extends Serializable {
|
||||
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the index of the largest (in terms of time) Point that is <= the provided time value.
|
||||
*
|
||||
* @param points
|
||||
* @param time
|
||||
* @return The index of the largest Point who's time is <= the time parameter.
|
||||
*/
|
||||
private static int binarySearch(List<Point> points, double time){
|
||||
final int pointSize = points.size();
|
||||
|
||||
if(pointSize == 0 || points.get(pointSize-1).getTime() <= time){
|
||||
// we're already too far
|
||||
return pointSize - 1;
|
||||
}
|
||||
|
||||
if(pointSize < 200){
|
||||
for(int i = 0; i < pointSize; i++){
|
||||
if(points.get(i).getTime() > time){
|
||||
return i - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// else
|
||||
|
||||
|
||||
final int middle = pointSize / 2;
|
||||
final double middleTime = points.get(middle).getTime();
|
||||
if(middleTime < time){
|
||||
// go right
|
||||
return binarySearch(points.subList(middle, pointSize), time) + middle;
|
||||
}
|
||||
else if(middleTime > time){
|
||||
// go left
|
||||
return binarySearch(points.subList(0, middle), time);
|
||||
}
|
||||
else{ // middleTime == time
|
||||
return middle;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
builder.append("Default point: ");
|
||||
builder.append(defaultValue);
|
||||
builder.append("\n");
|
||||
|
||||
for(final Point point : points){
|
||||
builder.append(point);
|
||||
builder.append("\n");
|
||||
}
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
double evaluate(double time);
|
||||
|
||||
}
|
||||
|
|
|
@ -11,26 +11,12 @@ import java.util.zip.GZIPOutputStream;
|
|||
*/
|
||||
public final class RUtils {
|
||||
|
||||
public static double[] extractTimes(final MathFunction function){
|
||||
final List<Point> pointList = function.getPoints();
|
||||
final double[] times = new double[pointList.size()];
|
||||
|
||||
for(int i=0; i<pointList.size(); i++){
|
||||
times[i] = pointList.get(i).getTime();
|
||||
public static double[] extractTimes(final RightContinuousStepFunction function){
|
||||
return function.getX();
|
||||
}
|
||||
|
||||
return times;
|
||||
}
|
||||
|
||||
public static double[] extractY(final MathFunction function){
|
||||
final List<Point> pointList = function.getPoints();
|
||||
final double[] times = new double[pointList.size()];
|
||||
|
||||
for(int i=0; i<pointList.size(); i++){
|
||||
times[i] = pointList.get(i).getY();
|
||||
}
|
||||
|
||||
return times;
|
||||
public static double[] extractY(final RightContinuousStepFunction function){
|
||||
return function.getY();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ListIterator;
|
||||
|
||||
/**
|
||||
* Represents a function represented by discrete points. We assume that the function is a stepwise right-continuous
|
||||
* function, constant at the value of the previous encountered point.
|
||||
*
|
||||
*/
|
||||
public final class RightContinuousStepFunction extends StepFunction {
|
||||
|
||||
private final double[] y;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
*/
|
||||
private final double defaultY;
|
||||
|
||||
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
||||
super(x);
|
||||
this.y = y;
|
||||
this.defaultY = defaultY;
|
||||
}
|
||||
|
||||
/**
|
||||
* This isn't a formal constructor because of limitations with abstract classes.
|
||||
*
|
||||
* @param pointList
|
||||
* @param defaultY
|
||||
* @return
|
||||
*/
|
||||
public static RightContinuousStepFunction constructFromPoints(final List<Point> pointList, final double defaultY){
|
||||
|
||||
final double[] x = new double[pointList.size()];
|
||||
final double[] y = new double[pointList.size()];
|
||||
|
||||
final ListIterator<Point> pointIterator = pointList.listIterator();
|
||||
while(pointIterator.hasNext()){
|
||||
final int index = pointIterator.nextIndex();
|
||||
final Point currentPoint = pointIterator.next();
|
||||
|
||||
x[index] = currentPoint.getTime();
|
||||
y[index] = currentPoint.getY();
|
||||
}
|
||||
|
||||
return new RightContinuousStepFunction(x, y, defaultY);
|
||||
|
||||
}
|
||||
|
||||
public double[] getY(){
|
||||
return y.clone();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluate(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluatePrevious(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluateByIndex(int i) {
|
||||
if(i < 0){
|
||||
return defaultY;
|
||||
}
|
||||
|
||||
return y[i];
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
builder.append("Default point: ");
|
||||
builder.append(defaultY);
|
||||
builder.append("\n");
|
||||
|
||||
for(int i=0; i<x.length; i++){
|
||||
builder.append("x:");
|
||||
builder.append(x[i]);
|
||||
builder.append("\ty:");
|
||||
builder.append(y[i]);
|
||||
builder.append("\n");
|
||||
}
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
public abstract class StepFunction implements MathFunction{
|
||||
|
||||
protected final double[] x;
|
||||
|
||||
StepFunction(double[] x){
|
||||
this.x = x;
|
||||
}
|
||||
|
||||
public double[] getX() {
|
||||
return x.clone();
|
||||
}
|
||||
|
||||
public abstract double evaluateByIndex(int i);
|
||||
|
||||
/**
|
||||
* Evaluate the function at the time *point* that occurred previous to time. This is NOT time - some delta, but rather
|
||||
* time[i-1].
|
||||
*
|
||||
* @param time
|
||||
* @return
|
||||
*/
|
||||
public abstract double evaluatePrevious(double time);
|
||||
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class SumFunction implements MathFunction{
|
||||
|
||||
private final MathFunction function1;
|
||||
private final MathFunction function2;
|
||||
|
||||
@Override
|
||||
public double evaluate(double time) {
|
||||
return function1.evaluate(time) + function2.evaluate(time);
|
||||
}
|
||||
}
|
|
@ -5,8 +5,7 @@ import java.util.concurrent.ThreadLocalRandom;
|
|||
|
||||
public final class Utils {
|
||||
|
||||
public static MathFunction estimateOneMinusECDF(final double[] times){
|
||||
final Point defaultPoint = new Point(0.0, 1.0);
|
||||
public static StepFunction estimateOneMinusECDF(final double[] times){
|
||||
Arrays.sort(times);
|
||||
|
||||
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
||||
|
@ -33,7 +32,7 @@ public final class Utils {
|
|||
pointList.add(new Point(entry.getKey(), (double) newCount / n));
|
||||
}
|
||||
|
||||
return new MathFunction(pointList, defaultPoint);
|
||||
return RightContinuousStepFunction.constructFromPoints(pointList, 1.0);
|
||||
|
||||
}
|
||||
|
||||
|
@ -64,6 +63,48 @@ public final class Utils {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the index of the largest (in terms of time) Point that is <= the provided time value.
|
||||
*
|
||||
* @param startIndex Only search from startIndex (inclusive)
|
||||
* @param endIndex Only search up to endIndex (exclusive)
|
||||
* @param time
|
||||
* @return The index of the largest Point who's time is <= the time parameter.
|
||||
*/
|
||||
public static int binarySearchLessThan(int startIndex, int endIndex, double[] x, double time){
|
||||
final int range = endIndex - startIndex;
|
||||
|
||||
if(range == 0 || x[endIndex-1] <= time){
|
||||
// we're already too far
|
||||
return endIndex - 1;
|
||||
}
|
||||
|
||||
if(range < 200){
|
||||
for(int i = startIndex; i < endIndex; i++){
|
||||
if(x[i] > time){
|
||||
return i - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// else
|
||||
|
||||
|
||||
final int middle = range / 2;
|
||||
final double middleTime = x[middle];
|
||||
if(middleTime < time){
|
||||
// go right
|
||||
return binarySearchLessThan(middle, endIndex, x, time);
|
||||
}
|
||||
else if(middleTime > time){
|
||||
// go left
|
||||
return binarySearchLessThan(0, middle, x, time);
|
||||
}
|
||||
else{ // middleTime == time
|
||||
return middle;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replacement for Java 9's List.of
|
||||
*
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
/**
|
||||
* Represents a step function represented by discrete points. However, there may be individual time values that has
|
||||
* a y value that doesn't belong to a particular 'step'.
|
||||
*/
|
||||
public final class VeryDiscontinuousStepFunction implements MathFunction {
|
||||
|
||||
private final double[] x;
|
||||
private final double[] yAt;
|
||||
private final double[] yRight;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
*/
|
||||
private final double defaultY;
|
||||
|
||||
public VeryDiscontinuousStepFunction(double[] x, double[] yAt, double[] yRight, double defaultY) {
|
||||
this.x = x;
|
||||
this.yAt = yAt;
|
||||
this.yRight = yRight;
|
||||
this.defaultY = defaultY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluate(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return yAt[index];
|
||||
}
|
||||
else{ // time > x[index]
|
||||
return yRight[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
builder.append("Default point: ");
|
||||
builder.append(defaultY);
|
||||
builder.append("\n");
|
||||
|
||||
for(int i=0; i<x.length; i++){
|
||||
builder.append("x:");
|
||||
builder.append(x[i]);
|
||||
builder.append("\tyAt:");
|
||||
builder.append(yAt[i]);
|
||||
builder.append("\tyRight:");
|
||||
builder.append(yRight[i]);
|
||||
builder.append("\n");
|
||||
}
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
}
|
|
@ -16,7 +16,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class TestSavingLoading {
|
||||
|
||||
|
@ -31,6 +30,9 @@ public class TestSavingLoading {
|
|||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
|
||||
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
|
||||
groupDifferentiatorSettings.set("events",
|
||||
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
||||
);
|
||||
|
||||
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
||||
|
@ -113,7 +115,7 @@ public class TestSavingLoading {
|
|||
|
||||
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
||||
assertNotNull(functions);
|
||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getPoints().size() > 2);
|
||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||
|
||||
|
||||
assertEquals(NTREE, forest.getTrees().size());
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
@ -23,53 +22,67 @@ public class TestUtils {
|
|||
*
|
||||
* @param function
|
||||
*/
|
||||
public static void assertCumulativeFunction(MathFunction function){
|
||||
Point previousPoint = null;
|
||||
for(final Point point : function.getPoints()){
|
||||
public static void assertCumulativeFunction(StepFunction function){
|
||||
Double previousTime = null;
|
||||
Double previousY = null;
|
||||
|
||||
if(previousPoint != null){
|
||||
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
||||
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
|
||||
final double[] times = function.getX();
|
||||
|
||||
for(int i=0; i<times.length; i++){
|
||||
final double time = times[i];
|
||||
final double y = function.evaluateByIndex(i);
|
||||
|
||||
if(previousTime != null){
|
||||
assertTrue(previousTime < time, "Points should be ordered and strictly different");
|
||||
assertTrue(previousY <= y, "Cumulative incidence functions are monotone");
|
||||
}
|
||||
|
||||
previousTime = time;
|
||||
previousY = y;
|
||||
|
||||
previousPoint = point;
|
||||
}
|
||||
}
|
||||
|
||||
public static void assertSurvivalCurve(MathFunction function){
|
||||
Point previousPoint = null;
|
||||
for(final Point point : function.getPoints()){
|
||||
public static void assertSurvivalCurve(StepFunction function){
|
||||
Double previousTime = null;
|
||||
Double previousY = null;
|
||||
|
||||
if(previousPoint != null){
|
||||
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
||||
assertTrue(previousPoint.getY() >= point.getY(), "Survival functions are monotone");
|
||||
final double[] times = function.getX();
|
||||
|
||||
for(int i=0; i<times.length; i++){
|
||||
final double time = times[i];
|
||||
final double y = function.evaluateByIndex(i);
|
||||
|
||||
if(previousTime != null){
|
||||
assertTrue(previousTime < time, "Points should be ordered and strictly different");
|
||||
assertTrue(previousY >= y, "Survival functions are monotone");
|
||||
}
|
||||
|
||||
previousTime = time;
|
||||
previousY = y;
|
||||
|
||||
previousPoint = point;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOneMinusECDF(){
|
||||
final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0};
|
||||
final MathFunction survivalCurve = Utils.estimateOneMinusECDF(times);
|
||||
final StepFunction survivalCurve = Utils.estimateOneMinusECDF(times);
|
||||
|
||||
final double margin = 0.000001;
|
||||
closeEnough(1.0, survivalCurve.evaluate(0.0).getY(), margin);
|
||||
closeEnough(1.0, survivalCurve.evaluate(0.0), margin);
|
||||
|
||||
closeEnough(1.0, survivalCurve.evaluatePrevious(1.0).getY(), margin);
|
||||
closeEnough(4.0/6.0, survivalCurve.evaluate(1.0).getY(), margin);
|
||||
closeEnough(1.0, survivalCurve.evaluatePrevious(1.0), margin);
|
||||
closeEnough(4.0/6.0, survivalCurve.evaluate(1.0), margin);
|
||||
|
||||
closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0).getY(), margin);
|
||||
closeEnough(3.0/6.0, survivalCurve.evaluate(2.0).getY(), margin);
|
||||
closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0), margin);
|
||||
closeEnough(3.0/6.0, survivalCurve.evaluate(2.0), margin);
|
||||
|
||||
closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0).getY(), margin);
|
||||
closeEnough(1.0/6.0, survivalCurve.evaluate(3.0).getY(), margin);
|
||||
closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0), margin);
|
||||
closeEnough(1.0/6.0, survivalCurve.evaluate(3.0), margin);
|
||||
|
||||
closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0).getY(), margin);
|
||||
closeEnough(0.0, survivalCurve.evaluate(50.0).getY(), margin);
|
||||
closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0), margin);
|
||||
closeEnough(0.0, survivalCurve.evaluate(50.0), margin);
|
||||
|
||||
assertSurvivalCurve(survivalCurve);
|
||||
|
||||
|
|
|
@ -0,0 +1,214 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestCalculatingCompetingRiskSets {
|
||||
|
||||
public List<CompetingRiskResponseWithCensorTime> generateData(){
|
||||
final List<CompetingRiskResponseWithCensorTime> data = new ArrayList<>();
|
||||
|
||||
data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3));
|
||||
data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3));
|
||||
data.add(new CompetingRiskResponseWithCensorTime(0, 1, 1));
|
||||
data.add(new CompetingRiskResponseWithCensorTime(1, 2, 2.5));
|
||||
data.add(new CompetingRiskResponseWithCensorTime(2, 3, 4));
|
||||
data.add(new CompetingRiskResponseWithCensorTime(0, 3, 3));
|
||||
data.add(new CompetingRiskResponseWithCensorTime(1, 4, 4));
|
||||
data.add(new CompetingRiskResponseWithCensorTime(0, 5, 5));
|
||||
data.add(new CompetingRiskResponseWithCensorTime(2, 6, 7));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCalculatingSets(){
|
||||
final List data = generateData();
|
||||
|
||||
final CompetingRiskSetsImpl sets = CompetingRiskUtils.calculateSetsEfficiently(data, new int[]{1,2});
|
||||
|
||||
final List<Double> times = sets.getEventTimes();
|
||||
assertEquals(5, times.size());
|
||||
|
||||
// Times
|
||||
assertEquals(1.0, times.get(0).doubleValue());
|
||||
assertEquals(2.0, times.get(1).doubleValue());
|
||||
assertEquals(3.0, times.get(2).doubleValue());
|
||||
assertEquals(4.0, times.get(3).doubleValue());
|
||||
assertEquals(6.0, times.get(4).doubleValue());
|
||||
|
||||
// Number of Events
|
||||
assertEquals(2, sets.getNumberOfEvents(1.0, 1));
|
||||
assertEquals(0, sets.getNumberOfEvents(1.0, 2));
|
||||
|
||||
assertEquals(1, sets.getNumberOfEvents(2.0, 1));
|
||||
assertEquals(0, sets.getNumberOfEvents(2.0, 2));
|
||||
|
||||
assertEquals(0, sets.getNumberOfEvents(3.0, 1));
|
||||
assertEquals(1, sets.getNumberOfEvents(3.0, 2));
|
||||
|
||||
assertEquals(1, sets.getNumberOfEvents(4.0, 1));
|
||||
assertEquals(0, sets.getNumberOfEvents(4.0, 2));
|
||||
|
||||
assertEquals(0, sets.getNumberOfEvents(6.0, 1));
|
||||
assertEquals(1, sets.getNumberOfEvents(6.0, 2));
|
||||
|
||||
// Make sure it doesn't break for other times
|
||||
assertEquals(0, sets.getNumberOfEvents(5.5, 1));
|
||||
assertEquals(0, sets.getNumberOfEvents(5.5, 2));
|
||||
|
||||
|
||||
// Risk set
|
||||
assertEquals(9, sets.getRiskSet(1).evaluate(0.5));
|
||||
assertEquals(9, sets.getRiskSet(2).evaluate(0.5));
|
||||
|
||||
assertEquals(9, sets.getRiskSet(1).evaluate(1.0));
|
||||
assertEquals(9, sets.getRiskSet(2).evaluate(1.0));
|
||||
|
||||
assertEquals(6, sets.getRiskSet(1).evaluate(1.5));
|
||||
assertEquals(6, sets.getRiskSet(2).evaluate(1.5));
|
||||
|
||||
assertEquals(6, sets.getRiskSet(1).evaluate(2.0));
|
||||
assertEquals(6, sets.getRiskSet(2).evaluate(2.0));
|
||||
|
||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.3));
|
||||
assertEquals(5, sets.getRiskSet(2).evaluate(2.3));
|
||||
|
||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.5));
|
||||
assertEquals(5, sets.getRiskSet(2).evaluate(2.5));
|
||||
|
||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.7));
|
||||
assertEquals(5, sets.getRiskSet(2).evaluate(2.7));
|
||||
|
||||
assertEquals(5, sets.getRiskSet(1).evaluate(3.0));
|
||||
assertEquals(5, sets.getRiskSet(2).evaluate(3.0));
|
||||
|
||||
assertEquals(3, sets.getRiskSet(1).evaluate(3.5));
|
||||
assertEquals(3, sets.getRiskSet(2).evaluate(3.5));
|
||||
|
||||
assertEquals(3, sets.getRiskSet(1).evaluate(4.0));
|
||||
assertEquals(3, sets.getRiskSet(2).evaluate(4.0));
|
||||
|
||||
assertEquals(2, sets.getRiskSet(1).evaluate(4.5));
|
||||
assertEquals(2, sets.getRiskSet(2).evaluate(4.5));
|
||||
|
||||
assertEquals(2, sets.getRiskSet(1).evaluate(5.0));
|
||||
assertEquals(2, sets.getRiskSet(2).evaluate(5.0));
|
||||
|
||||
assertEquals(1, sets.getRiskSet(1).evaluate(5.5));
|
||||
assertEquals(1, sets.getRiskSet(2).evaluate(5.5));
|
||||
|
||||
assertEquals(1, sets.getRiskSet(1).evaluate(6.0));
|
||||
assertEquals(1, sets.getRiskSet(2).evaluate(6.0));
|
||||
|
||||
assertEquals(0, sets.getRiskSet(1).evaluate(6.5));
|
||||
assertEquals(0, sets.getRiskSet(2).evaluate(6.5));
|
||||
|
||||
assertEquals(0, sets.getRiskSet(1).evaluate(7.0));
|
||||
assertEquals(0, sets.getRiskSet(2).evaluate(7.0));
|
||||
|
||||
assertEquals(0, sets.getRiskSet(1).evaluate(7.5));
|
||||
assertEquals(0, sets.getRiskSet(2).evaluate(7.5));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCalculatingGraySets(){
|
||||
final List<CompetingRiskResponseWithCensorTime> data = generateData();
|
||||
|
||||
final CompetingRiskGraySetsImpl sets = CompetingRiskUtils.calculateGraySetsEfficiently(data, new int[]{1,2});
|
||||
|
||||
final List<Double> times = sets.getEventTimes();
|
||||
assertEquals(5, times.size());
|
||||
|
||||
// Times
|
||||
assertEquals(1.0, times.get(0).doubleValue());
|
||||
assertEquals(2.0, times.get(1).doubleValue());
|
||||
assertEquals(3.0, times.get(2).doubleValue());
|
||||
assertEquals(4.0, times.get(3).doubleValue());
|
||||
assertEquals(6.0, times.get(4).doubleValue());
|
||||
|
||||
// Number of Events
|
||||
assertEquals(2, sets.getNumberOfEvents(1.0, 1));
|
||||
assertEquals(0, sets.getNumberOfEvents(1.0, 2));
|
||||
|
||||
assertEquals(1, sets.getNumberOfEvents(2.0, 1));
|
||||
assertEquals(0, sets.getNumberOfEvents(2.0, 2));
|
||||
|
||||
assertEquals(0, sets.getNumberOfEvents(3.0, 1));
|
||||
assertEquals(1, sets.getNumberOfEvents(3.0, 2));
|
||||
|
||||
assertEquals(1, sets.getNumberOfEvents(4.0, 1));
|
||||
assertEquals(0, sets.getNumberOfEvents(4.0, 2));
|
||||
|
||||
assertEquals(0, sets.getNumberOfEvents(6.0, 1));
|
||||
assertEquals(1, sets.getNumberOfEvents(6.0, 2));
|
||||
|
||||
// Make sure it doesn't break for other times
|
||||
assertEquals(0, sets.getNumberOfEvents(5.5, 1));
|
||||
assertEquals(0, sets.getNumberOfEvents(5.5, 2));
|
||||
|
||||
|
||||
// Risk set
|
||||
assertEquals(9, sets.getRiskSet(1).evaluate(0.5));
|
||||
assertEquals(9, sets.getRiskSet(2).evaluate(0.5));
|
||||
|
||||
assertEquals(9, sets.getRiskSet(1).evaluate(1.0));
|
||||
assertEquals(9, sets.getRiskSet(2).evaluate(1.0));
|
||||
|
||||
assertEquals(6, sets.getRiskSet(1).evaluate(1.5));
|
||||
assertEquals(8, sets.getRiskSet(2).evaluate(1.5));
|
||||
|
||||
assertEquals(6, sets.getRiskSet(1).evaluate(2.0));
|
||||
assertEquals(8, sets.getRiskSet(2).evaluate(2.0));
|
||||
|
||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.3));
|
||||
assertEquals(8, sets.getRiskSet(2).evaluate(2.3));
|
||||
|
||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.5));
|
||||
assertEquals(7, sets.getRiskSet(2).evaluate(2.5));
|
||||
|
||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.7));
|
||||
assertEquals(7, sets.getRiskSet(2).evaluate(2.7));
|
||||
|
||||
assertEquals(5, sets.getRiskSet(1).evaluate(3.0));
|
||||
assertEquals(5, sets.getRiskSet(2).evaluate(3.0));
|
||||
|
||||
assertEquals(4, sets.getRiskSet(1).evaluate(3.5));
|
||||
assertEquals(3, sets.getRiskSet(2).evaluate(3.5));
|
||||
|
||||
assertEquals(3, sets.getRiskSet(1).evaluate(4.0));
|
||||
assertEquals(3, sets.getRiskSet(2).evaluate(4.0));
|
||||
|
||||
assertEquals(2, sets.getRiskSet(1).evaluate(4.5));
|
||||
assertEquals(2, sets.getRiskSet(2).evaluate(4.5));
|
||||
|
||||
assertEquals(2, sets.getRiskSet(1).evaluate(5.0));
|
||||
assertEquals(2, sets.getRiskSet(2).evaluate(5.0));
|
||||
|
||||
assertEquals(1, sets.getRiskSet(1).evaluate(5.5));
|
||||
assertEquals(1, sets.getRiskSet(2).evaluate(5.5));
|
||||
|
||||
assertEquals(1, sets.getRiskSet(1).evaluate(6.0));
|
||||
assertEquals(1, sets.getRiskSet(2).evaluate(6.0));
|
||||
|
||||
assertEquals(1, sets.getRiskSet(1).evaluate(6.5));
|
||||
assertEquals(0, sets.getRiskSet(2).evaluate(6.5));
|
||||
|
||||
assertEquals(0, sets.getRiskSet(1).evaluate(7.0));
|
||||
assertEquals(0, sets.getRiskSet(2).evaluate(7.0));
|
||||
|
||||
assertEquals(0, sets.getRiskSet(1).evaluate(7.5));
|
||||
assertEquals(0, sets.getRiskSet(2).evaluate(7.5));
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -7,8 +7,7 @@ import ca.joeltherrien.randomforest.tree.Forest;
|
|||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.Node;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import com.fasterxml.jackson.databind.node.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
@ -19,7 +18,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class TestCompetingRisk {
|
||||
|
||||
|
@ -33,6 +31,9 @@ public class TestCompetingRisk {
|
|||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
|
||||
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
|
||||
groupDifferentiatorSettings.set("events",
|
||||
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
||||
);
|
||||
|
||||
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
||||
|
@ -109,32 +110,32 @@ public class TestCompetingRisk {
|
|||
|
||||
final CompetingRiskFunctions functions = node.evaluate(newRow);
|
||||
|
||||
final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
|
||||
final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
|
||||
final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
|
||||
final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
|
||||
final StepFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
|
||||
final StepFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
|
||||
final StepFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
|
||||
final StepFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
|
||||
|
||||
final double margin = 0.0000001;
|
||||
closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02).getY(), margin);
|
||||
closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00).getY(), margin);
|
||||
closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50).getY(), margin);
|
||||
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60).getY(), margin);
|
||||
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80).getY(), margin);
|
||||
closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02), margin);
|
||||
closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00), margin);
|
||||
closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50), margin);
|
||||
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60), margin);
|
||||
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80), margin);
|
||||
|
||||
|
||||
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
|
||||
closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin);
|
||||
closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin);
|
||||
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00), margin);
|
||||
closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50), margin);
|
||||
closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80), margin);
|
||||
|
||||
|
||||
closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin);
|
||||
closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin);
|
||||
closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin);
|
||||
closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00), margin);
|
||||
closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50), margin);
|
||||
closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80), margin);
|
||||
|
||||
|
||||
closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin);
|
||||
closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin);
|
||||
closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin);
|
||||
closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00), margin);
|
||||
closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50), margin);
|
||||
closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80), margin);
|
||||
|
||||
|
||||
}
|
||||
|
@ -162,18 +163,18 @@ public class TestCompetingRisk {
|
|||
|
||||
final CompetingRiskFunctions functions = node.evaluate(newRow);
|
||||
|
||||
final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
|
||||
final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
|
||||
final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
|
||||
final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
|
||||
final StepFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
|
||||
final StepFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
|
||||
final StepFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
|
||||
final StepFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
|
||||
|
||||
|
||||
final double margin = 0.0000001;
|
||||
closeEnough(0, causeOneCIFFunction.evaluate(0.02).getY(), margin);
|
||||
closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4).getY(), margin);
|
||||
closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8).getY(), margin);
|
||||
closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9).getY(), margin);
|
||||
closeEnough(1.0, causeOneCIFFunction.evaluate(1.0).getY(), margin);
|
||||
closeEnough(0, causeOneCIFFunction.evaluate(0.02), margin);
|
||||
closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4), margin);
|
||||
closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8), margin);
|
||||
closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9), margin);
|
||||
closeEnough(1.0, causeOneCIFFunction.evaluate(1.0), margin);
|
||||
|
||||
/*
|
||||
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
|
||||
|
@ -222,11 +223,11 @@ public class TestCompetingRisk {
|
|||
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
|
||||
|
||||
|
||||
closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0).getY(), 0.01);
|
||||
closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8).getY(), 0.01);
|
||||
closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0), 0.01);
|
||||
closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8), 0.01);
|
||||
|
||||
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
||||
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
||||
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0), 0.01);
|
||||
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8), 0.01);
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
|
||||
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||
|
@ -289,7 +290,8 @@ public class TestCompetingRisk {
|
|||
settings.setNtree(300); // results are too variable at 100
|
||||
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(),
|
||||
settings.getTrainingDataLocation());
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||
|
||||
|
@ -305,10 +307,9 @@ public class TestCompetingRisk {
|
|||
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1));
|
||||
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
|
||||
|
||||
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
|
||||
|
||||
// We seem to consistently underestimate the results.
|
||||
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||
final double endProbability = functions.getCumulativeIncidenceFunction(1).evaluate(10000000);
|
||||
assertTrue(endProbability > 0.74, "Results should match randomForestSRC; had " + endProbability); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
|
||||
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
|
||||
|
|
|
@ -4,8 +4,8 @@ import ca.joeltherrien.randomforest.Row;
|
|||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
@ -36,7 +36,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
|||
|
||||
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
||||
|
||||
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0));
|
||||
final StepFunction fakeCensorDistribution = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 1.0);
|
||||
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
|
||||
final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);
|
||||
|
||||
|
|
|
@ -2,8 +2,9 @@ package ca.joeltherrien.randomforest.competingrisk;
|
|||
|
||||
import ca.joeltherrien.randomforest.TestUtils;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
@ -13,19 +14,19 @@ public class TestCompetingRiskFunctions {
|
|||
|
||||
@Test
|
||||
public void testCalculateEventSpecificMortality(){
|
||||
final MathFunction cif1 = new MathFunction(
|
||||
final StepFunction cif1 = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(
|
||||
new Point(1.0, 0.3),
|
||||
new Point(1.5, 0.7),
|
||||
new Point(2.0, 0.8)
|
||||
), new Point(0.0 ,0.0)
|
||||
), 0.0
|
||||
);
|
||||
|
||||
// not being used
|
||||
final MathFunction chf1 = new MathFunction(Collections.emptyList());
|
||||
final StepFunction chf1 = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
|
||||
|
||||
// not being used
|
||||
final MathFunction km = new MathFunction(Collections.emptyList());
|
||||
final StepFunction km = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
|
||||
|
||||
final CompetingRiskFunctions functions = CompetingRiskFunctions.builder()
|
||||
.causeSpecificHazards(Collections.singletonList(chf1))
|
||||
|
|
|
@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.competingrisk;
|
|||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||
|
@ -33,17 +33,17 @@ public class TestCompetingRiskResponseCombiner {
|
|||
public void testCompetingRiskResponseCombiner(){
|
||||
final CompetingRiskFunctions functions = generateFunctions();
|
||||
|
||||
final MathFunction survivalCurve = functions.getSurvivalCurve();
|
||||
final StepFunction survivalCurve = functions.getSurvivalCurve();
|
||||
|
||||
// time = 1.0 1.5 2.0 2.5
|
||||
// surv = 0.7142857 0.5714286 0.1904762 0.1904762
|
||||
|
||||
final double margin = 0.0000001;
|
||||
|
||||
closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin);
|
||||
closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin);
|
||||
closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin);
|
||||
closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin);
|
||||
closeEnough(0.7142857, survivalCurve.evaluate(1.0), margin);
|
||||
closeEnough(0.5714286, survivalCurve.evaluate(1.5), margin);
|
||||
closeEnough(0.1904762, survivalCurve.evaluate(2.0), margin);
|
||||
closeEnough(0.1904762, survivalCurve.evaluate(2.5), margin);
|
||||
|
||||
|
||||
// Time = 1.0 1.5 2.0 2.5
|
||||
|
@ -53,17 +53,17 @@ public class TestCompetingRiskResponseCombiner {
|
|||
[2,] 0.0000000 0.2000000 0.5333333 0.5333333
|
||||
*/
|
||||
|
||||
final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1);
|
||||
closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin);
|
||||
closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin);
|
||||
closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin);
|
||||
closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin);
|
||||
final StepFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1);
|
||||
closeEnough(0.2857143, cumHaz1.evaluate(1.0), margin);
|
||||
closeEnough(0.2857143, cumHaz1.evaluate(1.5), margin);
|
||||
closeEnough(0.6190476, cumHaz1.evaluate(2.0), margin);
|
||||
closeEnough(0.6190476, cumHaz1.evaluate(2.5), margin);
|
||||
|
||||
final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2);
|
||||
closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin);
|
||||
closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin);
|
||||
closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin);
|
||||
closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin);
|
||||
final StepFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2);
|
||||
closeEnough(0.0, cumHaz2.evaluate(1.0), margin);
|
||||
closeEnough(0.2, cumHaz2.evaluate(1.5), margin);
|
||||
closeEnough(0.5333333, cumHaz2.evaluate(2.0), margin);
|
||||
closeEnough(0.5333333, cumHaz2.evaluate(2.5), margin);
|
||||
|
||||
/* Time = 1.0 1.5 2.0 2.5
|
||||
Cumulative Incidence Curve. Each row for one event.
|
||||
|
@ -72,17 +72,17 @@ public class TestCompetingRiskResponseCombiner {
|
|||
[2,] 0.0000000 0.1428571 0.3333333 0.3333333
|
||||
*/
|
||||
|
||||
final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1);
|
||||
closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin);
|
||||
closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin);
|
||||
closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin);
|
||||
closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin);
|
||||
final StepFunction cic1 = functions.getCumulativeIncidenceFunction(1);
|
||||
closeEnough(0.2857143, cic1.evaluate(1.0), margin);
|
||||
closeEnough(0.2857143, cic1.evaluate(1.5), margin);
|
||||
closeEnough(0.4761905, cic1.evaluate(2.0), margin);
|
||||
closeEnough(0.4761905, cic1.evaluate(2.5), margin);
|
||||
|
||||
final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2);
|
||||
closeEnough(0.0, cic2.evaluate(1.0).getY(), margin);
|
||||
closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin);
|
||||
closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin);
|
||||
closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin);
|
||||
final StepFunction cic2 = functions.getCumulativeIncidenceFunction(2);
|
||||
closeEnough(0.0, cic2.evaluate(1.0), margin);
|
||||
closeEnough(0.1428571, cic2.evaluate(1.5), margin);
|
||||
closeEnough(0.3333333, cic2.evaluate(2.0), margin);
|
||||
closeEnough(0.3333333, cic2.evaluate(2.5), margin);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ public class TestLogRankSingleGroupDifferentiator {
|
|||
final List<CompetingRiskResponse> data1 = generateData1();
|
||||
final List<CompetingRiskResponse> data2 = generateData2();
|
||||
|
||||
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1);
|
||||
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
|
||||
|
||||
final double score = differentiator.differentiate(data1, data2);
|
||||
final double margin = 0.000001;
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class TestMathFunction {
|
||||
|
||||
private MathFunction generateMathFunction(){
|
||||
final double[] time = new double[]{1.0, 2.0, 3.0};
|
||||
final double[] y = new double[]{-1.0, 1.0, 0.5};
|
||||
|
||||
final List<Point> pointList = new ArrayList<>();
|
||||
for(int i=0; i<time.length; i++){
|
||||
pointList.add(new Point(time[i], y[i]));
|
||||
}
|
||||
|
||||
return new MathFunction(pointList, new Point(0.0, 0.1));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test(){
|
||||
final MathFunction function = generateMathFunction();
|
||||
|
||||
assertEquals(new Point(1.0, -1.0), function.evaluate(1.0));
|
||||
assertEquals(new Point(2.0, 1.0), function.evaluate(2.0));
|
||||
assertEquals(new Point(3.0, 0.5), function.evaluate(3.0));
|
||||
assertEquals(new Point(0.0, 0.1), function.evaluate(0.5));
|
||||
|
||||
assertEquals(new Point(1.0, -1.0), function.evaluate(1.1));
|
||||
assertEquals(new Point(2.0, 1.0), function.evaluate(2.1));
|
||||
assertEquals(new Point(3.0, 0.5), function.evaluate(3.1));
|
||||
assertEquals(new Point(0.0, 0.1), function.evaluate(0.6));
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestMathFunctions {
|
||||
|
||||
private RightContinuousStepFunction generateRightContinuousStepFunction(){
|
||||
final double[] time = new double[]{1.0, 2.0, 3.0};
|
||||
final double[] y = new double[]{-1.0, 1.0, 0.5};
|
||||
|
||||
return new RightContinuousStepFunction(time, y, 0.1);
|
||||
}
|
||||
|
||||
private LeftContinuousStepFunction generateLeftContinuousStepFunction(){
|
||||
final double[] time = new double[]{1.0, 2.0, 3.0};
|
||||
final double[] y = new double[]{-1.0, 1.0, 0.5};
|
||||
|
||||
return new LeftContinuousStepFunction(time, y, 0.1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRightContinuousStepFunction(){
|
||||
final RightContinuousStepFunction function = generateRightContinuousStepFunction();
|
||||
|
||||
assertEquals(0.1, function.evaluate(0.5));
|
||||
assertEquals(-1.0, function.evaluate(1.0));
|
||||
assertEquals(1.0, function.evaluate(2.0));
|
||||
assertEquals(0.5, function.evaluate(3.0));
|
||||
|
||||
|
||||
assertEquals(0.1, function.evaluate(0.6));
|
||||
assertEquals(-1.0, function.evaluate(1.1));
|
||||
assertEquals(1.0, function.evaluate(2.1));
|
||||
assertEquals(0.5, function.evaluate(3.1));
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLeftContinuousStepFunction(){
|
||||
final LeftContinuousStepFunction function = generateLeftContinuousStepFunction();
|
||||
|
||||
assertEquals(0.1, function.evaluate(0.5));
|
||||
assertEquals(0.1, function.evaluate(1.0));
|
||||
assertEquals(-1.0, function.evaluate(2.0));
|
||||
assertEquals(1.0, function.evaluate(3.0));
|
||||
|
||||
|
||||
assertEquals(0.1, function.evaluate(0.6));
|
||||
assertEquals(-1.0, function.evaluate(1.1));
|
||||
assertEquals(1.0, function.evaluate(2.1));
|
||||
assertEquals(0.5, function.evaluate(3.1));
|
||||
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue