Finish the competing-risks implementation, and fix a bug in the tree-training algorithm.
This commit is contained in:
Joel Therrien 2018-07-16 16:58:11 -07:00
parent 462b0d9c35
commit fffdfe85bf
41 changed files with 4768 additions and 241 deletions

View file

@ -5,12 +5,12 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
@RequiredArgsConstructor
public class Bootstrapper<T> {
final private List<T> originalData;
final private Random random = new Random();
public List<T> bootstrap(){
final int n = originalData.size();
@ -18,7 +18,7 @@ public class Bootstrapper<T> {
final List<T> newList = new ArrayList<>(n);
for(int i=0; i<n; i++){
final int index = random.nextInt(n);
final int index = ThreadLocalRandom.current().nextInt(n);
newList.add(originalData.get(index));
}

View file

@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RequiredArgsConstructor
@ -24,4 +26,19 @@ public class CovariateRow {
return "CovariateRow " + this.id;
}
/**
 * Builds a CovariateRow from raw string values.
 *
 * Each entry of simpleMap whose key matches the name of a covariate in covariateList is
 * parsed into a Covariate.Value via that covariate; entries with no matching covariate
 * are silently ignored.
 *
 * @param simpleMap     covariate name -> raw string value
 * @param covariateList the covariates used to parse the raw values
 * @param id            identifier for the new row
 * @return a new CovariateRow holding the parsed values
 */
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
    // Index the covariates by name so each raw value can find its parser.
    final Map<String, Covariate> covariateMap = new HashMap<>();
    for(final Covariate covariate : covariateList){
        covariateMap.put(covariate.getName(), covariate);
    }

    final Map<String, Covariate.Value> valueMap = new HashMap<>();
    for(final Map.Entry<String, String> entry : simpleMap.entrySet()){
        final Covariate covariate = covariateMap.get(entry.getKey());
        if(covariate != null){ // unknown names are dropped, matching previous behaviour
            valueMap.put(entry.getKey(), covariate.createValue(entry.getValue()));
        }
    }

    return new CovariateRow(valueMap, id);
}
}

View file

@ -1,7 +1,6 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVFormat;
@ -49,11 +48,6 @@ public class DataLoader {
Y parse(CSVRecord record);
}
@FunctionalInterface
public interface ResponseLoaderConstructor<Y>{
ResponseLoader<Y> construct(ObjectNode node);
}
@RequiredArgsConstructor
public static class DoubleLoader implements ResponseLoader<Double> {

View file

@ -44,7 +44,7 @@ public class Main {
final List<Row<Double>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final ForestTrainer<Double, Double, Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
if(settings.isSaveProgress()){
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
@ -63,8 +63,14 @@ public class Main {
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
treeCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("y"));
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
final Settings settings = Settings.builder()
@ -75,8 +81,8 @@ public class Main {
)
)
.dataFileLocation("data.csv")
.responseCombiner("MeanResponseCombiner")
.treeResponseCombiner("MeanResponseCombiner")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)

View file

@ -3,8 +3,10 @@ package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
@ -16,6 +18,7 @@ import lombok.*;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.function.Function;
/**
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
@ -26,11 +29,11 @@ import java.util.*;
@EqualsAndHashCode
public class Settings {
private static Map<String, DataLoader.ResponseLoaderConstructor> RESPONSE_LOADER_MAP = new HashMap<>();
public static DataLoader.ResponseLoaderConstructor getResponseLoaderConstructor(final String name){
private static Map<String, Function<ObjectNode, DataLoader.ResponseLoader>> RESPONSE_LOADER_MAP = new HashMap<>();
public static Function<ObjectNode, DataLoader.ResponseLoader> getResponseLoaderConstructor(final String name){
return RESPONSE_LOADER_MAP.get(name.toLowerCase());
}
public static void registerResponseLoaderConstructor(final String name, final DataLoader.ResponseLoaderConstructor responseLoaderConstructor){
public static void registerResponseLoaderConstructor(final String name, final Function<ObjectNode, DataLoader.ResponseLoader> responseLoaderConstructor){
RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor);
}
@ -38,19 +41,19 @@ public class Settings {
registerResponseLoaderConstructor("double",
node -> new DataLoader.DoubleLoader(node)
);
registerResponseLoaderConstructor("CompetingResponse",
node -> new CompetingResponse.CompetingResponseLoader(node)
registerResponseLoaderConstructor("CompetingRiskResponse",
node -> new CompetingRiskResponse.CompetingResponseLoader(node)
);
registerResponseLoaderConstructor("CompetingResponseWithCensorTime",
node -> new CompetingResponseWithCensorTime.CompetingResponseWithCensorTimeLoader(node)
registerResponseLoaderConstructor("CompetingRiskResponseWithCensorTime",
node -> new CompetingRiskResponseWithCensorTime.CompetingResponseWithCensorTimeLoader(node)
);
}
private static Map<String, GroupDifferentiator.GroupDifferentiatorConstructor> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
public static GroupDifferentiator.GroupDifferentiatorConstructor getGroupDifferentiatorConstructor(final String name){
private static Map<String, Function<ObjectNode, GroupDifferentiator>> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
public static Function<ObjectNode, GroupDifferentiator> getGroupDifferentiatorConstructor(final String name){
return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase());
}
public static void registerGroupDifferentiatorConstructor(final String name, final GroupDifferentiator.GroupDifferentiatorConstructor groupDifferentiatorConstructor){
public static void registerGroupDifferentiatorConstructor(final String name, final Function<ObjectNode, GroupDifferentiator> groupDifferentiatorConstructor){
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
}
static{
@ -98,13 +101,67 @@ public class Settings {
);
}
// Registry of ResponseCombiner factories, keyed by lower-cased name. Entries are added via
// registerResponseCombinerConstructor (see the static initializer below for the built-ins).
// NOTE(review): ResponseCombiner is used as a raw type here — consider parameterizing once
// the surrounding Settings API settles.
private static Map<String, Function<ObjectNode, ResponseCombiner>> RESPONSE_COMBINER_MAP = new HashMap<>();
// Looks up a registered factory; returns null if the name is unknown. Lookup is
// case-insensitive because registration lower-cases the key as well.
public static Function<ObjectNode, ResponseCombiner> getResponseCombinerConstructor(final String name){
return RESPONSE_COMBINER_MAP.get(name.toLowerCase());
}
// Registers a factory under a case-insensitive name; a later registration under the same
// name replaces the earlier one.
public static void registerResponseCombinerConstructor(final String name, final Function<ObjectNode, ResponseCombiner> responseCombinerConstructor){
RESPONSE_COMBINER_MAP.put(name.toLowerCase(), responseCombinerConstructor);
}
// Registers the built-in ResponseCombiner factories. Each factory parses its configuration
// out of the supplied ObjectNode: an "events" array (required for the competing-risks
// combiners) and an optional "times" array restricting evaluation to specific time points.
static{
    registerResponseCombinerConstructor("MeanResponseCombiner",
            (node) -> new MeanResponseCombiner()
    );
    registerResponseCombinerConstructor("CompetingRiskResponseCombiner",
            (node) -> {
                final List<Integer> eventList = new ArrayList<>();
                node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
                final int[] events = eventList.stream().mapToInt(i -> i).toArray();

                double[] times = null;
                // note that times may be null
                if(node.hasNonNull("times")){
                    final List<Double> timeList = new ArrayList<>();
                    node.get("times").elements().forEachRemaining(time -> timeList.add(time.asDouble()));
                    // Bug fix: this previously streamed eventList (the integer event codes),
                    // silently discarding the parsed timeList.
                    times = timeList.stream().mapToDouble(db -> db).toArray();
                }

                return new CompetingRiskResponseCombiner(events, times);
            }
    );
    registerResponseCombinerConstructor("CompetingRiskFunctionCombiner",
            (node) -> {
                final List<Integer> eventList = new ArrayList<>();
                node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
                final int[] events = eventList.stream().mapToInt(i -> i).toArray();

                double[] times = null;
                // note that times may be null
                if(node.hasNonNull("times")){
                    final List<Double> timeList = new ArrayList<>();
                    node.get("times").elements().forEachRemaining(time -> timeList.add(time.asDouble()));
                    // Bug fix: same eventList/timeList mix-up as above — use the times we parsed.
                    times = timeList.stream().mapToDouble(db -> db).toArray();
                }

                return new CompetingRiskFunctionCombiner(events, times);
            }
    );
}
private int numberOfSplits = 5;
private int nodeSize = 5;
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
private String responseCombiner;
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
private String treeResponseCombiner;
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private List<CovariateSettings> covariates = new ArrayList<>();
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
@ -148,14 +205,28 @@ public class Settings {
public GroupDifferentiator getGroupDifferentiator(){
final String type = groupDifferentiatorSettings.get("type").asText();
return getGroupDifferentiatorConstructor(type).construct(groupDifferentiatorSettings);
return getGroupDifferentiatorConstructor(type).apply(groupDifferentiatorSettings);
}
@JsonIgnore
public DataLoader.ResponseLoader getResponseLoader(){
final String type = yVarSettings.get("type").asText();
return getResponseLoaderConstructor(type).construct(yVarSettings);
return getResponseLoaderConstructor(type).apply(yVarSettings);
}
@JsonIgnore
// Instantiates the terminal-node response combiner named by responseCombinerSettings'
// "type" field, passing the full settings node to the registered factory.
// NOTE(review): throws NullPointerException if "type" is missing or unregistered — confirm
// whether a friendlier error is wanted here.
public ResponseCombiner getResponseCombiner(){
final String type = responseCombinerSettings.get("type").asText();
return getResponseCombinerConstructor(type).apply(responseCombinerSettings);
}
@JsonIgnore
// Same as getResponseCombiner(), but builds the forest-level (tree-combining) combiner
// from treeCombinerSettings. Both lookups share the same factory registry.
public ResponseCombiner getTreeCombiner(){
final String type = treeCombinerSettings.get("type").asText();
return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
}
}

View file

@ -42,6 +42,11 @@ public final class BooleanCovariate implements Covariate<Boolean>{
}
}
/** Human-readable identifier for this covariate, e.g. {@code BooleanCovariate(name=x1)}. */
@Override
public String toString(){
    return String.format("BooleanCovariate(name=%s)", name);
}
public class BooleanValue implements Value<Boolean>{
private final Boolean value;

View file

@ -82,6 +82,11 @@ public final class FactorCovariate implements Covariate<String>{
return factorValue;
}
/** Human-readable identifier for this covariate, e.g. {@code FactorCovariate(name=x1)}. */
@Override
public String toString(){
    return String.format("FactorCovariate(name=%s)", name);
}
@EqualsAndHashCode
public final class FactorValue implements Covariate.Value<String>{

View file

@ -1,13 +1,16 @@
package ca.joeltherrien.randomforest.covariates;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
@RequiredArgsConstructor
@ToString
public final class NumericCovariate implements Covariate<Double>{
@Getter
@ -20,16 +23,18 @@ public final class NumericCovariate implements Covariate<Double>{
// only work with non-NA values
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
//data = data.stream().filter(value -> !value.isNA()).distinct().collect(Collectors.toList()); // TODO which to use?
// for this implementation we need to shuffle the data
final List<Value<Double>> shuffledData;
if(number > data.size()){
if(number >= data.size()){
shuffledData = new ArrayList<>(data);
Collections.shuffle(shuffledData, random);
}
else{ // only need the top number entries
shuffledData = new ArrayList<>(number);
final Set<Integer> indexesToUse = new HashSet<>();
//final List<Integer> indexesToUse = new ArrayList<>(); // TODO which to use?
while(indexesToUse.size() < number){
final int index = random.nextInt(data.size());
@ -56,7 +61,7 @@ public final class NumericCovariate implements Covariate<Double>{
}
@Override
public Value<Double> createValue(String value) {
public NumericValue createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){
return createValue((Double) null);
}
@ -64,6 +69,7 @@ public final class NumericCovariate implements Covariate<Double>{
return createValue(Double.parseDouble(value));
}
@EqualsAndHashCode
public class NumericValue implements Covariate.Value<Double>{
private final Double value; // may be null
@ -88,6 +94,7 @@ public final class NumericCovariate implements Covariate<Double>{
}
}
@EqualsAndHashCode
public class NumericSplitRule implements Covariate.SplitRule<Double>{
private final double threshold;

View file

@ -1,11 +0,0 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import java.util.List;
public class CompetingRiskFunction {
private List<Point> pointList;
}

View file

@ -0,0 +1,81 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Forest-level combiner: averages the per-tree {@link CompetingRiskFunctions} into a single
 * ensemble estimate by taking the pointwise mean of each curve over a common time grid.
 *
 * If {@code times} was supplied it is used as the evaluation grid; otherwise the grid is the
 * sorted, distinct set of time points appearing in the responses' survival curves.
 */
@RequiredArgsConstructor
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
// Event codes for which cause-specific curves are produced.
private final int[] events;
private final double[] times; // We may restrict ourselves to specific times.
@Override
public CompetingRiskFunctions combine(List<CompetingRiskFunctions> responses) {
final double[] timesToUse;
if(times != null){
timesToUse = times;
}
else{
// No fixed grid supplied: collect every distinct time point from the survival
// curves of all responses, in ascending order.
timesToUse = responses.stream()
.map(functions -> functions.getSurvivalCurve())
.flatMapToDouble(
function -> function.getPoints().stream()
.mapToDouble(point -> point.getTime())
).sorted().distinct().toArray();
}
final double n = responses.size();
// Pointwise average of the survival curves: S(t) = (1/n) * sum_i S_i(t).
final List<Point> survivalPoints = new ArrayList<>(timesToUse.length);
for(final double time : timesToUse){
final double survivalY = responses.stream()
.mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time).getY() / n)
.sum();
survivalPoints.add(new Point(time, survivalY));
}
// Default point (0.0, 1.0): survival is taken to be 1 before the first grid time.
final MathFunction survivalFunction = new MathFunction(survivalPoints, new Point(0.0, 1.0));
final Map<Integer, MathFunction> causeSpecificCumulativeHazardFunctionMap = new HashMap<>();
final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap = new HashMap<>();
// For each event type, average the cause-specific cumulative hazard function and the
// cumulative incidence function over the same grid, in a single pass per time point.
for(final int event : events){
final List<Point> cumulativeHazardFunctionPoints = new ArrayList<>(timesToUse.length);
final List<Point> cumulativeIncidenceFunctionPoints = new ArrayList<>(timesToUse.length);
for(final double time : timesToUse){
final double hazardY = responses.stream()
.mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time).getY() / n)
.sum();
final double incidenceY = responses.stream()
.mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time).getY() / n)
.sum();
cumulativeHazardFunctionPoints.add(new Point(time, hazardY));
cumulativeIncidenceFunctionPoints.add(new Point(time, incidenceY));
}
causeSpecificCumulativeHazardFunctionMap.put(event, new MathFunction(cumulativeHazardFunctionPoints));
cumulativeIncidenceFunctionMap.put(event, new MathFunction(cumulativeIncidenceFunctionPoints));
}
return CompetingRiskFunctions.builder()
.causeSpecificHazardFunctionMap(causeSpecificCumulativeHazardFunctionMap)
.cumulativeIncidenceFunctionMap(cumulativeIncidenceFunctionMap)
.survivalCurve(survivalFunction)
.build();
}
}

View file

@ -0,0 +1,25 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import lombok.Builder;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.Map;
/**
 * Container for the curves estimated for a competing-risks response: one cause-specific
 * cumulative hazard function and one cumulative incidence function per event code, plus the
 * overall survival curve. Instances are created via the Lombok-generated builder.
 */
@Builder
public class CompetingRiskFunctions {
// Keyed by event code. Lookups for an unregistered cause return null.
private final Map<Integer, MathFunction> causeSpecificHazardFunctionMap;
private final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap;
@Getter
private final MathFunction survivalCurve;
/** Returns the cause-specific cumulative hazard function for {@code cause}, or null if absent. */
public MathFunction getCauseSpecificHazardFunction(int cause){
return causeSpecificHazardFunctionMap.get(cause);
}
/** Returns the cumulative incidence function for {@code cause}, or null if absent. */
public MathFunction getCumulativeIncidenceFunction(int cause) {
return cumulativeIncidenceFunctionMap.get(cause);
}
}

View file

@ -12,14 +12,14 @@ import java.util.stream.Stream;
* modifies the abstract method.
*
*/
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingResponse> implements GroupDifferentiator<Y>{
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y>{
@Override
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
abstract double riskSet(final List<Y> eventList, double time, int eventOfFocus);
private double numberOFEventsAtTime(int eventOfFocus, List<Y> eventList, double time){
private double numberOfEventsAtTime(int eventOfFocus, List<Y> eventList, double time){
return (double) eventList.stream()
.filter(event -> event.getDelta() == eventOfFocus)
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
@ -31,8 +31,8 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRespon
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
*
* @param eventOfFocus
* @param leftHand A non-empty list of CompetingResponse
* @param rightHand A non-empty list of CompetingResponse
* @param leftHand A non-empty list of CompetingRiskResponse
* @param rightHand A non-empty list of CompetingRiskResponse
* @return
*/
LogRankValue specificLogRankValue(final int eventOfFocus, List<Y> leftHand, List<Y> rightHand){
@ -40,34 +40,42 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRespon
final double[] distinctEventTimes = Stream.concat(
leftHand.stream(), rightHand.stream()
)
.filter(event -> event.getDelta() != 0) // remove censored events
.filter(event -> !event.isCensored())
.mapToDouble(event -> event.getU())
.distinct()
.toArray();
double summation = 0.0;
double varianceSquared = 0.0;
double variance = 0.0;
for(final double time_k : distinctEventTimes){
final double weight = weight(time_k); // W_j(t_k)
final double numberEventsAtTimeDaughterLeft = numberOFEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k)
final double numberEventsAtTimeDaughterRight = numberOFEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k)
final double numberEventsAtTimeDaughterLeft = numberOfEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k)
final double numberEventsAtTimeDaughterRight = numberOfEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k)
final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)
final double individualsAtRiskDaughterLeft = riskSet(leftHand, time_k, eventOfFocus); // Y_l(t_k)
final double individualsAtRiskDaughterRight = riskSet(rightHand, time_k, eventOfFocus); // Y_r(t_k)
final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)
summation = summation + weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);
varianceSquared = varianceSquared + weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk
final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);
final double deltaVariance = weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk
* (1.0 - individualsAtRiskDaughterLeft / individualsAtRisk)
* ((individualsAtRisk - numberOfEventsAtTime) / (individualsAtRisk - 1.0));
// Note - notation differs slightly with what is found in STAT 855 notes, but they are equivalent.
// Note - if individualsAtRisk == 1 then variance will be NaN.
if(!Double.isNaN(deltaVariance)){
summation += deltaSummation;
variance += deltaVariance;
}
else{
// Do nothing; else statement left for breakpoints.
}
}
return new LogRankValue(summation, varianceSquared);
return new LogRankValue(summation, variance);
}
double weight(double time){
@ -80,10 +88,10 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRespon
@AllArgsConstructor
static class LogRankValue{
private final double numerator;
private final double varianceSquared;
private final double variance;
public double getVariance(){
return Math.sqrt(varianceSquared);
public double getVarianceSqrt(){
return Math.sqrt(variance);
}
}

View file

@ -7,14 +7,18 @@ import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord;
@Data
public class CompetingResponse {
public class CompetingRiskResponse {
private final int delta;
private final double u;
public boolean isCensored(){
return delta == 0;
}
@RequiredArgsConstructor
public static class CompetingResponseLoader implements DataLoader.ResponseLoader<CompetingResponse>{
public static class CompetingResponseLoader implements DataLoader.ResponseLoader<CompetingRiskResponse>{
private final String deltaName;
private final String uName;
@ -25,11 +29,11 @@ public class CompetingResponse {
}
@Override
public CompetingResponse parse(CSVRecord record) {
public CompetingRiskResponse parse(CSVRecord record) {
final int delta = Integer.parseInt(record.get(deltaName));
final double u = Double.parseDouble(record.get(uName));
return new CompetingResponse(delta, u);
return new CompetingRiskResponse(delta, u);
}
}

View file

@ -0,0 +1,123 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import lombok.RequiredArgsConstructor;
import java.util.*;
/**
 * This class takes all of the observations in a terminal node and combines them to produce estimates of the cause-specific hazard function
 * and the cumulative incidence curve.
 *
 * The survival curve is built with a product-limit update, the cause-specific cumulative
 * hazards accumulate d_j(t_k) / Y(t_k) increments, and the cumulative incidence functions
 * accumulate S(t_{k-1}) * d_j(t_k) / Y(t_k) increments.
 *
 * See https://kogalur.github.io/randomForestSRC/theory.html for details.
 *
 */
@RequiredArgsConstructor
public class CompetingRiskResponseCombiner implements ResponseCombiner<CompetingRiskResponse, CompetingRiskFunctions> {
// Event codes for which cause-specific curves are produced.
private final int[] events;
private final double[] times; // We may restrict ourselves to specific times.
@Override
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
final Map<Integer, MathFunction> causeSpecificCumulativeHazardFunctionMap = new HashMap<>();
final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap = new HashMap<>();
final double[] timesToUse;
if(times != null){
timesToUse = this.times;
}
else{
// No fixed grid supplied: use the distinct, sorted event (non-censored) times.
timesToUse = responses.stream()
.filter(response -> !response.isCensored())
.mapToDouble(response -> response.getU())
.sorted().distinct()
.toArray();
}
// Y(t_k) for every grid time, precomputed once since it is shared by all curves below.
final double[] individualsAtRiskArray = Arrays.stream(timesToUse).map(time -> riskSet(responses, time)).toArray();
// First we need to develop the overall survival curve!
final List<Point> survivalPoints = new ArrayList<>(timesToUse.length);
double previousSurvivalValue = 1.0;
for(int i=0; i<timesToUse.length; i++){
final double time_k = timesToUse[i];
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
final double numberOfEventsAtTime = (double) responses.stream()
.filter(event -> !event.isCensored())
.filter(event -> event.getU() == time_k) // since delta != 0 we know censoring didn't occur prior to this
.count();
// Product-limit update: S(t_k) = S(t_{k-1}) * (1 - d(t_k) / Y(t_k)).
final double newValue = previousSurvivalValue * (1.0 - numberOfEventsAtTime / individualsAtRisk);
survivalPoints.add(new Point(time_k, newValue));
previousSurvivalValue = newValue;
}
// Default point (0.0, 1.0): survival is 1 before the first event time.
final MathFunction survivalCurve = new MathFunction(survivalPoints, new Point(0.0, 1.0));
for(final int event : events){
final List<Point> hazardFunctionPoints = new ArrayList<>(timesToUse.length);
Point previousHazardFunctionPoint = new Point(0.0, 0.0);
final List<Point> cifPoints = new ArrayList<>(timesToUse.length);
Point previousCIFPoint = new Point(0.0, 0.0);
for(int i=0; i<timesToUse.length; i++){
final double time_k = timesToUse[i];
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
final double numberEventsAtTime = numberOfEventsAtTime(event, responses, time_k); // d_j(t_k)
// Cause-specific cumulative hazard function
final double hazardDeltaY = numberEventsAtTime / individualsAtRisk;
final Point newHazardPoint = new Point(time_k, previousHazardFunctionPoint.getY() + hazardDeltaY);
hazardFunctionPoints.add(newHazardPoint);
previousHazardFunctionPoint = newHazardPoint;
// Cumulative incidence function
// TODO - confirm this behaviour
//final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY();
final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : 1.0;
final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk);
final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY);
cifPoints.add(newCIFPoint);
previousCIFPoint = newCIFPoint;
}
final MathFunction causeSpecificCumulativeHazardFunction = new MathFunction(hazardFunctionPoints);
causeSpecificCumulativeHazardFunctionMap.put(event, causeSpecificCumulativeHazardFunction);
final MathFunction cifFunction = new MathFunction(cifPoints);
cumulativeIncidenceFunctionMap.put(event, cifFunction);
}
return CompetingRiskFunctions.builder()
.causeSpecificHazardFunctionMap(causeSpecificCumulativeHazardFunctionMap)
.cumulativeIncidenceFunctionMap(cumulativeIncidenceFunctionMap)
.survivalCurve(survivalCurve)
.build();
}
// Y(t): number of individuals still at risk at time t, i.e. with observed time u >= t.
private double riskSet(List<CompetingRiskResponse> eventList, double time) {
return eventList.stream()
.filter(event -> event.getU() >= time)
.count();
}
// d_j(t): number of individuals experiencing the event of focus exactly at time t.
private double numberOfEventsAtTime(int eventOfFocus, List<CompetingRiskResponse> eventList, double time){
return (double) eventList.stream()
.filter(event -> event.getDelta() == eventOfFocus)
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
.count();
}
}

View file

@ -11,16 +11,16 @@ import org.apache.commons.csv.CSVRecord;
*
*/
@Data
public class CompetingResponseWithCensorTime extends CompetingResponse{
public class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
private final double c;
public CompetingResponseWithCensorTime(int delta, double u, double c) {
public CompetingRiskResponseWithCensorTime(int delta, double u, double c) {
super(delta, u);
this.c = c;
}
@RequiredArgsConstructor
public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader<CompetingResponseWithCensorTime>{
public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader<CompetingRiskResponseWithCensorTime>{
private final String deltaName;
private final String uName;
@ -33,12 +33,12 @@ public class CompetingResponseWithCensorTime extends CompetingResponse{
}
@Override
public CompetingResponseWithCensorTime parse(CSVRecord record) {
public CompetingRiskResponseWithCensorTime parse(CSVRecord record) {
final int delta = Integer.parseInt(record.get(deltaName));
final double u = Double.parseDouble(record.get(uName));
final double c = Double.parseDouble(record.get(cName));
return new CompetingResponseWithCensorTime(delta, u, c);
return new CompetingRiskResponseWithCensorTime(delta, u, c);
}
}
}

View file

@ -1,12 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
@ -14,12 +9,12 @@ import java.util.List;
*
*/
@RequiredArgsConstructor
public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponseWithCensorTime> {
public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> {
private final int[] events;
@Override
public Double differentiate(List<CompetingResponseWithCensorTime> leftHand, List<CompetingResponseWithCensorTime> rightHand) {
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}
@ -30,8 +25,8 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
for(final int eventOfFocus : events){
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVariance();
denominatorSquared += valueOfInterest.getVarianceSquared();
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
denominatorSquared += valueOfInterest.getVariance();
}
@ -40,7 +35,7 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
}
@Override
double riskSet(List<CompetingResponseWithCensorTime> eventList, double time, int eventOfFocus) {
double riskSet(List<CompetingRiskResponseWithCensorTime> eventList, double time, int eventOfFocus) {
return eventList.stream()
.filter(event -> event.getU() >= time ||
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)

View file

@ -1,11 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
@ -13,24 +9,24 @@ import java.util.List;
*
*/
@RequiredArgsConstructor
public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponseWithCensorTime> {
public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> {
private final int eventOfFocus;
@Override
public Double differentiate(List<CompetingResponseWithCensorTime> leftHand, List<CompetingResponseWithCensorTime> rightHand) {
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVariance());
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
}
@Override
double riskSet(List<CompetingResponseWithCensorTime> eventList, double time, int eventOfFocus) {
double riskSet(List<CompetingRiskResponseWithCensorTime> eventList, double time, int eventOfFocus) {
return eventList.stream()
.filter(event -> event.getU() >= time ||
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)

View file

@ -1,11 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
@ -13,12 +9,12 @@ import java.util.List;
*
*/
@RequiredArgsConstructor
public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponse> {
public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> {
private final int[] events;
@Override
public Double differentiate(List<CompetingResponse> leftHand, List<CompetingResponse> rightHand) {
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}
@ -29,8 +25,8 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
for(final int eventOfFocus : events){
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVariance();
denominatorSquared += valueOfInterest.getVarianceSquared();
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
denominatorSquared += valueOfInterest.getVariance();
}
@ -39,7 +35,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
}
@Override
double riskSet(List<CompetingResponse> eventList, double time, int eventOfFocus) {
double riskSet(List<CompetingRiskResponse> eventList, double time, int eventOfFocus) {
return eventList.stream()
.filter(event -> event.getU() >= time)
.count();

View file

@ -1,6 +1,5 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import lombok.RequiredArgsConstructor;
import java.util.List;
@ -10,24 +9,24 @@ import java.util.List;
*
*/
@RequiredArgsConstructor
public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponse> {
public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> {
private final int eventOfFocus;
@Override
public Double differentiate(List<CompetingResponse> leftHand, List<CompetingResponse> rightHand) {
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVariance());
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
}
@Override
double riskSet(List<CompetingResponse> eventList, double time, int eventOfFocus) {
double riskSet(List<CompetingRiskResponse> eventList, double time, int eventOfFocus) {
return eventList.stream()
.filter(event -> event.getU() >= time)
.count();

View file

@ -0,0 +1,60 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import lombok.Getter;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
/**
* Represents a function represented by discrete points. We assume that the function is a stepwise continuous function,
* constant at the value of the previous encountered point.
*
*/
public class MathFunction {
// Discrete points defining the step function. Stored as an unmodifiable view (not a copy).
@Getter
private final List<Point> points;
/**
* Represents the value that should be returned by evaluate if there are no points at or prior to the time the function is being evaluated at.
*
* May be null.
*/
private final Point defaultValue;
// Convenience constructor; defaults the pre-first-point value to the origin (0.0, 0.0).
public MathFunction(final List<Point> points){
this(points, new Point(0.0, 0.0));
}
public MathFunction(final List<Point> points, final Point defaultValue){
// NOTE: unmodifiableList is a view over the caller's list; the caller must not mutate it afterwards.
this.points = Collections.unmodifiableList(points);
this.defaultValue = defaultValue;
}
/**
* Evaluate the stepwise-constant function at the given time.
*
* @param time The time to evaluate at.
* @return The point with the largest time that is at or before {@code time};
*         defaultValue if no such point exists (linear scan over all points).
*/
public Point evaluate(double time){
final Optional<Point> pointOptional = points.stream()
.filter(point -> point.getTime() <= time)
.max(Comparator.comparingDouble(Point::getTime));
return pointOptional.orElse(defaultValue);
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultValue);
builder.append("\n");
for(final Point point : points){
builder.append(point);
builder.append("\n");
}
return builder.toString();
}
}

View file

@ -8,8 +8,6 @@ import lombok.Data;
*/
@Data
public class Point {
private final Double time;
private final Double y;
}

View file

@ -3,11 +3,6 @@ package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import java.util.List;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
/**
* This implementation of the collector isn't great... but good enough given that I'm not planning to fully support regression trees.
@ -15,11 +10,7 @@ import java.util.function.Supplier;
* (It's not great because you'll lose accuracy as you sum up the doubles, since dividing by n is the very last step.)
*
*/
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
static{
ResponseCombiner.registerResponseCombiner("MeanResponseCombiner", new MeanResponseCombiner());
}
public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
@Override
public Double combine(List<Double> responses) {
@ -29,51 +20,5 @@ public class MeanResponseCombiner implements ResponseCombiner<Double, MeanRespon
}
@Override
public Supplier<Container> supplier() {
return () -> new Container(0 ,0);
}
@Override
public BiConsumer<Container, Double> accumulator() {
return (container, number) -> {
container.number+=number;
container.n++;
};
}
@Override
public BinaryOperator<Container> combiner() {
return (c1, c2) -> {
c1.number += c2.number;
c1.n += c2.n;
return c1;
};
}
@Override
public Function<Container, Double> finisher() {
return (container) -> container.number/(double)container.n;
}
@Override
public Set<Characteristics> characteristics() {
return Set.of(Characteristics.UNORDERED);
}
public static class Container{
Container(double number, int n){
this.number = number;
this.n = n;
}
public Double number;
public int n;
}
}

View file

@ -4,17 +4,22 @@ import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Builder;
import java.util.Collection;
import java.util.stream.Collectors;
@Builder
public class Forest<Y> {
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
private final Collection<Node<Y>> trees;
private final ResponseCombiner<Y, ?> treeResponseCombiner;
private final Collection<Node<O>> trees;
private final ResponseCombiner<O, FO> treeResponseCombiner;
public Y evaluate(CovariateRow row){
return trees.parallelStream()
public FO evaluate(CovariateRow row){
return treeResponseCombiner.combine(
trees.parallelStream()
.map(node -> node.evaluate(row))
.collect(treeResponseCombiner);
.collect(Collectors.toList())
);
}
}

View file

@ -16,17 +16,18 @@ import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@Builder
@AllArgsConstructor(access=AccessLevel.PRIVATE)
public class ForestTrainer<Y> {
public class ForestTrainer<Y, TO, FO> {
private final TreeTrainer<Y> treeTrainer;
private final List<Covariate> covariatesToTry;
private final ResponseCombiner<Y, ?> treeResponseCombiner;
private final TreeTrainer<Y, TO> treeTrainer;
private final List<Covariate> covariates;
private final ResponseCombiner<TO, FO> treeResponseCombiner;
private final List<Row<Y>> data;
// number of covariates to randomly try
@ -45,15 +46,15 @@ public class ForestTrainer<Y> {
this.displayProgress = true;
this.saveTreeLocation = settings.getSaveTreeLocation();
this.covariatesToTry = covariates;
this.treeResponseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getTreeResponseCombiner());
this.treeTrainer = new TreeTrainer<>(settings);
this.covariates = covariates;
this.treeResponseCombiner = settings.getTreeCombiner();
this.treeTrainer = new TreeTrainer<>(settings, covariates);
}
public Forest<Y> trainSerial(){
public Forest<TO, FO> trainSerial(){
final List<Node<Y>> trees = new ArrayList<>(ntree);
final List<Node<TO>> trees = new ArrayList<>(ntree);
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
for(int j=0; j<ntree; j++){
@ -71,18 +72,18 @@ public class ForestTrainer<Y> {
}
}
return Forest.<Y>builder()
return Forest.<TO, FO>builder()
.treeResponseCombiner(treeResponseCombiner)
.trees(trees)
.build();
}
public Forest<Y> trainParallelInMemory(int threads){
public Forest<TO, FO> trainParallelInMemory(int threads){
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
// the earlier indexes being filled.
final List<Node<Y>> trees = Stream.<Node<Y>>generate(() -> null).limit(ntree).collect(Collectors.toList());
final List<Node<TO>> trees = Stream.<Node<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
@ -102,7 +103,7 @@ public class ForestTrainer<Y> {
if(displayProgress) {
int numberTreesSet = 0;
for (final Node<Y> tree : trees) {
for (final Node<TO> tree : trees) {
if (tree != null) {
numberTreesSet++;
}
@ -117,7 +118,7 @@ public class ForestTrainer<Y> {
System.out.println("\nFinished");
}
return Forest.<Y>builder()
return Forest.<TO, FO>builder()
.treeResponseCombiner(treeResponseCombiner)
.trees(trees)
.build();
@ -156,20 +157,12 @@ public class ForestTrainer<Y> {
}
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
final List<Covariate> treeCovariates = new ArrayList<>(covariatesToTry);
Collections.shuffle(treeCovariates);
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
treeCovariates.remove(treeIndex);
}
private Node<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
return treeTrainer.growTree(bootstrappedData, treeCovariates);
return treeTrainer.growTree(bootstrappedData);
}
public void saveTree(final Node<Y> tree, String name) throws IOException {
public void saveTree(final Node<TO> tree, String name) throws IOException {
final String filename = saveTreeLocation + "/" + name;
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
@ -184,9 +177,9 @@ public class ForestTrainer<Y> {
private final Bootstrapper<Row<Y>> bootstrapper;
private final int treeIndex;
private final List<Node<Y>> treeList;
private final List<Node<TO>> treeList;
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<Y>> treeList) {
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<TO>> treeList) {
this.bootstrapper = new Bootstrapper<>(data);
this.treeIndex = treeIndex;
this.treeList = treeList;
@ -195,7 +188,7 @@ public class ForestTrainer<Y> {
@Override
public void run() {
final Node<Y> tree = trainTree(bootstrapper);
final Node<TO> tree = trainTree(bootstrapper);
// should be okay as the list structure isn't changing
treeList.set(treeIndex, tree);
@ -218,7 +211,7 @@ public class ForestTrainer<Y> {
@Override
public void run() {
final Node<Y> tree = trainTree(bootstrapper);
final Node<TO> tree = trainTree(bootstrapper);
try {
saveTree(tree, filename);

View file

@ -1,10 +1,6 @@
package ca.joeltherrien.randomforest.tree;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
@ -16,11 +12,4 @@ public interface GroupDifferentiator<Y> {
Double differentiate(List<Y> leftHand, List<Y> rightHand);
@FunctionalInterface
interface GroupDifferentiatorConstructor<Y>{
GroupDifferentiator<Y> construct(ObjectNode node);
}
}

View file

@ -1,20 +1,9 @@
package ca.joeltherrien.randomforest.tree;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
public interface ResponseCombiner<I, O> {
Y combine(List<Y> responses);
final static Map<String, ResponseCombiner> RESPONSE_COMBINER_MAP = new HashMap<>();
static ResponseCombiner loadResponseCombinerByName(final String name){
return RESPONSE_COMBINER_MAP.get(name);
}
static void registerResponseCombiner(final String name, final ResponseCombiner responseCombiner){
RESPONSE_COMBINER_MAP.put(name, responseCombiner);
}
O combine(List<I> responses);
}

View file

@ -7,13 +7,14 @@ import lombok.AllArgsConstructor;
import lombok.Builder;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
@Builder
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class TreeTrainer<Y> {
public class TreeTrainer<Y, O> {
private final ResponseCombiner<Y, ?> responseCombiner;
private final ResponseCombiner<Y, O> responseCombiner;
private final GroupDifferentiator<Y> groupDifferentiator;
/**
@ -23,54 +24,76 @@ public class TreeTrainer<Y> {
private final int numberOfSplits;
private final int nodeSize;
private final int maxNodeDepth;
private final int mtry;
public TreeTrainer(final Settings settings){
private final List<Covariate> covariates;
public TreeTrainer(final Settings settings, final List<Covariate> covariates){
this.numberOfSplits = settings.getNumberOfSplits();
this.nodeSize = settings.getNodeSize();
this.maxNodeDepth = settings.getMaxNodeDepth();
this.mtry = settings.getMtry();
this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner());
this.responseCombiner = settings.getResponseCombiner();
this.groupDifferentiator = settings.getGroupDifferentiator();
this.covariates = covariates;
}
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
return growNode(data, covariatesToTry, 0);
public Node<O> growTree(List<Row<Y>> data){
return growNode(data, 0);
}
private Node<Y> growNode(List<Row<Y>> data, List<Covariate> covariatesToTry, int depth){
private Node<O> growNode(List<Row<Y>> data, int depth){
// TODO: what is the minimum amount of data required per tree?
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
if(bestSplitRule == null){
return new TerminalNode<>(
data.stream()
.map(row -> row.getResponse())
.collect(responseCombiner)
return new TerminalNode<>(
responseCombiner.combine(
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
)
);
}
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
final Node<Y> rightNode = growNode(split.rightHand, covariatesToTry, depth+1);
final Node<O> leftNode = growNode(split.leftHand, depth+1);
final Node<O> rightNode = growNode(split.rightHand, depth+1);
return new SplitNode<>(leftNode, rightNode, bestSplitRule);
}
else{
return new TerminalNode<>(
data.stream()
.map(row -> row.getResponse())
.collect(responseCombiner)
responseCombiner.combine(
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
)
);
}
}
/**
 * Randomly choose mtry covariates to be considered when splitting a node.
 *
 * @param mtry The number of covariates to sample.
 * @return A random subset of size mtry of the trainer's covariates; the full
 *         covariate list itself when mtry covers every covariate.
 */
private List<Covariate> selectCovariates(int mtry){
    // No sampling needed when every covariate may be tried.
    if(mtry >= covariates.size()){
        return covariates;
    }
    final List<Covariate> shuffledCovariates = new ArrayList<>(covariates);
    Collections.shuffle(shuffledCovariates, ThreadLocalRandom.current());
    // Trim from the tail until only the first mtry (randomly ordered) covariates remain.
    while(shuffledCovariates.size() > mtry){
        shuffledCovariates.remove(shuffledCovariates.size() - 1);
    }
    return shuffledCovariates;
}
private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
Covariate.SplitRule bestSplitRule = null;
double bestSplitScore = 0.0;
@ -96,7 +119,9 @@ public class TreeTrainer<Y> {
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
);
if(score != null && (score > bestSplitScore || first)){
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
bestSplitRule = possibleRule;
bestSplitScore = score;
first = false;

View file

@ -0,0 +1,99 @@
# Comparing output between my two packages
require(survival)
require(randomForestSRC)
data(wihs, package = "randomForestSRC")
# Convert the 0/1 indicator columns to logicals so they match the Java loader's booleans.
wihs$idu = as.logical(wihs$idu)
wihs$black = as.logical(wihs$black)
# Reproducibly shuffle the rows so row order can't influence training.
set.seed(100)
wihs = wihs[sample(1:nrow(wihs), replace = FALSE),]
# example row
# time status ageatfda idu black cd4nadir
#409 1.3 1 35 FALSE FALSE 0.81
newData = data.frame(ageatfda=35, idu=FALSE, black=FALSE, cd4nadir=0.81)
# Single tree with a fixed seed so the Java implementation can be compared exactly.
one.tree <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 1, splitrule="logrank", cause=1, mtry=2, seed=-5, membership=TRUE)
# next get membership
membership=one.tree$inbag[,1]
# Disabled block: reconstructs the bootstrap sample used by the single tree above.
if(FALSE){
bootstrappedData = wihs[c(),]
for(i in 1:length(membership)){
times = membership[i]
if(times > 0){
for(j in 1:times){
bootstrappedData = rbind(bootstrappedData, wihs[i,])
}
}
}
}
output.one.tree = predict(one.tree, newData)
output.one.tree$cif[,,1]
output.one.tree$chf[,c(11,66,103),1]
# Note that "... ~ ." means 'use all explanatory variables'
#output = predict(wihs.obj, newData)
#output$cif[,,1] # CIF for cause 1
#output$cif[,,2] # CIF for cause 2
# Forest on the two boolean covariates only; predict for the example row.
many.trees <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE)
output.many.trees = predict(many.trees, newData)
output.many.trees$cif[,41,1]
output.many.trees$cif[,41,2]
# Same again, but using all four covariates.
many.trees.all <- rfsrc(Surv(time, status) ~ ageatfda + cd4nadir + idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE)
output.many.trees.all = predict(many.trees.all, newData)
output.many.trees.all$cif[,103,1]
output.many.trees.all$cif[,103,2]
# Train 100 independent single trees, recording the tail value of each cause-1 CIF.
end.numbers = c()
end.times = c()
lgths = c()
trees = list()
for(i in 1:100){
one.tree = rfsrc(Surv(time, status) ~ ageatfda + cd4nadir + idu + black, wihs, nsplit = 0, ntree = 1, splitrule="logrank", cause=1, mtry=4, membership=TRUE, statistics = TRUE)
trees[[i]] = one.tree
prediction = predict(one.tree, newData)
lgth = length(prediction$cif[,,1])
lgths = c(lgths, lgth)
end.numbers = c(end.numbers, prediction$cif[,lgth,1])
end.times = c(end.times, max(prediction$time.interest))
}
# Export the bootstrap sample used by the last tree so the Java tests can replay it.
special.tree = trees[[100]]
prediction = predict(special.tree, newData)
prediction$cif[,,1]
membership = special.tree$inbag[,1]
bootstrappedData = wihs[c(),]
for(i in 1:length(membership)){
times = membership[i]
if(times > 0){
# inbag holds how many times each row was drawn; duplicate the row accordingly.
for(j in 1:times){
bootstrappedData = rbind(bootstrappedData, wihs[i,])
}
}
}
write.csv(bootstrappedData, "RandomSurvivalForests/src/test/resources/wihs.bootstrapped2.csv", row.names=FALSE)
prediction$cif[,,1]

View file

@ -0,0 +1,325 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
import ca.joeltherrien.randomforest.responses.competingrisk.Point;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
/**
 * Integration tests for the competing risk forest implementation. Expected values were
 * produced with the R package randomForestSRC on the WIHS dataset (see the R scripts in
 * the test resources) and are compared against this implementation's output.
 */
public class TestCompetingRisk {
/**
* By default uses single log-rank test.
*
* @return Settings configured for the WIHS competing risk dataset with all four covariates.
*/
public Settings getSettings(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
responseCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
);
// not setting times
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
treeCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
);
// not setting times
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
yVarSettings.set("u", new TextNode("time"));
yVarSettings.set("delta", new TextNode("status"));
return Settings.builder()
.covariates(List.of(
new NumericCovariateSettings("ageatfda"),
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black"),
new NumericCovariateSettings("cd4nadir")
)
)
.dataFileLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
// TODO fill in these settings
.mtry(2)
.nodeSize(6)
.ntree(100)
.numberOfSplits(5)
.numberOfThreads(3)
.saveProgress(true)
.saveTreeLocation("trees/")
.build();
}
/**
* @param covariates The covariates the row's values should be parsed against.
* @return A CovariateRow matching dataset row 409 (see comments in the test methods), used for predictions.
*/
public CovariateRow getPredictionRow(List<Covariate> covariates){
return CovariateRow.createSimple(Map.of(
"ageatfda", "35",
"idu", "false",
"black", "false",
"cd4nadir", "0.81")
, covariates, 1);
}
// Trains one tree on a pre-bootstrapped file and checks CIF / cumulative hazard values
// against the corresponding randomForestSRC single-tree output.
@Test
public void testSingleTree() throws IOException {
final Settings settings = getSettings();
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
settings.setCovariates(List.of(
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black")
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
final CovariateRow newRow = getPredictionRow(covariates);
final CompetingRiskFunctions functions = node.evaluate(newRow);
final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
final double margin = 0.0000001;
// expected values taken from randomForestSRC on the same bootstrapped data
closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02).getY(), margin);
closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00).getY(), margin);
closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50).getY(), margin);
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60).getY(), margin);
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80).getY(), margin);
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin);
closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin);
closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin);
closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin);
closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin);
closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin);
closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin);
closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin);
}
/**
* Note - this test triggers a situation where the variance calculation in the log-rank test experiences an NaN.
*
* @throws IOException If the bootstrapped data file cannot be read.
*/
@Test
public void testSingleTree2() throws IOException {
final Settings settings = getSettings();
settings.setMtry(4);
settings.setNumberOfSplits(0);
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped2.csv");
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
final CovariateRow newRow = getPredictionRow(covariates);
final CompetingRiskFunctions functions = node.evaluate(newRow);
final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
final double margin = 0.0000001;
closeEnough(0, causeOneCIFFunction.evaluate(0.02).getY(), margin);
closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4).getY(), margin);
closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8).getY(), margin);
closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9).getY(), margin);
closeEnough(1.0, causeOneCIFFunction.evaluate(1.0).getY(), margin);
/*
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin);
closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin);
closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin);
closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin);
closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin);
closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin);
closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin);
closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin);
*/
}
/**
* @param settings The settings whose covariate settings should be built.
* @return The concrete Covariate list built from the covariate settings.
*/
public List<Covariate> getCovariates(Settings settings){
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
}
// Full serial forest on the two boolean covariates; CIF values are checked with a
// loose margin since the forest is not seeded.
@Test
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
final Settings settings = getSettings();
settings.setCovariates(List.of(
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black")
));
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
// prediction row
// time status ageatfda idu black cd4nadir
//409 1.3 1 35 FALSE FALSE 0.81
final CovariateRow newRow = getPredictionRow(covariates);
final CompetingRiskFunctions functions = forest.evaluate(newRow);
assertCumulativeFunction(functions.getCauseSpecificHazardFunction(1));
assertCumulativeFunction(functions.getCauseSpecificHazardFunction(2));
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1));
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0).getY(), 0.01);
closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8).getY(), 0.01);
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
}
// Sanity-check that the CSV loader read the WIHS dataset correctly by counting events.
@Test
public void verifyDataset() throws IOException {
final Settings settings = getSettings();
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
// Let's count the events and make sure the data was correctly read.
int countCensored = 0;
int countEventOne = 0;
int countEventTwo = 0;
for(final Row<CompetingRiskResponse> row : dataset){
final CompetingRiskResponse response = row.getResponse();
if(response.getDelta() == 0){
countCensored++;
}
else if(response.getDelta() == 1){
countEventOne++;
}
else if(response.getDelta() == 2){
countEventTwo++;
}
else{
throw new RuntimeException("There's an event of type " + response.getDelta());
}
}
assertEquals(126, countCensored);
assertEquals(679, countEventOne);
assertEquals(359, countEventTwo);
}
// Full serial forest using all four covariates; only a loose lower bound is asserted
// on the tail of the cause-1 CIF (see comment below).
@Test
public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException {
final Settings settings = getSettings();
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
// prediction row
// time status ageatfda idu black cd4nadir
//409 1.3 1 35 FALSE FALSE 0.81
final CovariateRow newRow = getPredictionRow(covariates);
final CompetingRiskFunctions functions = forest.evaluate(newRow);
assertCumulativeFunction(functions.getCauseSpecificHazardFunction(1));
assertCumulativeFunction(functions.getCauseSpecificHazardFunction(2));
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1));
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
// We seem to consistently underestimate the results.
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC");
}
/**
* We know the function is cumulative; make sure it is ordered correctly and that the function is monotone.
*
* @param function The cumulative function whose points are verified.
*/
private void assertCumulativeFunction(MathFunction function){
Point previousPoint = null;
for(final Point point : function.getPoints()){
if(previousPoint != null){
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
}
previousPoint = point;
}
}
// Asserts |expected - actual| < margin with a descriptive failure message.
private void closeEnough(double expected, double actual, double margin){
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
}
}

