Massive optimizations;

Refactored how MathFunctions are structured to use more primitives and
less objects.
Optimized competing risk group differentiators to run faster.
Removed alternative competing risk response combiners (may be added back
later)
This commit is contained in:
Joel Therrien 2018-10-25 10:34:27 -07:00
parent cce5ad1e0f
commit c68f67e47a
36 changed files with 1223 additions and 474 deletions

View file

@ -3,20 +3,17 @@ package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode; import com.fasterxml.jackson.databind.node.TextNode;
import java.io.*; import java.io.*;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class Main { public class Main {
@ -68,9 +65,6 @@ public class Main {
if(responseCombiner instanceof CompetingRiskFunctionCombiner){ if(responseCombiner instanceof CompetingRiskFunctionCombiner){
events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents(); events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents();
} }
else if(responseCombiner instanceof CompetingRiskListCombiner){
events = ((CompetingRiskListCombiner) responseCombiner).getOriginalCombiner().getEvents();
}
else{ else{
System.out.println("Unsupported tree combiner"); System.out.println("Unsupported tree combiner");
return; return;
@ -123,7 +117,7 @@ public class Main {
final double[] censorTimes = dataset.stream() final double[] censorTimes = dataset.stream()
.mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC()) .mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC())
.toArray(); .toArray();
final MathFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes); final StepFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes);
System.out.println("Finished generating censor distribution - running concordance"); System.out.println("Finished generating censor distribution - running concordance");

View file

@ -2,11 +2,10 @@ package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings; import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskResponseCombinerToList;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
@ -22,7 +21,10 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import lombok.*; import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -76,7 +78,13 @@ public class Settings {
(objectNode) -> { (objectNode) -> {
final int eventOfFocus = objectNode.get("eventOfFocus").asInt(); final int eventOfFocus = objectNode.get("eventOfFocus").asInt();
return new LogRankSingleGroupDifferentiator(eventOfFocus); final Iterator<JsonNode> elements = objectNode.get("events").elements();
final List<JsonNode> elementList = new ArrayList<>();
elements.forEachRemaining(node -> elementList.add(node));
final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray();
return new LogRankSingleGroupDifferentiator(eventOfFocus, eventArray);
} }
); );
registerGroupDifferentiatorConstructor("GrayLogRankMultipleGroupDifferentiator", registerGroupDifferentiatorConstructor("GrayLogRankMultipleGroupDifferentiator",
@ -105,7 +113,14 @@ public class Settings {
(objectNode) -> { (objectNode) -> {
final int eventOfFocus = objectNode.get("eventOfFocus").asInt(); final int eventOfFocus = objectNode.get("eventOfFocus").asInt();
return new GrayLogRankSingleGroupDifferentiator(eventOfFocus); final Iterator<JsonNode> elements = objectNode.get("events").elements();
final List<JsonNode> elementList = new ArrayList<>();
elements.forEachRemaining(node -> elementList.add(node));
final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray();
return new GrayLogRankSingleGroupDifferentiator(eventOfFocus, eventArray);
} }
); );
} }
@ -154,23 +169,6 @@ public class Settings {
} }
); );
registerResponseCombinerConstructor("CompetingRiskListCombiner",
(node) -> {
final List<Integer> eventList = new ArrayList<>();
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events);
return new CompetingRiskListCombiner(responseCombiner);
}
);
registerResponseCombinerConstructor("CompetingRiskResponseCombinerToList",
(node) -> new CompetingRiskResponseCombinerToList()
);
} }

View file

