Update the competing risk GroupDifferentiators to make efficient use of the SplitRuleUpdater updates
Results in a speed improvement of over 1/3 according to a timing of the TestCompetingRisk#testLogRankSingleGroupDifferentiatorAllCovariates() test
This commit is contained in:
parent
86122fd90d
commit
e709c42da1
17 changed files with 524 additions and 478 deletions
|
@ -40,6 +40,7 @@ public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
interface SplitRuleUpdater<Y, V> extends Iterator<Split<Y, V>>{
|
interface SplitRuleUpdater<Y, V> extends Iterator<Split<Y, V>>{
|
||||||
Split<Y, V> currentSplit();
|
Split<Y, V> currentSplit();
|
||||||
|
boolean currentSplitValid();
|
||||||
SplitUpdate<Y, V> nextUpdate();
|
SplitUpdate<Y, V> nextUpdate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,11 @@ public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y,
|
||||||
return this.currentSplit;
|
return this.currentSplit;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean currentSplitValid() {
|
||||||
|
return currentSplit.getLeftHand().size() > 0 && currentSplit.getRightHand().size() > 0;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public NumericSplitUpdate<Y> nextUpdate() {
|
public NumericSplitUpdate<Y> nextUpdate() {
|
||||||
if(hasNext()){
|
if(hasNext()){
|
||||||
|
@ -51,8 +56,8 @@ public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y,
|
||||||
// Update current split
|
// Update current split
|
||||||
this.currentSplit = new Split<>(
|
this.currentSplit = new Split<>(
|
||||||
splitRule,
|
splitRule,
|
||||||
orderedData.subList(0, newPosition),
|
Collections.unmodifiableList(orderedData.subList(0, newPosition)),
|
||||||
orderedData.subList(newPosition, orderedData.size()),
|
Collections.unmodifiableList(orderedData.subList(newPosition, orderedData.size())),
|
||||||
Collections.emptyList());
|
Collections.emptyList());
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,37 +1,77 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
import java.util.Arrays;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Getter;
|
|
||||||
|
|
||||||
import java.util.List;
|
public class CompetingRiskGraySetsImpl implements CompetingRiskSets<CompetingRiskResponseWithCensorTime> {
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
final double[] times; // length m array
|
||||||
* Represents a response from CompetingRiskUtils#calculateGraySetsEfficiently
|
int[][] riskSetLeft; // J x m array
|
||||||
*
|
final int[][] riskSetTotal; // J x m array
|
||||||
*/
|
int[][] numberOfEventsLeft; // J+1 x m array
|
||||||
@Builder
|
final int[][] numberOfEventsTotal; // J+1 x m array
|
||||||
@Getter
|
|
||||||
public class CompetingRiskGraySetsImpl implements CompetingRiskSets{
|
|
||||||
|
|
||||||
private final List<Double> eventTimes;
|
public CompetingRiskGraySetsImpl(double[] times, int[][] riskSetLeft, int[][] riskSetTotal, int[][] numberOfEventsLeft, int[][] numberOfEventsTotal) {
|
||||||
private final MathFunction[] riskSet;
|
this.times = times;
|
||||||
private final Map<Double, int[]> numberOfEvents;
|
this.riskSetLeft = riskSetLeft;
|
||||||
|
this.riskSetTotal = riskSetTotal;
|
||||||
@Override
|
this.numberOfEventsLeft = numberOfEventsLeft;
|
||||||
public MathFunction getRiskSet(int event){
|
this.numberOfEventsTotal = numberOfEventsTotal;
|
||||||
return riskSet[event-1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getNumberOfEvents(Double time, int event){
|
public double[] getDistinctTimes() {
|
||||||
if(numberOfEvents.containsKey(time)){
|
return times;
|
||||||
return numberOfEvents.get(time)[event];
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getRiskSetLeft(int timeIndex, int event) {
|
||||||
|
return riskSetLeft[event-1][timeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getRiskSetTotal(int timeIndex, int event) {
|
||||||
|
return riskSetTotal[event-1][timeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumberOfEventsLeft(int timeIndex, int event) {
|
||||||
|
return numberOfEventsLeft[event][timeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumberOfEventsTotal(int timeIndex, int event) {
|
||||||
|
return numberOfEventsTotal[event][timeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void update(CompetingRiskResponseWithCensorTime rowMovedToLeft) {
|
||||||
|
final double time = rowMovedToLeft.getU();
|
||||||
|
final int k = Arrays.binarySearch(times, time);
|
||||||
|
final int delta_m_1 = rowMovedToLeft.getDelta() - 1;
|
||||||
|
final double censorTime = rowMovedToLeft.getC();
|
||||||
|
|
||||||
|
for(int j=0; j<riskSetLeft.length; j++){
|
||||||
|
final int[] riskSetLeftJ = riskSetLeft[j];
|
||||||
|
|
||||||
|
// first iteration; perform normal increment as if Y is normal
|
||||||
|
// corresponds to the first part, U_i >= t, in I(...)
|
||||||
|
for(int i=0; i<=k; i++){
|
||||||
|
riskSetLeftJ[i]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// second iteration; only if delta-1 != j
|
||||||
|
// corresponds to the second part, U_i < t & delta_i != j & C_i > t
|
||||||
|
if(delta_m_1 != j && !rowMovedToLeft.isCensored()){
|
||||||
|
int i = k+1;
|
||||||
|
while(i < times.length && times[i] < censorTime){
|
||||||
|
riskSetLeftJ[i]++;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
numberOfEventsLeft[rowMovedToLeft.getDelta()][k]++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
public interface CompetingRiskSets<T extends CompetingRiskResponse> {
|
||||||
|
|
||||||
import java.util.List;
|
double[] getDistinctTimes();
|
||||||
|
int getRiskSetLeft(int timeIndex, int event);
|
||||||
|
int getRiskSetTotal(int timeIndex, int event);
|
||||||
|
int getNumberOfEventsLeft(int timeIndex, int event);
|
||||||
|
int getNumberOfEventsTotal(int timeIndex, int event);
|
||||||
|
|
||||||
public interface CompetingRiskSets {
|
void update(T rowMovedToLeft);
|
||||||
|
|
||||||
MathFunction getRiskSet(int event);
|
|
||||||
int getNumberOfEvents(Double time, int event);
|
|
||||||
List<Double> getEventTimes();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,36 +1,59 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
import java.util.Arrays;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Getter;
|
|
||||||
|
|
||||||
import java.util.List;
|
public class CompetingRiskSetsImpl implements CompetingRiskSets<CompetingRiskResponse> {
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
final double[] times; // length m array
|
||||||
* Represents a response from CompetingRiskUtils#calculateSetsEfficiently
|
int[] riskSetLeft; // length m array
|
||||||
*
|
final int[] riskSetTotal; // length m array
|
||||||
*/
|
int[][] numberOfEventsLeft; // J+1 x m array
|
||||||
@Builder
|
final int[][] numberOfEventsTotal; // J+1 x m array
|
||||||
@Getter
|
|
||||||
public class CompetingRiskSetsImpl implements CompetingRiskSets{
|
|
||||||
|
|
||||||
private final List<Double> eventTimes;
|
|
||||||
private final MathFunction riskSet;
|
|
||||||
private final Map<Double, int[]> numberOfEvents;
|
|
||||||
|
|
||||||
@Override
|
public CompetingRiskSetsImpl(double[] times, int[] riskSetLeft, int[] riskSetTotal, int[][] numberOfEventsLeft, int[][] numberOfEventsTotal) {
|
||||||
public MathFunction getRiskSet(int event){
|
this.times = times;
|
||||||
return riskSet;
|
this.riskSetLeft = riskSetLeft;
|
||||||
|
this.riskSetTotal = riskSetTotal;
|
||||||
|
this.numberOfEventsLeft = numberOfEventsLeft;
|
||||||
|
this.numberOfEventsTotal = numberOfEventsTotal;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getNumberOfEvents(Double time, int event){
|
public double[] getDistinctTimes() {
|
||||||
if(numberOfEvents.containsKey(time)){
|
return times;
|
||||||
return numberOfEvents.get(time)[event];
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getRiskSetLeft(int timeIndex, int event) {
|
||||||
|
return riskSetLeft[timeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getRiskSetTotal(int timeIndex, int event) {
|
||||||
|
return riskSetTotal[timeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumberOfEventsLeft(int timeIndex, int event) {
|
||||||
|
return numberOfEventsLeft[event][timeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumberOfEventsTotal(int timeIndex, int event) {
|
||||||
|
return numberOfEventsTotal[event][timeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void update(CompetingRiskResponse rowMovedToLeft) {
|
||||||
|
final double time = rowMovedToLeft.getU();
|
||||||
|
final int k = Arrays.binarySearch(times, time);
|
||||||
|
|
||||||
|
for(int i=0; i<=k; i++){
|
||||||
|
riskSetLeft[i]++;
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
numberOfEventsLeft[rowMovedToLeft.getDelta()][k]++;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
|
|
||||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.VeryDiscontinuousStepFunction;
|
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.DoubleStream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
public class CompetingRiskUtils {
|
public class CompetingRiskUtils {
|
||||||
|
|
||||||
|
@ -102,18 +100,30 @@ public class CompetingRiskUtils {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> responses, int[] eventsOfFocus){
|
|
||||||
final int n = responses.size();
|
|
||||||
int[] numberOfCurrentEvents = new int[eventsOfFocus.length+1];
|
|
||||||
|
|
||||||
final Map<Double, int[]> numberOfEvents = new HashMap<>();
|
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
|
||||||
|
final List<CompetingRiskResponse> initialRightHand,
|
||||||
|
int[] eventsOfFocus,
|
||||||
|
boolean calculateRiskSets){
|
||||||
|
|
||||||
final List<Double> eventTimes = new ArrayList<>(n);
|
final double[] distinctEventTimes = Stream.concat(
|
||||||
final List<Double> eventAndCensorTimes = new ArrayList<>(n);
|
initialLeftHand.stream(),
|
||||||
final List<Integer> riskSetNumberList = new ArrayList<>(n);
|
initialRightHand.stream())
|
||||||
|
//.filter(y -> !y.isCensored())
|
||||||
|
.map(CompetingRiskResponse::getU)
|
||||||
|
.mapToDouble(Double::doubleValue)
|
||||||
|
.sorted()
|
||||||
|
.distinct()
|
||||||
|
.toArray();
|
||||||
|
|
||||||
|
|
||||||
|
final int m = distinctEventTimes.length;
|
||||||
|
final int[][] numberOfCurrentEventsTotal = new int[eventsOfFocus.length+1][m];
|
||||||
|
|
||||||
|
// Left Hand First
|
||||||
|
|
||||||
// need to first sort responses
|
// need to first sort responses
|
||||||
Collections.sort(responses, (y1, y2) -> {
|
Collections.sort(initialLeftHand, (y1, y2) -> {
|
||||||
if(y1.getU() < y2.getU()){
|
if(y1.getU() < y2.getU()){
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
@ -125,127 +135,191 @@ public class CompetingRiskUtils {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
final int nLeft = initialLeftHand.size();
|
||||||
|
final int nRight = initialRightHand.size();
|
||||||
|
|
||||||
|
final int[][] numberOfCurrentEventsLeft = new int[eventsOfFocus.length+1][m];
|
||||||
|
final int[] riskSetArrayLeft = new int[m];
|
||||||
|
final int[] riskSetArrayTotal = new int[m];
|
||||||
|
|
||||||
|
|
||||||
for(int i=0; i<n; i++){
|
|
||||||
final CompetingRiskResponse currentResponse = responses.get(i);
|
|
||||||
final boolean lastOfTime = (i+1)==n || responses.get(i+1).getU() > currentResponse.getU();
|
|
||||||
|
|
||||||
numberOfCurrentEvents[currentResponse.getDelta()]++;
|
for(int k=0; k<m; k++){
|
||||||
|
riskSetArrayLeft[k] = nLeft;
|
||||||
|
riskSetArrayTotal[k] = nLeft + nRight;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Left Hand
|
||||||
|
for(int i=0; i<nLeft; i++){
|
||||||
|
final CompetingRiskResponse currentResponse = initialLeftHand.get(i);
|
||||||
|
final boolean lastOfTime = (i+1)==nLeft || initialLeftHand.get(i+1).getU() > currentResponse.getU();
|
||||||
|
|
||||||
|
final int k = Arrays.binarySearch(distinctEventTimes, currentResponse.getU());
|
||||||
|
|
||||||
|
numberOfCurrentEventsLeft[currentResponse.getDelta()][k]++;
|
||||||
|
numberOfCurrentEventsTotal[currentResponse.getDelta()][k]++;
|
||||||
|
|
||||||
if(lastOfTime){
|
if(lastOfTime){
|
||||||
int totalNumberOfCurrentEvents = 0;
|
int totalNumberOfCurrentEvents = 0;
|
||||||
for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events
|
for(int e = 1; e < eventsOfFocus.length+1; e++){ // exclude censored events
|
||||||
totalNumberOfCurrentEvents += numberOfCurrentEvents[e];
|
totalNumberOfCurrentEvents += numberOfCurrentEventsLeft[e][k];
|
||||||
}
|
}
|
||||||
|
|
||||||
final double currentTime = currentResponse.getU();
|
// Calculate risk set values
|
||||||
|
// Note that we only decrease values in the *future*
|
||||||
|
if(calculateRiskSets){
|
||||||
|
final int decreaseBy = totalNumberOfCurrentEvents + numberOfCurrentEventsLeft[0][k];
|
||||||
|
for(int j=k+1; j<m; j++){
|
||||||
|
riskSetArrayLeft[j] = riskSetArrayLeft[j] - decreaseBy;
|
||||||
|
riskSetArrayTotal[j] = riskSetArrayTotal[j] - decreaseBy;
|
||||||
|
|
||||||
if(totalNumberOfCurrentEvents > 0){ // add numberOfCurrentEvents
|
}
|
||||||
// Add point
|
|
||||||
eventTimes.add(currentTime);
|
|
||||||
numberOfEvents.put(currentTime, numberOfCurrentEvents);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always do risk set
|
|
||||||
// remember that the LeftContinuousFunction takes into account that at this currentTime the risk value is the previous value
|
|
||||||
final int riskSet = n - (i+1);
|
|
||||||
riskSetNumberList.add(riskSet);
|
|
||||||
eventAndCensorTimes.add(currentTime);
|
|
||||||
|
|
||||||
// reset counters
|
|
||||||
numberOfCurrentEvents = new int[eventsOfFocus.length+1];
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final double[] riskSetArray = new double[eventAndCensorTimes.size()];
|
|
||||||
final double[] timesArray = new double[eventAndCensorTimes.size()];
|
// Right Hand Next. Note that we only need to keep track of the Left Hand and the Total
|
||||||
for(int i=0; i<riskSetArray.length; i++){
|
|
||||||
timesArray[i] = eventAndCensorTimes.get(i);
|
// need to first sort responses
|
||||||
riskSetArray[i] = riskSetNumberList.get(i);
|
Collections.sort(initialRightHand, (y1, y2) -> {
|
||||||
|
if(y1.getU() < y2.getU()){
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
else if(y1.getU() > y2.getU()){
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Right Hand
|
||||||
|
int[] currentEventsRight = new int[eventsOfFocus.length+1];
|
||||||
|
for(int i=0; i<nRight; i++){
|
||||||
|
final CompetingRiskResponse currentResponse = initialRightHand.get(i);
|
||||||
|
final boolean lastOfTime = (i+1)==nRight || initialRightHand.get(i+1).getU() > currentResponse.getU();
|
||||||
|
|
||||||
|
final int k = Arrays.binarySearch(distinctEventTimes, currentResponse.getU());
|
||||||
|
|
||||||
|
currentEventsRight[currentResponse.getDelta()]++;
|
||||||
|
numberOfCurrentEventsTotal[currentResponse.getDelta()][k]++;
|
||||||
|
|
||||||
|
if(lastOfTime){
|
||||||
|
int totalNumberOfCurrentEvents = 0;
|
||||||
|
for(int e = 1; e < eventsOfFocus.length+1; e++){ // exclude censored events
|
||||||
|
totalNumberOfCurrentEvents += currentEventsRight[e];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate risk set values
|
||||||
|
// Note that we only decrease values in the *future*
|
||||||
|
if(calculateRiskSets){
|
||||||
|
final int decreaseBy = totalNumberOfCurrentEvents + currentEventsRight[0];
|
||||||
|
for(int j=k+1; j<m; j++){
|
||||||
|
riskSetArrayTotal[j] = riskSetArrayTotal[j] - decreaseBy;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset
|
||||||
|
currentEventsRight = new int[eventsOfFocus.length+1];
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final LeftContinuousStepFunction riskSetFunction = new LeftContinuousStepFunction(timesArray, riskSetArray, n);
|
return new CompetingRiskSetsImpl(distinctEventTimes, riskSetArrayLeft, riskSetArrayTotal, numberOfCurrentEventsLeft, numberOfCurrentEventsTotal);
|
||||||
|
|
||||||
return CompetingRiskSetsImpl.builder()
|
|
||||||
.numberOfEvents(numberOfEvents)
|
|
||||||
.riskSet(riskSetFunction)
|
|
||||||
.eventTimes(eventTimes)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static CompetingRiskGraySetsImpl calculateGraySetsEfficiently(final List<CompetingRiskResponseWithCensorTime> responses, int[] eventsOfFocus){
|
|
||||||
final List sillyList = responses; // annoying Java generic work-around
|
|
||||||
final CompetingRiskSetsImpl originalSets = calculateSetsEfficiently(sillyList, eventsOfFocus);
|
|
||||||
|
|
||||||
final double[] allTimes = DoubleStream.concat(
|
public static CompetingRiskGraySetsImpl calculateGraySetsEfficiently(final List<CompetingRiskResponseWithCensorTime> initialLeftHand,
|
||||||
responses.stream()
|
final List<CompetingRiskResponseWithCensorTime> initialRightHand,
|
||||||
.mapToDouble(CompetingRiskResponseWithCensorTime::getC),
|
int[] eventsOfFocus){
|
||||||
responses.stream()
|
|
||||||
.mapToDouble(CompetingRiskResponseWithCensorTime::getU)
|
|
||||||
).sorted().distinct().toArray();
|
|
||||||
|
|
||||||
|
final List leftHandGenericsSuck = initialLeftHand;
|
||||||
|
final List rightHandGenericsSuck = initialRightHand;
|
||||||
|
|
||||||
|
final CompetingRiskSetsImpl normalSets = calculateSetsEfficiently(
|
||||||
|
leftHandGenericsSuck,
|
||||||
|
rightHandGenericsSuck,
|
||||||
|
eventsOfFocus, false);
|
||||||
|
|
||||||
final VeryDiscontinuousStepFunction[] riskSets = new VeryDiscontinuousStepFunction[eventsOfFocus.length];
|
final double[] times = normalSets.times;
|
||||||
|
final int[][] numberOfEventsLeft = normalSets.numberOfEventsLeft;
|
||||||
|
final int[][] numberOfEventsTotal = normalSets.numberOfEventsTotal;
|
||||||
|
|
||||||
for(final int event : eventsOfFocus){
|
// FYI; initialLeftHand and initialRightHand have both now been sorted
|
||||||
final double[] yAt = new double[allTimes.length];
|
// Time to calculate the Gray modified risk sets
|
||||||
final double[] yRight = new double[allTimes.length];
|
final int[][] riskSetsLeft = new int[eventsOfFocus.length][times.length];
|
||||||
|
final int[][] riskSetsTotal = new int[eventsOfFocus.length][times.length];
|
||||||
|
|
||||||
for(final CompetingRiskResponseWithCensorTime response : responses){
|
// Left hand first
|
||||||
if(response.getDelta() == event){
|
for(final CompetingRiskResponseWithCensorTime response : initialLeftHand){
|
||||||
// traditional case only; increment on time t when I(t <= Ui)
|
final double time = response.getU();
|
||||||
final double time = response.getU();
|
final int k = Arrays.binarySearch(times, time);
|
||||||
final int index = Arrays.binarySearch(allTimes, time);
|
final int delta_m_1 = response.getDelta() - 1;
|
||||||
|
final double censorTime = response.getC();
|
||||||
|
|
||||||
if(index < 0){ // TODO remove once code is stable
|
for(int j=0; j<eventsOfFocus.length; j++){
|
||||||
throw new IllegalStateException("Index shouldn't be negative!");
|
final int[] riskSetLeftJ = riskSetsLeft[j];
|
||||||
}
|
final int[] riskSetTotalJ = riskSetsTotal[j];
|
||||||
|
|
||||||
// All yAts up to and including index are incremented;
|
// first iteration; perform normal increment as if Y is normal
|
||||||
// All yRights up to index are incremented
|
// corresponds to the first part, U_i >= t, in I(...)
|
||||||
yAt[index]++;
|
for(int i=0; i<=k; i++){
|
||||||
for(int i=0; i<index; i++){
|
riskSetLeftJ[i]++;
|
||||||
yAt[i]++;
|
riskSetTotalJ[i]++;
|
||||||
yRight[i]++;
|
}
|
||||||
|
|
||||||
|
// second iteration; only if delta-1 != j
|
||||||
|
// corresponds to the second part, U_i < t & delta_i != j & C_i > t
|
||||||
|
if(delta_m_1 != j && !response.isCensored()){
|
||||||
|
int i = k+1;
|
||||||
|
while(i < times.length && times[i] < censorTime){
|
||||||
|
riskSetLeftJ[i]++;
|
||||||
|
riskSetTotalJ[i]++;
|
||||||
|
i++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else{
|
|
||||||
// need to increment on time t on following conditions; I(t <= Ui | t < Ci)
|
|
||||||
// Fact: Ci >= Ui.
|
|
||||||
|
|
||||||
// increment yAt up to Ci. If Ui==Ci, increment yAt at Ci.
|
|
||||||
final double time = response.getC();
|
|
||||||
final int index = Arrays.binarySearch(allTimes, time);
|
|
||||||
|
|
||||||
if(index < 0){ // TODO remove once code is stable
|
|
||||||
throw new IllegalStateException("Index shouldn't be negative!");
|
|
||||||
}
|
|
||||||
|
|
||||||
for(int i=0; i<index; i++){
|
|
||||||
yAt[i]++;
|
|
||||||
yRight[i]++;
|
|
||||||
}
|
|
||||||
if(response.getU() == response.getC()){
|
|
||||||
yAt[index]++;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
riskSets[event-1] = new VeryDiscontinuousStepFunction(allTimes, yAt, yRight, responses.size());
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return CompetingRiskGraySetsImpl.builder()
|
// Repeat for right hand
|
||||||
.numberOfEvents(originalSets.getNumberOfEvents())
|
for(final CompetingRiskResponseWithCensorTime response : initialRightHand){
|
||||||
.eventTimes(originalSets.getEventTimes())
|
final double time = response.getU();
|
||||||
.riskSet(riskSets)
|
final int k = Arrays.binarySearch(times, time);
|
||||||
.build();
|
final int delta_m_1 = response.getDelta() - 1;
|
||||||
|
final double censorTime = response.getC();
|
||||||
|
|
||||||
|
for(int j=0; j<eventsOfFocus.length; j++){
|
||||||
|
final int[] riskSetTotalJ = riskSetsTotal[j];
|
||||||
|
|
||||||
|
// first iteration; perform normal increment as if Y is normal
|
||||||
|
// corresponds to the first part, U_i >= t, in I(...)
|
||||||
|
for(int i=0; i<=k; i++){
|
||||||
|
riskSetTotalJ[i]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// second iteration; only if delta-1 != j
|
||||||
|
// corresponds to the second part, U_i < t & delta_i != j & C_i > t
|
||||||
|
if(delta_m_1 != j && !response.isCensored()){
|
||||||
|
int i = k+1;
|
||||||
|
while(i < times.length && times[i] < censorTime){
|
||||||
|
riskSetTotalJ[i]++;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CompetingRiskGraySetsImpl(times, riskSetsLeft, riskSetsTotal, numberOfEventsLeft, numberOfEventsTotal);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,55 +1,132 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
import ca.joeltherrien.randomforest.tree.SplitAndScore;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.stream.Stream;
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See page 761 of Random survival forests for competing risks by Ishwaran et al. The class is abstract as Gray's test
|
* See page 761 of Random survival forests for competing risks by Ishwaran et al. The class is abstract as Gray's test
|
||||||
* modifies the abstract method.
|
* modifies the abstract method.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> extends SimpleGroupDifferentiator<Y> {
|
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y> {
|
||||||
|
|
||||||
|
abstract protected CompetingRiskSets<Y> createCompetingRiskSets(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
|
abstract protected Double getScore(final CompetingRiskSets<Y> competingRiskSets);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) {
|
||||||
|
|
||||||
|
if(splitIterator instanceof Covariate.SplitRuleUpdater){
|
||||||
|
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return differentiateWithBasicIterator(splitIterator);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private SplitAndScore<Y, ?> differentiateWithBasicIterator(Iterator<Split<Y, ?>> splitIterator){
|
||||||
|
Double bestScore = null;
|
||||||
|
Split<Y, ?> bestSplit = null;
|
||||||
|
|
||||||
|
while(splitIterator.hasNext()){
|
||||||
|
final Split<Y, ?> candidateSplit = splitIterator.next();
|
||||||
|
|
||||||
|
final List<Y> leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Y> rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
if(leftHand.isEmpty() || rightHand.isEmpty()){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
final CompetingRiskSets<Y> competingRiskSets = createCompetingRiskSets(leftHand, rightHand);
|
||||||
|
|
||||||
|
final Double score = getScore(competingRiskSets);
|
||||||
|
|
||||||
|
if(Double.isFinite(score) && (bestScore == null || score > bestScore)){
|
||||||
|
bestScore = score;
|
||||||
|
bestSplit = candidateSplit;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(bestSplit == null){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new SplitAndScore<>(bestSplit, bestScore);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SplitAndScore<Y, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Y, ?> splitRuleUpdater) {
|
||||||
|
|
||||||
|
final List<Y> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
|
||||||
|
.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Y> rightInitialSplit = splitRuleUpdater.currentSplit().getRightHand()
|
||||||
|
.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final CompetingRiskSets<Y> competingRiskSets = createCompetingRiskSets(leftInitialSplit, rightInitialSplit);
|
||||||
|
|
||||||
|
Double bestScore = null;
|
||||||
|
Split<Y, ?> bestSplit = null;
|
||||||
|
|
||||||
|
while(splitRuleUpdater.hasNext()){
|
||||||
|
for(Row<Y> rowMoved : splitRuleUpdater.nextUpdate().rowsMovedToLeftHand()){
|
||||||
|
competingRiskSets.update(rowMoved.getResponse());
|
||||||
|
}
|
||||||
|
|
||||||
|
final Double score = getScore(competingRiskSets);
|
||||||
|
|
||||||
|
if(Double.isFinite(score) && (bestScore == null || score > bestScore)){
|
||||||
|
bestScore = score;
|
||||||
|
bestSplit = splitRuleUpdater.currentSplit();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(bestSplit == null){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new SplitAndScore<>(bestSplit, bestScore);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
||||||
*
|
*
|
||||||
* @param eventOfFocus
|
* @param eventOfFocus
|
||||||
* @param competingRiskSetsLeft A summary of the different sets used in the calculation for the left side
|
* @param competingRiskSets A summary of the different sets used in the calculation
|
||||||
* @param competingRiskSetsRight A summary of the different sets used in the calculation for the right side
|
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSetsLeft, final CompetingRiskSets competingRiskSetsRight){
|
LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets<Y> competingRiskSets){
|
||||||
|
|
||||||
final double[] distinctEventTimes = Stream.concat(
|
|
||||||
competingRiskSetsLeft.getEventTimes().stream(),
|
|
||||||
competingRiskSetsRight.getEventTimes().stream())
|
|
||||||
.mapToDouble(Double::doubleValue)
|
|
||||||
.sorted()
|
|
||||||
.distinct()
|
|
||||||
.toArray();
|
|
||||||
|
|
||||||
double summation = 0.0;
|
double summation = 0.0;
|
||||||
double variance = 0.0;
|
double variance = 0.0;
|
||||||
|
|
||||||
for(final double time_k : distinctEventTimes){
|
final double[] distinctTimes = competingRiskSets.getDistinctTimes();
|
||||||
|
|
||||||
|
for(int k = 0; k<distinctTimes.length; k++){
|
||||||
|
final double time_k = distinctTimes[k];
|
||||||
final double weight = weight(time_k); // W_j(t_k)
|
final double weight = weight(time_k); // W_j(t_k)
|
||||||
final double numberEventsAtTimeDaughterLeft = competingRiskSetsLeft.getNumberOfEvents(time_k, eventOfFocus); // // d_{j,l}(t_k)
|
final double numberEventsAtTimeDaughterLeft = competingRiskSets.getNumberOfEventsLeft(k, eventOfFocus); // // d_{j,l}(t_k)
|
||||||
final double numberEventsAtTimeDaughterRight = competingRiskSetsRight.getNumberOfEvents(time_k, eventOfFocus); // d_{j,r}(t_k)
|
final double numberEventsAtTimeDaughterTotal = competingRiskSets.getNumberOfEventsTotal(k, eventOfFocus); // d_j(t_k)
|
||||||
final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)
|
|
||||||
|
|
||||||
final double individualsAtRiskDaughterLeft = competingRiskSetsLeft.getRiskSet(eventOfFocus).evaluate(time_k); // Y_l(t_k)
|
final double individualsAtRiskDaughterLeft = competingRiskSets.getRiskSetLeft(k, eventOfFocus); // Y_l(t_k)
|
||||||
final double individualsAtRiskDaughterRight = competingRiskSetsRight.getRiskSet(eventOfFocus).evaluate(time_k); // Y_r(t_k)
|
final double individualsAtRiskDaughterTotal = competingRiskSets.getRiskSetTotal(k, eventOfFocus); // Y(t_k)
|
||||||
final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)
|
|
||||||
|
|
||||||
final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);
|
final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberEventsAtTimeDaughterTotal*individualsAtRiskDaughterLeft/individualsAtRiskDaughterTotal);
|
||||||
final double deltaVariance = weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk
|
final double deltaVariance = weight*weight*numberEventsAtTimeDaughterTotal*individualsAtRiskDaughterLeft/individualsAtRiskDaughterTotal
|
||||||
* (1.0 - individualsAtRiskDaughterLeft / individualsAtRisk)
|
* (1.0 - individualsAtRiskDaughterLeft / individualsAtRiskDaughterTotal)
|
||||||
* ((individualsAtRisk - numberOfEventsAtTime) / (individualsAtRisk - 1.0));
|
* ((individualsAtRiskDaughterTotal - numberEventsAtTimeDaughterTotal) / (individualsAtRiskDaughterTotal - 1.0));
|
||||||
|
|
||||||
// Note - notation differs slightly with what is found in STAT 855 notes, but they are equivalent.
|
// Note - notation differs slightly with what is found in STAT 855 notes, but they are equivalent.
|
||||||
// Note - if individualsAtRisk == 1 then variance will be NaN.
|
// Note - if individualsAtRisk == 1 then variance will be NaN.
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
@ -17,19 +17,17 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
protected CompetingRiskSets<CompetingRiskResponseWithCensorTime> createCompetingRiskSets(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand){
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
return CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, rightHand, events);
|
||||||
return null;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
|
|
||||||
final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Double getScore(final CompetingRiskSets<CompetingRiskResponseWithCensorTime> competingRiskSets){
|
||||||
double numerator = 0.0;
|
double numerator = 0.0;
|
||||||
double denominatorSquared = 0.0;
|
double denominatorSquared = 0.0;
|
||||||
|
|
||||||
for(final int eventOfFocus : events){
|
for(final int eventOfFocus : events){
|
||||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
|
||||||
|
|
||||||
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
|
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
|
||||||
denominatorSquared += valueOfInterest.getVariance();
|
denominatorSquared += valueOfInterest.getVariance();
|
||||||
|
@ -37,7 +35,6 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
|
||||||
}
|
}
|
||||||
|
|
||||||
return Math.abs(numerator / Math.sqrt(denominatorSquared));
|
return Math.abs(numerator / Math.sqrt(denominatorSquared));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
@ -18,18 +18,14 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
protected CompetingRiskSets<CompetingRiskResponseWithCensorTime> createCompetingRiskSets(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand){
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
return CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, rightHand, events);
|
||||||
return null;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
|
|
||||||
final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
|
|
||||||
|
|
||||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Double getScore(final CompetingRiskSets<CompetingRiskResponseWithCensorTime> competingRiskSets){
|
||||||
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
|
||||||
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
|
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
@ -17,19 +17,17 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
protected CompetingRiskSets<CompetingRiskResponse> createCompetingRiskSets(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand){
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
return CompetingRiskUtils.calculateSetsEfficiently(leftHand, rightHand, events, true);
|
||||||
return null;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
|
|
||||||
final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Double getScore(final CompetingRiskSets<CompetingRiskResponse> competingRiskSets){
|
||||||
double numerator = 0.0;
|
double numerator = 0.0;
|
||||||
double denominatorSquared = 0.0;
|
double denominatorSquared = 0.0;
|
||||||
|
|
||||||
for(final int eventOfFocus : events){
|
for(final int eventOfFocus : events){
|
||||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
|
||||||
|
|
||||||
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
|
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
|
||||||
denominatorSquared += valueOfInterest.getVariance();
|
denominatorSquared += valueOfInterest.getVariance();
|
||||||
|
@ -37,7 +35,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
|
||||||
}
|
}
|
||||||
|
|
||||||
return Math.abs(numerator / Math.sqrt(denominatorSquared));
|
return Math.abs(numerator / Math.sqrt(denominatorSquared));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
@ -18,18 +18,14 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
protected CompetingRiskSets<CompetingRiskResponse> createCompetingRiskSets(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand){
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
return CompetingRiskUtils.calculateSetsEfficiently(leftHand, rightHand, events, true);
|
||||||
return null;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
|
|
||||||
final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
|
|
||||||
|
|
||||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Double getScore(final CompetingRiskSets<CompetingRiskResponse> competingRiskSets){
|
||||||
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
|
||||||
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
|
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,6 @@ import java.util.Iterator;
|
||||||
*/
|
*/
|
||||||
public interface GroupDifferentiator<Y> {
|
public interface GroupDifferentiator<Y> {
|
||||||
|
|
||||||
<V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator);
|
SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,12 +9,12 @@ import java.util.stream.Collectors;
|
||||||
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {
|
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator) {
|
public SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator) {
|
||||||
Double bestScore = null;
|
Double bestScore = null;
|
||||||
Split<Y, V> bestSplit = null;
|
Split<Y, ?> bestSplit = null;
|
||||||
|
|
||||||
while(splitIterator.hasNext()){
|
while(splitIterator.hasNext()){
|
||||||
final Split<Y, V> candidateSplit = splitIterator.next();
|
final Split<Y, ?> candidateSplit = splitIterator.next();
|
||||||
|
|
||||||
final List<Y> leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
final List<Y> leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
final List<Y> rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
final List<Y> rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
|
@ -1,214 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
|
|
||||||
public class TestCalculatingCompetingRiskSets {
|
|
||||||
|
|
||||||
public List<CompetingRiskResponseWithCensorTime> generateData(){
|
|
||||||
final List<CompetingRiskResponseWithCensorTime> data = new ArrayList<>();
|
|
||||||
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3));
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(1, 1, 3));
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(0, 1, 1));
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(1, 2, 2.5));
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(2, 3, 4));
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(0, 3, 3));
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(1, 4, 4));
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(0, 5, 5));
|
|
||||||
data.add(new CompetingRiskResponseWithCensorTime(2, 6, 7));
|
|
||||||
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCalculatingSets(){
|
|
||||||
final List data = generateData();
|
|
||||||
|
|
||||||
final CompetingRiskSetsImpl sets = CompetingRiskUtils.calculateSetsEfficiently(data, new int[]{1,2});
|
|
||||||
|
|
||||||
final List<Double> times = sets.getEventTimes();
|
|
||||||
assertEquals(5, times.size());
|
|
||||||
|
|
||||||
// Times
|
|
||||||
assertEquals(1.0, times.get(0).doubleValue());
|
|
||||||
assertEquals(2.0, times.get(1).doubleValue());
|
|
||||||
assertEquals(3.0, times.get(2).doubleValue());
|
|
||||||
assertEquals(4.0, times.get(3).doubleValue());
|
|
||||||
assertEquals(6.0, times.get(4).doubleValue());
|
|
||||||
|
|
||||||
// Number of Events
|
|
||||||
assertEquals(2, sets.getNumberOfEvents(1.0, 1));
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(1.0, 2));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getNumberOfEvents(2.0, 1));
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(2.0, 2));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(3.0, 1));
|
|
||||||
assertEquals(1, sets.getNumberOfEvents(3.0, 2));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getNumberOfEvents(4.0, 1));
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(4.0, 2));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(6.0, 1));
|
|
||||||
assertEquals(1, sets.getNumberOfEvents(6.0, 2));
|
|
||||||
|
|
||||||
// Make sure it doesn't break for other times
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(5.5, 1));
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(5.5, 2));
|
|
||||||
|
|
||||||
|
|
||||||
// Risk set
|
|
||||||
assertEquals(9, sets.getRiskSet(1).evaluate(0.5));
|
|
||||||
assertEquals(9, sets.getRiskSet(2).evaluate(0.5));
|
|
||||||
|
|
||||||
assertEquals(9, sets.getRiskSet(1).evaluate(1.0));
|
|
||||||
assertEquals(9, sets.getRiskSet(2).evaluate(1.0));
|
|
||||||
|
|
||||||
assertEquals(6, sets.getRiskSet(1).evaluate(1.5));
|
|
||||||
assertEquals(6, sets.getRiskSet(2).evaluate(1.5));
|
|
||||||
|
|
||||||
assertEquals(6, sets.getRiskSet(1).evaluate(2.0));
|
|
||||||
assertEquals(6, sets.getRiskSet(2).evaluate(2.0));
|
|
||||||
|
|
||||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.3));
|
|
||||||
assertEquals(5, sets.getRiskSet(2).evaluate(2.3));
|
|
||||||
|
|
||||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.5));
|
|
||||||
assertEquals(5, sets.getRiskSet(2).evaluate(2.5));
|
|
||||||
|
|
||||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.7));
|
|
||||||
assertEquals(5, sets.getRiskSet(2).evaluate(2.7));
|
|
||||||
|
|
||||||
assertEquals(5, sets.getRiskSet(1).evaluate(3.0));
|
|
||||||
assertEquals(5, sets.getRiskSet(2).evaluate(3.0));
|
|
||||||
|
|
||||||
assertEquals(3, sets.getRiskSet(1).evaluate(3.5));
|
|
||||||
assertEquals(3, sets.getRiskSet(2).evaluate(3.5));
|
|
||||||
|
|
||||||
assertEquals(3, sets.getRiskSet(1).evaluate(4.0));
|
|
||||||
assertEquals(3, sets.getRiskSet(2).evaluate(4.0));
|
|
||||||
|
|
||||||
assertEquals(2, sets.getRiskSet(1).evaluate(4.5));
|
|
||||||
assertEquals(2, sets.getRiskSet(2).evaluate(4.5));
|
|
||||||
|
|
||||||
assertEquals(2, sets.getRiskSet(1).evaluate(5.0));
|
|
||||||
assertEquals(2, sets.getRiskSet(2).evaluate(5.0));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getRiskSet(1).evaluate(5.5));
|
|
||||||
assertEquals(1, sets.getRiskSet(2).evaluate(5.5));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getRiskSet(1).evaluate(6.0));
|
|
||||||
assertEquals(1, sets.getRiskSet(2).evaluate(6.0));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getRiskSet(1).evaluate(6.5));
|
|
||||||
assertEquals(0, sets.getRiskSet(2).evaluate(6.5));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getRiskSet(1).evaluate(7.0));
|
|
||||||
assertEquals(0, sets.getRiskSet(2).evaluate(7.0));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getRiskSet(1).evaluate(7.5));
|
|
||||||
assertEquals(0, sets.getRiskSet(2).evaluate(7.5));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCalculatingGraySets(){
|
|
||||||
final List<CompetingRiskResponseWithCensorTime> data = generateData();
|
|
||||||
|
|
||||||
final CompetingRiskGraySetsImpl sets = CompetingRiskUtils.calculateGraySetsEfficiently(data, new int[]{1,2});
|
|
||||||
|
|
||||||
final List<Double> times = sets.getEventTimes();
|
|
||||||
assertEquals(5, times.size());
|
|
||||||
|
|
||||||
// Times
|
|
||||||
assertEquals(1.0, times.get(0).doubleValue());
|
|
||||||
assertEquals(2.0, times.get(1).doubleValue());
|
|
||||||
assertEquals(3.0, times.get(2).doubleValue());
|
|
||||||
assertEquals(4.0, times.get(3).doubleValue());
|
|
||||||
assertEquals(6.0, times.get(4).doubleValue());
|
|
||||||
|
|
||||||
// Number of Events
|
|
||||||
assertEquals(2, sets.getNumberOfEvents(1.0, 1));
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(1.0, 2));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getNumberOfEvents(2.0, 1));
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(2.0, 2));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(3.0, 1));
|
|
||||||
assertEquals(1, sets.getNumberOfEvents(3.0, 2));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getNumberOfEvents(4.0, 1));
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(4.0, 2));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(6.0, 1));
|
|
||||||
assertEquals(1, sets.getNumberOfEvents(6.0, 2));
|
|
||||||
|
|
||||||
// Make sure it doesn't break for other times
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(5.5, 1));
|
|
||||||
assertEquals(0, sets.getNumberOfEvents(5.5, 2));
|
|
||||||
|
|
||||||
|
|
||||||
// Risk set
|
|
||||||
assertEquals(9, sets.getRiskSet(1).evaluate(0.5));
|
|
||||||
assertEquals(9, sets.getRiskSet(2).evaluate(0.5));
|
|
||||||
|
|
||||||
assertEquals(9, sets.getRiskSet(1).evaluate(1.0));
|
|
||||||
assertEquals(9, sets.getRiskSet(2).evaluate(1.0));
|
|
||||||
|
|
||||||
assertEquals(6, sets.getRiskSet(1).evaluate(1.5));
|
|
||||||
assertEquals(8, sets.getRiskSet(2).evaluate(1.5));
|
|
||||||
|
|
||||||
assertEquals(6, sets.getRiskSet(1).evaluate(2.0));
|
|
||||||
assertEquals(8, sets.getRiskSet(2).evaluate(2.0));
|
|
||||||
|
|
||||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.3));
|
|
||||||
assertEquals(8, sets.getRiskSet(2).evaluate(2.3));
|
|
||||||
|
|
||||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.5));
|
|
||||||
assertEquals(7, sets.getRiskSet(2).evaluate(2.5));
|
|
||||||
|
|
||||||
assertEquals(5, sets.getRiskSet(1).evaluate(2.7));
|
|
||||||
assertEquals(7, sets.getRiskSet(2).evaluate(2.7));
|
|
||||||
|
|
||||||
assertEquals(5, sets.getRiskSet(1).evaluate(3.0));
|
|
||||||
assertEquals(5, sets.getRiskSet(2).evaluate(3.0));
|
|
||||||
|
|
||||||
assertEquals(4, sets.getRiskSet(1).evaluate(3.5));
|
|
||||||
assertEquals(3, sets.getRiskSet(2).evaluate(3.5));
|
|
||||||
|
|
||||||
assertEquals(3, sets.getRiskSet(1).evaluate(4.0));
|
|
||||||
assertEquals(3, sets.getRiskSet(2).evaluate(4.0));
|
|
||||||
|
|
||||||
assertEquals(2, sets.getRiskSet(1).evaluate(4.5));
|
|
||||||
assertEquals(2, sets.getRiskSet(2).evaluate(4.5));
|
|
||||||
|
|
||||||
assertEquals(2, sets.getRiskSet(1).evaluate(5.0));
|
|
||||||
assertEquals(2, sets.getRiskSet(2).evaluate(5.0));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getRiskSet(1).evaluate(5.5));
|
|
||||||
assertEquals(1, sets.getRiskSet(2).evaluate(5.5));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getRiskSet(1).evaluate(6.0));
|
|
||||||
assertEquals(1, sets.getRiskSet(2).evaluate(6.0));
|
|
||||||
|
|
||||||
assertEquals(1, sets.getRiskSet(1).evaluate(6.5));
|
|
||||||
assertEquals(0, sets.getRiskSet(2).evaluate(6.5));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getRiskSet(1).evaluate(7.0));
|
|
||||||
assertEquals(0, sets.getRiskSet(2).evaluate(7.0));
|
|
||||||
|
|
||||||
assertEquals(0, sets.getRiskSet(1).evaluate(7.5));
|
|
||||||
assertEquals(0, sets.getRiskSet(2).evaluate(7.5));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,10 +1,15 @@
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.covariates.*;
|
import ca.joeltherrien.randomforest.DataLoader;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
|
@ -14,14 +19,15 @@ import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.*;
|
import com.fasterxml.jackson.databind.node.*;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
|
|
||||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
|
||||||
|
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestCompetingRisk {
|
public class TestCompetingRisk {
|
||||||
|
|
||||||
|
|
||||||
|
@ -284,6 +290,34 @@ public class TestCompetingRisk {
|
||||||
assertEquals(359, countEventTwo);
|
assertEquals(359, countEventTwo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to time how long the algorithm takes
|
||||||
|
*
|
||||||
|
* @param args Not used.
|
||||||
|
* @throws IOException
|
||||||
|
*/
|
||||||
|
public static void main(String[] args) throws IOException {
|
||||||
|
// timing
|
||||||
|
final TestCompetingRisk tcr = new TestCompetingRisk();
|
||||||
|
|
||||||
|
final Settings settings = tcr.getSettings();
|
||||||
|
settings.setNtree(300); // results are too variable at 100
|
||||||
|
|
||||||
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(),
|
||||||
|
settings.getTrainingDataLocation());
|
||||||
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
|
final long startTime = System.currentTimeMillis();
|
||||||
|
for(int i=0; i<50; i++){
|
||||||
|
forestTrainer.trainSerial();
|
||||||
|
}
|
||||||
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
|
final double diffTime = endTime - startTime;
|
||||||
|
System.out.println(diffTime / 1000.0 / 50.0);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException {
|
public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException {
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
@ -15,6 +17,8 @@ import lombok.AllArgsConstructor;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@ -22,6 +26,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestLogRankMultipleGroupDifferentiator {
|
public class TestLogRankMultipleGroupDifferentiator {
|
||||||
|
|
||||||
|
private Iterator<Split<CompetingRiskResponse, ?>> turnIntoSplitIterator(List<Row<CompetingRiskResponse>> leftList,
|
||||||
|
List<Row<CompetingRiskResponse>> rightList){
|
||||||
|
return new SingletonIterator<Split<CompetingRiskResponse, ?>>(new Split(null, leftList, rightList, Collections.emptyList()));
|
||||||
|
}
|
||||||
|
|
||||||
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
|
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
|
||||||
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
|
yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
|
||||||
|
@ -53,16 +62,12 @@ public class TestLogRankMultipleGroupDifferentiator {
|
||||||
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
|
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
|
||||||
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
|
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
|
||||||
|
|
||||||
final double scoreBad = groupDifferentiator.getScore(
|
final double scoreBad = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Bad, group2Bad)).getScore();
|
||||||
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
|
|
||||||
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
|
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199);
|
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199);
|
||||||
final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size());
|
final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size());
|
||||||
|
|
||||||
final double scoreGood = groupDifferentiator.getScore(
|
final double scoreGood = groupDifferentiator.differentiate(turnIntoSplitIterator(group1Good, group2Good)).getScore();
|
||||||
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
|
|
||||||
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
|
|
||||||
|
|
||||||
// expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea
|
// expected results calculated manually using survival::survdiff in R; see issue #10 in Gitea
|
||||||
closeEnough(71.41135, scoreBad, 0.00001);
|
closeEnough(71.41135, scoreBad, 0.00001);
|
||||||
|
|
|
@ -3,10 +3,15 @@ package ca.joeltherrien.randomforest.competingrisk;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@ -15,42 +20,55 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestLogRankSingleGroupDifferentiator {
|
public class TestLogRankSingleGroupDifferentiator {
|
||||||
|
|
||||||
private List<CompetingRiskResponse> generateData1(){
|
private double getScore(final GroupDifferentiator<CompetingRiskResponse> groupDifferentiator, List<Row<CompetingRiskResponse>> left, List<Row<CompetingRiskResponse>> right){
|
||||||
final List<CompetingRiskResponse> data = new ArrayList<>();
|
final Iterator<Split<CompetingRiskResponse, ?>> iterator = new SingletonIterator<>(
|
||||||
|
new Split<>(null, left, right, Collections.emptyList()));
|
||||||
|
|
||||||
data.add(new CompetingRiskResponse(1, 1.0));
|
return groupDifferentiator.differentiate(iterator).getScore();
|
||||||
data.add(new CompetingRiskResponse(1, 1.0));
|
|
||||||
data.add(new CompetingRiskResponse(1, 2.0));
|
}
|
||||||
data.add(new CompetingRiskResponse(1, 1.5));
|
|
||||||
data.add(new CompetingRiskResponse(0, 2.0));
|
int count = 1;
|
||||||
data.add(new CompetingRiskResponse(0, 1.5));
|
private <Y> Row<Y> createRow(Y response){
|
||||||
data.add(new CompetingRiskResponse(0, 2.5));
|
return new Row<>(null, count++, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Row<CompetingRiskResponse>> generateData1(){
|
||||||
|
final List<Row<CompetingRiskResponse>> data = new ArrayList<>();
|
||||||
|
|
||||||
|
data.add(createRow(new CompetingRiskResponse(1, 1.0)));
|
||||||
|
data.add(createRow(new CompetingRiskResponse(1, 1.0)));
|
||||||
|
data.add(createRow(new CompetingRiskResponse(1, 2.0)));
|
||||||
|
data.add(createRow(new CompetingRiskResponse(1, 1.5)));
|
||||||
|
data.add(createRow(new CompetingRiskResponse(0, 2.0)));
|
||||||
|
data.add(createRow(new CompetingRiskResponse(0, 1.5)));
|
||||||
|
data.add(createRow(new CompetingRiskResponse(0, 2.5)));
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<CompetingRiskResponse> generateData2(){
|
private List<Row<CompetingRiskResponse>> generateData2(){
|
||||||
final List<CompetingRiskResponse> data = new ArrayList<>();
|
final List<Row<CompetingRiskResponse>> data = new ArrayList<>();
|
||||||
|
|
||||||
data.add(new CompetingRiskResponse(1, 2.0));
|
data.add(createRow(new CompetingRiskResponse(1, 2.0)));
|
||||||
data.add(new CompetingRiskResponse(1, 2.0));
|
data.add(createRow(new CompetingRiskResponse(1, 2.0)));
|
||||||
data.add(new CompetingRiskResponse(1, 4.0));
|
data.add(createRow(new CompetingRiskResponse(1, 4.0)));
|
||||||
data.add(new CompetingRiskResponse(1, 3.0));
|
data.add(createRow(new CompetingRiskResponse(1, 3.0)));
|
||||||
data.add(new CompetingRiskResponse(0, 4.0));
|
data.add(createRow(new CompetingRiskResponse(0, 4.0)));
|
||||||
data.add(new CompetingRiskResponse(0, 3.0));
|
data.add(createRow(new CompetingRiskResponse(0, 3.0)));
|
||||||
data.add(new CompetingRiskResponse(0, 5.0));
|
data.add(createRow(new CompetingRiskResponse(0, 5.0)));
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCompetingRiskResponseCombiner(){
|
public void testCompetingRiskResponseCombiner(){
|
||||||
final List<CompetingRiskResponse> data1 = generateData1();
|
final List<Row<CompetingRiskResponse>> data1 = generateData1();
|
||||||
final List<CompetingRiskResponse> data2 = generateData2();
|
final List<Row<CompetingRiskResponse>> data2 = generateData2();
|
||||||
|
|
||||||
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
|
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
|
||||||
|
|
||||||
final double score = differentiator.getScore(data1, data2);
|
final double score = getScore(differentiator, data1, data2);
|
||||||
final double margin = 0.000001;
|
final double margin = 0.000001;
|
||||||
|
|
||||||
// Tested using 855 method
|
// Tested using 855 method
|
||||||
|
@ -70,16 +88,12 @@ public class TestLogRankSingleGroupDifferentiator {
|
||||||
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
|
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
|
||||||
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
|
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
|
||||||
|
|
||||||
final double scoreGood = groupDifferentiator.getScore(
|
final double scoreGood = getScore(groupDifferentiator, group1Good, group2Good);
|
||||||
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
|
|
||||||
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
|
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
|
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
|
||||||
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
|
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
|
||||||
|
|
||||||
final double scoreBad = groupDifferentiator.getScore(
|
final double scoreBad = getScore(groupDifferentiator, group1Bad, group2Bad);
|
||||||
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
|
|
||||||
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
|
|
||||||
|
|
||||||
// Apparently not all groups are unique when splitting
|
// Apparently not all groups are unique when splitting
|
||||||
assertEquals(scoreGood, scoreBad);
|
assertEquals(scoreGood, scoreBad);
|
||||||
|
|
Loading…
Reference in a new issue