From c68f67e47ac350af1c653c17a2b9ee4409f72e58 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Thu, 25 Oct 2018 10:34:27 -0700 Subject: [PATCH] Massive optimizations; Refactored how MathFunctions are structured to use more primitives and less objects. Optimized competing risk group differentiators to run faster. Removed alternative competing risk response combiners (may be added back later) --- .../ca/joeltherrien/randomforest/Main.java | 10 +- .../joeltherrien/randomforest/Settings.java | 44 ++-- .../CompetingRiskErrorRateCalculator.java | 6 +- .../competingrisk/CompetingRiskFunctions.java | 39 ++-- .../CompetingRiskGraySetsImpl.java | 37 +++ .../competingrisk/CompetingRiskSets.java | 13 ++ .../competingrisk/CompetingRiskSetsImpl.java | 36 +++ .../competingrisk/CompetingRiskUtils.java | 164 +++++++++++++- .../CompetingRiskFunctionCombiner.java | 48 ++-- .../CompetingRiskResponseCombiner.java | 16 +- .../CompetingRiskListCombiner.java | 26 --- .../CompetingRiskResponseCombinerToList.java | 29 --- .../CompetingRiskGroupDifferentiator.java | 32 +-- ...rayLogRankMultipleGroupDifferentiator.java | 15 +- .../GrayLogRankSingleGroupDifferentiator.java | 17 +- .../LogRankMultipleGroupDifferentiator.java | 14 +- .../LogRankSingleGroupDifferentiator.java | 15 +- .../utils/DiscontinuousStepFunction.java | 97 ++++++++ .../utils/LeftContinuousStepFunction.java | 113 +++++++++ .../randomforest/utils/MathFunction.java | 109 +-------- .../randomforest/utils/RUtils.java | 22 +- .../utils/RightContinuousStepFunction.java | 107 +++++++++ .../randomforest/utils/StepFunction.java | 26 +++ .../randomforest/utils/SumFunction.java | 15 ++ .../randomforest/utils/Utils.java | 47 +++- .../utils/VeryDiscontinuousStepFunction.java | 64 ++++++ .../randomforest/TestSavingLoading.java | 6 +- .../joeltherrien/randomforest/TestUtils.java | 65 +++--- .../TestCalculatingCompetingRiskSets.java | 214 ++++++++++++++++++ .../competingrisk/TestCompetingRisk.java | 77 +++---- .../TestCompetingRiskErrorRateCalculator.java | 6 +- .../TestCompetingRiskFunctions.java | 11 +- .../TestCompetingRiskResponseCombiner.java | 52 ++--- .../TestLogRankSingleGroupDifferentiator.java | 2 +- .../competingrisk/TestMathFunction.java | 43 ---- .../competingrisk/TestMathFunctions.java | 60 +++++ 36 files changed, 1223 insertions(+), 474 deletions(-) create mode 100644 src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java delete mode 100644 src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskListCombiner.java delete mode 100644 src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/utils/LeftContinuousStepFunction.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/utils/StepFunction.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/utils/SumFunction.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/utils/VeryDiscontinuousStepFunction.java create mode 100644 src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java delete mode 100644 src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java create mode 100644 src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunctions.java diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index aa8173a..0328d87 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -3,20 +3,17 @@ package ca.joeltherrien.randomforest; import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ResponseCombiner; -import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.TextNode; import java.io.*; -import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; public class Main { @@ -68,9 +65,6 @@ public class Main { if(responseCombiner instanceof CompetingRiskFunctionCombiner){ events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents(); } - else if(responseCombiner instanceof CompetingRiskListCombiner){ - events = ((CompetingRiskListCombiner) responseCombiner).getOriginalCombiner().getEvents(); - } else{ System.out.println("Unsupported tree combiner"); return; @@ -123,7 +117,7 @@ public class Main { final double[] censorTimes = dataset.stream() .mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC()) .toArray(); - final MathFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes); + final StepFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes); System.out.println("Finished generating censor distribution - running concordance"); diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index e9f8b7f..ee40a35 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -2,11 +2,10 @@ package ca.joeltherrien.randomforest; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.CovariateSettings; -import ca.joeltherrien.randomforest.responses.competingrisk.*; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner; -import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskResponseCombinerToList; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; @@ -22,7 +21,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import lombok.*; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; import java.io.File; import java.io.IOException; @@ -76,7 +78,13 @@ public class Settings { (objectNode) -> { final int eventOfFocus = objectNode.get("eventOfFocus").asInt(); - return new LogRankSingleGroupDifferentiator(eventOfFocus); + final Iterator elements = objectNode.get("events").elements(); + final List elementList = new ArrayList<>(); + elements.forEachRemaining(node -> elementList.add(node)); + + final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray(); + + return new LogRankSingleGroupDifferentiator(eventOfFocus, eventArray); } ); registerGroupDifferentiatorConstructor("GrayLogRankMultipleGroupDifferentiator", @@ -105,7 +113,14 @@ public class Settings { (objectNode) -> { final int eventOfFocus = objectNode.get("eventOfFocus").asInt(); - return new GrayLogRankSingleGroupDifferentiator(eventOfFocus); + final Iterator elements = objectNode.get("events").elements(); + final List elementList = new ArrayList<>(); + elements.forEachRemaining(node -> elementList.add(node)); + + final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray(); + + + return new GrayLogRankSingleGroupDifferentiator(eventOfFocus, eventArray); } ); } @@ -154,23 +169,6 @@ public class Settings { } ); - registerResponseCombinerConstructor("CompetingRiskListCombiner", - (node) -> { - final List eventList = new ArrayList<>(); - node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt())); - final int[] events = eventList.stream().mapToInt(i -> i).toArray(); - - - final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events); - return new CompetingRiskListCombiner(responseCombiner); - - } - ); - - registerResponseCombinerConstructor("CompetingRiskResponseCombinerToList", - (node) -> new CompetingRiskResponseCombinerToList() - ); - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java index ac2e386..a837f0e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.tree.Forest; -import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import java.util.List; import java.util.stream.Collectors; @@ -110,13 +110,13 @@ public class CompetingRiskErrorRateCalculator { } - public double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution){ + public double[] calculateIPCWConcordance(final int[] events, final StepFunction censoringDistribution){ final double tau = dataset.stream().mapToDouble(row -> row.getResponse().getU()).max().orElse(0.0); return calculateIPCWConcordance(events, censoringDistribution, tau); } - private double[] calculateIPCWConcordance(final int[] events, final MathFunction censoringDistribution, final double tau){ + private double[] calculateIPCWConcordance(final int[] events, final StepFunction censoringDistribution, final double tau){ final double[] errorRates = new double[events.length]; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java index 869b7f4..56601c9 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java @@ -1,7 +1,6 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.StepFunction; import lombok.Builder; import lombok.Getter; @@ -11,44 +10,48 @@ import java.util.List; @Builder public class CompetingRiskFunctions implements Serializable { - private final List causeSpecificHazards; - private final List cumulativeIncidenceCurves; + private final List causeSpecificHazards; + private final List cumulativeIncidenceCurves; @Getter - private final MathFunction survivalCurve; + private final StepFunction survivalCurve; - public MathFunction getCauseSpecificHazardFunction(int cause){ + public StepFunction getCauseSpecificHazardFunction(int cause){ return causeSpecificHazards.get(cause-1); } - public MathFunction getCumulativeIncidenceFunction(int cause) { + public StepFunction getCumulativeIncidenceFunction(int cause) { return cumulativeIncidenceCurves.get(cause-1); } public double calculateEventSpecificMortality(final int event, final double tau){ - final MathFunction cif = getCumulativeIncidenceFunction(event); + final StepFunction cif = getCumulativeIncidenceFunction(event); double summation = 0.0; - Point previousPoint = null; - for(final Point point : cif.getPoints()){ - if(point.getTime() > tau){ + Double previousTime = null; + Double previousY = null; + + final double[] cifTimes = cif.getX(); + for(int i=0; i tau){ break; } - if(previousPoint != null){ - summation += previousPoint.getY() * (point.getTime() - previousPoint.getTime()); + if(previousTime != null){ + summation += previousY * (time - previousTime); } - previousPoint = point; - + previousTime = time; + previousY = cif.evaluateByIndex(i); } // this is to ensure that we integrate over the proper range - if(previousPoint != null){ - summation += cif.evaluate(tau).getY() * (tau - previousPoint.getTime()); + if(previousTime != null){ + summation += cif.evaluate(tau) * (tau - previousTime); } - return summation; } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java new file mode 100644 index 0000000..6650dcb --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java @@ -0,0 +1,37 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.utils.MathFunction; +import lombok.Builder; +import lombok.Getter; + +import java.util.List; +import java.util.Map; + +/** + * Represents a response from CompetingRiskUtils#calculateGraySetsEfficiently + * + */ +@Builder +@Getter +public class CompetingRiskGraySetsImpl implements CompetingRiskSets{ + + private final List eventTimes; + private final MathFunction[] riskSet; + private final Map numberOfEvents; + + @Override + public MathFunction getRiskSet(int event){ + return(riskSet[event-1]); + } + + @Override + public int getNumberOfEvents(Double time, int event){ + if(numberOfEvents.containsKey(time)){ + return numberOfEvents.get(time)[event]; + } + + return 0; + } + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java new file mode 100644 index 0000000..9a53d3e --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java @@ -0,0 +1,13 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.utils.MathFunction; + +import java.util.List; + +public interface CompetingRiskSets { + + MathFunction getRiskSet(int event); + int getNumberOfEvents(Double time, int event); + List getEventTimes(); + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java new file mode 100644 index 0000000..44a9dd8 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java @@ -0,0 +1,36 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.utils.MathFunction; +import lombok.Builder; +import lombok.Getter; + +import java.util.List; +import java.util.Map; + +/** + * Represents a response from CompetingRiskUtils#calculateSetsEfficiently + * + */ +@Builder +@Getter +public class CompetingRiskSetsImpl implements CompetingRiskSets{ + + private final List eventTimes; + private final MathFunction riskSet; + private final Map numberOfEvents; + + @Override + public MathFunction getRiskSet(int event){ + return riskSet; + } + + @Override + public int getNumberOfEvents(Double time, int event){ + if(numberOfEvents.containsKey(time)){ + return numberOfEvents.get(time)[event]; + } + + return 0; + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java index 44c078a..3a612b1 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java @@ -1,8 +1,11 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; +import ca.joeltherrien.randomforest.utils.VeryDiscontinuousStepFunction; -import java.util.List; +import java.util.*; +import java.util.stream.DoubleStream; public class CompetingRiskUtils { @@ -46,7 +49,9 @@ public class CompetingRiskUtils { } - public static double calculateIPCWConcordance(final List responseList, double[] mortalityArray, final int event, final MathFunction censoringDistribution){ + public static double calculateIPCWConcordance(final List responseList, + double[] mortalityArray, final int event, + final StepFunction censoringDistribution){ // Let \tau be the max time. @@ -61,8 +66,8 @@ public class CompetingRiskUtils { final double mortalityI = mortalityArray[i]; final double Ti = responseI.getU(); - final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY(); - final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus); + final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti); + final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti) * G_Ti_minus); for(int j=0; j= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1 - AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); + AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU())); } else{ continue; @@ -97,5 +102,152 @@ public class CompetingRiskUtils { } + public static CompetingRiskSetsImpl calculateSetsEfficiently(final List responses, int[] eventsOfFocus){ + final int n = responses.size(); + int[] numberOfCurrentEvents = new int[eventsOfFocus.length+1]; + + final Map numberOfEvents = new HashMap<>(); + + final List eventTimes = new ArrayList<>(n); + final List eventAndCensorTimes = new ArrayList<>(n); + final List riskSetNumberList = new ArrayList<>(n); + + // need to first sort responses + Collections.sort(responses, (y1, y2) -> { + if(y1.getU() < y2.getU()){ + return -1; + } + else if(y1.getU() > y2.getU()){ + return 1; + } + else{ + return 0; + } + }); + + + + for(int i=0; i currentResponse.getU(); + + numberOfCurrentEvents[currentResponse.getDelta()]++; + + if(lastOfTime){ + int totalNumberOfCurrentEvents = 0; + for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events + totalNumberOfCurrentEvents += numberOfCurrentEvents[e]; + } + + final double currentTime = currentResponse.getU(); + + if(totalNumberOfCurrentEvents > 0){ // add numberOfCurrentEvents + // Add point + eventTimes.add(currentTime); + numberOfEvents.put(currentTime, numberOfCurrentEvents); + } + + // Always do risk set + // remember that the LeftContinuousFunction takes into account that at this currentTime the risk value is the previous value + final int riskSet = n - (i+1); + riskSetNumberList.add(riskSet); + eventAndCensorTimes.add(currentTime); + + // reset counters + numberOfCurrentEvents = new int[eventsOfFocus.length+1]; + + } + + } + + final double[] riskSetArray = new double[eventAndCensorTimes.size()]; + final double[] timesArray = new double[eventAndCensorTimes.size()]; + for(int i=0; i responses, int[] eventsOfFocus){ + final List sillyList = responses; // annoying Java generic work-around + final CompetingRiskSetsImpl originalSets = calculateSetsEfficiently(sillyList, eventsOfFocus); + + final double[] allTimes = DoubleStream.concat( + responses.stream() + .mapToDouble(CompetingRiskResponseWithCensorTime::getC), + responses.stream() + .mapToDouble(CompetingRiskResponseWithCensorTime::getU) + ).sorted().distinct().toArray(); + + + + final VeryDiscontinuousStepFunction[] riskSets = new VeryDiscontinuousStepFunction[eventsOfFocus.length]; + + for(final int event : eventsOfFocus){ + final double[] yAt = new double[allTimes.length]; + final double[] yRight = new double[allTimes.length]; + + for(final CompetingRiskResponseWithCensorTime response : responses){ + if(response.getDelta() == event){ + // traditional case only; increment on time t when I(t <= Ui) + final double time = response.getU(); + final int index = Arrays.binarySearch(allTimes, time); + + if(index < 0){ // TODO remove once code is stable + throw new IllegalStateException("Index shouldn't be negative!"); + } + + // All yAts up to and including index are incremented; + // All yRights up to index are incremented + yAt[index]++; + for(int i=0; i= Ui. + + // increment yAt up to Ci. If Ui==Ci, increment yAt at Ci. + final double time = response.getC(); + final int index = Arrays.binarySearch(allTimes, time); + + if(index < 0){ // TODO remove once code is stable + throw new IllegalStateException("Index shouldn't be negative!"); + } + + for(int i=0; i functions.getSurvivalCurve()) .flatMapToDouble( - function -> function.getPoints().stream() - .mapToDouble(point -> point.getTime()) + function -> Arrays.stream(function.getX()) ).sorted().distinct().toArray(); } final double n = responses.size(); - final List survivalPoints = new ArrayList<>(timesToUse.length); - for(final double time : timesToUse){ + final double[] survivalY = new double[timesToUse.length]; - final double survivalY = responses.stream() - .mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time).getY() / n) + for(int i=0; i functions.getSurvivalCurve().evaluate(time) / n) .sum(); - - survivalPoints.add(new Point(time, survivalY)); - } - final MathFunction survivalFunction = new MathFunction(survivalPoints, new Point(0.0, 1.0)); + final StepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0); - final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); - final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); + final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); + final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); for(final int event : events){ - final List cumulativeHazardFunctionPoints = new ArrayList<>(timesToUse.length); - final List cumulativeIncidenceFunctionPoints = new ArrayList<>(timesToUse.length); + final double[] cumulativeHazardFunctionY = new double[timesToUse.length]; + final double[] cumulativeIncidenceFunctionY = new double[timesToUse.length]; - for(final double time : timesToUse){ + for(int i=0; i functions.getCauseSpecificHazardFunction(event).evaluate(time).getY() / n) + cumulativeHazardFunctionY[i] = responses.stream() + .mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time) / n) .sum(); - final double incidenceY = responses.stream() - .mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time).getY() / n) + cumulativeIncidenceFunctionY[i] = responses.stream() + .mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time) / n) .sum(); - cumulativeHazardFunctionPoints.add(new Point(time, hazardY)); - cumulativeIncidenceFunctionPoints.add(new Point(time, incidenceY)); - } - causeSpecificCumulativeHazardFunctionList.add(event-1, new MathFunction(cumulativeHazardFunctionPoints)); - cumulativeIncidenceFunctionList.add(event-1, new MathFunction(cumulativeIncidenceFunctionPoints)); + causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeHazardFunctionY, 0)); + cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cumulativeIncidenceFunctionY, 0)); } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java index edd4591..538245a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner.java @@ -3,9 +3,9 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.tree.ResponseCombiner; -import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.Point; -import lombok.RequiredArgsConstructor; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import java.util.*; @@ -38,8 +38,8 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner responses) { - final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); - final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); + final List causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); + final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); Collections.sort(responses, (y1, y2) -> { if(y1.getU() < y2.getU()){ @@ -97,7 +97,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY(); - final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUseList.get(i-1)).getY() : 1.0; + final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluateByIndex(i-1) : 1.0; final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk); final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY); @@ -138,10 +138,10 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner { - - @Getter - private final CompetingRiskResponseCombiner originalCombiner; - - @Override - public CompetingRiskFunctions combine(List responses) { - final List completeList = responses.stream().flatMap(Arrays::stream).collect(Collectors.toList()); - - return originalCombiner.combine(completeList); - } -} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java deleted file mode 100644 index 5526d0a..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java +++ /dev/null @@ -1,29 +0,0 @@ -package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative; - -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.tree.ResponseCombiner; - -import java.util.List; - -/** - * This class takes all of the observations in a terminal node and 'combines' them into just a list of the observations. - * - * This is used in the alternative approach to only compute the functions at the final stage when combining trees. - * - */ -public class CompetingRiskResponseCombinerToList implements ResponseCombiner { - - @Override - public CompetingRiskResponse[] combine(List responses) { - final CompetingRiskResponse[] array = new CompetingRiskResponse[responses.size()]; - - for(int i=0; i leftHand, List rightHand); - abstract double riskSet(final List eventList, double time, int eventOfFocus); - - private double numberOfEventsAtTime(int eventOfFocus, List eventList, double time){ - return (double) eventList.stream() - .filter(event -> event.getDelta() == eventOfFocus) - .filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this - .count(); - - } /** * Calculates the log rank value (or the Gray's test value) for a *specific* event cause. * * @param eventOfFocus - * @param leftHand A non-empty list of CompetingRiskResponse - * @param rightHand A non-empty list of CompetingRiskResponse + * @param competingRiskSetsLeft A summary of the different sets used in the calculation for the left side + * @param competingRiskSetsRight A summary of the different sets used in the calculation for the right side * @return */ - LogRankValue specificLogRankValue(final int eventOfFocus, List leftHand, List rightHand){ + LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSetsLeft, final CompetingRiskSets competingRiskSetsRight){ final double[] distinctEventTimes = Stream.concat( - leftHand.stream(), rightHand.stream() - ) - .filter(event -> !event.isCensored()) - .mapToDouble(event -> event.getU()) + competingRiskSetsLeft.getEventTimes().stream(), + competingRiskSetsRight.getEventTimes().stream()) + .mapToDouble(Double::doubleValue) + .sorted() .distinct() .toArray(); @@ -51,12 +43,12 @@ public abstract class CompetingRiskGroupDifferentiator eventList, double time, int eventOfFocus) { - return eventList.stream() - .filter(event -> event.getU() >= time || - (event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time) - ) - .count(); - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java index 7a8d7f9..48e3b6f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java @@ -1,6 +1,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; import java.util.List; @@ -13,6 +15,7 @@ import java.util.List; public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator { private final int eventOfFocus; + private final int[] events; @Override public Double differentiate(List leftHand, List rightHand) { @@ -20,19 +23,13 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff return null; } - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events); + final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events); + + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt()); } - @Override - double riskSet(List eventList, double time, int eventOfFocus) { - return eventList.stream() - .filter(event -> event.getU() >= time || - (event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time) - ) - .count(); - } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java index 7a62e86..2ad2424 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java @@ -1,6 +1,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; import java.util.List; @@ -20,11 +22,14 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer return null; } + final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events); + final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events); + double numerator = 0.0; double denominatorSquared = 0.0; for(final int eventOfFocus : events){ - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt(); denominatorSquared += valueOfInterest.getVariance(); @@ -35,11 +40,4 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer } - @Override - double riskSet(List eventList, double time, int eventOfFocus) { - return eventList.stream() - .filter(event -> event.getU() >= time) - .count(); - } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java index 06cb6bd..107964e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java @@ -1,6 +1,8 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; import java.util.List; @@ -13,6 +15,7 @@ import java.util.List; public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator { private final int eventOfFocus; + private final int[] events; @Override public Double differentiate(List leftHand, List rightHand) { @@ -20,17 +23,13 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen return null; } - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events); + final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events); + + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt()); } - @Override - double riskSet(List eventList, double time, int eventOfFocus) { - return eventList.stream() - .filter(event -> event.getU() >= time) - .count(); - } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java b/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java new file mode 100644 index 0000000..3e71977 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java @@ -0,0 +1,97 @@ +package ca.joeltherrien.randomforest.utils; + +/** + * Represents a function represented by discrete points. However, the function may be right-continuous or left-continuous + * at a given point, with no consistency. This function tracks that. + */ +public final class DiscontinuousStepFunction extends StepFunction { + + private final double[] y; + private final boolean[] isLeftContinuous; + + /** + * Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at. + * + * Map be null. + */ + private final double defaultY; + + public DiscontinuousStepFunction(double[] x, double[] y, boolean[] isLeftContinuous, double defaultY) { + super(x); + this.y = y; + this.isLeftContinuous = isLeftContinuous; + this.defaultY = defaultY; + } + + @Override + public double evaluate(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time); + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return evaluateByIndex(index); + } + else{ + return y[index]; + } + } + } + + + @Override + public double evaluatePrevious(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1; + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return evaluateByIndex(index); + } + else{ + return y[index]; + } + } + } + + @Override + public double evaluateByIndex(int i) { + if(isLeftContinuous[i]){ + i -= 1; + } + + if(i < 0){ + return defaultY; + } + + return y[i]; + } + + @Override + public String toString(){ + final StringBuilder builder = new StringBuilder(); + builder.append("Default point: "); + builder.append(defaultY); + builder.append("\n"); + + for(int i=0; i pointList, final double defaultY){ + + final double[] x = new double[pointList.size()]; + final double[] y = new double[pointList.size()]; + + final ListIterator pointIterator = pointList.listIterator(); + while(pointIterator.hasNext()){ + final int index = pointIterator.nextIndex(); + final Point currentPoint = pointIterator.next(); + + x[index] = currentPoint.getTime(); + y[index] = currentPoint.getY(); + } + + return new LeftContinuousStepFunction(x, y, defaultY); + + } + + @Override + public double evaluate(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time); + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return evaluateByIndex(index-1); + } + else{ + return y[index]; + } + } + } + + @Override + public double evaluatePrevious(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1; + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return evaluateByIndex(index-1); + } + else{ + return y[index]; + } + } + } + + @Override + public double evaluateByIndex(int i) { + if(i < 0){ + return defaultY; + } + + return y[i]; + } + + + @Override + public String toString(){ + final StringBuilder builder = new StringBuilder(); + builder.append("Default point: "); + builder.append(defaultY); + builder.append("\n"); + + for(int i=0; i points; - - /** - * Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at. - * - * Map be null. - */ - private final Point defaultValue; - - public MathFunction(final List points){ - this(points, new Point(0.0, 0.0)); - } - - public MathFunction(final List points, final Point defaultValue){ - this.points = Collections.unmodifiableList(points); - this.defaultValue = defaultValue; - } - - public Point evaluate(double time){ - int index = binarySearch(points, time); - if(index < 0){ - return defaultValue; - } - else{ - return points.get(index); - } - } - - public Point evaluatePrevious(double time){ - - int index = binarySearch(points, time) - 1; - if(index < 0){ - return defaultValue; - } - else{ - return points.get(index); - } - - - } - - /** - * Returns the index of the largest (in terms of time) Point that is <= the provided time value. - * - * @param points - * @param time - * @return The index of the largest Point who's time is <= the time parameter. - */ - private static int binarySearch(List points, double time){ - final int pointSize = points.size(); - - if(pointSize == 0 || points.get(pointSize-1).getTime() <= time){ - // we're already too far - return pointSize - 1; - } - - if(pointSize < 200){ - for(int i = 0; i < pointSize; i++){ - if(points.get(i).getTime() > time){ - return i - 1; - } - } - } - - // else - - - final int middle = pointSize / 2; - final double middleTime = points.get(middle).getTime(); - if(middleTime < time){ - // go right - return binarySearch(points.subList(middle, pointSize), time) + middle; - } - else if(middleTime > time){ - // go left - return binarySearch(points.subList(0, middle), time); - } - else{ // middleTime == time - return middle; - } - } - - @Override - public String toString(){ - final StringBuilder builder = new StringBuilder(); - builder.append("Default point: "); - builder.append(defaultValue); - builder.append("\n"); - - for(final Point point : points){ - builder.append(point); - builder.append("\n"); - } - - return builder.toString(); - } + double evaluate(double time); } diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java b/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java index 2ee3cb4..392425f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java @@ -11,26 +11,12 @@ import java.util.zip.GZIPOutputStream; */ public final class RUtils { - public static double[] extractTimes(final MathFunction function){ - final List pointList = function.getPoints(); - final double[] times = new double[pointList.size()]; - - for(int i=0; i pointList = function.getPoints(); - final double[] times = new double[pointList.size()]; - - for(int i=0; i pointList, final double defaultY){ + + final double[] x = new double[pointList.size()]; + final double[] y = new double[pointList.size()]; + + final ListIterator pointIterator = pointList.listIterator(); + while(pointIterator.hasNext()){ + final int index = pointIterator.nextIndex(); + final Point currentPoint = pointIterator.next(); + + x[index] = currentPoint.getTime(); + y[index] = currentPoint.getY(); + } + + return new RightContinuousStepFunction(x, y, defaultY); + + } + + public double[] getY(){ + return y.clone(); + } + + @Override + public double evaluate(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time); + if(index < 0){ + return defaultY; + } + else{ + return y[index]; + } + } + + @Override + public double evaluatePrevious(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1; + if(index < 0){ + return defaultY; + } + else{ + return y[index]; + } + } + + @Override + public double evaluateByIndex(int i) { + if(i < 0){ + return defaultY; + } + + return y[i]; + } + + + @Override + public String toString(){ + final StringBuilder builder = new StringBuilder(); + builder.append("Default point: "); + builder.append(defaultY); + builder.append("\n"); + + for(int i=0; i timeCounterMap = new HashMap<>(); @@ -33,7 +32,7 @@ public final class Utils { pointList.add(new Point(entry.getKey(), (double) newCount / n)); } - return new MathFunction(pointList, defaultPoint); + return RightContinuousStepFunction.constructFromPoints(pointList, 1.0); } @@ -64,6 +63,48 @@ public final class Utils { } } + /** + * Returns the index of the largest (in terms of time) Point that is <= the provided time value. + * + * @param startIndex Only search from startIndex (inclusive) + * @param endIndex Only search up to endIndex (exclusive) + * @param time + * @return The index of the largest Point who's time is <= the time parameter. + */ + public static int binarySearchLessThan(int startIndex, int endIndex, double[] x, double time){ + final int range = endIndex - startIndex; + + if(range == 0 || x[endIndex-1] <= time){ + // we're already too far + return endIndex - 1; + } + + if(range < 200){ + for(int i = startIndex; i < endIndex; i++){ + if(x[i] > time){ + return i - 1; + } + } + } + + // else + + + final int middle = range / 2; + final double middleTime = x[middle]; + if(middleTime < time){ + // go right + return binarySearchLessThan(middle, endIndex, x, time); + } + else if(middleTime > time){ + // go left + return binarySearchLessThan(0, middle, x, time); + } + else{ // middleTime == time + return middle; + } + } + /** * Replacement for Java 9's List.of * diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/VeryDiscontinuousStepFunction.java b/src/main/java/ca/joeltherrien/randomforest/utils/VeryDiscontinuousStepFunction.java new file mode 100644 index 0000000..31ca83f --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/VeryDiscontinuousStepFunction.java @@ -0,0 +1,64 @@ +package ca.joeltherrien.randomforest.utils; + +/** + * Represents a step function represented by discrete points. However, there may be individual time values that has + * a y value that doesn't belong to a particular 'step'. + */ +public final class VeryDiscontinuousStepFunction implements MathFunction { + + private final double[] x; + private final double[] yAt; + private final double[] yRight; + + /** + * Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at. + * + * Map be null. + */ + private final double defaultY; + + public VeryDiscontinuousStepFunction(double[] x, double[] yAt, double[] yRight, double defaultY) { + this.x = x; + this.yAt = yAt; + this.yRight = yRight; + this.defaultY = defaultY; + } + + @Override + public double evaluate(double time){ + int index = Utils.binarySearchLessThan(0, x.length, x, time); + if(index < 0){ + return defaultY; + } + else{ + if(x[index] == time){ + return yAt[index]; + } + else{ // time > x[index] + return yRight[index]; + } + } + } + + + @Override + public String toString(){ + final StringBuilder builder = new StringBuilder(); + builder.append("Default point: "); + builder.append(defaultY); + builder.append("\n"); + + for(int i=0; i 2); + assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2); assertEquals(NTREE, forest.getTrees().size()); diff --git a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java index 60e7a07..ab4c83a 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -1,7 +1,6 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -23,53 +22,67 @@ public class TestUtils { * * @param function */ - public static void assertCumulativeFunction(MathFunction function){ - Point previousPoint = null; - for(final Point point : function.getPoints()){ + public static void assertCumulativeFunction(StepFunction function){ + Double previousTime = null; + Double previousY = null; - if(previousPoint != null){ - assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different"); - assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone"); + final double[] times = function.getX(); + + for(int i=0; i= point.getY(), "Survival functions are monotone"); + final double[] times = function.getX(); + + for(int i=0; i= y, "Survival functions are monotone"); } + previousTime = time; + previousY = y; - previousPoint = point; } } @Test public void testOneMinusECDF(){ final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0}; - final MathFunction survivalCurve = Utils.estimateOneMinusECDF(times); + final StepFunction survivalCurve = Utils.estimateOneMinusECDF(times); final double margin = 0.000001; - closeEnough(1.0, survivalCurve.evaluate(0.0).getY(), margin); + closeEnough(1.0, survivalCurve.evaluate(0.0), margin); - closeEnough(1.0, survivalCurve.evaluatePrevious(1.0).getY(), margin); - closeEnough(4.0/6.0, survivalCurve.evaluate(1.0).getY(), margin); + closeEnough(1.0, survivalCurve.evaluatePrevious(1.0), margin); + closeEnough(4.0/6.0, survivalCurve.evaluate(1.0), margin); - closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0).getY(), margin); - closeEnough(3.0/6.0, survivalCurve.evaluate(2.0).getY(), margin); + closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0), margin); + closeEnough(3.0/6.0, survivalCurve.evaluate(2.0), margin); - closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0).getY(), margin); - closeEnough(1.0/6.0, survivalCurve.evaluate(3.0).getY(), margin); + closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0), margin); + closeEnough(1.0/6.0, survivalCurve.evaluate(3.0), margin); - closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0).getY(), margin); - closeEnough(0.0, survivalCurve.evaluate(50.0).getY(), margin); + closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0), margin); + closeEnough(0.0, survivalCurve.evaluate(50.0), margin); assertSurvivalCurve(survivalCurve); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java new file mode 100644 index 0000000..e6fec68 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java @@ -0,0 +1,214 @@ +package ca.joeltherrien.randomforest.competingrisk; + +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TestCalculatingCompetingRiskSets { + + public List generateData(){ + final List data = new ArrayList<>(); + + data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3)); + data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3)); + data.add(new CompetingRiskResponseWithCensorTime(0, 1, 1)); + data.add(new CompetingRiskResponseWithCensorTime(1, 2, 2.5)); + data.add(new CompetingRiskResponseWithCensorTime(2, 3, 4)); + data.add(new CompetingRiskResponseWithCensorTime(0, 3, 3)); + data.add(new CompetingRiskResponseWithCensorTime(1, 4, 4)); + data.add(new CompetingRiskResponseWithCensorTime(0, 5, 5)); + data.add(new CompetingRiskResponseWithCensorTime(2, 6, 7)); + + return data; + } + + @Test + public void testCalculatingSets(){ + final List data = generateData(); + + final CompetingRiskSetsImpl sets = CompetingRiskUtils.calculateSetsEfficiently(data, new int[]{1,2}); + + final List times = sets.getEventTimes(); + assertEquals(5, times.size()); + + // Times + assertEquals(1.0, times.get(0).doubleValue()); + assertEquals(2.0, times.get(1).doubleValue()); + assertEquals(3.0, times.get(2).doubleValue()); + assertEquals(4.0, times.get(3).doubleValue()); + assertEquals(6.0, times.get(4).doubleValue()); + + // Number of Events + assertEquals(2, sets.getNumberOfEvents(1.0, 1)); + assertEquals(0, sets.getNumberOfEvents(1.0, 2)); + + assertEquals(1, sets.getNumberOfEvents(2.0, 1)); + assertEquals(0, sets.getNumberOfEvents(2.0, 2)); + + assertEquals(0, sets.getNumberOfEvents(3.0, 1)); + assertEquals(1, sets.getNumberOfEvents(3.0, 2)); + + assertEquals(1, sets.getNumberOfEvents(4.0, 1)); + assertEquals(0, sets.getNumberOfEvents(4.0, 2)); + + assertEquals(0, sets.getNumberOfEvents(6.0, 1)); + assertEquals(1, sets.getNumberOfEvents(6.0, 2)); + + // Make sure it doesn't break for other times + assertEquals(0, sets.getNumberOfEvents(5.5, 1)); + assertEquals(0, sets.getNumberOfEvents(5.5, 2)); + + + // Risk set + assertEquals(9, sets.getRiskSet(1).evaluate(0.5)); + assertEquals(9, sets.getRiskSet(2).evaluate(0.5)); + + assertEquals(9, sets.getRiskSet(1).evaluate(1.0)); + assertEquals(9, sets.getRiskSet(2).evaluate(1.0)); + + assertEquals(6, sets.getRiskSet(1).evaluate(1.5)); + assertEquals(6, sets.getRiskSet(2).evaluate(1.5)); + + assertEquals(6, sets.getRiskSet(1).evaluate(2.0)); + assertEquals(6, sets.getRiskSet(2).evaluate(2.0)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.3)); + assertEquals(5, sets.getRiskSet(2).evaluate(2.3)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.5)); + assertEquals(5, sets.getRiskSet(2).evaluate(2.5)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.7)); + assertEquals(5, sets.getRiskSet(2).evaluate(2.7)); + + assertEquals(5, sets.getRiskSet(1).evaluate(3.0)); + assertEquals(5, sets.getRiskSet(2).evaluate(3.0)); + + assertEquals(3, sets.getRiskSet(1).evaluate(3.5)); + assertEquals(3, sets.getRiskSet(2).evaluate(3.5)); + + assertEquals(3, sets.getRiskSet(1).evaluate(4.0)); + assertEquals(3, sets.getRiskSet(2).evaluate(4.0)); + + assertEquals(2, sets.getRiskSet(1).evaluate(4.5)); + assertEquals(2, sets.getRiskSet(2).evaluate(4.5)); + + assertEquals(2, sets.getRiskSet(1).evaluate(5.0)); + assertEquals(2, sets.getRiskSet(2).evaluate(5.0)); + + assertEquals(1, sets.getRiskSet(1).evaluate(5.5)); + assertEquals(1, sets.getRiskSet(2).evaluate(5.5)); + + assertEquals(1, sets.getRiskSet(1).evaluate(6.0)); + assertEquals(1, sets.getRiskSet(2).evaluate(6.0)); + + assertEquals(0, sets.getRiskSet(1).evaluate(6.5)); + assertEquals(0, sets.getRiskSet(2).evaluate(6.5)); + + assertEquals(0, sets.getRiskSet(1).evaluate(7.0)); + assertEquals(0, sets.getRiskSet(2).evaluate(7.0)); + + assertEquals(0, sets.getRiskSet(1).evaluate(7.5)); + assertEquals(0, sets.getRiskSet(2).evaluate(7.5)); + + } + + @Test + public void testCalculatingGraySets(){ + final List data = generateData(); + + final CompetingRiskGraySetsImpl sets = CompetingRiskUtils.calculateGraySetsEfficiently(data, new int[]{1,2}); + + final List times = sets.getEventTimes(); + assertEquals(5, times.size()); + + // Times + assertEquals(1.0, times.get(0).doubleValue()); + assertEquals(2.0, times.get(1).doubleValue()); + assertEquals(3.0, times.get(2).doubleValue()); + assertEquals(4.0, times.get(3).doubleValue()); + assertEquals(6.0, times.get(4).doubleValue()); + + // Number of Events + assertEquals(2, sets.getNumberOfEvents(1.0, 1)); + assertEquals(0, sets.getNumberOfEvents(1.0, 2)); + + assertEquals(1, sets.getNumberOfEvents(2.0, 1)); + assertEquals(0, sets.getNumberOfEvents(2.0, 2)); + + assertEquals(0, sets.getNumberOfEvents(3.0, 1)); + assertEquals(1, sets.getNumberOfEvents(3.0, 2)); + + assertEquals(1, sets.getNumberOfEvents(4.0, 1)); + assertEquals(0, sets.getNumberOfEvents(4.0, 2)); + + assertEquals(0, sets.getNumberOfEvents(6.0, 1)); + assertEquals(1, sets.getNumberOfEvents(6.0, 2)); + + // Make sure it doesn't break for other times + assertEquals(0, sets.getNumberOfEvents(5.5, 1)); + assertEquals(0, sets.getNumberOfEvents(5.5, 2)); + + + // Risk set + assertEquals(9, sets.getRiskSet(1).evaluate(0.5)); + assertEquals(9, sets.getRiskSet(2).evaluate(0.5)); + + assertEquals(9, sets.getRiskSet(1).evaluate(1.0)); + assertEquals(9, sets.getRiskSet(2).evaluate(1.0)); + + assertEquals(6, sets.getRiskSet(1).evaluate(1.5)); + assertEquals(8, sets.getRiskSet(2).evaluate(1.5)); + + assertEquals(6, sets.getRiskSet(1).evaluate(2.0)); + assertEquals(8, sets.getRiskSet(2).evaluate(2.0)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.3)); + assertEquals(8, sets.getRiskSet(2).evaluate(2.3)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.5)); + assertEquals(7, sets.getRiskSet(2).evaluate(2.5)); + + assertEquals(5, sets.getRiskSet(1).evaluate(2.7)); + assertEquals(7, sets.getRiskSet(2).evaluate(2.7)); + + assertEquals(5, sets.getRiskSet(1).evaluate(3.0)); + assertEquals(5, sets.getRiskSet(2).evaluate(3.0)); + + assertEquals(4, sets.getRiskSet(1).evaluate(3.5)); + assertEquals(3, sets.getRiskSet(2).evaluate(3.5)); + + assertEquals(3, sets.getRiskSet(1).evaluate(4.0)); + assertEquals(3, sets.getRiskSet(2).evaluate(4.0)); + + assertEquals(2, sets.getRiskSet(1).evaluate(4.5)); + assertEquals(2, sets.getRiskSet(2).evaluate(4.5)); + + assertEquals(2, sets.getRiskSet(1).evaluate(5.0)); + assertEquals(2, sets.getRiskSet(2).evaluate(5.0)); + + assertEquals(1, sets.getRiskSet(1).evaluate(5.5)); + assertEquals(1, sets.getRiskSet(2).evaluate(5.5)); + + assertEquals(1, sets.getRiskSet(1).evaluate(6.0)); + assertEquals(1, sets.getRiskSet(2).evaluate(6.0)); + + assertEquals(1, sets.getRiskSet(1).evaluate(6.5)); + assertEquals(0, sets.getRiskSet(2).evaluate(6.5)); + + assertEquals(0, sets.getRiskSet(1).evaluate(7.0)); + assertEquals(0, sets.getRiskSet(2).evaluate(7.0)); + + assertEquals(0, sets.getRiskSet(1).evaluate(7.5)); + assertEquals(0, sets.getRiskSet(2).evaluate(7.5)); + + } + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index c761401..e38ce57 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -7,8 +7,7 @@ import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer; -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.*; import org.junit.jupiter.api.Test; @@ -19,7 +18,6 @@ import static org.junit.jupiter.api.Assertions.*; import java.io.IOException; import java.util.List; -import java.util.stream.Collectors; public class TestCompetingRisk { @@ -33,6 +31,9 @@ public class TestCompetingRisk { final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator")); groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1)); + groupDifferentiatorSettings.set("events", + new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) + ); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); @@ -109,32 +110,32 @@ public class TestCompetingRisk { final CompetingRiskFunctions functions = node.evaluate(newRow); - final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); - final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); - final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); - final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); + final StepFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); + final StepFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); + final StepFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); + final StepFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); final double margin = 0.0000001; - closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02).getY(), margin); - closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00).getY(), margin); - closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50).getY(), margin); - closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60).getY(), margin); - closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80).getY(), margin); + closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02), margin); + closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00), margin); + closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50), margin); + closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60), margin); + closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80), margin); - closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin); - closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin); - closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin); + closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00), margin); + closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50), margin); + closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80), margin); - closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin); - closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin); - closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin); + closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00), margin); + closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50), margin); + closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80), margin); - closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin); - closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin); - closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin); + closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00), margin); + closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50), margin); + closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80), margin); } @@ -162,18 +163,18 @@ public class TestCompetingRisk { final CompetingRiskFunctions functions = node.evaluate(newRow); - final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); - final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); - final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); - final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); + final StepFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1); + final StepFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2); + final StepFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1); + final StepFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2); final double margin = 0.0000001; - closeEnough(0, causeOneCIFFunction.evaluate(0.02).getY(), margin); - closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4).getY(), margin); - closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8).getY(), margin); - closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9).getY(), margin); - closeEnough(1.0, causeOneCIFFunction.evaluate(1.0).getY(), margin); + closeEnough(0, causeOneCIFFunction.evaluate(0.02), margin); + closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4), margin); + closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8), margin); + closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9), margin); + closeEnough(1.0, causeOneCIFFunction.evaluate(1.0), margin); /* closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin); @@ -222,11 +223,11 @@ public class TestCompetingRisk { assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2)); - closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0).getY(), 0.01); - closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8).getY(), 0.01); + closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0), 0.01); + closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8), 0.01); - closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01); - closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01); + closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0), 0.01); + closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8), 0.01); final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); @@ -289,7 +290,8 @@ public class TestCompetingRisk { settings.setNtree(300); // results are too variable at 100 final List covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), + settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final Forest forest = forestTrainer.trainSerial(); @@ -305,10 +307,9 @@ public class TestCompetingRisk { assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1)); assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2)); - final List causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints(); - // We seem to consistently underestimate the results. - assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.74, "Results should match randomForestSRC; had " + causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY()); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 + final double endProbability = functions.getCumulativeIncidenceFunction(1).evaluate(10000000); + assertTrue(endProbability > 0.74, "Results should match randomForestSRC; had " + endProbability); // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72 final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, true); final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1,2}); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java index 6da0d38..45d0928 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -4,8 +4,8 @@ import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.tree.Forest; -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -36,7 +36,7 @@ public class TestCompetingRiskErrorRateCalculator { final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event); - final MathFunction fakeCensorDistribution = new MathFunction(Collections.emptyList(), new Point(0.0, 1.0)); + final StepFunction fakeCensorDistribution = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 1.0); // This distribution will make the IPCW weights == 1, giving identical results to the naive concordance. final double ipcwConcordance = CompetingRiskUtils.calculateIPCWConcordance(responseList, mortalityArray, event, fakeCensorDistribution); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctions.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctions.java index 18acf89..7147f4c 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctions.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskFunctions.java @@ -2,8 +2,9 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.TestUtils; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; -import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -13,19 +14,19 @@ public class TestCompetingRiskFunctions { @Test public void testCalculateEventSpecificMortality(){ - final MathFunction cif1 = new MathFunction( + final StepFunction cif1 = RightContinuousStepFunction.constructFromPoints( Utils.easyList( new Point(1.0, 0.3), new Point(1.5, 0.7), new Point(2.0, 0.8) - ), new Point(0.0 ,0.0) + ), 0.0 ); // not being used - final MathFunction chf1 = new MathFunction(Collections.emptyList()); + final StepFunction chf1 = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0); // not being used - final MathFunction km = new MathFunction(Collections.emptyList()); + final StepFunction km = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0); final CompetingRiskFunctions functions = CompetingRiskFunctions.builder() .causeSpecificHazards(Collections.singletonList(chf1)) diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java index 4cfea67..44ee657 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java @@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; -import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.StepFunction; import org.junit.jupiter.api.Test; import static ca.joeltherrien.randomforest.TestUtils.closeEnough; @@ -33,17 +33,17 @@ public class TestCompetingRiskResponseCombiner { public void testCompetingRiskResponseCombiner(){ final CompetingRiskFunctions functions = generateFunctions(); - final MathFunction survivalCurve = functions.getSurvivalCurve(); + final StepFunction survivalCurve = functions.getSurvivalCurve(); // time = 1.0 1.5 2.0 2.5 // surv = 0.7142857 0.5714286 0.1904762 0.1904762 final double margin = 0.0000001; - closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin); - closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin); - closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin); - closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin); + closeEnough(0.7142857, survivalCurve.evaluate(1.0), margin); + closeEnough(0.5714286, survivalCurve.evaluate(1.5), margin); + closeEnough(0.1904762, survivalCurve.evaluate(2.0), margin); + closeEnough(0.1904762, survivalCurve.evaluate(2.5), margin); // Time = 1.0 1.5 2.0 2.5 @@ -53,17 +53,17 @@ public class TestCompetingRiskResponseCombiner { [2,] 0.0000000 0.2000000 0.5333333 0.5333333 */ - final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1); - closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin); - closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin); - closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin); - closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin); + final StepFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1); + closeEnough(0.2857143, cumHaz1.evaluate(1.0), margin); + closeEnough(0.2857143, cumHaz1.evaluate(1.5), margin); + closeEnough(0.6190476, cumHaz1.evaluate(2.0), margin); + closeEnough(0.6190476, cumHaz1.evaluate(2.5), margin); - final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2); - closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin); - closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin); - closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin); - closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin); + final StepFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2); + closeEnough(0.0, cumHaz2.evaluate(1.0), margin); + closeEnough(0.2, cumHaz2.evaluate(1.5), margin); + closeEnough(0.5333333, cumHaz2.evaluate(2.0), margin); + closeEnough(0.5333333, cumHaz2.evaluate(2.5), margin); /* Time = 1.0 1.5 2.0 2.5 Cumulative Incidence Curve. Each row for one event. @@ -72,17 +72,17 @@ public class TestCompetingRiskResponseCombiner { [2,] 0.0000000 0.1428571 0.3333333 0.3333333 */ - final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1); - closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin); - closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin); - closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin); - closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin); + final StepFunction cic1 = functions.getCumulativeIncidenceFunction(1); + closeEnough(0.2857143, cic1.evaluate(1.0), margin); + closeEnough(0.2857143, cic1.evaluate(1.5), margin); + closeEnough(0.4761905, cic1.evaluate(2.0), margin); + closeEnough(0.4761905, cic1.evaluate(2.5), margin); - final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2); - closeEnough(0.0, cic2.evaluate(1.0).getY(), margin); - closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin); - closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin); - closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin); + final StepFunction cic2 = functions.getCumulativeIncidenceFunction(2); + closeEnough(0.0, cic2.evaluate(1.0), margin); + closeEnough(0.1428571, cic2.evaluate(1.5), margin); + closeEnough(0.3333333, cic2.evaluate(2.0), margin); + closeEnough(0.3333333, cic2.evaluate(2.5), margin); } diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java index 56e668b..ef79b65 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java @@ -44,7 +44,7 @@ public class TestLogRankSingleGroupDifferentiator { final List data1 = generateData1(); final List data2 = generateData2(); - final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1); + final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1}); final double score = differentiator.differentiate(data1, data2); final double margin = 0.000001; diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java deleted file mode 100644 index 591ca9a..0000000 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunction.java +++ /dev/null @@ -1,43 +0,0 @@ -package ca.joeltherrien.randomforest.competingrisk; - -import ca.joeltherrien.randomforest.utils.MathFunction; -import ca.joeltherrien.randomforest.utils.Point; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; - -import java.util.ArrayList; -import java.util.List; - -public class TestMathFunction { - - private MathFunction generateMathFunction(){ - final double[] time = new double[]{1.0, 2.0, 3.0}; - final double[] y = new double[]{-1.0, 1.0, 0.5}; - - final List pointList = new ArrayList<>(); - for(int i=0; i