View file

@ -0,0 +1,12 @@
# Small hand-checkable competing-risk example (delta: 0 = censored, 1/2 = event types)
# used to produce reference values for the Java response combiner tests.
testData = data.frame(delta=c(1,1,1,2,2,0,0), u=c(1,1,2,1.5,2,1.5,2.5))
testData$survDelta = ifelse(testData$delta==0, 0, 1) # for KM curves on any events
require(survival)
# Kaplan-Meier curve treating any event (delta != 0) as the event of interest.
kmCurve = survfit(Surv(u, survDelta, type="right") ~ 1, data=testData)
kmCurve$surv
# Multi-state fit; cumhaz / pstate give the per-cause reference values.
curve = survfit(Surv(u, event=delta, type="mstate") ~ 1, data=testData)
curve$cumhaz[3,1:2,]
print(t(curve$pstate[,1:2]))

View file

@ -0,0 +1,94 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.util.ArrayList;
import java.util.List;
/**
 * Verifies CompetingRiskResponseCombiner against reference values computed in R
 * with the survival package (see the accompanying R script that produces the
 * expected survival, cumulative hazard, and cumulative incidence numbers).
 */
public class TestCompetingRiskResponseCombiner {

    /**
     * Builds a small competing-risk dataset (status 0 = censored, 1 and 2 = the
     * two event types) and runs it through the combiner to obtain estimated
     * survival, cause-specific hazard, and cumulative incidence functions.
     */
    private CompetingRiskFunctions generateFunctions(){
        final List<CompetingRiskResponse> data = new ArrayList<>();
        data.add(new CompetingRiskResponse(1, 1.0));
        data.add(new CompetingRiskResponse(1, 1.0));
        data.add(new CompetingRiskResponse(1, 2.0));
        data.add(new CompetingRiskResponse(2, 1.5));
        data.add(new CompetingRiskResponse(2, 2.0));
        data.add(new CompetingRiskResponse(0, 1.5));
        data.add(new CompetingRiskResponse(0, 2.5));

        // Events of interest are codes 1 and 2. The second argument is null here;
        // presumably an optional parameter meaning "derive times from the data" - TODO confirm.
        final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);