@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -110,13 +110,13 @@ public class CompetingRiskErrorRateCalculator {
} }
public double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution){ public double[] calculateIPCWConcordance(final int[] events, final StepFunction censoringDistribution){
final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0); final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0);
return calculateIPCWConcordance(events, censoringDistribution, tau); return calculateIPCWConcordance(events, censoringDistribution, tau);
} }
private double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution, final double tau){ private double[] calculateIPCWConcordance(final int[] events, final StepFunction censoringDistribution, final double tau){
final double[] errorRates = new double[events.length]; final double[] errorRates = new double[events.length];

View file

@ -1,7 +1,6 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Point;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
@ -11,44 +10,48 @@ import java.util.List;
@Builder @Builder
public class CompetingRiskFunctions implements Serializable { public class CompetingRiskFunctions implements Serializable {
private final List<MathFunction> causeSpecificHazards; private final List<StepFunction> causeSpecificHazards;
private final List<MathFunction> cumulativeIncidenceCurves; private final List<StepFunction> cumulativeIncidenceCurves;
@Getter @Getter
private final MathFunction survivalCurve; private final StepFunction survivalCurve;
public MathFunction getCauseSpecificHazardFunction(int cause){ public StepFunction getCauseSpecificHazardFunction(int cause){
return causeSpecificHazards.get(cause-1); return causeSpecificHazards.get(cause-1);
} }
public MathFunction getCumulativeIncidenceFunction(int cause) { public StepFunction getCumulativeIncidenceFunction(int cause) {
return cumulativeIncidenceCurves.get(cause-1); return cumulativeIncidenceCurves.get(cause-1);
} }
public double calculateEventSpecificMortality(final int event, final double tau){ public double calculateEventSpecificMortality(final int event, final double tau){
final MathFunction cif = getCumulativeIncidenceFunction(event); final StepFunction cif = getCumulativeIncidenceFunction(event);
double summation = 0.0; double summation = 0.0;
Point previousPoint = null;
for(final Point point : cif.getPoints()){ Double previousTime = null;
if(point.getTime() > tau){ Double previousY = null;
final double[] cifTimes = cif.getX();
for(int i=0; i<cifTimes.length; i++){
final double time = cifTimes[i];
if(time > tau){
break; break;
} }
if(previousPoint != null){ if(previousTime != null){
summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime()); summation += previousY * (time - previousTime);
} }
previousPoint = point; previousTime = time;
previousY = cif.evaluateByIndex(i);
} }
// this is to ensure that we integrate over the proper range // this is to ensure that we integrate over the proper range
if(previousPoint != null){ if(previousTime != null){
summation += cif.evaluate(tau).getY() * (tau - previousPoint.getTime()); summation += cif.evaluate(tau) * (tau - previousTime);
} }
return summation; return summation;
} }

View file

@ -0,0 +1,37 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
import lombok.Builder;
import lombok.Getter;
import java.util.List;
import java.util.Map;
/**
* Represents a response from CompetingRiskUtils#calculateGraySetsEfficiently
*
*/
@Builder
@Getter
public class CompetingRiskGraySetsImpl implements CompetingRiskSets{
private final List<Double> eventTimes;
private final MathFunction[] riskSet;
private final Map<Double, int[]> numberOfEvents;
@Override
public MathFunction getRiskSet(int event){
return(riskSet[event-1]);
}
@Override
public int getNumberOfEvents(Double time, int event){
if(numberOfEvents.containsKey(time)){
return numberOfEvents.get(time)[event];
}
return 0;
}
}

View file

@ -0,0 +1,13 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
import java.util.List;
public interface CompetingRiskSets {
MathFunction getRiskSet(int event);
int getNumberOfEvents(Double time, int event);
List<Double> getEventTimes();
}

View file

@ -0,0 +1,36 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
import lombok.Builder;
import lombok.Getter;
import java.util.List;
import java.util.Map;
/**
* Represents a response from CompetingRiskUtils#calculateSetsEfficiently
*
*/
@Builder
@Getter
public class CompetingRiskSetsImpl implements CompetingRiskSets{
private final List<Double> eventTimes;
private final MathFunction riskSet;
private final Map<Double, int[]> numberOfEvents;
@Override
public MathFunction getRiskSet(int event){
return riskSet;
}
@Override
public int getNumberOfEvents(Double time, int event){
if(numberOfEvents.containsKey(time)){
return numberOfEvents.get(time)[event];
}
return 0;
}
}

View file

@ -1,8 +1,11 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.VeryDiscontinuousStepFunction;
import java.util.List; import java.util.*;
import java.util.stream.DoubleStream;
public class CompetingRiskUtils { public class CompetingRiskUtils {
@ -46,7 +49,9 @@ public class CompetingRiskUtils {
} }
public static double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){ public static double calculateIPCWConcordance(final List<CompetingRiskResponse> responseList,
double[] mortalityArray, final int event,
final StepFunction censoringDistribution){
// Let \tau be the max time. // Let \tau be the max time.
@ -61,8 +66,8 @@ public class CompetingRiskUtils {
final double mortalityI = mortalityArray[i]; final double mortalityI = mortalityArray[i];
final double Ti = responseI.getU(); final double Ti = responseI.getU();
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY(); final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti);
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus); final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti) * G_Ti_minus);
for(int j=0; j<mortalityArray.length; j++){ for(int j=0; j<mortalityArray.length; j++){
final CompetingRiskResponse responseJ = responseList.get(j); final CompetingRiskResponse responseJ = responseList.get(j);
@ -73,7 +78,7 @@ public class CompetingRiskUtils {
AijWeightPlusBijWeight = AijWeight; AijWeightPlusBijWeight = AijWeight;
} }
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1 else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()));
} }
else{ else{
continue; continue;
@ -97,5 +102,152 @@ public class CompetingRiskUtils {
} }
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> responses, int[] eventsOfFocus){
final int n = responses.size();
int[] numberOfCurrentEvents = new int[eventsOfFocus.length+1];
final Map<Double, int[]> numberOfEvents = new HashMap<>();
final List<Double> eventTimes = new ArrayList<>(n);
final List<Double> eventAndCensorTimes = new ArrayList<>(n);
final List<Integer> riskSetNumberList = new ArrayList<>(n);
// need to first sort responses
Collections.sort(responses, (y1, y2) -> {
if(y1.getU() < y2.getU()){
return -1;
}
else if(y1.getU() > y2.getU()){
return 1;
}
else{
return 0;
}
});
for(int i=0; i<n; i++){
final CompetingRiskResponse currentResponse = responses.get(i);
final boolean lastOfTime = (i+1)==n || responses.get(i+1).getU() > currentResponse.getU();
numberOfCurrentEvents[currentResponse.getDelta()]++;
if(lastOfTime){
int totalNumberOfCurrentEvents = 0;
for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events
totalNumberOfCurrentEvents += numberOfCurrentEvents[e];
}
final double currentTime = currentResponse.getU();
if(totalNumberOfCurrentEvents > 0){ // add numberOfCurrentEvents
// Add point
eventTimes.add(currentTime);
numberOfEvents.put(currentTime, numberOfCurrentEvents);
}
// Always do risk set
// remember that the LeftContinuousFunction takes into account that at this currentTime the risk value is the previous value
final int riskSet = n - (i+1);
riskSetNumberList.add(riskSet);
eventAndCensorTimes.add(currentTime);
// reset counters
numberOfCurrentEvents = new int[eventsOfFocus.length+1];
}
}
final double[] riskSetArray = new double[eventAndCensorTimes.size()];
final double[] timesArray = new double[eventAndCensorTimes.size()];
for(int i=0; i<riskSetArray.length; i++){
timesArray[i] = eventAndCensorTimes.get(i);
riskSetArray[i] = riskSetNumberList.get(i);
}
final LeftContinuousStepFunction riskSetFunction = new LeftContinuousStepFunction(timesArray, riskSetArray, n);
return CompetingRiskSetsImpl.builder()
.numberOfEvents(numberOfEvents)
.riskSet(riskSetFunction)
.eventTimes(eventTimes)
.build();
}
public static CompetingRiskGraySetsImpl calculateGraySetsEfficiently(final List<CompetingRiskResponseWithCensorTime> responses, int[] eventsOfFocus){
final List sillyList = responses; // annoying Java generic work-around
final CompetingRiskSetsImpl originalSets = calculateSetsEfficiently(sillyList, eventsOfFocus);
final double[] allTimes = DoubleStream.concat(
responses.stream()
.mapToDouble(CompetingRiskResponseWithCensorTime::getC),
responses.stream()
.mapToDouble(CompetingRiskResponseWithCensorTime::getU)
).sorted().distinct().toArray();
final VeryDiscontinuousStepFunction[] riskSets = new VeryDiscontinuousStepFunction[eventsOfFocus.length];
for(final int event : eventsOfFocus){
final double[] yAt = new double[allTimes.length];
final double[] yRight = new double[allTimes.length];
for(final CompetingRiskResponseWithCensorTime response : responses){
if(response.getDelta() == event){
// traditional case only; increment on time t when I(t <= Ui)
final double time = response.getU();
final int index = Arrays.binarySearch(allTimes, time);
if(index < 0){ // TODO remove once code is stable
throw new IllegalStateException("Index shouldn't be negative!");
}
// All yAts up to and including index are incremented;
// All yRights up to index are incremented
yAt[index]++;
for(int i=0; i<index; i++){
yAt[i]++;
yRight[i]++;
}
}
else{
// need to increment on time t on following conditions; I(t <= Ui | t < Ci)
// Fact: Ci >= Ui.
// increment yAt up to Ci. If Ui==Ci, increment yAt at Ci.
final double time = response.getC();
final int index = Arrays.binarySearch(allTimes, time);
if(index < 0){ // TODO remove once code is stable
throw new IllegalStateException("Index shouldn't be negative!");
}
for(int i=0; i<index; i++){
yAt[i]++;
yRight[i]++;
}
if(response.getU() == response.getC()){
yAt[index]++;
}
}
}
riskSets[event-1] = new VeryDiscontinuousStepFunction(allTimes, yAt, yRight, responses.size());
}
return CompetingRiskGraySetsImpl.builder()
.numberOfEvents(originalSets.getNumberOfEvents())
.eventTimes(originalSets.getEventTimes())
.riskSet(riskSets)
.build();
}
} }

View file

@ -2,11 +2,12 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.StepFunction;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
@RequiredArgsConstructor @RequiredArgsConstructor
@ -34,51 +35,46 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
timesToUse = responses.stream() timesToUse = responses.stream()
.map(functions -> functions.getSurvivalCurve()) .map(functions -> functions.getSurvivalCurve())
.flatMapToDouble( .flatMapToDouble(
function -> function.getPoints().stream() function -> Arrays.stream(function.getX())
.mapToDouble(point -> point.getTime())
).sorted().distinct().toArray(); ).sorted().distinct().toArray();
} }
final double n = responses.size(); final double n = responses.size();
final List<Point> survivalPoints = new ArrayList<>(timesToUse.length); final double[] survivalY = new double[timesToUse.length];
for(final double time : timesToUse){
final double survivalY = responses.stream() for(int i=0; i<timesToUse.length; i++){
.mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time).getY() / n) final double time = timesToUse[i];
survivalY[i] = responses.stream()
.mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time) / n)
.sum(); .sum();
survivalPoints.add(new Point(time, survivalY));
} }
final MathFunction survivalFunction = new MathFunction(survivalPoints, new Point(0.0, 1.0)); final StepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
final List<MathFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); final List<StepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<MathFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length); final List<StepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
for(final int event : events){ for(final int event : events){
final List<Point> cumulativeHazardFunctionPoints = new ArrayList<>(timesToUse.length); final double[] cumulativeHazardFunctionY = new double[timesToUse.length];
final List<Point> cumulativeIncidenceFunctionPoints = new ArrayList<>(timesToUse.length); final double[] cumulativeIncidenceFunctionY = new double[timesToUse.length];
for(final double time : timesToUse){ for(int i=0; i<timesToUse.length; i++){
final double time = timesToUse[i];
final double hazardY = responses.stream() cumulativeHazardFunctionY[i] = responses.stream()
.mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time).getY() / n) .mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time) / n)
.sum(); .sum();
final double incidenceY = responses.stream() cumulativeIncidenceFunctionY[i] = responses.stream()
.mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time).getY() / n) .mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time) / n)
.sum(); .sum();
cumulativeHazardFunctionPoints.add(new Point(time, hazardY));
cumulativeIncidenceFunctionPoints.add(new Point(time, incidenceY));
} }
causeSpecificCumulativeHazardFunctionList.add(event-1, new MathFunction(cumulativeHazardFunctionPoints)); causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeHazardFunctionY, 0));
cumulativeIncidenceFunctionList.add(event-1, new MathFunction(cumulativeIncidenceFunctionPoints)); cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeIncidenceFunctionY, 0));
} }

