diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index aa8173a..0328d87 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -3,20 +3,17 @@ package ca.joeltherrien.randomforest; import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ResponseCombiner; -import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.TextNode; import java.io.*; -import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; public class Main { @@ -68,9 +65,6 @@ public class Main { if(responseCombiner instanceof CompetingRiskFunctionCombiner){ events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents(); } - else if(responseCombiner instanceof CompetingRiskListCombiner){ - events = ((CompetingRiskListCombiner) responseCombiner).getOriginalCombiner().getEvents(); - } else{ System.out.println("Unsupported tree combiner"); return; @@ -123,7 +117,7 @@ public class Main { final double[] censorTimes = dataset.stream() .mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC()) .toArray(); - final MathFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes); + final StepFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes); System.out.println("Finished generating censor distribution - running concordance"); diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index e9f8b7f..ee40a35 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -2,11 +2,10 @@ package ca.joeltherrien.randomforest; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.CovariateSettings; -import ca.joeltherrien.randomforest.responses.competingrisk.*; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskResponseCombinerToList; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; @@ -22,7 +21,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import lombok.*; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; import java.io.File; import java.io.IOException; @@ -76,7 +78,13 @@ public class Settings { (objectNode) -> { final int eventOfFocus = objectNode.get("eventOfFocus").asInt(); - return new LogRankSingleGroupDifferentiator(eventOfFocus); + final Iterator elements = objectNode.get("events").elements(); + final List elementList = new ArrayList<>(); + elements.forEachRemaining(node -> elementList.add(node)); + + final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray(); + + return new LogRankSingleGroupDifferentiator(eventOfFocus, eventArray); } ); registerGroupDifferentiatorConstructor("GrayLogRankMultipleGroupDifferentiator", @@ -105,7 +113,14 @@ public class Settings { (objectNode) -> { final int eventOfFocus = objectNode.get("eventOfFocus").asInt(); - return new GrayLogRankSingleGroupDifferentiator(eventOfFocus); + final Iterator elements = objectNode.get("events").elements(); + final List elementList = new ArrayList<>(); + elements.forEachRemaining(node -> elementList.add(node)); + + final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray(); + + + return new GrayLogRankSingleGroupDifferentiator(eventOfFocus, eventArray); } ); } @@ -154,23 +169,6 @@ public class Settings { } ); - registerResponseCombinerConstructor("CompetingRiskListCombiner", - (node) -> { - final List eventList = new ArrayList<>(); - node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt())); - final int[] events = eventList.stream().mapToInt(i -> i).toArray(); - - - final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events); - return new CompetingRiskListCombiner(responseCombiner); - - } - ); - - registerResponseCombinerConstructor("CompetingRiskResponseCombinerToList", - (node) -> new CompetingRiskResponseCombinerToList() - ); - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java index ac2e386..a837f0e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.tree.Forest; -import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import java.util.List; import java.util.stream.Collectors; @@ -110,13 +110,13 @@ public class CompetingRiskErrorRateCalculator { } - public double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution){ + public double[] calculateIPCWConcordance(final int[] events, final StepFunction censoringDistribution){ final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0); return calculateIPCWConcordance(events, censoringDistribution, tau); } - private double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution, final double tau){ + private double[] calculateIPCWConcordance(final int[] events, final StepFunction censoringDistribution, final double tau){ final double[] errorRates = new double[events.length]; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java index 869b7f4..56601c9 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java @@ -1,7 +1,6 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.StepFunction; import lombok.Builder; import lombok.Getter; @@ -11,44 +10,48 @@ import java.util.List; @Builder public class CompetingRiskFunctions implements Serializable { - private final List causeSpecificHazards; - private final List cumulativeIncidenceCurves; + private final List causeSpecificHazards; + private final List cumulativeIncidenceCurves; @Getter - private final MathFunction survivalCurve; + private final StepFunction survivalCurve; - public MathFunction getCauseSpecificHazardFunction(int cause){ + public StepFunction getCauseSpecificHazardFunction(int cause){ return causeSpecificHazards.get(cause-1); } - public MathFunction getCumulativeIncidenceFunction(int cause) { + public StepFunction getCumulativeIncidenceFunction(int cause) { return cumulativeIncidenceCurves.get(cause-1); } public double calculateEventSpecificMortality(final int event, final double tau){ - final MathFunction cif = getCumulativeIncidenceFunction(event); + final StepFunction cif = getCumulativeIncidenceFunction(event); double summation = 0.0; - Point previousPoint = null; - for(final Point point : cif.getPoints()){ - if(point.getTime() > tau){ + Double previousTime = null; + Double previousY = null; + + final double[] cifTimes = cif.getX(); + for(int i=0; i tau){ break; } - if(previousPoint != null){ - summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime()); + if(previousTime != null){ + summation += previousY * (time - previousTime); } - previousPoint = point; - + previousTime = time; + previousY = cif.evaluateByIndex(i); } // this is to ensure that we integrate over the proper range - if(previousPoint != null){ - summation += cif.evaluate(tau).getY() * (tau - previousPoint.getTime()); + if(previousTime != null){ + summation += cif.evaluate(tau) * (tau - previousTime); } - return summation; } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java new file mode 100644 index 0000000..6650dcb --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java @@ -0,0 +1,37 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.utils.MathFunction; +import lombok.Builder; +import lombok.Getter; + +import java.util.List; +import java.util.Map; + +/** + * Represents a response from CompetingRiskUtils#calculateGraySetsEfficiently + * + */ +@Builder +@Getter +public class CompetingRiskGraySetsImpl implements CompetingRiskSets{ + + private final List eventTimes; + private final MathFunction[] riskSet; + private final Map numberOfEvents; + + @Override + public MathFunction getRiskSet(int event){ + return(riskSet[event-1]); + } + + @Override + public int getNumberOfEvents(Double time, int event){ + if(numberOfEvents.containsKey(time)){ + return numberOfEvents.get(time)[event]; + } + + return 0; + } + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java new file mode 100644 index 0000000..9a53d3e --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java @@ -0,0 +1,13 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.utils.MathFunction; + +import java.util.List; + +public interface CompetingRiskSets { + + MathFunction getRiskSet(int event); + int getNumberOfEvents(Double time, int event); + List getEventTimes(); + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java new file mode 100644 index 0000000..44a9dd8 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java @@ -0,0 +1,36 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.utils.MathFunction; +import lombok.Builder; +import lombok.Getter; + +import java.util.List; +import java.util.Map; + +/** + * Represents a response from CompetingRiskUtils#calculateSetsEfficiently + * + */ +@Builder +@Getter +public class CompetingRiskSetsImpl implements CompetingRiskSets{ + + private final List eventTimes; + private final MathFunction riskSet; + private final Map numberOfEvents; + + @Override + public MathFunction getRiskSet(int event){ + return riskSet; + } + + @Override + public int getNumberOfEvents(Double time, int event){ + if(numberOfEvents.containsKey(time)){ + return numberOfEvents.get(time)[event]; + } + + return 0; + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java index 44c078a..3a612b1 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java @@ -1,8 +1,11 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; +import ca.joeltherrien.randomforest.utils.VeryDiscontinuousStepFunction; -import java.util.List; +import java.util.*; +import java.util.stream.DoubleStream; public class CompetingRiskUtils { @@ -46,7 +49,9 @@ public class CompetingRiskUtils { } - public static double calculateIPCWConcordance(final List responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){ + public static double calculateIPCWConcordance(final List responseList, + double[] mortalityArray, final int event, + final StepFunction censoringDistribution){ // Let \tau be the max time. @@ -61,8 +66,8 @@ public class CompetingRiskUtils { final double mortalityI = mortalityArray[i]; final double Ti = responseI.getU(); - final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY(); - final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus); + final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti); + final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti) * G_Ti_minus); for(int j=0; j= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1 - AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); + AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU())); } else{ continue; @@ -97,5 +102,152 @@ public class CompetingRiskUtils { } + public static CompetingRiskSetsImpl calculateSetsEfficiently(final List responses, int[] eventsOfFocus){ + final int n = responses.size(); + int[] numberOfCurrentEvents = new int[eventsOfFocus.length+1]; + + final Map numberOfEvents = new HashMap<>(); + + final List eventTimes = new ArrayList<>(n); + final List eventAndCensorTimes = new ArrayList<>(n); + final List riskSetNumberList = new ArrayList<>(n); + + // need to first sort responses + Collections.sort(responses, (y1, y2) -> { + if(y1.getU() < y2.getU()){ + return -1; + } + else if(y1.getU() > y2.getU()){ + return 1; + } + else{ + return 0; + } + }); + + + + for(int i=0; i currentResponse.getU(); + + numberOfCurrentEvents[currentResponse.getDelta()]++; + + if(lastOfTime){ + int totalNumberOfCurrentEvents = 0; + for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events + totalNumberOfCurrentEvents += numberOfCurrentEvents[e]; + } + + final double currentTime = currentResponse.getU(); + + if(totalNumberOfCurrentEvents > 0){ // add numberOfCurrentEvents + // Add point + eventTimes.add(currentTime); + numberOfEvents.put(currentTime, numberOfCurrentEvents); + } + + // Always do risk set + // remember that the LeftContinuousFunction takes into account that at this currentTime the risk value is the previous value + final int riskSet = n - (i+1); + riskSetNumberList.add(riskSet); + eventAndCensorTimes.add(currentTime); + + // reset counters + numberOfCurrentEvents = new int[eventsOfFocus.length+1]; + + } + + } + + final double[] riskSetArray = new double[eventAndCensorTimes.size()]; + final double[] timesArray = new double[eventAndCensorTimes.size()]; + for(int i=0; i responses, int[] eventsOfFocus){ + final List sillyList = responses; // annoying Java generic work-around + final CompetingRiskSetsImpl originalSets = calculateSetsEfficiently(sillyList, eventsOfFocus); + + final double[] allTimes = DoubleStream.concat( + responses.stream() + .mapToDouble(CompetingRiskResponseWithCensorTime::getC), + responses.stream() + .mapToDouble(CompetingRiskResponseWithCensorTime::getU) + ).sorted().distinct().toArray(); + + + + final VeryDiscontinuousStepFunction[] riskSets = new VeryDiscontinuousStepFunction[eventsOfFocus.length]; + + for(final int event : eventsOfFocus){ + final double[] yAt = new double[allTimes.length]; + final double[] yRight = new double[allTimes.length]; + + for(final CompetingRiskResponseWithCensorTime response : responses){ + if(response.getDelta() == event){ + // traditional case only; increment on time t when I(t <= Ui) + final double time = response.getU(); + final int index = Arrays.binarySearch(allTimes, time); + + if(index < 0){ // TODO remove once code is stable + throw new IllegalStateException("Index shouldn't be negative!"); + } + + // All yAts up to and including index are incremented; + // All yRights up to index are incremented + yAt[index]++; + for(int i=0; i= Ui. + + // increment yAt up to Ci. If Ui==Ci, increment yAt at Ci. + final double time = response.getC(); + final int index = Arrays.binarySearch(allTimes, time); + + if(index < 0){ // TODO remove once code is stable + throw new IllegalStateException("Index shouldn't be negative!"); + } + + for(int i=0; i functions.getSurvivalCurve()) .flatMapToDouble( - function -> function.getPoints().stream() - .mapToDouble(point -> point.getTime()) + function -> Arrays.stream(function.getX()) ).sorted().distinct().toArray(); } final double n = responses.size(); - final List survivalPoints = new ArrayList<>(timesToUse.length); - for(final double time : timesToUse){ + final double[] survivalY = new double[timesToUse.length]; - final double survivalY = responses.stream() - .mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time).getY() / n) + for(int i=0; i functions.getSurvivalCurve().evaluate(time) / n) .sum(); - - survivalPoints.add(new Point(time, survivalY)); - } - final MathFunction survivalFunction = new MathFunction(survivalPoints, new Point(0.0, 1.0)); + final StepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0); - final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); - final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); + final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); + final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); for(final int event : events){ - final List cumulativeHazardFunctionPoints = new ArrayList<>(timesToUse.length); - final List cumulativeIncidenceFunctionPoints = new ArrayList<>(timesToUse.length); + final double[] cumulativeHazardFunctionY = new double[timesToUse.length]; + final double[] cumulativeIncidenceFunctionY = new double[timesToUse.length]; - for(final double time : timesToUse){ + for(int i=0; i functions.getCauseSpecificHazardFunction(event).evaluate(time).getY() / n) + cumulativeHazardFunctionY[i] = responses.stream() + .mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time) / n) .sum(); - final double incidenceY = responses.stream() - .mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time).getY() / n) + cumulativeIncidenceFunctionY[i] = responses.stream() + .mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time) / n) .sum(); - cumulativeHazardFunctionPoints.add(new Point(time, hazardY)); - cumulativeIncidenceFunctionPoints.add(new Point(time, incidenceY)); - } - causeSpecificCumulativeHazardFunctionList.add(event-1, new MathFunction(cumulativeHazardFunctionPoints)); - cumulativeIncidenceFunctionList.add(event-1, new MathFunction(cumulativeIncidenceFunctionPoints)); + causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeHazardFunctionY, 0)); + cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeIncidenceFunctionY, 0)); } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java index edd4591..538245a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java @@ -3,9 +3,9 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.tree.ResponseCombiner; -import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.Point; -import lombok.RequiredArgsConstructor; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import java.util.*; @@ -38,8 +38,8 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner responses) { - final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); - final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); + final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); + final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); Collections.sort(responses, (y1, y2) -> { if(y1.getU() < y2.getU()){ @@ -97,7 +97,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY(); - final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUseList.get(i-1)).getY() : 1.0; + final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluateByIndex(i-1) : 1.0; final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk); final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY); @@ -138,10 +138,10 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner { - - @Getter - private final CompetingRiskResponseCombiner originalCombiner; - - @Override - public CompetingRiskFunctions combine(List responses) { - final List completeList = responses.stream().flatMap(Arrays::stream).collect(Collectors.toList()); - - return originalCombiner.combine(completeList); - } -} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java deleted file mode 100644 index 5526d0a..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java +++ /dev/null @@ -1,29 +0,0 @@ -package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative; - -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.tree.ResponseCombiner; - -import java.util.List; - -/** - * This class takes all of the observations in a terminal node and 'combines' them into just a list of the observations. - * - * This is used in the alternative approach to only compute the functions at the final stage when combining trees. - * - */ -public class CompetingRiskResponseCombinerToList implements ResponseCombiner { - - @Override - public CompetingRiskResponse[] combine(List responses) { - final CompetingRiskResponse[] array = new CompetingRiskResponse[responses.size()]; - - for(int i=0; i leftHand, List rightHand); - abstract double riskSet(final List eventList, double time, int eventOfFocus); - - private double numberOfEventsAtTime(int eventOfFocus, List eventList, double time){ - return (double) eventList.stream() - .filter(event -> event.getDelta() == eventOfFocus) - .filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this - .count(); - - } /** * Calculates the log rank value (or the Gray's test value) for a *specific* event cause. * * @param eventOfFocus - * @param leftHand A non-empty list of CompetingRiskResponse - * @param rightHand A non-empty list of CompetingRiskResponse + * @param competingRiskSetsLeft A summary of the different sets used in the calculation for the left side + * @param competingRiskSetsRight A summary of the different sets used in the calculation for the right side * @return */ - LogRankValue specificLogRankValue(final int eventOfFocus, List leftHand, List rightHand){ + LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSetsLeft, final CompetingRiskSets competingRiskSetsRight){ final double[] distinctEventTimes = Stream.concat( - leftHand.stream(), rightHand.stream() - ) - .filter(event -> !event.isCensored()) - .mapToDouble(event -> event.getU()) + competingRiskSetsLeft.getEventTimes().stream(), + competingRiskSetsRight.getEventTimes().stream()) + .mapToDouble(Double::doubleValue) + .sorted() .distinct() .toArray(); @@ -51,12 +43,12 @@ public abstract class CompetingRiskGroupDifferentiator eventList, double time, int eventOfFocus) { - return eventList.stream() - .filter(event -> event.getU() >= time || - (event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time) - ) - .count(); - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java index 7a8d7f9..48e3b6f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java @@ -1,6 +1,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; import java.util.List; @@ -13,6 +15,7 @@ import java.util.List; public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator { private final int eventOfFocus; + private final int[] events; @Override public Double differentiate(List leftHand, List rightHand) { @@ -20,19 +23,13 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff return null; } - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events); + final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events); + + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt()); } - @Override - double riskSet(List eventList, double time, int eventOfFocus) { - return eventList.stream() - .filter(event -> event.getU() >= time || - (event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time) - ) - .count(); - } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java index 7a62e86..2ad2424 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java @@ -1,6 +1,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; import java.util.List; @@ -20,11 +22,14 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer return null; } + final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events); + final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events); + double numerator = 0.0; double denominatorSquared = 0.0; for(final int eventOfFocus : events){ - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt(); denominatorSquared += valueOfInterest.getVariance(); @@ -35,11 +40,4 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer } - @Override - double riskSet(List eventList, double time, int eventOfFocus) { - return eventList.stream() - .filter(event -> event.getU() >= time) - .count(); - } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java index 06cb6bd..107964e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java @@ -1,6 +1,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; import java.util.List; @@ -13,6 +15,7 @@ import java.util.List; public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator { private final int eventOfFocus; + private final int[] events; @Override public Double differentiate(List leftHand, List rightHand) { @@ -20,17 +23,13 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen return null; } - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events); + final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events); + + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt()); } - @Override - double riskSet(List eventList, double time, int eventOfFocus) { - return eventList.stream() - .filter(event -> event.getU() >= time) - .count(); - } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java b/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java new file mode 100644 index 0000000..3e71977 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java @@ -0,0 +1,97 @@ +package ca.joeltherrien.randomforest.utils; + +/** + * Represents a function represented by discrete points. However, the function may be right-continuous or left-continuous + * at a given point, with no consistency. This function tracks that. + */ +public final class DiscontinuousStepFunction extends StepFunction { + + private final double[] y; + private final boolean[] isLeftContinuous; + + /** + * Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at. + * + * Map be null. + */ + private final double defaultY; + + public DiscontinuousStepFunction(double[] x, double[] y, boolean[] isLeftContinuous, double defaultY) { + super(x); + this.y = y; + this.isLeftContinuous = isLeftContinuous; + this.defaultY = defaultY; + } + + @Override + public double evaluate(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time); + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return evaluateByIndex(index); + } + else{ + return y[index]; + } + } + } + + + @Override + public double evaluatePrevious(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1; + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return evaluateByIndex(index); + } + else{ + return y[index]; + } + } + } + + @Override + public double evaluateByIndex(int i) { + if(isLeftContinuous[i]){ + i -= 1; + } + + if(i < 0){ + return defaultY; + } + + return y[i]; + } + + @Override + public String toString(){ + final StringBuilder builder = new StringBuilder(); + builder.append("Default point: "); + builder.append(defaultY); + builder.append("\n"); + + for(int i=0; i pointList, final double defaultY){ + + final double[] x = new double[pointList.size()]; + final double[] y = new double[pointList.size()]; + + final ListIterator pointIterator = pointList.listIterator(); + while(pointIterator.hasNext()){ + final int index = pointIterator.nextIndex(); + final Point currentPoint = pointIterator.next(); + + x[index] = currentPoint.getTime(); + y[index] = currentPoint.getY(); + } + + return new LeftContinuousStepFunction(x, y, defaultY); + + } + + @Override + public double evaluate(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time); + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return evaluateByIndex(index-1); + } + else{ + return y[index]; + } + } + } + + @Override + public double evaluatePrevious(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1; + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return evaluateByIndex(index-1); + } + else{ + return y[index]; + } + } + } + + @Override + public double evaluateByIndex(int i) { + if(i < 0){ + return defaultY; + } + + return y[i]; + } + + + @Override + public String toString(){ + final StringBuilder builder = new StringBuilder(); + builder.append("Default point: "); + builder.append(defaultY); + builder.append("\n"); + + for(int i=0; i points; - - /** - * Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at. - * - * Map be null. - */ - private final Point defaultValue; - - public MathFunction(final List points){ - this(points, new Point(0.0, 0.0)); - } - - public MathFunction(final List points, final Point defaultValue){ - this.points = Collections.unmodifiableList(points); - this.defaultValue = defaultValue; - } - - public Point evaluate(double time){ - int index = binarySearch(points, time); - if(index < 0){ - return defaultValue; - } - else{ - return points.get(index); - } - } - - public Point evaluatePrevious(double time){ - - int index = binarySearch(points, time) - 1; - if(index < 0){ - return defaultValue; - } - else{ - return points.get(index); - } - - - } - - /** - * Returns the index of the largest (in terms of time) Point that is <= the provided time value. - * - * @param points - * @param time - * @return The index of the largest Point who's time is <= the time parameter. - */ - private static int binarySearch(List points, double time){ - final int pointSize = points.size(); - - if(pointSize == 0 || points.get(pointSize-1).getTime() <= time){ - // we're already too far - return pointSize - 1; - } - - if(pointSize < 200){ - for(int i = 0; i < pointSize; i++){ - if(points.get(i).getTime() > time){ - return i - 1; - } - } - } - - // else - - - final int middle = pointSize / 2; - final double middleTime = points.get(middle).getTime(); - if(middleTime < time){ - // go right - return binarySearch(points.subList(middle, pointSize), time) + middle; - } - else if(middleTime > time){ - // go left - return binarySearch(points.subList(0, middle), time); - } - else{ // middleTime == time - return middle; - } - } - - @Override - public String toString(){ - final StringBuilder builder = new StringBuilder(); - builder.append("Default point: "); - builder.append(defaultValue); - builder.append("\n"); - - for(final Point point : points){ - builder.append(point); - builder.append("\n"); - } - - return builder.toString(); - } + double evaluate(double time); } diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java b/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java index 2ee3cb4..392425f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java @@ -11,26 +11,12 @@ import java.util.zip.GZIPOutputStream; */ public final class RUtils { - public static double[] extractTimes(final MathFunction function){ - final List pointList = function.getPoints(); - final double[] times = new double[pointList.size()]; - - for(int i=0; i pointList = function.getPoints(); - final double[] times = new double[pointList.size()]; - - for(int i=0; i pointList, final double defaultY){ + + final double[] x = new double[pointList.size()]; + final double[] y = new double[pointList.size()]; + + final ListIterator pointIterator = pointList.listIterator(); + while(pointIterator.hasNext()){ + final int index = pointIterator.nextIndex(); + final Point currentPoint = pointIterator.next(); + + x[index] = currentPoint.getTime(); + y[index] = currentPoint.getY(); + } + + return new RightContinuousStepFunction(x, y, defaultY); + + } + + public double[] getY(){ + return y.clone(); + } + + @Override + public double evaluate(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time); + if(index < 0){ + return defaultY; + } + else{ + return y[index]; + } + } + + @Override + public double evaluatePrevious(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1; + if(index < 0){ + return defaultY; + } + else{ + return y[index]; + } + } + + @Override + public double evaluateByIndex(int i) { + if(i < 0){ + return defaultY; + } + + return y[i]; + } + + + @Override + public String toString(){ + final StringBuilder builder = new StringBuilder(); + builder.append("Default point: "); + builder.append(defaultY); + builder.append("\n"); + + for(int i=0; i timeCounterMap = new HashMap<>(); @@ -33,7 +32,7 @@ public final class Utils { pointList.add(new Point(entry.getKey(), (double) newCount / n)); } - return new MathFunction(pointList, defaultPoint); + return RightContinuousStepFunction.constructFromPoints(pointList, 1.0); } @@ -64,6 +63,48 @@ public final class Utils { } } + /** + * Returns the index of the largest (in terms of time) Point that is <= the provided time value. + * + * @param startIndex Only search from startIndex (inclusive) + * @param endIndex Only search up to endIndex (exclusive) + * @param time + * @return The index of the largest Point who's time is <= the time parameter. + */ + public static int binarySearchLessThan(int startIndex, int endIndex, double[] x, double time){ + final int range = endIndex - startIndex; + + if(range == 0 || x[endIndex-1] <= time){ + // we're already too far + return endIndex - 1; + } + + if(range < 200){ + for(int i = startIndex; i < endIndex; i++){ + if(x[i] > time){ + return i - 1; + } + } + } + + // else + + + final int middle = range / 2; + final double middleTime = x[middle]; + if(middleTime < time){ + // go right + return binarySearchLessThan(middle, endIndex, x, time); + } + else if(middleTime > time){ + // go left + return binarySearchLessThan(0, middle, x, time); + } + else{ // middleTime == time + return middle; + } + } + /** * Replacement for Java 9's List.of * diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/VeryDiscontinuousStepFunction.java b/src/main/java/ca/joeltherrien/randomforest/utils/VeryDiscontinuousStepFunction.java new file mode 100644 index 0000000..31ca83f --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/VeryDiscontinuousStepFunction.java @@ -0,0 +1,64 @@ +package ca.joeltherrien.randomforest.utils; + +/** + * Represents a step function represented by discrete points. However, there may be individual time values that has + * a y value that doesn't belong to a particular 'step'. + */ +public final class VeryDiscontinuousStepFunction implements MathFunction { + + private final double[] x; + private final double[] yAt; + private final double[] yRight; + + /** + * Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at. + * + * Map be null. + */ + private final double defaultY; + + public VeryDiscontinuousStepFunction(double[] x, double[] yAt, double[] yRight, double defaultY) { + this.x = x; + this.yAt = yAt; + this.yRight = yRight; + this.defaultY = defaultY; + } + + @Override + public double evaluate(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time); + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return yAt[index]; + } + else{ // time > x[index] + return yRight[index]; + } + } + } + + + @Override + public String toString(){ + final StringBuilder builder = new StringBuilder(); + builder.append("Default point: "); + builder.append(defaultY); + builder.append("\n"); + + for(int i=0; i 2); + assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2); assertEquals(NTREE, forest.getTrees().size()); diff --git a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java index 60e7a07..ab4c83a 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -1,7 +1,6 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -23,53 +22,67 @@ public class TestUtils { * * @param function */ - public static void assertCumulativeFunction(MathFunction function){ - Point previousPoint = null; - for(final Point point : function.getPoints()){ + public static void assertCumulativeFunction(StepFunction function){ + Double previousTime = null; + Double previousY = null; - if(previousPoint != null){ - assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different"); - assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone"); + final double[] times = function.getX(); + + for(int i=0; i= point.getY(), "Survival functions are monotone"); + final double[] times = function.getX(); + + for(int i=0; i= y, "Survival functions are monotone"); } + previousTime = time; + previousY = y; - previousPoint = point; } } @Test public void testOneMinusECDF(){ final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0}; - final MathFunction survivalCurve = Utils.estimateOneMinusECDF(times); + final StepFunction survivalCurve = Utils.estimateOneMinusECDF(times); final double margin = 0.000001; - closeEnough(1.0, survivalCurve.evaluate(0.0).getY(), margin); + closeEnough(1.0, survivalCurve.evaluate(0.0), margin); - closeEnough(1.0, survivalCurve.evaluatePrevious(1.0).getY(), margin); - closeEnough(4.0/6.0, survivalCurve.evaluate(1.0).getY(), margin); + closeEnough(1.0, survivalCurve.evaluatePrevious(1.0), margin); + closeEnough(4.0/6.0, survivalCurve.evaluate(1.0), margin); - closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0).getY(), margin); - closeEnough(3.0/6.0, survivalCurve.evaluate(2.0).getY(), margin); + closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0), margin); + closeEnough(3.0/6.0, survivalCurve.evaluate(2.0), margin); - closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0).getY(), margin); - closeEnough(1.0/6.0, survivalCurve.evaluate(3.0).getY(), margin); + closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0), margin); + closeEnough(1.0/6.0, survivalCurve.evaluate(3.0), margin); - closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0).getY(), margin); - closeEnough(0.0, survivalCurve.evaluate(50.0).getY(), margin); + closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0), margin); + closeEnough(0.0, survivalCurve.evaluate(50.0), margin); assertSurvivalCurve(survivalCurve); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java new file mode 100644 index 0000000..e6fec68 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java @@ -0,0 +1,214 @@ +package ca.joeltherrien.randomforest.competingrisk; + +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TestCalculatingCompetingRiskSets { + + public List generateData(){ + final List data = new ArrayList<>(); + + data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3)); + data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3)); + data.add(new CompetingRiskResponseWithCensorTime(0, 1, 1)); + data.add(new CompetingRiskResponseWithCensorTime(1, 2, 2.5)); + data.add(new CompetingRiskResponseWithCensorTime(2, 3, 4)); + data.add(new CompetingRiskResponseWithCensorTime(0, 3, 3)); + data.add(new CompetingRiskResponseWithCensorTime(1, 4, 4)); + data.add(new CompetingRiskResponseWithCensorTime(0, 5, 5)); + data.add(new CompetingRiskResponseWithCensorTime(2, 6, 7)); + + return data; + } + + @Test + public void testCalculatingSets(){ + final List data = generateData(); + + final CompetingRiskSetsImpl sets = CompetingRiskUtils.calculateSetsEfficiently(data, new int[]{1,2}); + + final List times = sets.getEventTimes(); + assertEquals(5, times.size()); + + // Times + assertEquals(1.0, times.get(0).doubleValue()); + assertEquals(2.0, times.get(1).doubleValue()); + assertEquals(3.0, times.get(2).doubleValue()); + assertEquals(4.0, times.get(3).doubleValue()); + assertEquals(6.0, times.get(4).doubleValue()); + + // Number of Events + assertEquals(2, sets.getNumberOfEvents(1.0, 1)); + assertEquals(0, sets.getNumberOfEvents(1.0, 2)); + + assertEquals(1, sets.getNumberOfEvents(2.0, 1)); + assertEquals(0, sets.getNumberOfEvents(2.0, 2)); + + assertEquals(0, sets.getNumberOfEvents(3.0, 1)); + assertEquals(1, sets.getNumberOfEvents(3.0, 2)); + + assertEquals(1, sets.getNumberOfEvents(4.0, 1)); + assertEquals(0, sets.getNumberOfEvents(4.0, 2)); + + assertEquals(0, sets.getNumberOfEvents(6.0, 1)); + assertEquals(1, sets.getNumberOfEvents(6.0, 2)); + + // Make sure it doesn't break for other times + assertEquals(0, sets.getNumberOfEvents(5.5, 1)); + assertEquals(0, sets.getNumberOfEvents(5.5, 2)); + + + // Risk set + assertEquals(9, sets.getRiskSet(1).evaluate(0.5)); + assertEquals(9, sets.getRiskSet(2).evaluate(0.5)); + + assertEquals(9, sets.getRiskSet(1).evaluate(1.0)); + assertEquals(9, sets.getRiskSet(2).evaluate(1.0)); + + assertEquals(6, sets.getRiskSet(1).evaluate(1.5)); + assertEquals(6, sets.getRiskSet(2).evaluate(1.5)); + + assertEquals(6, sets.getRiskSet(1).evaluate(2.0)); + assertEquals(6, sets.getRiskSet(2).evaluate(2.0)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.3)); + assertEquals(5, sets.getRiskSet(2).evaluate(2.3)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.5)); + assertEquals(5, sets.getRiskSet(2).evaluate(2.5)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.7)); + assertEquals(5, sets.getRiskSet(2).evaluate(2.7)); + + assertEquals(5, sets.getRiskSet(1).evaluate(3.0)); + assertEquals(5, sets.getRiskSet(2).evaluate(3.0)); + + assertEquals(3, sets.getRiskSet(1).evaluate(3.5)); + assertEquals(3, sets.getRiskSet(2).evaluate(3.5)); + + assertEquals(3, sets.getRiskSet(1).evaluate(4.0)); + assertEquals(3, sets.getRiskSet(2).evaluate(4.0)); + + assertEquals(2, sets.getRiskSet(1).evaluate(4.5)); + assertEquals(2, sets.getRiskSet(2).evaluate(4.5)); + + assertEquals(2, sets.getRiskSet(1).evaluate(5.0)); + assertEquals(2, sets.getRiskSet(2).evaluate(5.0)); + + assertEquals(1, sets.getRiskSet(1).evaluate(5.5)); + assertEquals(1, sets.getRiskSet(2).evaluate(5.5)); + + assertEquals(1, sets.getRiskSet(1).evaluate(6.0)); + assertEquals(1, sets.getRiskSet(2).evaluate(6.0)); + + assertEquals(0, sets.getRiskSet(1).evaluate(6.5)); + assertEquals(0, sets.getRiskSet(2).evaluate(6.5)); + + assertEquals(0, sets.getRiskSet(1).evaluate(7.0)); + assertEquals(0, sets.getRiskSet(2).evaluate(7.0)); + + assertEquals(0, sets.getRiskSet(1).evaluate(7.5)); + assertEquals(0, sets.getRiskSet(2).evaluate(7.5)); + + } + + @Test + public void testCalculatingGraySets(){ + final List data = generateData(); + + final CompetingRiskGraySetsImpl sets = CompetingRiskUtils.calculateGraySetsEfficiently(data, new int[]{1,2}); + + final List times = sets.getEventTimes(); + assertEquals(5, times.size()); + + // Times + assertEquals(1.0, times.get(0).doubleValue()); + assertEquals(2.0, times.get(1).doubleValue()); + assertEquals(3.0, times.get(2).doubleValue()); + assertEquals(4.0, times.get(3).doubleValue()); + assertEquals(6.0, times.get(4).doubleValue()); + + // Number of Events + assertEquals(2, sets.getNumberOfEvents(1.0, 1)); + assertEquals(0, sets.getNumberOfEvents(1.0, 2)); + + assertEquals(1, sets.getNumberOfEvents(2.0, 1)); + assertEquals(0, sets.getNumberOfEvents(2.0, 2)); + + assertEquals(0, sets.getNumberOfEvents(3.0, 1)); + assertEquals(1, sets.getNumberOfEvents(3.0, 2)); + + assertEquals(1, sets.getNumberOfEvents(4.0, 1)); + assertEquals(0, sets.getNumberOfEvents(4.0, 2)); + + assertEquals(0, sets.getNumberOfEvents(6.0, 1)); + assertEquals(1, sets.getNumberOfEvents(6.0, 2)); + + // Make sure it doesn't break for other times + assertEquals(0, sets.getNumberOfEvents(5.5, 1)); + assertEquals(0, sets.getNumberOfEvents(5.5, 2)); + + + // Risk set + assertEquals(9, sets.getRiskSet(1).evaluate(0.5)); + assertEquals(9, sets.getRiskSet(2).evaluate(0.5)); + + assertEquals(9, sets.getRiskSet(1).evaluate(1.0)); + assertEquals(9, sets.getRiskSet(2).evaluate(1.0)); + + assertEquals(6, sets.getRiskSet(1).evaluate(1.5)); + assertEquals(8, sets.getRiskSet(2).evaluate(1.5)); + + assertEquals(6, sets.getRiskSet(1).evaluate(2.0)); + assertEquals(8, sets.getRiskSet(2).evaluate(2.0)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.3)); + assertEquals(8, sets.getRiskSet(2).evaluate(2.3)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.5)); + assertEquals(7, sets.getRiskSet(2).evaluate(2.5)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.7)); + assertEquals(7, sets.getRiskSet(2).evaluate(2.7)); + + assertEquals(5, sets.getRiskSet(1).evaluate(3.0)); + assertEquals(5, sets.getRiskSet(2).evaluate(3.0)); + + assertEquals(4, sets.getRiskSet(1).evaluate(3.5)); + assertEquals(3, sets.getRiskSet(2).evaluate(3.5)); + + assertEquals(3, sets.getRiskSet(1).evaluate(4.0)); + assertEquals(3, sets.getRiskSet(2).evaluate(4.0)); + + assertEquals(2, sets.getRiskSet(1).evaluate(4.5)); + assertEquals(2, sets.getRiskSet(2).evaluate(4.5)); + + assertEquals(2, sets.getRiskSet(1).evaluate(5.0)); + assertEquals(2, sets.getRiskSet(2).evaluate(5.0)); + + assertEquals(1, sets.getRiskSet(1).evaluate(5.5)); + assertEquals(1, sets.getRiskSet(2).evaluate(5.5)); + + assertEquals(1, sets.getRiskSet(1).evaluate(6.0)); + assertEquals(1, sets.getRiskSet(2).evaluate(6.0)); + + assertEquals(1, sets.getRiskSet(1).evaluate(6.5)); + assertEquals(0, sets.getRiskSet(2).evaluate(6.5)); + + assertEquals(0, sets.getRiskSet(1).evaluate(7.0)); + assertEquals(0, sets.getRiskSet(2).evaluate(7.0)); + + assertEquals(0, sets.getRiskSet(1).evaluate(7.5)); + assertEquals(0, sets.getRiskSet(2).evaluate(7.5)); + + } + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index c761401..e38ce57 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -7,8 +7,7 @@ import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer; -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.*; import org.junit.jupiter.api.Test; @@ -19,7 +18,6 @@ import static org.junit.jupiter.api.Assertions.*; import java.io.IOException; import java.util.List; -import java.util.stream.Collectors; public class TestCompetingRisk { @@ -33,6 +31,9 @@ public class TestCompetingRisk { final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator")); groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1)); + groupDifferentiatorSettings.set("events", + new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) + ); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); @@ -109,32 +110,32 @@ public class TestCompetingRisk { final CompetingRiskFunctions functions = node.evaluate(newRow); - final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); - final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); - final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); - final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); + final StepFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); + final StepFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); + final StepFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); + final StepFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); final double margin = 0.0000001; - closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02).getY(), margin); - closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00).getY(), margin); - closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50).getY(), margin); - closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60).getY(), margin); - closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80).getY(), margin); + closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02), margin); + closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00), margin); + closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50), margin); + closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60), margin); + closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80), margin); - closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin); - closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin); - closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin); + closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00), margin); + closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50), margin); + closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80), margin); - closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin); - closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin); - closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin); + closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00), margin); + closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50), margin); + closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80), margin); - closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin); - closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin); - closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin); + closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00), margin); + closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50), margin); + closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80), margin); } @@ -162,18 +163,18 @@ public class TestCompetingRisk { final CompetingRiskFunctions functions = node.evaluate(newRow); - final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); - final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); - final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); - final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); + final StepFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); + final StepFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); + final StepFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); + final StepFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); final double margin = 0.0000001; - closeEnough(0, causeOneCIFFunction.evaluate(0.02).getY(), margin); - closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4).getY(), margin); - closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8).getY(), margin); - closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9).getY(), margin); - closeEnough(1.0, causeOneCIFFunction.evaluate(1.0).getY(), margin); + closeEnough(0, causeOneCIFFunction.evaluate(0.02), margin); + closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4), margin); + closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8), margin); + closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9), margin); + closeEnough(1.0, causeOneCIFFunction.evaluate(1.0), margin); /* closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin); @@ -222,11 +223,11 @@ public class TestCompetingRisk { assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2)); - closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0).getY(), 0.01); - closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8).getY(), 0.01); + closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0), 0.01); + closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8), 0.01); - closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01); - closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01); + closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0), 0.01); + closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8), 0.01); final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); @@ -289,7 +290,8 @@ public class TestCompetingRisk { settings.setNtree(300); // results are too variable at 100 final List covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), + settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final Forest forest = forestTrainer.trainSerial(); @@ -305,10 +307,9 @@ public class TestCompetingRisk { assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1)); assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2)); - final List causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints(); - // We seem to consistently underestimate the results. - assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 + final double endProbability = functions.getCumulativeIncidenceFunction(1).evaluate(10000000); + assertTrue(endProbability > 0.74, "Results should match randomForestSRC; had " + endProbability); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java index 6da0d38..45d0928 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -4,8 +4,8 @@ import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.tree.Forest; -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -36,7 +36,7 @@ public class TestCompetingRiskErrorRateCalculator { final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event); - final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0)); + final StepFunction fakeCensorDistribution = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 1.0); // This distribution will make the IPCW weights == 1, giving identical results to the naive concordance. final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctions.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctions.java index 18acf89..7147f4c 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctions.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctions.java @@ -2,8 +2,9 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.TestUtils; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; -import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -13,19 +14,19 @@ public class TestCompetingRiskFunctions { @Test public void testCalculateEventSpecificMortality(){ - final MathFunction cif1 = new MathFunction( + final StepFunction cif1 = RightContinuousStepFunction.constructFromPoints( Utils.easyList( new Point(1.0, 0.3), new Point(1.5, 0.7), new Point(2.0, 0.8) - ), new Point(0.0 ,0.0) + ), 0.0 ); // not being used - final MathFunction chf1 = new MathFunction(Collections.emptyList()); + final StepFunction chf1 = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0); // not being used - final MathFunction km = new MathFunction(Collections.emptyList()); + final StepFunction km = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0); final CompetingRiskFunctions functions = CompetingRiskFunctions.builder() .causeSpecificHazards(Collections.singletonList(chf1)) diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java index 4cfea67..44ee657 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java @@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; -import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import org.junit.jupiter.api.Test; import static ca.joeltherrien.randomforest.TestUtils.closeEnough; @@ -33,17 +33,17 @@ public class TestCompetingRiskResponseCombiner { public void testCompetingRiskResponseCombiner(){ final CompetingRiskFunctions functions = generateFunctions(); - final MathFunction survivalCurve = functions.getSurvivalCurve(); + final StepFunction survivalCurve = functions.getSurvivalCurve(); // time = 1.0 1.5 2.0 2.5 // surv = 0.7142857 0.5714286 0.1904762 0.1904762 final double margin = 0.0000001; - closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin); - closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin); - closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin); - closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin); + closeEnough(0.7142857, survivalCurve.evaluate(1.0), margin); + closeEnough(0.5714286, survivalCurve.evaluate(1.5), margin); + closeEnough(0.1904762, survivalCurve.evaluate(2.0), margin); + closeEnough(0.1904762, survivalCurve.evaluate(2.5), margin); // Time = 1.0 1.5 2.0 2.5 @@ -53,17 +53,17 @@ public class TestCompetingRiskResponseCombiner { [2,] 0.0000000 0.2000000 0.5333333 0.5333333 */ - final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1); - closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin); - closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin); - closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin); - closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin); + final StepFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1); + closeEnough(0.2857143, cumHaz1.evaluate(1.0), margin); + closeEnough(0.2857143, cumHaz1.evaluate(1.5), margin); + closeEnough(0.6190476, cumHaz1.evaluate(2.0), margin); + closeEnough(0.6190476, cumHaz1.evaluate(2.5), margin); - final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2); - closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin); - closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin); - closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin); - closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin); + final StepFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2); + closeEnough(0.0, cumHaz2.evaluate(1.0), margin); + closeEnough(0.2, cumHaz2.evaluate(1.5), margin); + closeEnough(0.5333333, cumHaz2.evaluate(2.0), margin); + closeEnough(0.5333333, cumHaz2.evaluate(2.5), margin); /* Time = 1.0 1.5 2.0 2.5 Cumulative Incidence Curve. Each row for one event. @@ -72,17 +72,17 @@ public class TestCompetingRiskResponseCombiner { [2,] 0.0000000 0.1428571 0.3333333 0.3333333 */ - final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1); - closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin); - closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin); - closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin); - closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin); + final StepFunction cic1 = functions.getCumulativeIncidenceFunction(1); + closeEnough(0.2857143, cic1.evaluate(1.0), margin); + closeEnough(0.2857143, cic1.evaluate(1.5), margin); + closeEnough(0.4761905, cic1.evaluate(2.0), margin); + closeEnough(0.4761905, cic1.evaluate(2.5), margin); - final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2); - closeEnough(0.0, cic2.evaluate(1.0).getY(), margin); - closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin); - closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin); - closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin); + final StepFunction cic2 = functions.getCumulativeIncidenceFunction(2); + closeEnough(0.0, cic2.evaluate(1.0), margin); + closeEnough(0.1428571, cic2.evaluate(1.5), margin); + closeEnough(0.3333333, cic2.evaluate(2.0), margin); + closeEnough(0.3333333, cic2.evaluate(2.5), margin); } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java index 56e668b..ef79b65 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java @@ -44,7 +44,7 @@ public class TestLogRankSingleGroupDifferentiator { final List data1 = generateData1(); final List data2 = generateData2(); - final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1); + final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1}); final double score = differentiator.differentiate(data1, data2); final double margin = 0.000001; diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java deleted file mode 100644 index 591ca9a..0000000 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java +++ /dev/null @@ -1,43 +0,0 @@ -package ca.joeltherrien.randomforest.competingrisk; - -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; - -import java.util.ArrayList; -import java.util.List; - -public class TestMathFunction { - - private MathFunction generateMathFunction(){ - final double[] time = new double[]{1.0, 2.0, 3.0}; - final double[] y = new double[]{-1.0, 1.0, 0.5}; - - final List pointList = new ArrayList<>(); - for(int i=0; i