        final CompetingRiskFunctions functions = combiner.combine(data);

        return functions;
    }

    @Test
    public void testCompetingRiskResponseCombiner(){
        final CompetingRiskFunctions functions = generateFunctions();

        final MathFunction survivalCurve = functions.getSurvivalCurve();

        // Expected values from R's survival package:
        // time = 1.0 1.5 2.0 2.5
        // surv = 0.7142857 0.5714286 0.1904762 0.1904762
        final double margin = 0.0000001;
        closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin);
        closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin);
        closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin);
        closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin);

        // Time = 1.0 1.5 2.0 2.5
        /* Cumulative hazard function from R. Each row is one event type.
                  [,1]      [,2]      [,3]      [,4]
        [1,] 0.2857143 0.2857143 0.6190476 0.6190476
        [2,] 0.0000000 0.2000000 0.5333333 0.5333333
        */
        final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1);
        closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin);
        closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin);
        closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin);
        closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin);

        final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2);
        closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin);
        closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin);
        closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin);
        closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin);

        /* Time = 1.0 1.5 2.0 2.5
        Cumulative Incidence Curve from R. Each row is one event type.
                  [,1]      [,2]      [,3]      [,4]
        [1,] 0.2857143 0.2857143 0.4761905 0.4761905
        [2,] 0.0000000 0.1428571 0.3333333 0.3333333
        */
        final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1);
        closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin);
        closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin);
        closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin);
        closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin);

        final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2);
        closeEnough(0.0, cic2.evaluate(1.0).getY(), margin);
        closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin);
        closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin);
        closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin);

    }

    // Asserts the actual value lies within margin of the expected value.
    private void closeEnough(double expected, double actual, double margin){
        assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
    }

}