View file

@ -3,9 +3,9 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.Point;
import lombok.RequiredArgsConstructor; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import java.util.*; import java.util.*;
@ -38,8 +38,8 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
@Override @Override
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) { public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
final List<MathFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); final List<StepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<MathFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length); final List<StepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
Collections.sort(responses, (y1, y2) -> { Collections.sort(responses, (y1, y2) -> {
if(y1.getU() < y2.getU()){ if(y1.getU() < y2.getU()){
@ -97,7 +97,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
} }
} }
final MathFunction survivalCurve = new MathFunction(survivalPoints, new Point(0.0, 1.0)); final StepFunction survivalCurve = RightContinuousStepFunction.constructFromPoints(survivalPoints, 1.0);
for(final int event : events){ for(final int event : events){
@ -129,7 +129,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
// Cumulative incidence function // Cumulative incidence function
// TODO - confirm this behaviour // TODO - confirm this behaviour
//final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY(); //final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY();
final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUseList.get(i-1)).getY() : 1.0; final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluateByIndex(i-1) : 1.0;
final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk); final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk);
final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY); final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY);
@ -138,10 +138,10 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
} }
final MathFunction causeSpecificCumulativeHazardFunction = new MathFunction(hazardFunctionPoints); final StepFunction causeSpecificCumulativeHazardFunction = RightContinuousStepFunction.constructFromPoints(hazardFunctionPoints, 0.0);
causeSpecificCumulativeHazardFunctionList.add(event-1, causeSpecificCumulativeHazardFunction); causeSpecificCumulativeHazardFunctionList.add(event-1, causeSpecificCumulativeHazardFunction);
final MathFunction cifFunction = new MathFunction(cifPoints); final StepFunction cifFunction = RightContinuousStepFunction.constructFromPoints(cifPoints, 0.0);
cumulativeIncidenceFunctionList.add(event-1, cifFunction); cumulativeIncidenceFunctionList.add(event-1, cifFunction);
} }

View file

@ -1,26 +0,0 @@
package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
@RequiredArgsConstructor
public class CompetingRiskListCombiner implements ResponseCombiner<CompetingRiskResponse[], CompetingRiskFunctions> {
@Getter
private final CompetingRiskResponseCombiner originalCombiner;
@Override
public CompetingRiskFunctions combine(List<CompetingRiskResponse[]> responses) {
final List<CompetingRiskResponse> completeList = responses.stream().flatMap(Arrays::stream).collect(Collectors.toList());
return originalCombiner.combine(completeList);
}
}

View file

@ -1,29 +0,0 @@
package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import java.util.List;
/**
* This class takes all of the observations in a terminal node and 'combines' them into just a list of the observations.
*
* This is used in the alternative approach to only compute the functions at the final stage when combining trees.
*
*/
public class CompetingRiskResponseCombinerToList implements ResponseCombiner<CompetingRiskResponse, CompetingRiskResponse[]> {
@Override
public CompetingRiskResponse[] combine(List<CompetingRiskResponse> responses) {
final CompetingRiskResponse[] array = new CompetingRiskResponse[responses.size()];
for(int i=0; i<array.length; i++){
array[i] = responses.get(i);
}
return array;
}
}

View file

@ -1,6 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
@ -18,31 +19,22 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskRe
@Override @Override
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand); public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
abstract double riskSet(final List<Y> eventList, double time, int eventOfFocus);
private double numberOfEventsAtTime(int eventOfFocus, List<Y> eventList, double time){
return (double) eventList.stream()
.filter(event -> event.getDelta() == eventOfFocus)
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
.count();
}
/** /**
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause. * Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
* *
* @param eventOfFocus * @param eventOfFocus
* @param leftHand A non-empty list of CompetingRiskResponse * @param competingRiskSetsLeft A summary of the different sets used in the calculation for the left side
* @param rightHand A non-empty list of CompetingRiskResponse * @param competingRiskSetsRight A summary of the different sets used in the calculation for the right side
* @return * @return
*/ */
LogRankValue specificLogRankValue(final int eventOfFocus, List<Y> leftHand, List<Y> rightHand){ LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSetsLeft, final CompetingRiskSets competingRiskSetsRight){
final double[] distinctEventTimes = Stream.concat( final double[] distinctEventTimes = Stream.concat(
leftHand.stream(), rightHand.stream() competingRiskSetsLeft.getEventTimes().stream(),
) competingRiskSetsRight.getEventTimes().stream())
.filter(event -> !event.isCensored()) .mapToDouble(Double::doubleValue)
.mapToDouble(event -> event.getU()) .sorted()
.distinct() .distinct()
.toArray(); .toArray();
@ -51,12 +43,12 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskRe
for(final double time_k : distinctEventTimes){ for(final double time_k : distinctEventTimes){
final double weight = weight(time_k); // W_j(t_k) final double weight = weight(time_k); // W_j(t_k)
final double numberEventsAtTimeDaughterLeft = numberOfEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k) final double numberEventsAtTimeDaughterLeft = competingRiskSetsLeft.getNumberOfEvents(time_k, eventOfFocus); // // d_{j,l}(t_k)
final double numberEventsAtTimeDaughterRight = numberOfEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k) final double numberEventsAtTimeDaughterRight = competingRiskSetsRight.getNumberOfEvents(time_k, eventOfFocus); // d_{j,r}(t_k)
final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k) final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)
final double individualsAtRiskDaughterLeft = riskSet(leftHand, time_k, eventOfFocus); // Y_l(t_k) final double individualsAtRiskDaughterLeft = competingRiskSetsLeft.getRiskSet(eventOfFocus).evaluate(time_k); // Y_l(t_k)
final double individualsAtRiskDaughterRight = riskSet(rightHand, time_k, eventOfFocus); // Y_r(t_k) final double individualsAtRiskDaughterRight = competingRiskSetsRight.getRiskSet(eventOfFocus).evaluate(time_k); // Y_r(t_k)
final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k) final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)
final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk); final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);

View file

@ -1,6 +1,8 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.List; import java.util.List;
@ -20,11 +22,14 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
return null; return null;
} }
final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
double numerator = 0.0; double numerator = 0.0;
double denominatorSquared = 0.0; double denominatorSquared = 0.0;
for(final int eventOfFocus : events){ for(final int eventOfFocus : events){
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt(); numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
denominatorSquared += valueOfInterest.getVariance(); denominatorSquared += valueOfInterest.getVariance();
@ -35,13 +40,5 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
} }
@Override
double riskSet(List<CompetingRiskResponseWithCensorTime> eventList, double time, int eventOfFocus) {
return eventList.stream()
.filter(event -> event.getU() >= time ||
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
)
.count();
}
} }

