Update the competing risk GroupDifferentiators to make efficient use of the SplitRuleUpdater updates

This yields a speed improvement of more than one third, measured by timing the TestCompetingRisk#testLogRankSingleGroupDifferentiatorAllCovariates() test.
This commit is contained in:
Joel Therrien 2019-01-11 22:56:41 -08:00
parent 86122fd90d
commit e709c42da1
17 changed files with 524 additions and 478 deletions

View file

@ -40,6 +40,7 @@ public interface Covariate<V> extends Serializable {
/**
 * An iterator over candidate splits that additionally exposes the split it is currently
 * positioned on and can advance by handing back a description of the change.
 *
 * @param <Y> response type carried by the rows being split
 * @param <V> second type parameter passed through to Split and SplitUpdate
 */
interface SplitRuleUpdater<Y, V> extends Iterator<Split<Y, V>>{
// The split this updater is currently positioned on.
Split<Y, V> currentSplit();
// Whether the current split is usable; NumericSplitRuleUpdater reports true only when both hands are non-empty.
boolean currentSplitValid();
// Moves on to the next split and returns the update describing it; callers read rowsMovedToLeftHand() from the result.
SplitUpdate<Y, V> nextUpdate();
}

View file

@ -37,6 +37,11 @@ public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y,
return this.currentSplit;
}
/**
 * Reports whether the current split leaves rows on both sides.
 *
 * @return {@code true} only when neither the left hand nor the right hand of the current split is empty
 */