View file

@ -0,0 +1,62 @@
# Generates a reference log-rank (Mantel) statistic in R for the Java
# LogRankSingleGroupDifferentiator unit test.
# Group 1 and group 2 use the same event pattern; group 2's times are doubled.
testData1 = data.frame(delta=c(1,1,1,1,0,0,0), u=c(1,1,2,1.5,2,1.5,2.5), group=TRUE)
testData2 = data.frame(delta=c(1,1,1,1,0,0,0), u=c(2,2,4,3,4,3,5), group=FALSE) # just doubled everything
testData = rbind(testData1, testData2)

require(survival)
# Built-in log-rank test, kept for comparison against the hand-rolled version below.
results = survdiff(Surv(u, delta, type="right") ~ group, data=testData)

# Hand-rolled Mantel (log-rank) test.
#   times:    observed times for all subjects
#   observed: TRUE where the event of interest occurred (FALSE = censored)
#   group0 / group1: logical masks selecting the two groups
mantelTest = function(times, observed, group0, group1=!group0){
  U0 = times[group0]
  observed0 = observed[group0]
  U1 = times[group1]
  observed1 = observed[group1]

  # Distinct event times across both groups (censoring times excluded).
  Vs = sort(unique(c(U0[observed0], U1[observed1])))

  # Number of subjects still at risk (u >= v) at time v.
  atRisk = function(v, u){
    u = subset(u, u >= v)
    return(length(u))
  }

  Os = c()     # observed failures in the treatment group at each event time
  Es = c()     # expected failures under the null hypothesis
  varOs = c()  # hypergeometric variance contributions

  # we're going to treat group 1 as treatment
  for(v in Vs){
    placeboAtRisk = atRisk(v, U0)
    treatmentAtRisk = atRisk(v, U1)
    totalAtRisk = placeboAtRisk + treatmentAtRisk

    numTreatmentFailures = length(subset(U1, observed1 & U1 == v))
    numPlaceboFailures = length(subset(U0, observed0 & U0 == v))
    totalFailures = numTreatmentFailures + numPlaceboFailures

    Os = c(Os, numTreatmentFailures)
    Es = c(Es, (totalFailures)*treatmentAtRisk/totalAtRisk)

    varOfO = (totalAtRisk - treatmentAtRisk)/(totalAtRisk - 1) *
      treatmentAtRisk * (totalFailures / totalAtRisk) *
      (1 - totalFailures / totalAtRisk)

    # Guard against division by zero above: with one subject at risk the
    # variance contribution is defined to be 0 (overwrites the NaN just computed).
    if(totalAtRisk == 1){
      varOfO = 0
    }

    varOs = c(varOs, varOfO)

  }

  numerator = sum(Os - Es)
  variance = sum(varOs)

  # Standardized statistic; two-sided p-value from the normal approximation.
  Z = numerator/sqrt(variance)

  return(list(
    statistic = Z,
    pvalue = 2*pnorm(abs(Z), lower.tail=FALSE),
    numerator = numerator,
    variance = variance
  ))

}