View file

@ -1,6 +1,8 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.List; import java.util.List;
@ -13,6 +15,7 @@ import java.util.List;
public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> { public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> {
private final int eventOfFocus; private final int eventOfFocus;
private final int[] events;
@Override @Override
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) { public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
@ -20,19 +23,13 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
return null; return null;
} }
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt()); return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
} }
@Override
double riskSet(List<CompetingRiskResponseWithCensorTime> eventList, double time, int eventOfFocus) {
return eventList.stream()
.filter(event -> event.getU() >= time ||
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
)
.count();
}
} }

View file

@ -1,6 +1,8 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.List; import java.util.List;
@ -20,11 +22,14 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
return null; return null;
} }
final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
double numerator = 0.0; double numerator = 0.0;
double denominatorSquared = 0.0; double denominatorSquared = 0.0;
for(final int eventOfFocus : events){ for(final int eventOfFocus : events){
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt(); numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
denominatorSquared += valueOfInterest.getVariance(); denominatorSquared += valueOfInterest.getVariance();
@ -35,11 +40,4 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
} }
@Override
double riskSet(List<CompetingRiskResponse> eventList, double time, int eventOfFocus) {
return eventList.stream()
.filter(event -> event.getU() >= time)
.count();
}
} }

View file

@ -1,6 +1,8 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.List; import java.util.List;
@ -13,6 +15,7 @@ import java.util.List;
public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> { public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> {
private final int eventOfFocus; private final int eventOfFocus;
private final int[] events;
@Override @Override
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) { public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
@ -20,17 +23,13 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
return null; return null;
} }
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt()); return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
} }
@Override
double riskSet(List<CompetingRiskResponse> eventList, double time, int eventOfFocus) {
return eventList.stream()
.filter(event -> event.getU() >= time)
.count();
}
} }

View file

@ -0,0 +1,97 @@
package ca.joeltherrien.randomforest.utils;
/**
* Represents a function represented by discrete points. However, the function may be right-continuous or left-continuous
* at a given point, with no consistency. This function tracks that.
*/
public final class DiscontinuousStepFunction extends StepFunction {
private final double[] y;
private final boolean[] isLeftContinuous;
/**
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* Map be null.
*/
private final double defaultY;
public DiscontinuousStepFunction(double[] x, double[] y, boolean[] isLeftContinuous, double defaultY) {
super(x);
this.y = y;
this.isLeftContinuous = isLeftContinuous;
this.defaultY = defaultY;
}
@Override
public double evaluate(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time);
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return evaluateByIndex(index);
}
else{
return y[index];
}
}
}
@Override
public double evaluatePrevious(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return evaluateByIndex(index);
}
else{
return y[index];
}
}
}
@Override
public double evaluateByIndex(int i) {
if(isLeftContinuous[i]){
i -= 1;
}
if(i < 0){
return defaultY;
}
return y[i];
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultY);
builder.append("\n");
for(int i=0; i<x.length; i++){
builder.append("x:");
builder.append(x[i]);
builder.append('\t');
if(isLeftContinuous[i]){
builder.append("*y:");
}
else{
builder.append("y*:");
}
builder.append(y[i]);
builder.append("\n");
}
return builder.toString();
}
}

View file

@ -0,0 +1,113 @@
package ca.joeltherrien.randomforest.utils;
import java.util.List;
import java.util.ListIterator;
/**
* Represents a function represented by discrete points. We assume that the function is a stepwise left-continuous
* function, constant at the value of the previous encountered point.
*
*/
public final class LeftContinuousStepFunction extends StepFunction {
private final double[] y;
/**
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* Map be null.
*/
private final double defaultY;
public LeftContinuousStepFunction(double[] x, double[] y, double defaultY) {
super(x);
this.y = y;
this.defaultY = defaultY;
}
/**
* This isn't a formal constructor because of limitations with abstract classes.
*
* @param pointList
* @param defaultY
* @return
*/
public static LeftContinuousStepFunction constructFromPoints(final List<Point> pointList, final double defaultY){
final double[] x = new double[pointList.size()];
final double[] y = new double[pointList.size()];
final ListIterator<Point> pointIterator = pointList.listIterator();
while(pointIterator.hasNext()){
final int index = pointIterator.nextIndex();
final Point currentPoint = pointIterator.next();
x[index] = currentPoint.getTime();
y[index] = currentPoint.getY();
}
return new LeftContinuousStepFunction(x, y, defaultY);
}
@Override
public double evaluate(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time);
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return evaluateByIndex(index-1);
}
else{
return y[index];
}
}
}
@Override
public double evaluatePrevious(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return evaluateByIndex(index-1);
}
else{
return y[index];
}
}
}
@Override
public double evaluateByIndex(int i) {
if(i < 0){
return defaultY;
}
return y[i];
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultY);
builder.append("\n");
for(int i=0; i<x.length; i++){
builder.append("x:");
builder.append(x[i]);
builder.append("\ty:");
builder.append(y[i]);
builder.append("\n");
}
return builder.toString();
}
}

View file

@ -1,114 +1,9 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import lombok.Getter;
import java.io.Serializable; import java.io.Serializable;
import java.util.Collections;
import java.util.List;
/** public interface MathFunction extends Serializable {
* Represents a function represented by discrete points. We assume that the function is a stepwise continuous function,
* constant at the value of the previous encountered point.
*
*/
public class MathFunction implements Serializable {
@Getter
private final List<Point> points;
/**
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* Map be null.
*/
private final Point defaultValue;
public MathFunction(final List<Point> points){
this(points, new Point(0.0, 0.0));
}
public MathFunction(final List<Point> points, final Point defaultValue){
this.points = Collections.unmodifiableList(points);
this.defaultValue = defaultValue;
}
public Point evaluate(double time){
int index = binarySearch(points, time);
if(index < 0){
return defaultValue;
}
else{
return points.get(index);
}
}
public Point evaluatePrevious(double time){
int index = binarySearch(points, time) - 1;
if(index < 0){
return defaultValue;
}
else{
return points.get(index);
}
double evaluate(double time);
}
/**
* Returns the index of the largest (in terms of time) Point that is <= the provided time value.
*
* @param points
* @param time
* @return The index of the largest Point who's time is <= the time parameter.
*/
private static int binarySearch(List<Point> points, double time){
final int pointSize = points.size();
if(pointSize == 0 || points.get(pointSize-1).getTime() <= time){
// we're already too far
return pointSize - 1;
}
if(pointSize < 200){
for(int i = 0; i < pointSize; i++){
if(points.get(i).getTime() > time){
return i - 1;
}
}
}
// else
final int middle = pointSize / 2;
final double middleTime = points.get(middle).getTime();
if(middleTime < time){
// go right
return binarySearch(points.subList(middle, pointSize), time) + middle;
}
else if(middleTime > time){
// go left
return binarySearch(points.subList(0, middle), time);
}
else{ // middleTime == time
return middle;
}
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultValue);
builder.append("\n");
for(final Point point : points){
builder.append(point);
builder.append("\n");
}
return builder.toString();
}
} }

View file