@Override
public boolean currentSplitValid() {
    // Check the left hand first; the right hand is only inspected when the left has rows.
    if (currentSplit.getLeftHand().isEmpty()) {
        return false;
    }
    return !currentSplit.getRightHand().isEmpty();
}
@Override
public NumericSplitUpdate<Y> nextUpdate() {
if(hasNext()){
@ -51,8 +56,8 @@ public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y,
// Update current split
this.currentSplit = new Split<>(
splitRule,
orderedData.subList(0, newPosition),
orderedData.subList(newPosition, orderedData.size()),
Collections.unmodifiableList(orderedData.subList(0, newPosition)),
Collections.unmodifiableList(orderedData.subList(newPosition, orderedData.size())),
Collections.emptyList());

View file

@ -1,37 +1,77 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
import lombok.Builder;
import lombok.Getter;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
public class CompetingRiskGraySetsImpl implements CompetingRiskSets<CompetingRiskResponseWithCensorTime> {
/**
* Represents a response from CompetingRiskUtils#calculateGraySetsEfficiently
*
*/
@Builder
@Getter
public class CompetingRiskGraySetsImpl implements CompetingRiskSets{
final double[] times; // length m array
int[][] riskSetLeft; // J x m array
final int[][] riskSetTotal; // J x m array
int[][] numberOfEventsLeft; // J+1 x m array
final int[][] numberOfEventsTotal; // J+1 x m array
private final List<Double> eventTimes;
private final MathFunction[] riskSet;
private final Map<Double, int[]> numberOfEvents;
@Override
public MathFunction getRiskSet(int event){
return riskSet[event-1];
public CompetingRiskGraySetsImpl(double[] times, int[][] riskSetLeft, int[][] riskSetTotal, int[][] numberOfEventsLeft, int[][] numberOfEventsTotal) {
this.times = times;
this.riskSetLeft = riskSetLeft;
this.riskSetTotal = riskSetTotal;
this.numberOfEventsLeft = numberOfEventsLeft;
this.numberOfEventsTotal = numberOfEventsTotal;
}
@Override
public int getNumberOfEvents(Double time, int event){
if(numberOfEvents.containsKey(time)){
return numberOfEvents.get(time)[event];
public double[] getDistinctTimes() {
return times;
}
/**
 * Looks up the left-hand risk set stored for one event at one distinct time.
 *
 * @param timeIndex index into the distinct times array
 * @param event event code; row {@code event - 1} is read, so codes are expected to start at 1
 * @return the stored left-hand risk set count
 */
@Override
public int getRiskSetLeft(int timeIndex, int event) {
    final int[] leftRiskSetsForEvent = riskSetLeft[event - 1];
    return leftRiskSetsForEvent[timeIndex];
}
/**
 * Looks up the total risk set stored for one event at one distinct time.
 *
 * @param timeIndex index into the distinct times array
 * @param event event code; row {@code event - 1} is read, so codes are expected to start at 1
 * @return the stored total risk set count
 */
@Override
public int getRiskSetTotal(int timeIndex, int event) {
    final int[] totalRiskSetsForEvent = riskSetTotal[event - 1];
    return totalRiskSetsForEvent[timeIndex];
}
/**
 * Looks up the left-hand event count stored for one event code at one distinct time.
 *
 * @param timeIndex index into the distinct times array
 * @param event event code, used directly as the row index (the counts array has one more row than the risk set arrays)
 * @return the stored left-hand count
 */
@Override
public int getNumberOfEventsLeft(int timeIndex, int event) {
    final int[] leftCountsForEvent = numberOfEventsLeft[event];
    return leftCountsForEvent[timeIndex];
}
/**
 * Looks up the total event count stored for one event code at one distinct time.
 *
 * @param timeIndex index into the distinct times array
 * @param event event code, used directly as the row index (the counts array has one more row than the risk set arrays)
 * @return the stored total count
 */
@Override
public int getNumberOfEventsTotal(int timeIndex, int event) {
    final int[] totalCountsForEvent = numberOfEventsTotal[event];
    return totalCountsForEvent[timeIndex];
}
@Override
public void update(CompetingRiskResponseWithCensorTime rowMovedToLeft) {
final double time = rowMovedToLeft.getU();
final int k = Arrays.binarySearch(times, time);
final int delta_m_1 = rowMovedToLeft.getDelta() - 1;
final double censorTime = rowMovedToLeft.getC();
for(int j=0; j<riskSetLeft.length; j++){
final int[] riskSetLeftJ = riskSetLeft[j];
// first iteration; perform normal increment as if Y is normal
// corresponds to the first part, U_i >= t, in I(...)
for(int i=0; i<=k; i++){
riskSetLeftJ[i]++;
}
// second iteration; only if delta-1 != j
// corresponds to the second part, U_i < t & delta_i != j & C_i > t
if(delta_m_1 != j && !rowMovedToLeft.isCensored()){
int i = k+1;
while(i < times.length && times[i] < censorTime){
riskSetLeftJ[i]++;
i++;
}
}
}
return 0;
numberOfEventsLeft[rowMovedToLeft.getDelta()][k]++;
}
}

View file

@ -1,13 +1,13 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
public interface CompetingRiskSets<T extends CompetingRiskResponse> {
import java.util.List;
double[] getDistinctTimes();
int getRiskSetLeft(int timeIndex, int event);
int getRiskSetTotal(int timeIndex, int event);
int getNumberOfEventsLeft(int timeIndex, int event);
int getNumberOfEventsTotal(int timeIndex, int event);
public interface CompetingRiskSets {
MathFunction getRiskSet(int event);
int getNumberOfEvents(Double time, int event);
List<Double> getEventTimes();
void update(T rowMovedToLeft);
}

View file

@ -1,36 +1,59 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.MathFunction;
import lombok.Builder;
import lombok.Getter;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
public class CompetingRiskSetsImpl implements CompetingRiskSets<CompetingRiskResponse> {
/**
* Represents a response from CompetingRiskUtils#calculateSetsEfficiently
*
*/
@Builder
@Getter
public class CompetingRiskSetsImpl implements CompetingRiskSets{
final double[] times; // length m array
int[] riskSetLeft; // length m array
final int[] riskSetTotal; // length m array
int[][] numberOfEventsLeft; // J+1 x m array
final int[][] numberOfEventsTotal; // J+1 x m array
private final List<Double> eventTimes;
private final MathFunction riskSet;
private final Map<Double, int[]> numberOfEvents;
@Override
public MathFunction getRiskSet(int event){
return riskSet;
public CompetingRiskSetsImpl(double[] times, int[] riskSetLeft, int[] riskSetTotal, int[][] numberOfEventsLeft, int[][] numberOfEventsTotal) {
this.times = times;
this.riskSetLeft = riskSetLeft;
this.riskSetTotal = riskSetTotal;
this.numberOfEventsLeft = numberOfEventsLeft;
this.numberOfEventsTotal = numberOfEventsTotal;
}
@Override
public int getNumberOfEvents(Double time, int event){
if(numberOfEvents.containsKey(time)){
return numberOfEvents.get(time)[event];
public double[] getDistinctTimes() {
return times;
}
/**
 * Looks up the left-hand risk set at one distinct time.
 *
 * @param timeIndex index into the distinct times array
 * @param event not used here; a single risk set array is kept for all events
 * @return the stored left-hand risk set count
 */
@Override
public int getRiskSetLeft(int timeIndex, int event) {
    final int leftAtRisk = riskSetLeft[timeIndex];
    return leftAtRisk;
}
/**
 * Looks up the total risk set at one distinct time.
 *
 * @param timeIndex index into the distinct times array
 * @param event not used here; a single risk set array is kept for all events
 * @return the stored total risk set count
 */
@Override
public int getRiskSetTotal(int timeIndex, int event) {
    final int totalAtRisk = riskSetTotal[timeIndex];
    return totalAtRisk;
}
/**
 * Looks up the left-hand event count stored for one event code at one distinct time.
 *
 * @param timeIndex index into the distinct times array
 * @param event event code, used directly as the row index of the counts array
 * @return the stored left-hand count
 */
@Override
public int getNumberOfEventsLeft(int timeIndex, int event) {
    final int[] leftCountsForEvent = numberOfEventsLeft[event];
    return leftCountsForEvent[timeIndex];
}
/**
 * Looks up the total event count stored for one event code at one distinct time.
 *
 * @param timeIndex index into the distinct times array
 * @param event event code, used directly as the row index of the counts array
 * @return the stored total count
 */
@Override
public int getNumberOfEventsTotal(int timeIndex, int event) {
    final int[] totalCountsForEvent = numberOfEventsTotal[event];
    return totalCountsForEvent[timeIndex];
}
@Override
public void update(CompetingRiskResponse rowMovedToLeft) {
final double time = rowMovedToLeft.getU();
final int k = Arrays.binarySearch(times, time);
for(int i=0; i<=k; i++){
riskSetLeft[i]++;
}
return 0;
numberOfEventsLeft[rowMovedToLeft.getDelta()][k]++;
}
}

View file

@ -1,11 +1,9 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.VeryDiscontinuousStepFunction;
import java.util.*;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
public class CompetingRiskUtils {
@ -102,18 +100,30 @@ public class CompetingRiskUtils {
}
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> responses, int[] eventsOfFocus){
final int n = responses.size();
int[] numberOfCurrentEvents = new int[eventsOfFocus.length+1];
final Map<Double, int[]> numberOfEvents = new HashMap<>();
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
final List<CompetingRiskResponse> initialRightHand,
int[] eventsOfFocus,
boolean calculateRiskSets){
final List<Double> eventTimes = new ArrayList<>(n);
final List<Double> eventAndCensorTimes = new ArrayList<>(n);
final List<Integer> riskSetNumberList = new ArrayList<>(n);
final double[] distinctEventTimes = Stream.concat(
initialLeftHand.stream(),
initialRightHand.stream())
//.filter(y -> !y.isCensored())
.map(CompetingRiskResponse::getU)
.mapToDouble(Double::doubleValue)
.sorted()
.distinct()
.toArray();
final int m = distinctEventTimes.length;
final int[][] numberOfCurrentEventsTotal = new int[eventsOfFocus.length+1][m];
// Left Hand First
// need to first sort responses
Collections.sort(responses, (y1, y2) -> {
Collections.sort(initialLeftHand, (y1, y2) -> {
if(y1.getU() < y2.getU()){
return -1;
}
@ -125,127 +135,191 @@ public class CompetingRiskUtils {
}
});
final int nLeft = initialLeftHand.size();
final int nRight = initialRightHand.size();
final int[][] numberOfCurrentEventsLeft = new int[eventsOfFocus.length+1][m];
final int[] riskSetArrayLeft = new int[m];
final int[] riskSetArrayTotal = new int[m];
for(int i=0; i<n; i++){
final CompetingRiskResponse currentResponse = responses.get(i);
final boolean lastOfTime = (i+1)==n || responses.get(i+1).getU() > currentResponse.getU();
numberOfCurrentEvents[currentResponse.getDelta()]++;
for(int k=0; k<m; k++){
riskSetArrayLeft[k] = nLeft;
riskSetArrayTotal[k] = nLeft + nRight;
}
// Left Hand
for(int i=0; i<nLeft; i++){
final CompetingRiskResponse currentResponse = initialLeftHand.get(i);
final boolean lastOfTime = (i+1)==nLeft || initialLeftHand.get(i+1).getU() > currentResponse.getU();
final int k = Arrays.binarySearch(distinctEventTimes, currentResponse.getU());
numberOfCurrentEventsLeft[currentResponse.getDelta()][k]++;
numberOfCurrentEventsTotal[currentResponse.getDelta()][k]++;
if(lastOfTime){
int totalNumberOfCurrentEvents = 0;
for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events
totalNumberOfCurrentEvents += numberOfCurrentEvents[e];
for(int e = 1; e < eventsOfFocus.length+1; e++){ // exclude censored events
totalNumberOfCurrentEvents += numberOfCurrentEventsLeft[e][k];
}
final double currentTime = currentResponse.getU();
// Calculate risk set values
// Note that we only decrease values in the *future*
if(calculateRiskSets){
final int decreaseBy = totalNumberOfCurrentEvents + numberOfCurrentEventsLeft[0][k];
for(int j=k+1; j<m; j++){
riskSetArrayLeft[j] = riskSetArrayLeft[j] - decreaseBy;
riskSetArrayTotal[j] = riskSetArrayTotal[j] - decreaseBy;
if(totalNumberOfCurrentEvents > 0){ // add numberOfCurrentEvents
// Add point
eventTimes.add(currentTime);
numberOfEvents.put(currentTime, numberOfCurrentEvents);
}
}
// Always do risk set
// remember that the LeftContinuousFunction takes into account that at this currentTime the risk value is the previous value
final int riskSet = n - (i+1);
riskSetNumberList.add(riskSet);
eventAndCensorTimes.add(currentTime);
// reset counters
numberOfCurrentEvents = new int[eventsOfFocus.length+1];
}
}
final double[] riskSetArray = new double[eventAndCensorTimes.size()];
final double[] timesArray = new double[eventAndCensorTimes.size()];
for(int i=0; i<riskSetArray.length; i++){
timesArray[i] = eventAndCensorTimes.get(i);
riskSetArray[i] = riskSetNumberList.get(i);
// Right Hand Next. Note that we only need to keep track of the Left Hand and the Total
// need to first sort responses
Collections.sort(initialRightHand, (y1, y2) -> {
if(y1.getU() < y2.getU()){
return -1;
}
else if(y1.getU() > y2.getU()){
return 1;
}
else{
return 0;
}
});
// Right Hand
int[] currentEventsRight = new int[eventsOfFocus.length+1];
for(int i=0; i<nRight; i++){
final CompetingRiskResponse currentResponse = initialRightHand.get(i);
final boolean lastOfTime = (i+1)==nRight || initialRightHand.get(i+1).getU() > currentResponse.getU();
final int k = Arrays.binarySearch(distinctEventTimes, currentResponse.getU());
currentEventsRight[currentResponse.getDelta()]++;
numberOfCurrentEventsTotal[currentResponse.getDelta()][k]++;
if(lastOfTime){
int totalNumberOfCurrentEvents = 0;
for(int e = 1; e < eventsOfFocus.length+1; e++){ // exclude censored events
totalNumberOfCurrentEvents += currentEventsRight[e];
}
// Calculate risk set values
// Note that we only decrease values in the *future*
if(calculateRiskSets){
final int decreaseBy = totalNumberOfCurrentEvents + currentEventsRight[0];
for(int j=k+1; j<m; j++){
riskSetArrayTotal[j] = riskSetArrayTotal[j] - decreaseBy;
}
}
// Reset
currentEventsRight = new int[eventsOfFocus.length+1];
}
}
final LeftContinuousStepFunction riskSetFunction = new LeftContinuousStepFunction(timesArray, riskSetArray, n);
return CompetingRiskSetsImpl.builder()
.numberOfEvents(numberOfEvents)
.riskSet(riskSetFunction)
.eventTimes(eventTimes)
.build();
return new CompetingRiskSetsImpl(distinctEventTimes, riskSetArrayLeft, riskSetArrayTotal, numberOfCurrentEventsLeft, numberOfCurrentEventsTotal);
}
public static CompetingRiskGraySetsImpl calculateGraySetsEfficiently(final List<CompetingRiskResponseWithCensorTime> responses, int[] eventsOfFocus){
final List sillyList = responses; // annoying Java generic work-around
final CompetingRiskSetsImpl originalSets = calculateSetsEfficiently(sillyList, eventsOfFocus);
final double[] allTimes = DoubleStream.concat(
responses.stream()
.mapToDouble(CompetingRiskResponseWithCensorTime::getC),
responses.stream()
.mapToDouble(CompetingRiskResponseWithCensorTime::getU)
).sorted().distinct().toArray();
public static CompetingRiskGraySetsImpl calculateGraySetsEfficiently(final List<CompetingRiskResponseWithCensorTime> initialLeftHand,
final List<CompetingRiskResponseWithCensorTime> initialRightHand,
int[] eventsOfFocus){
final List leftHandGenericsSuck = initialLeftHand;
final List rightHandGenericsSuck = initialRightHand;
final CompetingRiskSetsImpl normalSets = calculateSetsEfficiently(
leftHandGenericsSuck,
rightHandGenericsSuck,
eventsOfFocus, false);
final VeryDiscontinuousStepFunction[] riskSets = new VeryDiscontinuousStepFunction[eventsOfFocus.length];
final double[] times = normalSets.times;
final int[][] numberOfEventsLeft = normalSets.numberOfEventsLeft;
final int[][] numberOfEventsTotal = normalSets.numberOfEventsTotal;
for(final int event : eventsOfFocus){
final double[] yAt = new double[allTimes.length];
final double[] yRight = new double[allTimes.length];
// FYI; initialLeftHand and initialRightHand have both now been sorted
// Time to calculate the Gray modified risk sets
final int[][] riskSetsLeft = new int[eventsOfFocus.length][times.length];
final int[][] riskSetsTotal = new int[eventsOfFocus.length][times.length];
for(final CompetingRiskResponseWithCensorTime response : responses){
if(response.getDelta() == event){
// traditional case only; increment on time t when I(t <= Ui)
final double time = response.getU();
final int index = Arrays.binarySearch(allTimes, time);
// Left hand first
for(final CompetingRiskResponseWithCensorTime response : initialLeftHand){
final double time = response.getU();
final int k = Arrays.binarySearch(times, time);
final int delta_m_1 = response.getDelta() - 1;
final double censorTime = response.getC();
if(index < 0){ // TODO remove once code is stable
throw new IllegalStateException("Index shouldn't be negative!");
}
for(int j=0; j<eventsOfFocus.length; j++){
final int[] riskSetLeftJ = riskSetsLeft[j];
final int[] riskSetTotalJ = riskSetsTotal[j];
// All yAts up to and including index are incremented;
// All yRights up to index are incremented
yAt[index]++;
for(int i=0; i<index; i++){
yAt[i]++;
yRight[i]++;
// first iteration; perform normal increment as if Y is normal
// corresponds to the first part, U_i >= t, in I(...)
for(int i=0; i<=k; i++){
riskSetLeftJ[i]++;
riskSetTotalJ[i]++;
}
// second iteration; only if delta-1 != j
// corresponds to the second part, U_i < t & delta_i != j & C_i > t
if(delta_m_1 != j && !response.isCensored()){
int i = k+1;
while(i < times.length && times[i] < censorTime){
riskSetLeftJ[i]++;
riskSetTotalJ[i]++;
i++;
}
}
else{
// need to increment on time t on following conditions; I(t <= Ui | t < Ci)
// Fact: Ci >= Ui.
// increment yAt up to Ci. If Ui==Ci, increment yAt at Ci.
final double time = response.getC();
final int index = Arrays.binarySearch(allTimes, time);
if(index < 0){ // TODO remove once code is stable
throw new IllegalStateException("Index shouldn't be negative!");
}
for(int i=0; i<index; i++){
yAt[i]++;
yRight[i]++;
}
if(response.getU() == response.getC()){
yAt[index]++;
}
}
}
riskSets[event-1] = new VeryDiscontinuousStepFunction(allTimes, yAt, yRight, responses.size());
}
return CompetingRiskGraySetsImpl.builder()
.numberOfEvents(originalSets.getNumberOfEvents())
.eventTimes(originalSets.getEventTimes())
.riskSet(riskSets)
.build();
// Repeat for right hand
for(final CompetingRiskResponseWithCensorTime response : initialRightHand){
final double time = response.getU();
final int k = Arrays.binarySearch(times, time);
final int delta_m_1 = response.getDelta() - 1;
final double censorTime = response.getC();
for(int j=0; j<eventsOfFocus.length; j++){
final int[] riskSetTotalJ = riskSetsTotal[j];
// first iteration; perform normal increment as if Y is normal
// corresponds to the first part, U_i >= t, in I(...)
for(int i=0; i<=k; i++){
riskSetTotalJ[i]++;
}
// second iteration; only if delta-1 != j
// corresponds to the second part, U_i < t & delta_i != j & C_i > t
if(delta_m_1 != j && !response.isCensored()){
int i = k+1;
while(i < times.length && times[i] < censorTime){
riskSetTotalJ[i]++;
i++;
}
}
}
}
return new CompetingRiskGraySetsImpl(times, riskSetsLeft, riskSetsTotal, numberOfEventsLeft, numberOfEventsTotal);
}

View file

@ -1,55 +1,132 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.tree.SplitAndScore;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.util.stream.Stream;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
/**
* See page 761 of Random survival forests for competing risks by Ishwaran et al. The class is abstract as Gray's test
* modifies the abstract method.
*
*/
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> extends SimpleGroupDifferentiator<Y> {
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y> {
/**
 * Builds the risk set / event count summary for a split, given the responses on each side.
 *
 * @param leftHand responses of the rows on the left side of the split
 * @param rightHand responses of the rows on the right side of the split
 * @return the sets that are later passed to getScore and, on the updater path, updated row by row
 */
abstract protected CompetingRiskSets<Y> createCompetingRiskSets(List<Y> leftHand, List<Y> rightHand);
/**
 * Scores the split described by the given sets. Callers in this class keep the split with the
 * largest score and ignore scores that are not finite.
 *
 * @param competingRiskSets summary of the split to score
 * @return the score of the split
 */
abstract protected Double getScore(final CompetingRiskSets<Y> competingRiskSets);
/**
 * Finds the best scoring split offered by the iterator.
 *
 * Iterators that are SplitRuleUpdaters are handled incrementally; any other iterator is
 * handled by scoring every candidate split from scratch.
 *
 * @param splitIterator source of candidate splits
 * @return the best split together with its score, or {@code null} if no split had a finite score
 */
@Override
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) {
    // Plain iterators take the slow path.
    if(!(splitIterator instanceof Covariate.SplitRuleUpdater)){
        return differentiateWithBasicIterator(splitIterator);
    }

    return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
}
/**
 * Scores every candidate split independently, rebuilding the competing risk sets each time.
 *
 * Splits with an empty side are skipped, as are splits whose score is not finite.
 *
 * @param splitIterator source of candidate splits
 * @return the highest scoring split and its score, or {@code null} if none qualified
 */
private SplitAndScore<Y, ?> differentiateWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){
    Split<Y, ?> winningSplit = null;
    Double winningScore = null;

    while(splitIterator.hasNext()){
        final Split<Y, ?> candidate = splitIterator.next();

        final List<Y> leftResponses = candidate.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
        final List<Y> rightResponses = candidate.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());

        // Only splits with rows on both sides are worth scoring.
        if(!leftResponses.isEmpty() && !rightResponses.isEmpty()){
            final CompetingRiskSets<Y> sets = createCompetingRiskSets(leftResponses, rightResponses);
            final Double candidateScore = getScore(sets);

            if(Double.isFinite(candidateScore) && (winningScore == null || candidateScore > winningScore)){
                winningSplit = candidate;
                winningScore = candidateScore;
            }
        }
    }

    if(winningSplit == null){
        return null;
    }

    return new SplitAndScore<>(winningSplit, winningScore);
}
/**
 * Scores candidate splits incrementally: the competing risk sets are built once from the
 * updater's starting split and then adjusted for each row that moves to the left hand.
 *
 * Scores that are not finite are ignored.
 *
 * @param splitRuleUpdater updater positioned on its starting split
 * @return the highest scoring split and its score, or {@code null} if none qualified
 */
private SplitAndScore<Y, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) {
    // Responses on each side of the starting split.
    final List<Y> startingLeft = splitRuleUpdater.currentSplit().getLeftHand()
            .stream().map(Row::getResponse).collect(Collectors.toList());
    final List<Y> startingRight = splitRuleUpdater.currentSplit().getRightHand()
            .stream().map(Row::getResponse).collect(Collectors.toList());

    final CompetingRiskSets<Y> runningSets = createCompetingRiskSets(startingLeft, startingRight);

    Split<Y, ?> winningSplit = null;
    Double winningScore = null;

    while(splitRuleUpdater.hasNext()){
        // Fold every row that just moved to the left hand into the running sets.
        for(Row<Y> movedRow : splitRuleUpdater.nextUpdate().rowsMovedToLeftHand()){
            runningSets.update(movedRow.getResponse());
        }

        final Double candidateScore = getScore(runningSets);

        if(Double.isFinite(candidateScore) && (winningScore == null || candidateScore > winningScore)){
            winningScore = candidateScore;
            winningSplit = splitRuleUpdater.currentSplit();
        }
    }

    if(winningSplit == null){
        return null;
    }

    return new SplitAndScore<>(winningSplit, winningScore);
}
/**
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
*
* @param eventOfFocus
* @param competingRiskSetsLeft A summary of the different sets used in the calculation for the left side
* @param competingRiskSetsRight A summary of the different sets used in the calculation for the right side
* @param competingRiskSets A summary of the different sets used in the calculation
* @return
*/
LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSetsLeft, final CompetingRiskSets competingRiskSetsRight){
final double[] distinctEventTimes = Stream.concat(
competingRiskSetsLeft.getEventTimes().stream(),
competingRiskSetsRight.getEventTimes().stream())
.mapToDouble(Double::doubleValue)
.sorted()
.distinct()
.toArray();
LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets<Y> competingRiskSets){
double summation = 0.0;
double variance = 0.0;
for(final double time_k : distinctEventTimes){
final double[] distinctTimes = competingRiskSets.getDistinctTimes();
for(int k = 0; k<distinctTimes.length; k++){
final double time_k = distinctTimes[k];
final double weight = weight(time_k); // W_j(t_k)
final double numberEventsAtTimeDaughterLeft = competingRiskSetsLeft.getNumberOfEvents(time_k, eventOfFocus); // // d_{j,l}(t_k)
final double numberEventsAtTimeDaughterRight = competingRiskSetsRight.getNumberOfEvents(time_k, eventOfFocus); // d_{j,r}(t_k)
final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)
final double numberEventsAtTimeDaughterLeft = competingRiskSets.getNumberOfEventsLeft(k, eventOfFocus); // // d_{j,l}(t_k)
final double numberEventsAtTimeDaughterTotal = competingRiskSets.getNumberOfEventsTotal(k, eventOfFocus); // d_j(t_k)
final double individualsAtRiskDaughterLeft = competingRiskSetsLeft.getRiskSet(eventOfFocus).evaluate(time_k); // Y_l(t_k)
final double individualsAtRiskDaughterRight = competingRiskSetsRight.getRiskSet(eventOfFocus).evaluate(time_k); // Y_r(t_k)
final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)
final double individualsAtRiskDaughterLeft = competingRiskSets.getRiskSetLeft(k, eventOfFocus); // Y_l(t_k)
final double individualsAtRiskDaughterTotal = competingRiskSets.getRiskSetTotal(k, eventOfFocus); // Y(t_k)
final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);
final double deltaVariance = weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk
* (1.0 - individualsAtRiskDaughterLeft / individualsAtRisk)
* ((individualsAtRisk - numberOfEventsAtTime) / (individualsAtRisk - 1.0));
final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberEventsAtTimeDaughterTotal*individualsAtRiskDaughterLeft/individualsAtRiskDaughterTotal);
final double deltaVariance = weight*weight*numberEventsAtTimeDaughterTotal*individualsAtRiskDaughterLeft/individualsAtRiskDaughterTotal
* (1.0 - individualsAtRiskDaughterLeft / individualsAtRiskDaughterTotal)
* ((individualsAtRiskDaughterTotal - numberEventsAtTimeDaughterTotal) / (individualsAtRiskDaughterTotal - 1.0));
// Note - notation differs slightly with what is found in STAT 855 notes, but they are equivalent.
// Note - if individualsAtRisk == 1 then variance will be NaN.

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor;
@ -17,19 +17,17 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
private final int[] events;
@Override
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}
final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
/**
 * Builds the Gray-modified competing risk sets for a split by delegating to
 * CompetingRiskUtils.calculateGraySetsEfficiently with this differentiator's events.
 *
 * @param leftHand responses on the left side of the split
 * @param rightHand responses on the right side of the split
 * @return the sets computed by the utility
 */
protected CompetingRiskSets<CompetingRiskResponseWithCensorTime> createCompetingRiskSets(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand){
    final CompetingRiskSets<CompetingRiskResponseWithCensorTime> graySets =
            CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, rightHand, events);
    return graySets;
}
@Override
protected Double getScore(final CompetingRiskSets<CompetingRiskResponseWithCensorTime> competingRiskSets){
double numerator = 0.0;
double denominatorSquared = 0.0;
for(final int eventOfFocus : events){
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
denominatorSquared += valueOfInterest.getVariance();
@ -37,7 +35,6 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
}
return Math.abs(numerator / Math.sqrt(denominatorSquared));
}

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor;
@ -18,18 +18,14 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
private final int[] events;
@Override
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}
final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
/**
 * Builds the Gray-modified competing risk sets for a split by delegating to
 * CompetingRiskUtils.calculateGraySetsEfficiently with this differentiator's events.
 *
 * @param leftHand responses on the left side of the split
 * @param rightHand responses on the right side of the split
 * @return the sets computed by the utility
 */
protected CompetingRiskSets<CompetingRiskResponseWithCensorTime> createCompetingRiskSets(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand){
    final CompetingRiskSets<CompetingRiskResponseWithCensorTime> graySets =
            CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, rightHand, events);
    return graySets;
}
/**
 * Scores a split using the log rank value of the single event of focus.
 *
 * @param competingRiskSets summary of the split to score
 * @return the absolute value of the numerator divided by the square root of the variance
 */
@Override
protected Double getScore(final CompetingRiskSets<CompetingRiskResponseWithCensorTime> competingRiskSets){
    final LogRankValue logRank = specificLogRankValue(eventOfFocus, competingRiskSets);

    final double standardized = logRank.getNumerator() / logRank.getVarianceSqrt();
    return Math.abs(standardized);
}
}

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor;
@ -17,19 +17,17 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
private final int[] events;
@Override
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}
final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
/**
 * Builds the competing risk sets for a split by delegating to
 * CompetingRiskUtils.calculateSetsEfficiently with this differentiator's events and with
 * risk set calculation switched on.
 *
 * @param leftHand responses on the left side of the split
 * @param rightHand responses on the right side of the split
 * @return the sets computed by the utility
 */
protected CompetingRiskSets<CompetingRiskResponse> createCompetingRiskSets(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand){
    final CompetingRiskSets<CompetingRiskResponse> sets =
            CompetingRiskUtils.calculateSetsEfficiently(leftHand, rightHand, events, true);
    return sets;
}
@Override
protected Double getScore(final CompetingRiskSets<CompetingRiskResponse> competingRiskSets){
double numerator = 0.0;
double denominatorSquared = 0.0;
for(final int eventOfFocus : events){
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
denominatorSquared += valueOfInterest.getVariance();
@ -37,7 +35,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
}
return Math.abs(numerator / Math.sqrt(denominatorSquared));
}
}

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor;
@ -18,18 +18,14 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
private final int[] events;
@Override
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}
final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
/**
 * Builds the competing risk sets for a split by delegating to
 * CompetingRiskUtils.calculateSetsEfficiently with this differentiator's events and with
 * risk set calculation switched on.
 *
 * @param leftHand responses on the left side of the split
 * @param rightHand responses on the right side of the split
 * @return the sets computed by the utility
 */
protected CompetingRiskSets<CompetingRiskResponse> createCompetingRiskSets(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand){
    final CompetingRiskSets<CompetingRiskResponse> sets =
            CompetingRiskUtils.calculateSetsEfficiently(leftHand, rightHand, events, true);
    return sets;
}
/**
 * Scores a split using the log rank value of the single event of focus.
 *
 * @param competingRiskSets summary of the split to score
 * @return the absolute value of the numerator divided by the square root of the variance
 */
@Override
protected Double getScore(final CompetingRiskSets<CompetingRiskResponse> competingRiskSets){
    final LogRankValue logRank = specificLogRankValue(eventOfFocus, competingRiskSets);

    final double standardized = logRank.getNumerator() / logRank.getVarianceSqrt();
    return Math.abs(standardized);
}
}

View file

@ -12,6 +12,6 @@ import java.util.Iterator;
*/
public interface GroupDifferentiator<Y> {
<V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator);
SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator);
}

View file

@ -9,12 +9,12 @@ import java.util.stream.Collectors;
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {
@Override
public <V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator) {
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) {
Double bestScore = null;
Split<Y, V> bestSplit = null;
Split<Y, ?> bestSplit = null;
while(splitIterator.hasNext()){
final Split<Y, V> candidateSplit = splitIterator.next();
final Split<Y, ?> candidateSplit = splitIterator.next();
final List<Y> leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
final List<Y> rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());

View file

@ -1,214 +0,0 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCalculatingCompetingRiskSets {

    // Distinct event times expected from the fixture, in ascending order.
    private static final double[] EXPECTED_EVENT_TIMES = {1.0, 2.0, 3.0, 4.0, 6.0};

    // Times at which event counts are queried; 5.5 is not an event time and is
    // included to make sure the lookup doesn't break for other times.
    private static final double[] EVENT_COUNT_TIMES = {1.0, 2.0, 3.0, 4.0, 6.0, 5.5};

    // Expected {event 1, event 2} counts, row-aligned with EVENT_COUNT_TIMES.
    private static final int[][] EXPECTED_EVENT_COUNTS = {
            {2, 0},
            {1, 0},
            {0, 1},
            {1, 0},
            {0, 1},
            {0, 0}
    };

    // Times at which the risk set functions are evaluated.
    private static final double[] RISK_SET_TIMES = {
            0.5, 1.0, 1.5, 2.0, 2.3, 2.5, 2.7, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5
    };

    // Expected {event 1, event 2} risk set sizes from calculateSetsEfficiently,
    // row-aligned with RISK_SET_TIMES.
    private static final int[][] EXPECTED_RISK_SETS = {
            {9, 9}, {9, 9}, {6, 6}, {6, 6}, {5, 5}, {5, 5}, {5, 5}, {5, 5}, {3, 3},
            {3, 3}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, {0, 0}
    };

    // Expected {event 1, event 2} risk set sizes from calculateGraySetsEfficiently,
    // row-aligned with RISK_SET_TIMES.
    private static final int[][] EXPECTED_GRAY_RISK_SETS = {
            {9, 9}, {9, 9}, {6, 8}, {6, 8}, {5, 8}, {5, 7}, {5, 7}, {5, 5}, {4, 3},
            {3, 3}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 0}, {0, 0}, {0, 0}
    };

    /**
     * Builds the nine-row fixture shared by both tests.
     *
     * @return responses created from the three constructor arguments listed per row below
     */
    public List<CompetingRiskResponseWithCensorTime> generateData(){
        // Parallel arrays: one entry per response, in insertion order.
        final int[] first = {1, 1, 0, 1, 2, 0, 1, 0, 2};
        final int[] second = {1, 1, 1, 2, 3, 3, 4, 5, 6};
        final double[] third = {3, 3, 1, 2.5, 4, 3, 4, 5, 7};

        final List<CompetingRiskResponseWithCensorTime> responses = new ArrayList<>();
        for(int row = 0; row < first.length; row++){
            responses.add(new CompetingRiskResponseWithCensorTime(first[row], second[row], third[row]));
        }

        return responses;
    }

    /**
     * Checks event times, per-event counts and risk set sizes produced by
     * CompetingRiskUtils.calculateSetsEfficiently for events 1 and 2.
     */
    @Test
    public void testCalculatingSets(){
        // Raw type kept as in the original call site.
        final List data = generateData();

        final CompetingRiskSetsImpl sets = CompetingRiskUtils.calculateSetsEfficiently(data, new int[]{1,2});

        final List<Double> times = sets.getEventTimes();
        assertEquals(5, times.size());

        // Times
        for(int i = 0; i < EXPECTED_EVENT_TIMES.length; i++){
            assertEquals(EXPECTED_EVENT_TIMES[i], times.get(i).doubleValue());
        }

        // Number of Events
        for(int i = 0; i < EVENT_COUNT_TIMES.length; i++){
            final double time = EVENT_COUNT_TIMES[i];
            assertEquals(EXPECTED_EVENT_COUNTS[i][0], sets.getNumberOfEvents(time, 1));
            assertEquals(EXPECTED_EVENT_COUNTS[i][1], sets.getNumberOfEvents(time, 2));
        }

        // Risk set
        for(int i = 0; i < RISK_SET_TIMES.length; i++){
            final double time = RISK_SET_TIMES[i];
            assertEquals(EXPECTED_RISK_SETS[i][0], sets.getRiskSet(1).evaluate(time));
            assertEquals(EXPECTED_RISK_SETS[i][1], sets.getRiskSet(2).evaluate(time));
        }
    }

    /**
     * Checks event times, per-event counts and risk set sizes produced by
     * CompetingRiskUtils.calculateGraySetsEfficiently for events 1 and 2.
     */
    @Test
    public void testCalculatingGraySets(){
        final List<CompetingRiskResponseWithCensorTime> data = generateData();

        final CompetingRiskGraySetsImpl sets = CompetingRiskUtils.calculateGraySetsEfficiently(data, new int[]{1,2});

        final List<Double> times = sets.getEventTimes();
        assertEquals(5, times.size());

        // Times
        for(int i = 0; i < EXPECTED_EVENT_TIMES.length; i++){
            assertEquals(EXPECTED_EVENT_TIMES[i], times.get(i).doubleValue());
        }

        // Number of Events
        for(int i = 0; i < EVENT_COUNT_TIMES.length; i++){
            final double time = EVENT_COUNT_TIMES[i];
            assertEquals(EXPECTED_EVENT_COUNTS[i][0], sets.getNumberOfEvents(time, 1));
            assertEquals(EXPECTED_EVENT_COUNTS[i][1], sets.getNumberOfEvents(time, 2));
        }

        // Risk set
        for(int i = 0; i < RISK_SET_TIMES.length; i++){
            final double time = RISK_SET_TIMES[i];
            assertEquals(EXPECTED_GRAY_RISK_SETS[i][0], sets.getRiskSet(1).evaluate(time));
            assertEquals(EXPECTED_GRAY_RISK_SETS[i][1], sets.getRiskSet(2).evaluate(time));
        }
    }
}

View file

@ -1,10 +1,15 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.DataLoader;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.Node;
@ -14,14 +19,15 @@ import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test;
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestCompetingRisk {
@ -284,6 +290,34 @@ public class TestCompetingRisk {
assertEquals(359, countEventTwo);
}
/**
 * Used to time how long the algorithm takes.
 *
 * Trains the forest serially a fixed number of times and prints the average
 * number of seconds taken per training run to standard output.
 *
 * @param args Not used.
 * @throws IOException if thrown while loading the training data
 */
public static void main(String[] args) throws IOException {
    // timing
    // Number of training runs averaged over; kept in one place so the loop bound
    // and the divisor below cannot drift apart.
    final int repetitions = 50;

    final TestCompetingRisk tcr = new TestCompetingRisk();
    final Settings settings = tcr.getSettings();
    settings.setNtree(300); // results are too variable at 100

    final List<Covariate> covariates = settings.getCovariates();
    final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(),
            settings.getTrainingDataLocation());
    final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);

    // System.nanoTime() is monotonic, so the measurement is not distorted if the
    // wall clock is adjusted while the benchmark is running.
    final long startTime = System.nanoTime();
    for(int i=0; i<repetitions; i++){
        forestTrainer.trainSerial();
    }
    final long elapsedNanos = System.nanoTime() - startTime;

    // Average seconds per training run.
    System.out.println(elapsedNanos / 1.0e9 / repetitions);
}
@Test
public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException {

View file

@ -7,6 +7,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.SingletonIterator;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
@ -15,6 +17,8 @@ import lombok.AllArgsConstructor;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
@ -22,6 +26,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLogRankMultipleGroupDifferentiator {
/**
 * Wraps a single left/right partition in a one-element split iterator so it can be
 * passed to {@code GroupDifferentiator#differentiate}.
 *
 * @param leftList rows forming the left-hand side of the split
 * @param rightList rows forming the right-hand side of the split
 * @return an iterator yielding exactly one split, built with a null split rule and
 *         an empty fourth (list) argument
 */
private Iterator<Split<CompetingRiskResponse, ?>> turnIntoSplitIterator(List<Row<CompetingRiskResponse>> leftList,
        List<Row<CompetingRiskResponse>> rightList){
    // Diamond instead of the raw `new Split(...)`: keeps the construction type-checked
    // and avoids an unchecked-conversion warning.
    return new SingletonIterator<>(
            new Split<>(null, leftList, rightList, Collections.emptyList()));
}
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
@ -53,16 +62,12 @@ public class TestLogRankMultipleGroupDifferentiator {
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
final double scoreBad = groupDifferentiator.getScore(
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
final double scoreBad = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Bad, group2Bad)).getScore();
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199);
final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size());
final double scoreGood = groupDifferentiator.getScore(
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
final double scoreGood = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Good, group2Good)).getScore();
// expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea
closeEnough(71.41135, scoreBad, 0.00001);

View file

@ -3,10 +3,15 @@ package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.SingletonIterator;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
@ -15,42 +20,55 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLogRankSingleGroupDifferentiator {
private List<CompetingRiskResponse> generateData1(){
final List<CompetingRiskResponse> data = new ArrayList<>();
private double getScore(final GroupDifferentiator<CompetingRiskResponse> groupDifferentiator, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){
final Iterator<Split<CompetingRiskResponse, ?>> iterator = new SingletonIterator<>(
new Split<>(null, left, right, Collections.emptyList()));
data.add(new CompetingRiskResponse(1, 1.0));
data.add(new CompetingRiskResponse(1, 1.0));
data.add(new CompetingRiskResponse(1, 2.0));
data.add(new CompetingRiskResponse(1, 1.5));
data.add(new CompetingRiskResponse(0, 2.0));
data.add(new CompetingRiskResponse(0, 1.5));
data.add(new CompetingRiskResponse(0, 2.5));
return groupDifferentiator.differentiate(iterator).getScore();
}
int count = 1;
private <Y> Row<Y> createRow(Y response){
    // Take the current counter value for this row, then advance it for the next one.
    final int id = count;
    count++;
    return new Row<>(null, id, response);
}
private List<Row<CompetingRiskResponse>> generateData1(){
    // Parallel arrays holding the two CompetingRiskResponse constructor arguments,
    // one entry per row, in insertion order.
    final int[] first = {1, 1, 1, 1, 0, 0, 0};
    final double[] second = {1.0, 1.0, 2.0, 1.5, 2.0, 1.5, 2.5};

    final List<Row<CompetingRiskResponse>> rows = new ArrayList<>();
    for(int i = 0; i < first.length; i++){
        rows.add(createRow(new CompetingRiskResponse(first[i], second[i])));
    }

    return rows;
}
private List<CompetingRiskResponse> generateData2(){
final List<CompetingRiskResponse> data = new ArrayList<>();
private List<Row<CompetingRiskResponse>> generateData2(){
final List<Row<CompetingRiskResponse>> data = new ArrayList<>();
data.add(new CompetingRiskResponse(1, 2.0));
data.add(new CompetingRiskResponse(1, 2.0));
data.add(new CompetingRiskResponse(1, 4.0));
data.add(new CompetingRiskResponse(1, 3.0));
data.add(new CompetingRiskResponse(0, 4.0));
data.add(new CompetingRiskResponse(0, 3.0));
data.add(new CompetingRiskResponse(0, 5.0));
data.add(createRow(new CompetingRiskResponse(1, 2.0)));
data.add(createRow(new CompetingRiskResponse(1, 2.0)));
data.add(createRow(new CompetingRiskResponse(1, 4.0)));
data.add(createRow(new CompetingRiskResponse(1, 3.0)));
data.add(createRow(new CompetingRiskResponse(0, 4.0)));
data.add(createRow(new CompetingRiskResponse(0, 3.0)));
data.add(createRow(new CompetingRiskResponse(0, 5.0)));
return data;
}
@Test
public void testCompetingRiskResponseCombiner(){
final List<CompetingRiskResponse> data1 = generateData1();
final List<CompetingRiskResponse> data2 = generateData2();
final List<Row<CompetingRiskResponse>> data1 = generateData1();
final List<Row<CompetingRiskResponse>> data2 = generateData2();
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
final double score = differentiator.getScore(data1, data2);
final double score = getScore(differentiator, data1, data2);
final double margin = 0.000001;
// Tested using 855 method
@ -70,16 +88,12 @@ public class TestLogRankSingleGroupDifferentiator {
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
final double scoreGood = groupDifferentiator.getScore(
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
final double scoreGood = getScore(groupDifferentiator, group1Good, group2Good);
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
final double scoreBad = groupDifferentiator.getScore(
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
final double scoreBad = getScore(groupDifferentiator, group1Bad, group2Bad);
// Apparently not all groups are unique when splitting
assertEquals(scoreGood, scoreBad);