# The Z statistic produced here is the reference value asserted in the Java test.
myTest = mantelTest(testData$u, testData$delta == 1, group0=testData$group==1)

View file

@ -0,0 +1,61 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
 * Tests LogRankSingleGroupDifferentiator by comparing its log-rank score
 * against a reference value computed independently in R (see the accompanying
 * mantelTest R script, which uses the same two datasets).
 */
public class TestLogRankSingleGroupDifferentiator {

    // Group 1: status 1 = event of interest, 0 = censored.
    private List<CompetingRiskResponse> generateData1(){
        final List<CompetingRiskResponse> data = new ArrayList<>();
        data.add(new CompetingRiskResponse(1, 1.0));
        data.add(new CompetingRiskResponse(1, 1.0));
        data.add(new CompetingRiskResponse(1, 2.0));
        data.add(new CompetingRiskResponse(1, 1.5));
        data.add(new CompetingRiskResponse(0, 2.0));
        data.add(new CompetingRiskResponse(0, 1.5));
        data.add(new CompetingRiskResponse(0, 2.5));

        return data;
    }

    // Group 2: identical event pattern to group 1 but with every time doubled.
    private List<CompetingRiskResponse> generateData2(){
        final List<CompetingRiskResponse> data = new ArrayList<>();
        data.add(new CompetingRiskResponse(1, 2.0));
        data.add(new CompetingRiskResponse(1, 2.0));
        data.add(new CompetingRiskResponse(1, 4.0));
        data.add(new CompetingRiskResponse(1, 3.0));
        data.add(new CompetingRiskResponse(0, 4.0));
        data.add(new CompetingRiskResponse(0, 3.0));
        data.add(new CompetingRiskResponse(0, 5.0));

        return data;
    }