@ -11,26 +11,12 @@ import java.util.zip.GZIPOutputStream;
*/ */
public final class RUtils { public final class RUtils {
public static double[] extractTimes(final MathFunction function){ public static double[] extractTimes(final RightContinuousStepFunction function){
final List<Point> pointList = function.getPoints(); return function.getX();
final double[] times = new double[pointList.size()];
for(int i=0; i<pointList.size(); i++){
times[i] = pointList.get(i).getTime();
} }
return times; public static double[] extractY(final RightContinuousStepFunction function){
} return function.getY();
public static double[] extractY(final MathFunction function){
final List<Point> pointList = function.getPoints();
final double[] times = new double[pointList.size()];
for(int i=0; i<pointList.size(); i++){
times[i] = pointList.get(i).getY();
}
return times;
} }
/** /**

View file

@ -0,0 +1,107 @@
package ca.joeltherrien.randomforest.utils;
import java.util.List;
import java.util.ListIterator;
/**
* Represents a function represented by discrete points. We assume that the function is a stepwise right-continuous
* function, constant at the value of the previous encountered point.
*
*/
public final class RightContinuousStepFunction extends StepFunction {
private final double[] y;
/**
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* Map be null.
*/
private final double defaultY;
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
super(x);
this.y = y;
this.defaultY = defaultY;
}
/**
* This isn't a formal constructor because of limitations with abstract classes.
*
* @param pointList
* @param defaultY
* @return
*/
public static RightContinuousStepFunction constructFromPoints(final List<Point> pointList, final double defaultY){
final double[] x = new double[pointList.size()];
final double[] y = new double[pointList.size()];
final ListIterator<Point> pointIterator = pointList.listIterator();
while(pointIterator.hasNext()){
final int index = pointIterator.nextIndex();
final Point currentPoint = pointIterator.next();
x[index] = currentPoint.getTime();
y[index] = currentPoint.getY();
}
return new RightContinuousStepFunction(x, y, defaultY);
}
public double[] getY(){
return y.clone();
}
@Override
public double evaluate(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time);
if(index < 0){
return defaultY;
}
else{
return y[index];
}
}
@Override
public double evaluatePrevious(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
if(index < 0){
return defaultY;
}
else{
return y[index];
}
}
@Override
public double evaluateByIndex(int i) {
if(i < 0){
return defaultY;
}
return y[i];
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultY);
builder.append("\n");
for(int i=0; i<x.length; i++){
builder.append("x:");
builder.append(x[i]);
builder.append("\ty:");
builder.append(y[i]);
builder.append("\n");
}
return builder.toString();
}
}

View file

@ -0,0 +1,26 @@
package ca.joeltherrien.randomforest.utils;
public abstract class StepFunction implements MathFunction{
protected final double[] x;
StepFunction(double[] x){
this.x = x;
}
public double[] getX() {
return x.clone();
}
public abstract double evaluateByIndex(int i);
/**
* Evaluate the function at the time *point* that occurred previous to time. This is NOT time - some delta, but rather
* time[i-1].
*
* @param time
* @return
*/
public abstract double evaluatePrevious(double time);
}

View file

@ -0,0 +1,15 @@
package ca.joeltherrien.randomforest.utils;
import lombok.RequiredArgsConstructor;
@RequiredArgsConstructor
public class SumFunction implements MathFunction{
private final MathFunction function1;
private final MathFunction function2;
@Override
public double evaluate(double time) {
return function1.evaluate(time) + function2.evaluate(time);
}
}

View file

@ -5,8 +5,7 @@ import java.util.concurrent.ThreadLocalRandom;
public final class Utils { public final class Utils {
public static MathFunction estimateOneMinusECDF(final double[] times){ public static StepFunction estimateOneMinusECDF(final double[] times){
final Point defaultPoint = new Point(0.0, 1.0);
Arrays.sort(times); Arrays.sort(times);
final Map<Double, Integer> timeCounterMap = new HashMap<>(); final Map<Double, Integer> timeCounterMap = new HashMap<>();
@ -33,7 +32,7 @@ public final class Utils {
pointList.add(new Point(entry.getKey(), (double) newCount / n)); pointList.add(new Point(entry.getKey(), (double) newCount / n));
} }
return new MathFunction(pointList, defaultPoint); return RightContinuousStepFunction.constructFromPoints(pointList, 1.0);
} }
@ -64,6 +63,48 @@ public final class Utils {
} }
} }
/**
* Returns the index of the largest (in terms of time) Point that is <= the provided time value.
*
* @param startIndex Only search from startIndex (inclusive)
* @param endIndex Only search up to endIndex (exclusive)
* @param time
* @return The index of the largest Point who's time is <= the time parameter.
*/
public static int binarySearchLessThan(int startIndex, int endIndex, double[] x, double time){
final int range = endIndex - startIndex;
if(range == 0 || x[endIndex-1] <= time){
// we're already too far
return endIndex - 1;
}
if(range < 200){
for(int i = startIndex; i < endIndex; i++){
if(x[i] > time){
return i - 1;
}
}
}
// else
final int middle = range / 2;
final double middleTime = x[middle];
if(middleTime < time){
// go right
return binarySearchLessThan(middle, endIndex, x, time);
}
else if(middleTime > time){
// go left
return binarySearchLessThan(0, middle, x, time);
}
else{ // middleTime == time
return middle;
}
}
/** /**
* Replacement for Java 9's List.of * Replacement for Java 9's List.of
* *

View file

@ -0,0 +1,64 @@
package ca.joeltherrien.randomforest.utils;
/**
* Represents a step function represented by discrete points. However, there may be individual time values that has
* a y value that doesn't belong to a particular 'step'.
*/
public final class VeryDiscontinuousStepFunction implements MathFunction {
private final double[] x;
private final double[] yAt;
private final double[] yRight;
/**
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* Map be null.
*/
private final double defaultY;
public VeryDiscontinuousStepFunction(double[] x, double[] yAt, double[] yRight, double defaultY) {
this.x = x;
this.yAt = yAt;
this.yRight = yRight;
this.defaultY = defaultY;
}
@Override
public double evaluate(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time);
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return yAt[index];
}
else{ // time > x[index]
return yRight[index];
}
}
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultY);
builder.append("\n");
for(int i=0; i<x.length; i++){
builder.append("x:");
builder.append(x[i]);
builder.append("\tyAt:");
builder.append(yAt[i]);
builder.append("\tyRight:");
builder.append(yRight[i]);
builder.append("\n");
}
return builder.toString();
}
}

View file

