From e709c42da136a88063e8a3b6eac0d4fef4877482 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Fri, 11 Jan 2019 22:56:41 -0800 Subject: [PATCH] Update the competing risk GroupDifferentiators to make efficient use of the SplitRuleUpdater updates Results in a speed improvement of over 1/3 according to a timing of the TestCompetingRisk#testLogRankSingleGroupDifferentiatorAllCovariates() test --- .../randomforest/covariates/Covariate.java | 1 + .../numeric/NumericSplitRuleUpdater.java | 9 +- .../CompetingRiskGraySetsImpl.java | 90 ++++-- .../competingrisk/CompetingRiskSets.java | 14 +- .../competingrisk/CompetingRiskSetsImpl.java | 69 +++-- .../competingrisk/CompetingRiskUtils.java | 266 +++++++++++------- .../CompetingRiskGroupDifferentiator.java | 127 +++++++-- ...rayLogRankMultipleGroupDifferentiator.java | 17 +- .../GrayLogRankSingleGroupDifferentiator.java | 18 +- .../LogRankMultipleGroupDifferentiator.java | 18 +- .../LogRankSingleGroupDifferentiator.java | 18 +- .../tree/GroupDifferentiator.java | 2 +- .../tree/SimpleGroupDifferentiator.java | 6 +- .../TestCalculatingCompetingRiskSets.java | 214 -------------- .../competingrisk/TestCompetingRisk.java | 48 +++- ...estLogRankMultipleGroupDifferentiator.java | 17 +- .../TestLogRankSingleGroupDifferentiator.java | 68 +++-- 17 files changed, 524 insertions(+), 478 deletions(-) delete mode 100644 src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index 46c7849..95c64df 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -40,6 +40,7 @@ public interface Covariate extends Serializable { interface SplitRuleUpdater extends Iterator>{ Split currentSplit(); + boolean currentSplitValid(); SplitUpdate nextUpdate(); } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java index d41b270..4f1586c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java @@ -37,6 +37,11 @@ public class NumericSplitRuleUpdater implements Covariate.SplitRuleUpdater 0 && currentSplit.getRightHand().size() > 0; + } + @Override public NumericSplitUpdate nextUpdate() { if(hasNext()){ @@ -51,8 +56,8 @@ public class NumericSplitRuleUpdater implements Covariate.SplitRuleUpdater( splitRule, - orderedData.subList(0, newPosition), - orderedData.subList(newPosition, orderedData.size()), + Collections.unmodifiableList(orderedData.subList(0, newPosition)), + Collections.unmodifiableList(orderedData.subList(newPosition, orderedData.size())), Collections.emptyList()); diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java index a3b6f31..2eced4c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java @@ -1,37 +1,77 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.utils.MathFunction; -import lombok.Builder; -import lombok.Getter; +import java.util.Arrays; -import java.util.List; -import java.util.Map; +public class CompetingRiskGraySetsImpl implements CompetingRiskSets { -/** - * Represents a response from CompetingRiskUtils#calculateGraySetsEfficiently - * - */ -@Builder -@Getter -public class CompetingRiskGraySetsImpl implements CompetingRiskSets{ + final double[] times; // length m array + int[][] riskSetLeft; // J x m array + final int[][] riskSetTotal; // J x m array + int[][] numberOfEventsLeft; // J+1 x m array + final int[][] numberOfEventsTotal; // J+1 x m array - private final List eventTimes; - private final MathFunction[] riskSet; - private final Map numberOfEvents; - - @Override - public MathFunction getRiskSet(int event){ - return riskSet[event-1]; + public CompetingRiskGraySetsImpl(double[] times, int[][] riskSetLeft, int[][] riskSetTotal, int[][] numberOfEventsLeft, int[][] numberOfEventsTotal) { + this.times = times; + this.riskSetLeft = riskSetLeft; + this.riskSetTotal = riskSetTotal; + this.numberOfEventsLeft = numberOfEventsLeft; + this.numberOfEventsTotal = numberOfEventsTotal; } @Override - public int getNumberOfEvents(Double time, int event){ - if(numberOfEvents.containsKey(time)){ - return numberOfEvents.get(time)[event]; + public double[] getDistinctTimes() { + return times; + } + + @Override + public int getRiskSetLeft(int timeIndex, int event) { + return riskSetLeft[event-1][timeIndex]; + } + + @Override + public int getRiskSetTotal(int timeIndex, int event) { + return riskSetTotal[event-1][timeIndex]; + } + + + @Override + public int getNumberOfEventsLeft(int timeIndex, int event) { + return numberOfEventsLeft[event][timeIndex]; + } + + @Override + public int getNumberOfEventsTotal(int timeIndex, int event) { + return numberOfEventsTotal[event][timeIndex]; + } + + @Override + public void update(CompetingRiskResponseWithCensorTime rowMovedToLeft) { + final double time = rowMovedToLeft.getU(); + final int k = Arrays.binarySearch(times, time); + final int delta_m_1 = rowMovedToLeft.getDelta() - 1; + final double censorTime = rowMovedToLeft.getC(); + + for(int j=0; j= t, in I(...) + for(int i=0; i<=k; i++){ + riskSetLeftJ[i]++; + } + + // second iteration; only if delta-1 != j + // corresponds to the second part, U_i < t & delta_i != j & C_i > t + if(delta_m_1 != j && !rowMovedToLeft.isCensored()){ + int i = k+1; + while(i < times.length && times[i] < censorTime){ + riskSetLeftJ[i]++; + i++; + } + } + } - return 0; + numberOfEventsLeft[rowMovedToLeft.getDelta()][k]++; } - - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java index 9a53d3e..d9a72b5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java @@ -1,13 +1,13 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.utils.MathFunction; +public interface CompetingRiskSets { -import java.util.List; + double[] getDistinctTimes(); + int getRiskSetLeft(int timeIndex, int event); + int getRiskSetTotal(int timeIndex, int event); + int getNumberOfEventsLeft(int timeIndex, int event); + int getNumberOfEventsTotal(int timeIndex, int event); -public interface CompetingRiskSets { - - MathFunction getRiskSet(int event); - int getNumberOfEvents(Double time, int event); - List getEventTimes(); + void update(T rowMovedToLeft); } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java index 44a9dd8..7fd8ba3 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java @@ -1,36 +1,59 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.utils.MathFunction; -import lombok.Builder; -import lombok.Getter; +import java.util.Arrays; -import java.util.List; -import java.util.Map; +public class CompetingRiskSetsImpl implements CompetingRiskSets { -/** - * Represents a response from CompetingRiskUtils#calculateSetsEfficiently - * - */ -@Builder -@Getter -public class CompetingRiskSetsImpl implements CompetingRiskSets{ + final double[] times; // length m array + int[] riskSetLeft; // length m array + final int[] riskSetTotal; // length m array + int[][] numberOfEventsLeft; // J+1 x m array + final int[][] numberOfEventsTotal; // J+1 x m array - private final List eventTimes; - private final MathFunction riskSet; - private final Map numberOfEvents; - @Override - public MathFunction getRiskSet(int event){ - return riskSet; + public CompetingRiskSetsImpl(double[] times, int[] riskSetLeft, int[] riskSetTotal, int[][] numberOfEventsLeft, int[][] numberOfEventsTotal) { + this.times = times; + this.riskSetLeft = riskSetLeft; + this.riskSetTotal = riskSetTotal; + this.numberOfEventsLeft = numberOfEventsLeft; + this.numberOfEventsTotal = numberOfEventsTotal; } @Override - public int getNumberOfEvents(Double time, int event){ - if(numberOfEvents.containsKey(time)){ - return numberOfEvents.get(time)[event]; + public double[] getDistinctTimes() { + return times; + } + + @Override + public int getRiskSetLeft(int timeIndex, int event) { + return riskSetLeft[timeIndex]; + } + + @Override + public int getRiskSetTotal(int timeIndex, int event) { + return riskSetTotal[timeIndex]; + } + + + @Override + public int getNumberOfEventsLeft(int timeIndex, int event) { + return numberOfEventsLeft[event][timeIndex]; + } + + @Override + public int getNumberOfEventsTotal(int timeIndex, int event) { + return numberOfEventsTotal[event][timeIndex]; + } + + @Override + public void update(CompetingRiskResponse rowMovedToLeft) { + final double time = rowMovedToLeft.getU(); + final int k = Arrays.binarySearch(times, time); + + for(int i=0; i<=k; i++){ + riskSetLeft[i]++; } - return 0; + numberOfEventsLeft[rowMovedToLeft.getDelta()][k]++; } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java index 3a612b1..94dc48e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java @@ -1,11 +1,9 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction; import ca.joeltherrien.randomforest.utils.StepFunction; -import ca.joeltherrien.randomforest.utils.VeryDiscontinuousStepFunction; import java.util.*; -import java.util.stream.DoubleStream; +import java.util.stream.Stream; public class CompetingRiskUtils { @@ -102,18 +100,30 @@ public class CompetingRiskUtils { } - public static CompetingRiskSetsImpl calculateSetsEfficiently(final List responses, int[] eventsOfFocus){ - final int n = responses.size(); - int[] numberOfCurrentEvents = new int[eventsOfFocus.length+1]; - final Map numberOfEvents = new HashMap<>(); + public static CompetingRiskSetsImpl calculateSetsEfficiently(final List initialLeftHand, + final List initialRightHand, + int[] eventsOfFocus, + boolean calculateRiskSets){ - final List eventTimes = new ArrayList<>(n); - final List eventAndCensorTimes = new ArrayList<>(n); - final List riskSetNumberList = new ArrayList<>(n); + final double[] distinctEventTimes = Stream.concat( + initialLeftHand.stream(), + initialRightHand.stream()) + //.filter(y -> !y.isCensored()) + .map(CompetingRiskResponse::getU) + .mapToDouble(Double::doubleValue) + .sorted() + .distinct() + .toArray(); + + + final int m = distinctEventTimes.length; + final int[][] numberOfCurrentEventsTotal = new int[eventsOfFocus.length+1][m]; + + // Left Hand First // need to first sort responses - Collections.sort(responses, (y1, y2) -> { + Collections.sort(initialLeftHand, (y1, y2) -> { if(y1.getU() < y2.getU()){ return -1; } @@ -125,127 +135,191 @@ public class CompetingRiskUtils { } }); + final int nLeft = initialLeftHand.size(); + final int nRight = initialRightHand.size(); + + final int[][] numberOfCurrentEventsLeft = new int[eventsOfFocus.length+1][m]; + final int[] riskSetArrayLeft = new int[m]; + final int[] riskSetArrayTotal = new int[m]; - for(int i=0; i currentResponse.getU(); - numberOfCurrentEvents[currentResponse.getDelta()]++; + for(int k=0; k currentResponse.getU(); + + final int k = Arrays.binarySearch(distinctEventTimes, currentResponse.getU()); + + numberOfCurrentEventsLeft[currentResponse.getDelta()][k]++; + numberOfCurrentEventsTotal[currentResponse.getDelta()][k]++; if(lastOfTime){ int totalNumberOfCurrentEvents = 0; - for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events - totalNumberOfCurrentEvents += numberOfCurrentEvents[e]; + for(int e = 1; e < eventsOfFocus.length+1; e++){ // exclude censored events + totalNumberOfCurrentEvents += numberOfCurrentEventsLeft[e][k]; } - final double currentTime = currentResponse.getU(); + // Calculate risk set values + // Note that we only decrease values in the *future* + if(calculateRiskSets){ + final int decreaseBy = totalNumberOfCurrentEvents + numberOfCurrentEventsLeft[0][k]; + for(int j=k+1; j 0){ // add numberOfCurrentEvents - // Add point - eventTimes.add(currentTime); - numberOfEvents.put(currentTime, numberOfCurrentEvents); + } } - // Always do risk set - // remember that the LeftContinuousFunction takes into account that at this currentTime the risk value is the previous value - final int riskSet = n - (i+1); - riskSetNumberList.add(riskSet); - eventAndCensorTimes.add(currentTime); - - // reset counters - numberOfCurrentEvents = new int[eventsOfFocus.length+1]; } } - final double[] riskSetArray = new double[eventAndCensorTimes.size()]; - final double[] timesArray = new double[eventAndCensorTimes.size()]; - for(int i=0; i { + if(y1.getU() < y2.getU()){ + return -1; + } + else if(y1.getU() > y2.getU()){ + return 1; + } + else{ + return 0; + } + }); + + // Right Hand + int[] currentEventsRight = new int[eventsOfFocus.length+1]; + for(int i=0; i currentResponse.getU(); + + final int k = Arrays.binarySearch(distinctEventTimes, currentResponse.getU()); + + currentEventsRight[currentResponse.getDelta()]++; + numberOfCurrentEventsTotal[currentResponse.getDelta()][k]++; + + if(lastOfTime){ + int totalNumberOfCurrentEvents = 0; + for(int e = 1; e < eventsOfFocus.length+1; e++){ // exclude censored events + totalNumberOfCurrentEvents += currentEventsRight[e]; + } + + // Calculate risk set values + // Note that we only decrease values in the *future* + if(calculateRiskSets){ + final int decreaseBy = totalNumberOfCurrentEvents + currentEventsRight[0]; + for(int j=k+1; j responses, int[] eventsOfFocus){ - final List sillyList = responses; // annoying Java generic work-around - final CompetingRiskSetsImpl originalSets = calculateSetsEfficiently(sillyList, eventsOfFocus); - final double[] allTimes = DoubleStream.concat( - responses.stream() - .mapToDouble(CompetingRiskResponseWithCensorTime::getC), - responses.stream() - .mapToDouble(CompetingRiskResponseWithCensorTime::getU) - ).sorted().distinct().toArray(); + public static CompetingRiskGraySetsImpl calculateGraySetsEfficiently(final List initialLeftHand, + final List initialRightHand, + int[] eventsOfFocus){ + final List leftHandGenericsSuck = initialLeftHand; + final List rightHandGenericsSuck = initialRightHand; + final CompetingRiskSetsImpl normalSets = calculateSetsEfficiently( + leftHandGenericsSuck, + rightHandGenericsSuck, + eventsOfFocus, false); - final VeryDiscontinuousStepFunction[] riskSets = new VeryDiscontinuousStepFunction[eventsOfFocus.length]; + final double[] times = normalSets.times; + final int[][] numberOfEventsLeft = normalSets.numberOfEventsLeft; + final int[][] numberOfEventsTotal = normalSets.numberOfEventsTotal; - for(final int event : eventsOfFocus){ - final double[] yAt = new double[allTimes.length]; - final double[] yRight = new double[allTimes.length]; + // FYI; initialLeftHand and initialRightHand have both now been sorted + // Time to calculate the Gray modified risk sets + final int[][] riskSetsLeft = new int[eventsOfFocus.length][times.length]; + final int[][] riskSetsTotal = new int[eventsOfFocus.length][times.length]; - for(final CompetingRiskResponseWithCensorTime response : responses){ - if(response.getDelta() == event){ - // traditional case only; increment on time t when I(t <= Ui) - final double time = response.getU(); - final int index = Arrays.binarySearch(allTimes, time); + // Left hand first + for(final CompetingRiskResponseWithCensorTime response : initialLeftHand){ + final double time = response.getU(); + final int k = Arrays.binarySearch(times, time); + final int delta_m_1 = response.getDelta() - 1; + final double censorTime = response.getC(); - if(index < 0){ // TODO remove once code is stable - throw new IllegalStateException("Index shouldn't be negative!"); - } + for(int j=0; j= t, in I(...) + for(int i=0; i<=k; i++){ + riskSetLeftJ[i]++; + riskSetTotalJ[i]++; + } + + // second iteration; only if delta-1 != j + // corresponds to the second part, U_i < t & delta_i != j & C_i > t + if(delta_m_1 != j && !response.isCensored()){ + int i = k+1; + while(i < times.length && times[i] < censorTime){ + riskSetLeftJ[i]++; + riskSetTotalJ[i]++; + i++; } } - else{ - // need to increment on time t on following conditions; I(t <= Ui | t < Ci) - // Fact: Ci >= Ui. - // increment yAt up to Ci. If Ui==Ci, increment yAt at Ci. - final double time = response.getC(); - final int index = Arrays.binarySearch(allTimes, time); - - if(index < 0){ // TODO remove once code is stable - throw new IllegalStateException("Index shouldn't be negative!"); - } - - for(int i=0; i= t, in I(...) + for(int i=0; i<=k; i++){ + riskSetTotalJ[i]++; + } + + // second iteration; only if delta-1 != j + // corresponds to the second part, U_i < t & delta_i != j & C_i > t + if(delta_m_1 != j && !response.isCensored()){ + int i = k+1; + while(i < times.length && times[i] < censorTime){ + riskSetTotalJ[i]++; + i++; + } + } + + } + + } + + return new CompetingRiskGraySetsImpl(times, riskSetsLeft, riskSetsTotal, numberOfEventsLeft, numberOfEventsTotal); } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java index d925094..703ebfc 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java @@ -1,55 +1,132 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; -import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator; +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.Split; +import ca.joeltherrien.randomforest.tree.SplitAndScore; import lombok.AllArgsConstructor; import lombok.Data; -import java.util.stream.Stream; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; /** * See page 761 of Random survival forests for competing risks by Ishwaran et al. The class is abstract as Gray's test * modifies the abstract method. * */ -public abstract class CompetingRiskGroupDifferentiator extends SimpleGroupDifferentiator { +public abstract class CompetingRiskGroupDifferentiator implements GroupDifferentiator { + + abstract protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand); + + abstract protected Double getScore(final CompetingRiskSets competingRiskSets); + + @Override + public SplitAndScore differentiate(Iterator> splitIterator) { + + if(splitIterator instanceof Covariate.SplitRuleUpdater){ + return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); + } + else{ + return differentiateWithBasicIterator(splitIterator); + } + } + + private SplitAndScore differentiateWithBasicIterator(Iterator> splitIterator){ + Double bestScore = null; + Split bestSplit = null; + + while(splitIterator.hasNext()){ + final Split candidateSplit = splitIterator.next(); + + final List leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList()); + final List rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList()); + + if(leftHand.isEmpty() || rightHand.isEmpty()){ + continue; + } + + final CompetingRiskSets competingRiskSets = createCompetingRiskSets(leftHand, rightHand); + + final Double score = getScore(competingRiskSets); + + if(Double.isFinite(score) && (bestScore == null || score > bestScore)){ + bestScore = score; + bestSplit = candidateSplit; + } + } + + if(bestSplit == null){ + return null; + } + + return new SplitAndScore<>(bestSplit, bestScore); + } + + private SplitAndScore differentiateWithSplitUpdater(Covariate.SplitRuleUpdater splitRuleUpdater) { + + final List leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand() + .stream().map(Row::getResponse).collect(Collectors.toList()); + final List rightInitialSplit = splitRuleUpdater.currentSplit().getRightHand() + .stream().map(Row::getResponse).collect(Collectors.toList()); + + final CompetingRiskSets competingRiskSets = createCompetingRiskSets(leftInitialSplit, rightInitialSplit); + + Double bestScore = null; + Split bestSplit = null; + + while(splitRuleUpdater.hasNext()){ + for(Row rowMoved : splitRuleUpdater.nextUpdate().rowsMovedToLeftHand()){ + competingRiskSets.update(rowMoved.getResponse()); + } + + final Double score = getScore(competingRiskSets); + + if(Double.isFinite(score) && (bestScore == null || score > bestScore)){ + bestScore = score; + bestSplit = splitRuleUpdater.currentSplit(); + } + } + + if(bestSplit == null){ + return null; + } + + return new SplitAndScore<>(bestSplit, bestScore); + + } /** * Calculates the log rank value (or the Gray's test value) for a *specific* event cause. * * @param eventOfFocus - * @param competingRiskSetsLeft A summary of the different sets used in the calculation for the left side - * @param competingRiskSetsRight A summary of the different sets used in the calculation for the right side + * @param competingRiskSets A summary of the different sets used in the calculation * @return */ - LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSetsLeft, final CompetingRiskSets competingRiskSetsRight){ - - final double[] distinctEventTimes = Stream.concat( - competingRiskSetsLeft.getEventTimes().stream(), - competingRiskSetsRight.getEventTimes().stream()) - .mapToDouble(Double::doubleValue) - .sorted() - .distinct() - .toArray(); + LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSets){ double summation = 0.0; double variance = 0.0; - for(final double time_k : distinctEventTimes){ + final double[] distinctTimes = competingRiskSets.getDistinctTimes(); + + for(int k = 0; k leftHand, List rightHand) { - if(leftHand.size() == 0 || rightHand.size() == 0){ - return null; - } - - final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events); - final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events); + protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand){ + return CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, rightHand, events); + } + @Override + protected Double getScore(final CompetingRiskSets competingRiskSets){ double numerator = 0.0; double denominatorSquared = 0.0; for(final int eventOfFocus : events){ - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets); numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt(); denominatorSquared += valueOfInterest.getVariance(); @@ -37,7 +35,6 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi } return Math.abs(numerator / Math.sqrt(denominatorSquared)); - } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java index 1287f03..afe66db 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java @@ -1,7 +1,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; @@ -18,18 +18,14 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff private final int[] events; @Override - public Double getScore(List leftHand, List rightHand) { - if(leftHand.size() == 0 || rightHand.size() == 0){ - return null; - } - - final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events); - final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events); - - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); + protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand){ + return CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, rightHand, events); + } + @Override + protected Double getScore(final CompetingRiskSets competingRiskSets){ + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets); return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt()); - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java index 2478c8b..6465b44 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java @@ -1,7 +1,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; @@ -17,19 +17,17 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer private final int[] events; @Override - public Double getScore(List leftHand, List rightHand) { - if(leftHand.size() == 0 || rightHand.size() == 0){ - return null; - } - - final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events); - final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events); + protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand){ + return CompetingRiskUtils.calculateSetsEfficiently(leftHand, rightHand, events, true); + } + @Override + protected Double getScore(final CompetingRiskSets competingRiskSets){ double numerator = 0.0; double denominatorSquared = 0.0; for(final int eventOfFocus : events){ - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets); numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt(); denominatorSquared += valueOfInterest.getVariance(); @@ -37,7 +35,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer } return Math.abs(numerator / Math.sqrt(denominatorSquared)); - } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java index 817b4eb..7c633b1 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java @@ -1,7 +1,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import lombok.RequiredArgsConstructor; @@ -18,18 +18,14 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen private final int[] events; @Override - public Double getScore(List leftHand, List rightHand) { - if(leftHand.size() == 0 || rightHand.size() == 0){ - return null; - } - - final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events); - final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events); - - final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight); + protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand){ + return CompetingRiskUtils.calculateSetsEfficiently(leftHand, rightHand, events, true); + } + @Override + protected Double getScore(final CompetingRiskSets competingRiskSets){ + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets); return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt()); - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java index b040fe0..66e8cbc 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java @@ -12,6 +12,6 @@ import java.util.Iterator; */ public interface GroupDifferentiator { - SplitAndScore differentiate(Iterator> splitIterator); + SplitAndScore differentiate(Iterator> splitIterator); } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java index 9a06afe..596f81e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java @@ -9,12 +9,12 @@ import java.util.stream.Collectors; public abstract class SimpleGroupDifferentiator implements GroupDifferentiator { @Override - public SplitAndScore differentiate(Iterator> splitIterator) { + public SplitAndScore differentiate(Iterator> splitIterator) { Double bestScore = null; - Split bestSplit = null; + Split bestSplit = null; while(splitIterator.hasNext()){ - final Split candidateSplit = splitIterator.next(); + final Split candidateSplit = splitIterator.next(); final List leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList()); final List rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList()); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java deleted file mode 100644 index e6fec68..0000000 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCalculatingCompetingRiskSets.java +++ /dev/null @@ -1,214 +0,0 @@ -package ca.joeltherrien.randomforest.competingrisk; - -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl; -import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class TestCalculatingCompetingRiskSets { - - public List generateData(){ - final List data = new ArrayList<>(); - - data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3)); - data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3)); - data.add(new CompetingRiskResponseWithCensorTime(0, 1, 1)); - data.add(new CompetingRiskResponseWithCensorTime(1, 2, 2.5)); - data.add(new CompetingRiskResponseWithCensorTime(2, 3, 4)); - data.add(new CompetingRiskResponseWithCensorTime(0, 3, 3)); - data.add(new CompetingRiskResponseWithCensorTime(1, 4, 4)); - data.add(new CompetingRiskResponseWithCensorTime(0, 5, 5)); - data.add(new CompetingRiskResponseWithCensorTime(2, 6, 7)); - - return data; - } - - @Test - public void testCalculatingSets(){ - final List data = generateData(); - - final CompetingRiskSetsImpl sets = CompetingRiskUtils.calculateSetsEfficiently(data, new int[]{1,2}); - - final List times = sets.getEventTimes(); - assertEquals(5, times.size()); - - // Times - assertEquals(1.0, times.get(0).doubleValue()); - assertEquals(2.0, times.get(1).doubleValue()); - assertEquals(3.0, times.get(2).doubleValue()); - assertEquals(4.0, times.get(3).doubleValue()); - assertEquals(6.0, times.get(4).doubleValue()); - - // Number of Events - assertEquals(2, sets.getNumberOfEvents(1.0, 1)); - assertEquals(0, sets.getNumberOfEvents(1.0, 2)); - - assertEquals(1, sets.getNumberOfEvents(2.0, 1)); - assertEquals(0, sets.getNumberOfEvents(2.0, 2)); - - assertEquals(0, sets.getNumberOfEvents(3.0, 1)); - assertEquals(1, sets.getNumberOfEvents(3.0, 2)); - - assertEquals(1, sets.getNumberOfEvents(4.0, 1)); - assertEquals(0, sets.getNumberOfEvents(4.0, 2)); - - assertEquals(0, sets.getNumberOfEvents(6.0, 1)); - assertEquals(1, sets.getNumberOfEvents(6.0, 2)); - - // Make sure it doesn't break for other times - assertEquals(0, sets.getNumberOfEvents(5.5, 1)); - assertEquals(0, sets.getNumberOfEvents(5.5, 2)); - - - // Risk set - assertEquals(9, sets.getRiskSet(1).evaluate(0.5)); - assertEquals(9, sets.getRiskSet(2).evaluate(0.5)); - - assertEquals(9, sets.getRiskSet(1).evaluate(1.0)); - assertEquals(9, sets.getRiskSet(2).evaluate(1.0)); - - assertEquals(6, sets.getRiskSet(1).evaluate(1.5)); - assertEquals(6, sets.getRiskSet(2).evaluate(1.5)); - - assertEquals(6, sets.getRiskSet(1).evaluate(2.0)); - assertEquals(6, sets.getRiskSet(2).evaluate(2.0)); - - assertEquals(5, sets.getRiskSet(1).evaluate(2.3)); - assertEquals(5, sets.getRiskSet(2).evaluate(2.3)); - - assertEquals(5, sets.getRiskSet(1).evaluate(2.5)); - assertEquals(5, sets.getRiskSet(2).evaluate(2.5)); - - assertEquals(5, sets.getRiskSet(1).evaluate(2.7)); - assertEquals(5, sets.getRiskSet(2).evaluate(2.7)); - - assertEquals(5, sets.getRiskSet(1).evaluate(3.0)); - assertEquals(5, sets.getRiskSet(2).evaluate(3.0)); - - assertEquals(3, sets.getRiskSet(1).evaluate(3.5)); - assertEquals(3, sets.getRiskSet(2).evaluate(3.5)); - - assertEquals(3, sets.getRiskSet(1).evaluate(4.0)); - assertEquals(3, sets.getRiskSet(2).evaluate(4.0)); - - assertEquals(2, sets.getRiskSet(1).evaluate(4.5)); - assertEquals(2, sets.getRiskSet(2).evaluate(4.5)); - - assertEquals(2, sets.getRiskSet(1).evaluate(5.0)); - assertEquals(2, sets.getRiskSet(2).evaluate(5.0)); - - assertEquals(1, sets.getRiskSet(1).evaluate(5.5)); - assertEquals(1, sets.getRiskSet(2).evaluate(5.5)); - - assertEquals(1, sets.getRiskSet(1).evaluate(6.0)); - assertEquals(1, sets.getRiskSet(2).evaluate(6.0)); - - assertEquals(0, sets.getRiskSet(1).evaluate(6.5)); - assertEquals(0, sets.getRiskSet(2).evaluate(6.5)); - - assertEquals(0, sets.getRiskSet(1).evaluate(7.0)); - assertEquals(0, sets.getRiskSet(2).evaluate(7.0)); - - assertEquals(0, sets.getRiskSet(1).evaluate(7.5)); - assertEquals(0, sets.getRiskSet(2).evaluate(7.5)); - - } - - @Test - public void testCalculatingGraySets(){ - final List data = generateData(); - - final CompetingRiskGraySetsImpl sets = CompetingRiskUtils.calculateGraySetsEfficiently(data, new int[]{1,2}); - - final List times = sets.getEventTimes(); - assertEquals(5, times.size()); - - // Times - assertEquals(1.0, times.get(0).doubleValue()); - assertEquals(2.0, times.get(1).doubleValue()); - assertEquals(3.0, times.get(2).doubleValue()); - assertEquals(4.0, times.get(3).doubleValue()); - assertEquals(6.0, times.get(4).doubleValue()); - - // Number of Events - assertEquals(2, sets.getNumberOfEvents(1.0, 1)); - assertEquals(0, sets.getNumberOfEvents(1.0, 2)); - - assertEquals(1, sets.getNumberOfEvents(2.0, 1)); - assertEquals(0, sets.getNumberOfEvents(2.0, 2)); - - assertEquals(0, sets.getNumberOfEvents(3.0, 1)); - assertEquals(1, sets.getNumberOfEvents(3.0, 2)); - - assertEquals(1, sets.getNumberOfEvents(4.0, 1)); - assertEquals(0, sets.getNumberOfEvents(4.0, 2)); - - assertEquals(0, sets.getNumberOfEvents(6.0, 1)); - assertEquals(1, sets.getNumberOfEvents(6.0, 2)); - - // Make sure it doesn't break for other times - assertEquals(0, sets.getNumberOfEvents(5.5, 1)); - assertEquals(0, sets.getNumberOfEvents(5.5, 2)); - - - // Risk set - assertEquals(9, sets.getRiskSet(1).evaluate(0.5)); - assertEquals(9, sets.getRiskSet(2).evaluate(0.5)); - - assertEquals(9, sets.getRiskSet(1).evaluate(1.0)); - assertEquals(9, sets.getRiskSet(2).evaluate(1.0)); - - assertEquals(6, sets.getRiskSet(1).evaluate(1.5)); - assertEquals(8, sets.getRiskSet(2).evaluate(1.5)); - - assertEquals(6, sets.getRiskSet(1).evaluate(2.0)); - assertEquals(8, sets.getRiskSet(2).evaluate(2.0)); - - assertEquals(5, sets.getRiskSet(1).evaluate(2.3)); - assertEquals(8, sets.getRiskSet(2).evaluate(2.3)); - - assertEquals(5, sets.getRiskSet(1).evaluate(2.5)); - assertEquals(7, sets.getRiskSet(2).evaluate(2.5)); - - assertEquals(5, sets.getRiskSet(1).evaluate(2.7)); - assertEquals(7, sets.getRiskSet(2).evaluate(2.7)); - - assertEquals(5, sets.getRiskSet(1).evaluate(3.0)); - assertEquals(5, sets.getRiskSet(2).evaluate(3.0)); - - assertEquals(4, sets.getRiskSet(1).evaluate(3.5)); - assertEquals(3, sets.getRiskSet(2).evaluate(3.5)); - - assertEquals(3, sets.getRiskSet(1).evaluate(4.0)); - assertEquals(3, sets.getRiskSet(2).evaluate(4.0)); - - assertEquals(2, sets.getRiskSet(1).evaluate(4.5)); - assertEquals(2, sets.getRiskSet(2).evaluate(4.5)); - - assertEquals(2, sets.getRiskSet(1).evaluate(5.0)); - assertEquals(2, sets.getRiskSet(2).evaluate(5.0)); - - assertEquals(1, sets.getRiskSet(1).evaluate(5.5)); - assertEquals(1, sets.getRiskSet(2).evaluate(5.5)); - - assertEquals(1, sets.getRiskSet(1).evaluate(6.0)); - assertEquals(1, sets.getRiskSet(2).evaluate(6.0)); - - assertEquals(1, sets.getRiskSet(1).evaluate(6.5)); - assertEquals(0, sets.getRiskSet(2).evaluate(6.5)); - - assertEquals(0, sets.getRiskSet(1).evaluate(7.0)); - assertEquals(0, sets.getRiskSet(2).evaluate(7.0)); - - assertEquals(0, sets.getRiskSet(1).evaluate(7.5)); - assertEquals(0, sets.getRiskSet(2).evaluate(7.5)); - - } - -} diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 2d4365c..a48d38c 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -1,10 +1,15 @@ package ca.joeltherrien.randomforest.competingrisk; -import ca.joeltherrien.randomforest.*; -import ca.joeltherrien.randomforest.covariates.*; +import ca.joeltherrien.randomforest.CovariateRow; +import ca.joeltherrien.randomforest.DataLoader; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.Settings; +import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; -import ca.joeltherrien.randomforest.responses.competingrisk.*; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.Node; @@ -14,14 +19,15 @@ import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.*; import org.junit.jupiter.api.Test; -import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction; -import static ca.joeltherrien.randomforest.TestUtils.closeEnough; -import static org.junit.jupiter.api.Assertions.*; - import java.io.IOException; import java.util.List; import java.util.Random; +import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction; +import static ca.joeltherrien.randomforest.TestUtils.closeEnough; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + public class TestCompetingRisk { @@ -284,6 +290,34 @@ public class TestCompetingRisk { assertEquals(359, countEventTwo); } + /** + * Used to time how long the algorithm takes + * + * @param args Not used. + * @throws IOException + */ + public static void main(String[] args) throws IOException { + // timing + final TestCompetingRisk tcr = new TestCompetingRisk(); + + final Settings settings = tcr.getSettings(); + settings.setNtree(300); // results are too variable at 100 + + final List covariates = settings.getCovariates(); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), + settings.getTrainingDataLocation()); + final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); + + final long startTime = System.currentTimeMillis(); + for(int i=0; i<50; i++){ + forestTrainer.trainSerial(); + } + final long endTime = System.currentTimeMillis(); + + final double diffTime = endTime - startTime; + System.out.println(diffTime / 1000.0 / 50.0); + } + @Test public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException { diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java index 22c975f..9836ee3 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java @@ -7,6 +7,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; +import ca.joeltherrien.randomforest.tree.Split; +import ca.joeltherrien.randomforest.utils.SingletonIterator; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; @@ -15,6 +17,8 @@ import lombok.AllArgsConstructor; import org.junit.jupiter.api.Test; import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; @@ -22,6 +26,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class TestLogRankMultipleGroupDifferentiator { + private Iterator> turnIntoSplitIterator(List> leftList, + List> rightList){ + return new SingletonIterator>(new Split(null, leftList, rightList, Collections.emptyList())); + } + public static Data loadData(String filename) throws IOException { final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); yVarSettings.set("type", new TextNode("CompetingRiskResponse")); @@ -53,16 +62,12 @@ public class TestLogRankMultipleGroupDifferentiator { final List> group1Bad = data.subList(0, 196); final List> group2Bad = data.subList(196, data.size()); - final double scoreBad = groupDifferentiator.getScore( - group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()), - group2Bad.stream().map(Row::getResponse).collect(Collectors.toList())); + final double scoreBad = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Bad, group2Bad)).getScore(); final List> group1Good = data.subList(0, 199); final List> group2Good= data.subList(199, data.size()); - final double scoreGood = groupDifferentiator.getScore( - group1Good.stream().map(Row::getResponse).collect(Collectors.toList()), - group2Good.stream().map(Row::getResponse).collect(Collectors.toList())); + final double scoreGood = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Good, group2Good)).getScore(); // expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea closeEnough(71.41135, scoreBad, 0.00001); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java index e48c888..03856b0 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankSingleGroupDifferentiator.java @@ -3,10 +3,15 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator; +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.Split; +import ca.joeltherrien.randomforest.utils.SingletonIterator; import org.junit.jupiter.api.Test; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; @@ -15,42 +20,55 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class TestLogRankSingleGroupDifferentiator { - private List generateData1(){ - final List data = new ArrayList<>(); + private double getScore(final GroupDifferentiator groupDifferentiator, List> left, List> right){ + final Iterator> iterator = new SingletonIterator<>( + new Split<>(null, left, right, Collections.emptyList())); - data.add(new CompetingRiskResponse(1, 1.0)); - data.add(new CompetingRiskResponse(1, 1.0)); - data.add(new CompetingRiskResponse(1, 2.0)); - data.add(new CompetingRiskResponse(1, 1.5)); - data.add(new CompetingRiskResponse(0, 2.0)); - data.add(new CompetingRiskResponse(0, 1.5)); - data.add(new CompetingRiskResponse(0, 2.5)); + return groupDifferentiator.differentiate(iterator).getScore(); + + } + + int count = 1; + private Row createRow(Y response){ + return new Row<>(null, count++, response); + } + + private List> generateData1(){ + final List> data = new ArrayList<>(); + + data.add(createRow(new CompetingRiskResponse(1, 1.0))); + data.add(createRow(new CompetingRiskResponse(1, 1.0))); + data.add(createRow(new CompetingRiskResponse(1, 2.0))); + data.add(createRow(new CompetingRiskResponse(1, 1.5))); + data.add(createRow(new CompetingRiskResponse(0, 2.0))); + data.add(createRow(new CompetingRiskResponse(0, 1.5))); + data.add(createRow(new CompetingRiskResponse(0, 2.5))); return data; } - private List generateData2(){ - final List data = new ArrayList<>(); + private List> generateData2(){ + final List> data = new ArrayList<>(); - data.add(new CompetingRiskResponse(1, 2.0)); - data.add(new CompetingRiskResponse(1, 2.0)); - data.add(new CompetingRiskResponse(1, 4.0)); - data.add(new CompetingRiskResponse(1, 3.0)); - data.add(new CompetingRiskResponse(0, 4.0)); - data.add(new CompetingRiskResponse(0, 3.0)); - data.add(new CompetingRiskResponse(0, 5.0)); + data.add(createRow(new CompetingRiskResponse(1, 2.0))); + data.add(createRow(new CompetingRiskResponse(1, 2.0))); + data.add(createRow(new CompetingRiskResponse(1, 4.0))); + data.add(createRow(new CompetingRiskResponse(1, 3.0))); + data.add(createRow(new CompetingRiskResponse(0, 4.0))); + data.add(createRow(new CompetingRiskResponse(0, 3.0))); + data.add(createRow(new CompetingRiskResponse(0, 5.0))); return data; } @Test public void testCompetingRiskResponseCombiner(){ - final List data1 = generateData1(); - final List data2 = generateData2(); + final List> data1 = generateData1(); + final List> data2 = generateData2(); final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1}); - final double score = differentiator.getScore(data1, data2); + final double score = getScore(differentiator, data1, data2); final double margin = 0.000001; // Tested using 855 method @@ -70,16 +88,12 @@ public class TestLogRankSingleGroupDifferentiator { final List> group1Good = data.subList(0, 221); final List> group2Good = data.subList(221, data.size()); - final double scoreGood = groupDifferentiator.getScore( - group1Good.stream().map(Row::getResponse).collect(Collectors.toList()), - group2Good.stream().map(Row::getResponse).collect(Collectors.toList())); + final double scoreGood = getScore(groupDifferentiator, group1Good, group2Good); final List> group1Bad = data.subList(0, 222); final List> group2Bad = data.subList(222, data.size()); - final double scoreBad = groupDifferentiator.getScore( - group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()), - group2Bad.stream().map(Row::getResponse).collect(Collectors.toList())); + final double scoreBad = getScore(groupDifferentiator, group1Bad, group2Bad); // Apparently not all groups are unique when splitting assertEquals(scoreGood, scoreBad);