    // Renamed from testCompetingRiskResponseCombiner, which was a copy-paste
    // leftover from another test class and did not describe what is under test.
    @Test
    public void testLogRankSingleGroupDifferentiator(){
        final List<CompetingRiskResponse> data1 = generateData1();
        final List<CompetingRiskResponse> data2 = generateData2();

        final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1);
        final double score = differentiator.differentiate(data1, data2);
        final double margin = 0.000001;

        // Tested using 855 method (reference value; presumably from the
        // hand-rolled mantelTest R implementation - confirm against that script)
        closeEnough(1.540139, score, margin);

    }

    // Asserts the actual value lies within margin of the expected value.
    private void closeEnough(double expected, double actual, double margin){
        assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
    }

}

View file

@ -0,0 +1,43 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
import ca.joeltherrien.randomforest.responses.competingrisk.Point;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.util.ArrayList;
import java.util.List;
/**
 * Tests that MathFunction behaves as a right-continuous step function: each
 * evaluation returns the point with the largest time not exceeding the query,
 * or the default point when the query precedes every stored time.
 */
public class TestMathFunction {

    /**
     * Builds a step function over (1.0, -1.0), (2.0, 1.0), (3.0, 0.5), with
     * (0.0, 0.1) as the default point returned before the first time.
     */
    private MathFunction generateMathFunction(){
        final List<Point> points = new ArrayList<>();
        points.add(new Point(1.0, -1.0));
        points.add(new Point(2.0, 1.0));
        points.add(new Point(3.0, 0.5));

        return new MathFunction(points, new Point(0.0, 0.1));
    }