@ -16,7 +16,6 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class TestSavingLoading { public class TestSavingLoading {
@ -31,6 +30,9 @@ public class TestSavingLoading {
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator")); groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1)); groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
groupDifferentiatorSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
);
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
@ -113,7 +115,7 @@ public class TestSavingLoading {
final CompetingRiskFunctions functions = forest.evaluate(predictionRow); final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
assertNotNull(functions); assertNotNull(functions);
assertTrue(functions.getCumulativeIncidenceFunction(1).getPoints().size() > 2); assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
assertEquals(NTREE, forest.getTrees().size()); assertEquals(NTREE, forest.getTrees().size());

View file

@ -1,7 +1,6 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -23,53 +22,67 @@ public class TestUtils {
* *
* @param function * @param function
*/ */
public static void assertCumulativeFunction(MathFunction function){ public static void assertCumulativeFunction(StepFunction function){
Point previousPoint = null; Double previousTime = null;
for(final Point point : function.getPoints()){ Double previousY = null;
if(previousPoint != null){ final double[] times = function.getX();
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone"); for(int i=0; i<times.length; i++){
final double time = times[i];
final double y = function.evaluateByIndex(i);
if(previousTime != null){
assertTrue(previousTime < time, "Points should be ordered and strictly different");
assertTrue(previousY <= y, "Cumulative incidence functions are monotone");
} }
previousTime = time;
previousY = y;
previousPoint = point;
} }
} }
public static void assertSurvivalCurve(MathFunction function){ public static void assertSurvivalCurve(StepFunction function){
Point previousPoint = null; Double previousTime = null;
for(final Point point : function.getPoints()){ Double previousY = null;
if(previousPoint != null){ final double[] times = function.getX();
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
assertTrue(previousPoint.getY() >= point.getY(), "Survival functions are monotone"); for(int i=0; i<times.length; i++){
final double time = times[i];
final double y = function.evaluateByIndex(i);
if(previousTime != null){
assertTrue(previousTime < time, "Points should be ordered and strictly different");
assertTrue(previousY >= y, "Survival functions are monotone");
} }
previousTime = time;
previousY = y;
previousPoint = point;
} }
} }
@Test @Test
public void testOneMinusECDF(){ public void testOneMinusECDF(){
final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0}; final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0};
final MathFunction survivalCurve = Utils.estimateOneMinusECDF(times); final StepFunction survivalCurve = Utils.estimateOneMinusECDF(times);
final double margin = 0.000001; final double margin = 0.000001;
closeEnough(1.0, survivalCurve.evaluate(0.0).getY(), margin); closeEnough(1.0, survivalCurve.evaluate(0.0), margin);
closeEnough(1.0, survivalCurve.evaluatePrevious(1.0).getY(), margin); closeEnough(1.0, survivalCurve.evaluatePrevious(1.0), margin);
closeEnough(4.0/6.0, survivalCurve.evaluate(1.0).getY(), margin); closeEnough(4.0/6.0, survivalCurve.evaluate(1.0), margin);
closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0).getY(), margin); closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0), margin);
closeEnough(3.0/6.0, survivalCurve.evaluate(2.0).getY(), margin); closeEnough(3.0/6.0, survivalCurve.evaluate(2.0), margin);
closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0).getY(), margin); closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0), margin);
closeEnough(1.0/6.0, survivalCurve.evaluate(3.0).getY(), margin); closeEnough(1.0/6.0, survivalCurve.evaluate(3.0), margin);
closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0).getY(), margin); closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0), margin);
closeEnough(0.0, survivalCurve.evaluate(50.0).getY(), margin); closeEnough(0.0, survivalCurve.evaluate(50.0), margin);
assertSurvivalCurve(survivalCurve); assertSurvivalCurve(survivalCurve);

View file

@ -0,0 +1,214 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCalculatingCompetingRiskSets {
public List<CompetingRiskResponseWithCensorTime> generateData(){
final List<CompetingRiskResponseWithCensorTime> data = new ArrayList<>();
data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3));
data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3));
data.add(new CompetingRiskResponseWithCensorTime(0, 1, 1));
data.add(new CompetingRiskResponseWithCensorTime(1, 2, 2.5));
data.add(new CompetingRiskResponseWithCensorTime(2, 3, 4));
data.add(new CompetingRiskResponseWithCensorTime(0, 3, 3));
data.add(new CompetingRiskResponseWithCensorTime(1, 4, 4));
data.add(new CompetingRiskResponseWithCensorTime(0, 5, 5));
data.add(new CompetingRiskResponseWithCensorTime(2, 6, 7));
return data;
}
@Test
public void testCalculatingSets(){
final List data = generateData();
final CompetingRiskSetsImpl sets = CompetingRiskUtils.calculateSetsEfficiently(data, new int[]{1,2});
final List<Double> times = sets.getEventTimes();
assertEquals(5, times.size());
// Times
assertEquals(1.0, times.get(0).doubleValue());
assertEquals(2.0, times.get(1).doubleValue());
assertEquals(3.0, times.get(2).doubleValue());
assertEquals(4.0, times.get(3).doubleValue());
assertEquals(6.0, times.get(4).doubleValue());
// Number of Events
assertEquals(2, sets.getNumberOfEvents(1.0, 1));
assertEquals(0, sets.getNumberOfEvents(1.0, 2));
assertEquals(1, sets.getNumberOfEvents(2.0, 1));
assertEquals(0, sets.getNumberOfEvents(2.0, 2));
assertEquals(0, sets.getNumberOfEvents(3.0, 1));
assertEquals(1, sets.getNumberOfEvents(3.0, 2));
assertEquals(1, sets.getNumberOfEvents(4.0, 1));
assertEquals(0, sets.getNumberOfEvents(4.0, 2));
assertEquals(0, sets.getNumberOfEvents(6.0, 1));
assertEquals(1, sets.getNumberOfEvents(6.0, 2));
// Make sure it doesn't break for other times
assertEquals(0, sets.getNumberOfEvents(5.5, 1));
assertEquals(0, sets.getNumberOfEvents(5.5, 2));
// Risk set
assertEquals(9, sets.getRiskSet(1).evaluate(0.5));
assertEquals(9, sets.getRiskSet(2).evaluate(0.5));
assertEquals(9, sets.getRiskSet(1).evaluate(1.0));
assertEquals(9, sets.getRiskSet(2).evaluate(1.0));
assertEquals(6, sets.getRiskSet(1).evaluate(1.5));
assertEquals(6, sets.getRiskSet(2).evaluate(1.5));
assertEquals(6, sets.getRiskSet(1).evaluate(2.0));
assertEquals(6, sets.getRiskSet(2).evaluate(2.0));
assertEquals(5, sets.getRiskSet(1).evaluate(2.3));
assertEquals(5, sets.getRiskSet(2).evaluate(2.3));
assertEquals(5, sets.getRiskSet(1).evaluate(2.5));
assertEquals(5, sets.getRiskSet(2).evaluate(2.5));
assertEquals(5, sets.getRiskSet(1).evaluate(2.7));
assertEquals(5, sets.getRiskSet(2).evaluate(2.7));
assertEquals(5, sets.getRiskSet(1).evaluate(3.0));
assertEquals(5, sets.getRiskSet(2).evaluate(3.0));
assertEquals(3, sets.getRiskSet(1).evaluate(3.5));
assertEquals(3, sets.getRiskSet(2).evaluate(3.5));
assertEquals(3, sets.getRiskSet(1).evaluate(4.0));
assertEquals(3, sets.getRiskSet(2).evaluate(4.0));
assertEquals(2, sets.getRiskSet(1).evaluate(4.5));
assertEquals(2, sets.getRiskSet(2).evaluate(4.5));
assertEquals(2, sets.getRiskSet(1).evaluate(5.0));
assertEquals(2, sets.getRiskSet(2).evaluate(5.0));
assertEquals(1, sets.getRiskSet(1).evaluate(5.5));
assertEquals(1, sets.getRiskSet(2).evaluate(5.5));
assertEquals(1, sets.getRiskSet(1).evaluate(6.0));
assertEquals(1, sets.getRiskSet(2).evaluate(6.0));
assertEquals(0, sets.getRiskSet(1).evaluate(6.5));
assertEquals(0, sets.getRiskSet(2).evaluate(6.5));
assertEquals(0, sets.getRiskSet(1).evaluate(7.0));
assertEquals(0, sets.getRiskSet(2).evaluate(7.0));
assertEquals(0, sets.getRiskSet(1).evaluate(7.5));
assertEquals(0, sets.getRiskSet(2).evaluate(7.5));
}
@Test
public void testCalculatingGraySets(){
final List<CompetingRiskResponseWithCensorTime> data = generateData();
final CompetingRiskGraySetsImpl sets = CompetingRiskUtils.calculateGraySetsEfficiently(data, new int[]{1,2});
final List<Double> times = sets.getEventTimes();
assertEquals(5, times.size());
// Times
assertEquals(1.0, times.get(0).doubleValue());
assertEquals(2.0, times.get(1).doubleValue());
assertEquals(3.0, times.get(2).doubleValue());
assertEquals(4.0, times.get(3).doubleValue());
assertEquals(6.0, times.get(4).doubleValue());
// Number of Events
assertEquals(2, sets.getNumberOfEvents(1.0, 1));
assertEquals(0, sets.getNumberOfEvents(1.0, 2));
assertEquals(1, sets.getNumberOfEvents(2.0, 1));
assertEquals(0, sets.getNumberOfEvents(2.0, 2));
assertEquals(0, sets.getNumberOfEvents(3.0, 1));
assertEquals(1, sets.getNumberOfEvents(3.0, 2));
assertEquals(1, sets.getNumberOfEvents(4.0, 1));
assertEquals(0, sets.getNumberOfEvents(4.0, 2));
assertEquals(0, sets.getNumberOfEvents(6.0, 1));
assertEquals(1, sets.getNumberOfEvents(6.0, 2));
// Make sure it doesn't break for other times
assertEquals(0, sets.getNumberOfEvents(5.5, 1));
assertEquals(0, sets.getNumberOfEvents(5.5, 2));
// Risk set
assertEquals(9, sets.getRiskSet(1).evaluate(0.5));
assertEquals(9, sets.getRiskSet(2).evaluate(0.5));
assertEquals(9, sets.getRiskSet(1).evaluate(1.0));
assertEquals(9, sets.getRiskSet(2).evaluate(1.0));
assertEquals(6, sets.getRiskSet(1).evaluate(1.5));
assertEquals(8, sets.getRiskSet(2).evaluate(1.5));
assertEquals(6, sets.getRiskSet(1).evaluate(2.0));
assertEquals(8, sets.getRiskSet(2).evaluate(2.0));
assertEquals(5, sets.getRiskSet(1).evaluate(2.3));
assertEquals(8, sets.getRiskSet(2).evaluate(2.3));
assertEquals(5, sets.getRiskSet(1).evaluate(2.5));
assertEquals(7, sets.getRiskSet(2).evaluate(2.5));
assertEquals(5, sets.getRiskSet(1).evaluate(2.7));
assertEquals(7, sets.getRiskSet(2).evaluate(2.7));
assertEquals(5, sets.getRiskSet(1).evaluate(3.0));
assertEquals(5, sets.getRiskSet(2).evaluate(3.0));
assertEquals(4, sets.getRiskSet(1).evaluate(3.5));
assertEquals(3, sets.getRiskSet(2).evaluate(3.5));
assertEquals(3, sets.getRiskSet(1).evaluate(4.0));
assertEquals(3, sets.getRiskSet(2).evaluate(4.0));
assertEquals(2, sets.getRiskSet(1).evaluate(4.5));
assertEquals(2, sets.getRiskSet(2).evaluate(4.5));
assertEquals(2, sets.getRiskSet(1).evaluate(5.0));
assertEquals(2, sets.getRiskSet(2).evaluate(5.0));
assertEquals(1, sets.getRiskSet(1).evaluate(5.5));
assertEquals(1, sets.getRiskSet(2).evaluate(5.5));
assertEquals(1, sets.getRiskSet(1).evaluate(6.0));
assertEquals(1, sets.getRiskSet(2).evaluate(6.0));
assertEquals(1, sets.getRiskSet(1).evaluate(6.5));
assertEquals(0, sets.getRiskSet(2).evaluate(6.5));
assertEquals(0, sets.getRiskSet(1).evaluate(7.0));
assertEquals(0, sets.getRiskSet(2).evaluate(7.0));
assertEquals(0, sets.getRiskSet(1).evaluate(7.5));
assertEquals(0, sets.getRiskSet(2).evaluate(7.5));
}
}

View file

@ -7,8 +7,7 @@ import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*; import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -19,7 +18,6 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class TestCompetingRisk { public class TestCompetingRisk {
@ -33,6 +31,9 @@ public class TestCompetingRisk {
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator")); groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1)); groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
groupDifferentiatorSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
);
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
@ -109,32 +110,32 @@ public class TestCompetingRisk {
final CompetingRiskFunctions functions = node.evaluate(newRow); final CompetingRiskFunctions functions = node.evaluate(newRow);
final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); final StepFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); final StepFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); final StepFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); final StepFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
final double margin = 0.0000001; final double margin = 0.0000001;
closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02).getY(), margin); closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02), margin);
closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00).getY(), margin); closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00), margin);
closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50).getY(), margin); closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50), margin);
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60).getY(), margin); closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60), margin);
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80).getY(), margin); closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80), margin);
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin); closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00), margin);
closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin); closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50), margin);
closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin); closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80), margin);
closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin); closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00), margin);
closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin); closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50), margin);
closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin); closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80), margin);
closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin); closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00), margin);
closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin); closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50), margin);
closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin); closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80), margin);
} }
@ -162,18 +163,18 @@ public class TestCompetingRisk {
final CompetingRiskFunctions functions = node.evaluate(newRow); final CompetingRiskFunctions functions = node.evaluate(newRow);
final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); final StepFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); final StepFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); final StepFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); final StepFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
final double margin = 0.0000001; final double margin = 0.0000001;
closeEnough(0, causeOneCIFFunction.evaluate(0.02).getY(), margin); closeEnough(0, causeOneCIFFunction.evaluate(0.02), margin);
closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4).getY(), margin); closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4), margin);
closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8).getY(), margin); closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8), margin);
closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9).getY(), margin); closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9), margin);
closeEnough(1.0, causeOneCIFFunction.evaluate(1.0).getY(), margin); closeEnough(1.0, causeOneCIFFunction.evaluate(1.0), margin);
/* /*
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin); closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
@ -222,11 +223,11 @@ public class TestCompetingRisk {
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2)); assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0).getY(), 0.01); closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0), 0.01);
closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8).getY(), 0.01); closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8), 0.01);
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01); closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0), 0.01);
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01); closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8), 0.01);
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true); final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});
@ -289,7 +290,8 @@ public class TestCompetingRisk {
settings.setNtree(300); // results are too variable at 100 settings.setNtree(300); // results are too variable at 100
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(),
settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
@ -305,10 +307,9 @@ public class TestCompetingRisk {
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1)); assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1));
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2)); assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
// We seem to consistently underestimate the results. // We seem to consistently underestimate the results.
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 final double endProbability = functions.getCumulativeIncidenceFunction(1).evaluate(10000000);
assertTrue(endProbability > 0.74, "Results should match randomForestSRC; had " + endProbability); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true); final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true);
final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2});

View file

@ -4,8 +4,8 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -36,7 +36,7 @@ public class TestCompetingRiskErrorRateCalculator {
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event); final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0)); final StepFunction fakeCensorDistribution = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 1.0);
// This distribution will make the IPCW weights == 1, giving identical results to the naive concordance. // This distribution will make the IPCW weights == 1, giving identical results to the naive concordance.
final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution); final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution);

View file

@ -2,8 +2,9 @@ package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.TestUtils; import ca.joeltherrien.randomforest.TestUtils;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -13,19 +14,19 @@ public class TestCompetingRiskFunctions {
@Test @Test
public void testCalculateEventSpecificMortality(){ public void testCalculateEventSpecificMortality(){
final MathFunction cif1 = new MathFunction( final StepFunction cif1 = RightContinuousStepFunction.constructFromPoints(
Utils.easyList( Utils.easyList(
new Point(1.0, 0.3), new Point(1.0, 0.3),
new Point(1.5, 0.7), new Point(1.5, 0.7),
new Point(2.0, 0.8) new Point(2.0, 0.8)
), new Point(0.0 ,0.0) ), 0.0
); );
// not being used // not being used
final MathFunction chf1 = new MathFunction(Collections.emptyList()); final StepFunction chf1 = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
// not being used // not being used
final MathFunction km = new MathFunction(Collections.emptyList()); final StepFunction km = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
final CompetingRiskFunctions functions = CompetingRiskFunctions.builder() final CompetingRiskFunctions functions = CompetingRiskFunctions.builder()
.causeSpecificHazards(Collections.singletonList(chf1)) .causeSpecificHazards(Collections.singletonList(chf1))

View file

@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough; import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
@ -33,17 +33,17 @@ public class TestCompetingRiskResponseCombiner {
public void testCompetingRiskResponseCombiner(){ public void testCompetingRiskResponseCombiner(){
final CompetingRiskFunctions functions = generateFunctions(); final CompetingRiskFunctions functions = generateFunctions();
final MathFunction survivalCurve = functions.getSurvivalCurve(); final StepFunction survivalCurve = functions.getSurvivalCurve();
// time = 1.0 1.5 2.0 2.5 // time = 1.0 1.5 2.0 2.5
// surv = 0.7142857 0.5714286 0.1904762 0.1904762 // surv = 0.7142857 0.5714286 0.1904762 0.1904762
final double margin = 0.0000001; final double margin = 0.0000001;
closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin); closeEnough(0.7142857, survivalCurve.evaluate(1.0), margin);
closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin); closeEnough(0.5714286, survivalCurve.evaluate(1.5), margin);
closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin); closeEnough(0.1904762, survivalCurve.evaluate(2.0), margin);
closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin); closeEnough(0.1904762, survivalCurve.evaluate(2.5), margin);
// Time = 1.0 1.5 2.0 2.5 // Time = 1.0 1.5 2.0 2.5
@ -53,17 +53,17 @@ public class TestCompetingRiskResponseCombiner {
[2,] 0.0000000 0.2000000 0.5333333 0.5333333 [2,] 0.0000000 0.2000000 0.5333333 0.5333333
*/ */
final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1); final StepFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1);
closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin); closeEnough(0.2857143, cumHaz1.evaluate(1.0), margin);
closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin); closeEnough(0.2857143, cumHaz1.evaluate(1.5), margin);
closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin); closeEnough(0.6190476, cumHaz1.evaluate(2.0), margin);
closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin); closeEnough(0.6190476, cumHaz1.evaluate(2.5), margin);
final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2); final StepFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2);
closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin); closeEnough(0.0, cumHaz2.evaluate(1.0), margin);
closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin); closeEnough(0.2, cumHaz2.evaluate(1.5), margin);
closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin); closeEnough(0.5333333, cumHaz2.evaluate(2.0), margin);
closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin); closeEnough(0.5333333, cumHaz2.evaluate(2.5), margin);
/* Time = 1.0 1.5 2.0 2.5 /* Time = 1.0 1.5 2.0 2.5
Cumulative Incidence Curve. Each row for one event. Cumulative Incidence Curve. Each row for one event.
@ -72,17 +72,17 @@ public class TestCompetingRiskResponseCombiner {
[2,] 0.0000000 0.1428571 0.3333333 0.3333333 [2,] 0.0000000 0.1428571 0.3333333 0.3333333
*/ */
final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1); final StepFunction cic1 = functions.getCumulativeIncidenceFunction(1);
closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin); closeEnough(0.2857143, cic1.evaluate(1.0), margin);
closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin); closeEnough(0.2857143, cic1.evaluate(1.5), margin);
closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin); closeEnough(0.4761905, cic1.evaluate(2.0), margin);
closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin); closeEnough(0.4761905, cic1.evaluate(2.5), margin);
final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2); final StepFunction cic2 = functions.getCumulativeIncidenceFunction(2);
closeEnough(0.0, cic2.evaluate(1.0).getY(), margin); closeEnough(0.0, cic2.evaluate(1.0), margin);
closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin); closeEnough(0.1428571, cic2.evaluate(1.5), margin);
closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin); closeEnough(0.3333333, cic2.evaluate(2.0), margin);
closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin); closeEnough(0.3333333, cic2.evaluate(2.5), margin);
} }