    @Test
    public void test(){
        final MathFunction function = generateMathFunction();

        // Evaluating exactly at a stored time returns that point; before the
        // first time we get the default point.
        assertEquals(new Point(1.0, -1.0), function.evaluate(1.0));
        assertEquals(new Point(2.0, 1.0), function.evaluate(2.0));
        assertEquals(new Point(3.0, 0.5), function.evaluate(3.0));
        assertEquals(new Point(0.0, 0.1), function.evaluate(0.5));

        // Evaluating between stored times returns the most recent point to the left.
        assertEquals(new Point(1.0, -1.0), function.evaluate(1.1));
        assertEquals(new Point(2.0, 1.0), function.evaluate(2.1));
        assertEquals(new Point(3.0, 0.5), function.evaluate(3.1));
        assertEquals(new Point(0.0, 0.1), function.evaluate(0.6));

    }

}

View file

@ -20,6 +20,12 @@ public class TestPersistence {
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
treeCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
@ -32,8 +38,8 @@ public class TestPersistence {
)
)
.dataFileLocation("data.csv")
.responseCombiner("MeanResponseCombiner")
.treeResponseCombiner("MeanResponseCombiner")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)

View file

@ -51,7 +51,7 @@ public class TrainForest {
}
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.numberOfSplits(5)
.nodeSize(5)
.maxNodeDepth(100000000)
@ -59,10 +59,10 @@ public class TrainForest {
.responseCombiner(new MeanResponseCombiner())
.build();
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
final ForestTrainer<Double, Double, Double> forestTrainer = ForestTrainer.<Double, Double, Double>builder()
.treeTrainer(treeTrainer)
.data(data)
.covariatesToTry(covariateList)
.covariates(covariateList)
.mtry(4)
.ntree(100)
.treeResponseCombiner(new MeanResponseCombiner())

View file

@ -48,18 +48,20 @@ public class TrainSingleTree {
trainingSet.add(generateRow(x1, x2, i));
}
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.covariates(covariateNames)
.responseCombiner(new MeanResponseCombiner())
.maxNodeDepth(30)
.nodeSize(5)
.numberOfSplits(0)
.build();
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
final long startTime = System.currentTimeMillis();
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
final long endTime = System.currentTimeMillis();
System.out.println(((double)(endTime - startTime))/1000.0);

View file

@ -69,18 +69,21 @@ public class TrainSingleTreeFactor {
trainingSet.add(generateRow(x1, x2, x3, i));
}
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.responseCombiner(new MeanResponseCombiner())
.covariates(covariateNames)
.maxNodeDepth(30)
.nodeSize(5)
.numberOfSplits(5)
.build();
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
final long startTime = System.currentTimeMillis();
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
final long endTime = System.currentTimeMillis();
System.out.println(((double)(endTime - startTime))/1000.0);

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

1165
src/test/resources/wihs.csv Normal file

File diff suppressed because it is too large Load diff