View file

@ -44,7 +44,7 @@ public class TestLogRankSingleGroupDifferentiator {
final List<CompetingRiskResponse> data1 = generateData1(); final List<CompetingRiskResponse> data1 = generateData1();
final List<CompetingRiskResponse> data2 = generateData2(); final List<CompetingRiskResponse> data2 = generateData2();
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1); final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
final double score = differentiator.differentiate(data1, data2); final double score = differentiator.differentiate(data1, data2);
final double margin = 0.000001; final double margin = 0.000001;

View file

@ -1,43 +0,0 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.util.ArrayList;
import java.util.List;
public class TestMathFunction {
private MathFunction generateMathFunction(){
final double[] time = new double[]{1.0, 2.0, 3.0};
final double[] y = new double[]{-1.0, 1.0, 0.5};
final List<Point> pointList = new ArrayList<>();
for(int i=0; i<time.length; i++){
pointList.add(new Point(time[i], y[i]));
}
return new MathFunction(pointList, new Point(0.0, 0.1));
}
@Test
public void test(){
final MathFunction function = generateMathFunction();
assertEquals(new Point(1.0, -1.0), function.evaluate(1.0));
assertEquals(new Point(2.0, 1.0), function.evaluate(2.0));
assertEquals(new Point(3.0, 0.5), function.evaluate(3.0));
assertEquals(new Point(0.0, 0.1), function.evaluate(0.5));
assertEquals(new Point(1.0, -1.0), function.evaluate(1.1));
assertEquals(new Point(2.0, 1.0), function.evaluate(2.1));
assertEquals(new Point(3.0, 0.5), function.evaluate(3.1));
assertEquals(new Point(0.0, 0.1), function.evaluate(0.6));
}
}

View file

@ -0,0 +1,60 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestMathFunctions {
private RightContinuousStepFunction generateRightContinuousStepFunction(){
final double[] time = new double[]{1.0, 2.0, 3.0};
final double[] y = new double[]{-1.0, 1.0, 0.5};
return new RightContinuousStepFunction(time, y, 0.1);
}
private LeftContinuousStepFunction generateLeftContinuousStepFunction(){
final double[] time = new double[]{1.0, 2.0, 3.0};
final double[] y = new double[]{-1.0, 1.0, 0.5};
return new LeftContinuousStepFunction(time, y, 0.1);
}
@Test
public void testRightContinuousStepFunction(){
final RightContinuousStepFunction function = generateRightContinuousStepFunction();
assertEquals(0.1, function.evaluate(0.5));
assertEquals(-1.0, function.evaluate(1.0));
assertEquals(1.0, function.evaluate(2.0));
assertEquals(0.5, function.evaluate(3.0));
assertEquals(0.1, function.evaluate(0.6));
assertEquals(-1.0, function.evaluate(1.1));
assertEquals(1.0, function.evaluate(2.1));
assertEquals(0.5, function.evaluate(3.1));
}
@Test
public void testLeftContinuousStepFunction(){
final LeftContinuousStepFunction function = generateLeftContinuousStepFunction();
assertEquals(0.1, function.evaluate(0.5));
assertEquals(0.1, function.evaluate(1.0));
assertEquals(-1.0, function.evaluate(2.0));
assertEquals(1.0, function.evaluate(3.0));
assertEquals(0.1, function.evaluate(0.6));
assertEquals(-1.0, function.evaluate(1.1));
assertEquals(1.0, function.evaluate(2.1));
assertEquals(0.5, function.evaluate(3.1));
}
}