Finish competing risk implementation. Fix a bug in tree training
algorithm.
This commit is contained in:
parent
462b0d9c35
commit
fffdfe85bf
41 changed files with 4768 additions and 241 deletions
|
@ -5,12 +5,12 @@ import lombok.RequiredArgsConstructor;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class Bootstrapper<T> {
|
public class Bootstrapper<T> {
|
||||||
|
|
||||||
final private List<T> originalData;
|
final private List<T> originalData;
|
||||||
final private Random random = new Random();
|
|
||||||
|
|
||||||
public List<T> bootstrap(){
|
public List<T> bootstrap(){
|
||||||
final int n = originalData.size();
|
final int n = originalData.size();
|
||||||
|
@ -18,7 +18,7 @@ public class Bootstrapper<T> {
|
||||||
final List<T> newList = new ArrayList<>(n);
|
final List<T> newList = new ArrayList<>(n);
|
||||||
|
|
||||||
for(int i=0; i<n; i++){
|
for(int i=0; i<n; i++){
|
||||||
final int index = random.nextInt(n);
|
final int index = ThreadLocalRandom.current().nextInt(n);
|
||||||
|
|
||||||
newList.add(originalData.get(index));
|
newList.add(originalData.get(index));
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
|
@ -24,4 +26,19 @@ public class CovariateRow {
|
||||||
return "CovariateRow " + this.id;
|
return "CovariateRow " + this.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
||||||
|
final Map<String, Covariate.Value> valueMap = new HashMap<>();
|
||||||
|
final Map<String, Covariate> covariateMap = new HashMap<>();
|
||||||
|
|
||||||
|
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
|
||||||
|
|
||||||
|
simpleMap.forEach((name, valueStr) -> {
|
||||||
|
if(covariateMap.containsKey(name)){
|
||||||
|
valueMap.put(name, covariateMap.get(name).createValue(valueStr));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return new CovariateRow(valueMap, id);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.apache.commons.csv.CSVFormat;
|
import org.apache.commons.csv.CSVFormat;
|
||||||
|
@ -49,11 +48,6 @@ public class DataLoader {
|
||||||
Y parse(CSVRecord record);
|
Y parse(CSVRecord record);
|
||||||
}
|
}
|
||||||
|
|
||||||
@FunctionalInterface
|
|
||||||
public interface ResponseLoaderConstructor<Y>{
|
|
||||||
ResponseLoader<Y> construct(ObjectNode node);
|
|
||||||
}
|
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public static class DoubleLoader implements ResponseLoader<Double> {
|
public static class DoubleLoader implements ResponseLoader<Double> {
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ public class Main {
|
||||||
|
|
||||||
final List<Row<Double>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<Double>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
|
||||||
final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
final ForestTrainer<Double, Double, Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
if(settings.isSaveProgress()){
|
if(settings.isSaveProgress()){
|
||||||
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||||
|
@ -63,8 +63,14 @@ public class Main {
|
||||||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
|
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
|
||||||
|
|
||||||
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
||||||
|
|
||||||
|
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
treeCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
||||||
|
|
||||||
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
yVarSettings.set("type", new TextNode("y"));
|
yVarSettings.set("type", new TextNode("Double"));
|
||||||
yVarSettings.set("name", new TextNode("y"));
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
final Settings settings = Settings.builder()
|
final Settings settings = Settings.builder()
|
||||||
|
@ -75,8 +81,8 @@ public class Main {
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.dataFileLocation("data.csv")
|
.dataFileLocation("data.csv")
|
||||||
.responseCombiner("MeanResponseCombiner")
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeResponseCombiner("MeanResponseCombiner")
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
.yVarSettings(yVarSettings)
|
.yVarSettings(yVarSettings)
|
||||||
.maxNodeDepth(100000)
|
.maxNodeDepth(100000)
|
||||||
|
|
|
@ -3,8 +3,10 @@ package ca.joeltherrien.randomforest;
|
||||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
@ -16,6 +18,7 @@ import lombok.*;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
||||||
|
@ -26,11 +29,11 @@ import java.util.*;
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public class Settings {
|
public class Settings {
|
||||||
|
|
||||||
private static Map<String, DataLoader.ResponseLoaderConstructor> RESPONSE_LOADER_MAP = new HashMap<>();
|
private static Map<String, Function<ObjectNode, DataLoader.ResponseLoader>> RESPONSE_LOADER_MAP = new HashMap<>();
|
||||||
public static DataLoader.ResponseLoaderConstructor getResponseLoaderConstructor(final String name){
|
public static Function<ObjectNode, DataLoader.ResponseLoader> getResponseLoaderConstructor(final String name){
|
||||||
return RESPONSE_LOADER_MAP.get(name.toLowerCase());
|
return RESPONSE_LOADER_MAP.get(name.toLowerCase());
|
||||||
}
|
}
|
||||||
public static void registerResponseLoaderConstructor(final String name, final DataLoader.ResponseLoaderConstructor responseLoaderConstructor){
|
public static void registerResponseLoaderConstructor(final String name, final Function<ObjectNode, DataLoader.ResponseLoader> responseLoaderConstructor){
|
||||||
RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor);
|
RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,19 +41,19 @@ public class Settings {
|
||||||
registerResponseLoaderConstructor("double",
|
registerResponseLoaderConstructor("double",
|
||||||
node -> new DataLoader.DoubleLoader(node)
|
node -> new DataLoader.DoubleLoader(node)
|
||||||
);
|
);
|
||||||
registerResponseLoaderConstructor("CompetingResponse",
|
registerResponseLoaderConstructor("CompetingRiskResponse",
|
||||||
node -> new CompetingResponse.CompetingResponseLoader(node)
|
node -> new CompetingRiskResponse.CompetingResponseLoader(node)
|
||||||
);
|
);
|
||||||
registerResponseLoaderConstructor("CompetingResponseWithCensorTime",
|
registerResponseLoaderConstructor("CompetingRiskResponseWithCensorTime",
|
||||||
node -> new CompetingResponseWithCensorTime.CompetingResponseWithCensorTimeLoader(node)
|
node -> new CompetingRiskResponseWithCensorTime.CompetingResponseWithCensorTimeLoader(node)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, GroupDifferentiator.GroupDifferentiatorConstructor> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
|
private static Map<String, Function<ObjectNode, GroupDifferentiator>> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
|
||||||
public static GroupDifferentiator.GroupDifferentiatorConstructor getGroupDifferentiatorConstructor(final String name){
|
public static Function<ObjectNode, GroupDifferentiator> getGroupDifferentiatorConstructor(final String name){
|
||||||
return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase());
|
return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase());
|
||||||
}
|
}
|
||||||
public static void registerGroupDifferentiatorConstructor(final String name, final GroupDifferentiator.GroupDifferentiatorConstructor groupDifferentiatorConstructor){
|
public static void registerGroupDifferentiatorConstructor(final String name, final Function<ObjectNode, GroupDifferentiator> groupDifferentiatorConstructor){
|
||||||
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
|
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
|
||||||
}
|
}
|
||||||
static{
|
static{
|
||||||
|
@ -98,13 +101,67 @@ public class Settings {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static Map<String, Function<ObjectNode, ResponseCombiner>> RESPONSE_COMBINER_MAP = new HashMap<>();
|
||||||
|
public static Function<ObjectNode, ResponseCombiner> getResponseCombinerConstructor(final String name){
|
||||||
|
return RESPONSE_COMBINER_MAP.get(name.toLowerCase());
|
||||||
|
}
|
||||||
|
public static void registerResponseCombinerConstructor(final String name, final Function<ObjectNode, ResponseCombiner> responseCombinerConstructor){
|
||||||
|
RESPONSE_COMBINER_MAP.put(name.toLowerCase(), responseCombinerConstructor);
|
||||||
|
}
|
||||||
|
|
||||||
|
static{
|
||||||
|
|
||||||
|
registerResponseCombinerConstructor("MeanResponseCombiner",
|
||||||
|
(node) -> new MeanResponseCombiner()
|
||||||
|
);
|
||||||
|
registerResponseCombinerConstructor("CompetingRiskResponseCombiner",
|
||||||
|
(node) -> {
|
||||||
|
final List<Integer> eventList = new ArrayList<>();
|
||||||
|
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
|
||||||
|
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
|
||||||
|
|
||||||
|
double[] times = null;
|
||||||
|
// note that times may be null
|
||||||
|
if(node.hasNonNull("times")){
|
||||||
|
final List<Double> timeList = new ArrayList<>();
|
||||||
|
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
||||||
|
times = eventList.stream().mapToDouble(db -> db).toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CompetingRiskResponseCombiner(events, times);
|
||||||
|
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
registerResponseCombinerConstructor("CompetingRiskFunctionCombiner",
|
||||||
|
(node) -> {
|
||||||
|
final List<Integer> eventList = new ArrayList<>();
|
||||||
|
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
|
||||||
|
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
|
||||||
|
|
||||||
|
double[] times = null;
|
||||||
|
// note that times may be null
|
||||||
|
if(node.hasNonNull("times")){
|
||||||
|
final List<Double> timeList = new ArrayList<>();
|
||||||
|
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
||||||
|
times = eventList.stream().mapToDouble(db -> db).toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CompetingRiskFunctionCombiner(events, times);
|
||||||
|
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
private int numberOfSplits = 5;
|
private int numberOfSplits = 5;
|
||||||
private int nodeSize = 5;
|
private int nodeSize = 5;
|
||||||
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
|
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
|
||||||
|
|
||||||
private String responseCombiner;
|
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
private String treeResponseCombiner;
|
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
|
||||||
private List<CovariateSettings> covariates = new ArrayList<>();
|
private List<CovariateSettings> covariates = new ArrayList<>();
|
||||||
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
@ -148,14 +205,28 @@ public class Settings {
|
||||||
public GroupDifferentiator getGroupDifferentiator(){
|
public GroupDifferentiator getGroupDifferentiator(){
|
||||||
final String type = groupDifferentiatorSettings.get("type").asText();
|
final String type = groupDifferentiatorSettings.get("type").asText();
|
||||||
|
|
||||||
return getGroupDifferentiatorConstructor(type).construct(groupDifferentiatorSettings);
|
return getGroupDifferentiatorConstructor(type).apply(groupDifferentiatorSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
public DataLoader.ResponseLoader getResponseLoader(){
|
public DataLoader.ResponseLoader getResponseLoader(){
|
||||||
final String type = yVarSettings.get("type").asText();
|
final String type = yVarSettings.get("type").asText();
|
||||||
|
|
||||||
return getResponseLoaderConstructor(type).construct(yVarSettings);
|
return getResponseLoaderConstructor(type).apply(yVarSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@JsonIgnore
|
||||||
|
public ResponseCombiner getResponseCombiner(){
|
||||||
|
final String type = responseCombinerSettings.get("type").asText();
|
||||||
|
|
||||||
|
return getResponseCombinerConstructor(type).apply(responseCombinerSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@JsonIgnore
|
||||||
|
public ResponseCombiner getTreeCombiner(){
|
||||||
|
final String type = treeCombinerSettings.get("type").asText();
|
||||||
|
|
||||||
|
return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,6 +42,11 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return "BooleanCovariate(name=" + name + ")";
|
||||||
|
}
|
||||||
|
|
||||||
public class BooleanValue implements Value<Boolean>{
|
public class BooleanValue implements Value<Boolean>{
|
||||||
|
|
||||||
private final Boolean value;
|
private final Boolean value;
|
||||||
|
|
|
@ -82,6 +82,11 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
return factorValue;
|
return factorValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return "FactorCovariate(name=" + name + ")";
|
||||||
|
}
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public final class FactorValue implements Covariate.Value<String>{
|
public final class FactorValue implements Covariate.Value<String>{
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
|
@ToString
|
||||||
public final class NumericCovariate implements Covariate<Double>{
|
public final class NumericCovariate implements Covariate<Double>{
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -20,16 +23,18 @@ public final class NumericCovariate implements Covariate<Double>{
|
||||||
|
|
||||||
// only work with non-NA values
|
// only work with non-NA values
|
||||||
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
|
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
|
||||||
|
//data = data.stream().filter(value -> !value.isNA()).distinct().collect(Collectors.toList()); // TODO which to use?
|
||||||
|
|
||||||
// for this implementation we need to shuffle the data
|
// for this implementation we need to shuffle the data
|
||||||
final List<Value<Double>> shuffledData;
|
final List<Value<Double>> shuffledData;
|
||||||
if(number > data.size()){
|
if(number >= data.size()){
|
||||||
shuffledData = new ArrayList<>(data);
|
shuffledData = new ArrayList<>(data);
|
||||||
Collections.shuffle(shuffledData, random);
|
Collections.shuffle(shuffledData, random);
|
||||||
}
|
}
|
||||||
else{ // only need the top number entries
|
else{ // only need the top number entries
|
||||||
shuffledData = new ArrayList<>(number);
|
shuffledData = new ArrayList<>(number);
|
||||||
final Set<Integer> indexesToUse = new HashSet<>();
|
final Set<Integer> indexesToUse = new HashSet<>();
|
||||||
|
//final List<Integer> indexesToUse = new ArrayList<>(); // TODO which to use?
|
||||||
|
|
||||||
while(indexesToUse.size() < number){
|
while(indexesToUse.size() < number){
|
||||||
final int index = random.nextInt(data.size());
|
final int index = random.nextInt(data.size());
|
||||||
|
@ -56,7 +61,7 @@ public final class NumericCovariate implements Covariate<Double>{
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Value<Double> createValue(String value) {
|
public NumericValue createValue(String value) {
|
||||||
if(value == null || value.equalsIgnoreCase("na")){
|
if(value == null || value.equalsIgnoreCase("na")){
|
||||||
return createValue((Double) null);
|
return createValue((Double) null);
|
||||||
}
|
}
|
||||||
|
@ -64,6 +69,7 @@ public final class NumericCovariate implements Covariate<Double>{
|
||||||
return createValue(Double.parseDouble(value));
|
return createValue(Double.parseDouble(value));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@EqualsAndHashCode
|
||||||
public class NumericValue implements Covariate.Value<Double>{
|
public class NumericValue implements Covariate.Value<Double>{
|
||||||
|
|
||||||
private final Double value; // may be null
|
private final Double value; // may be null
|
||||||
|
@ -88,6 +94,7 @@ public final class NumericCovariate implements Covariate<Double>{
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@EqualsAndHashCode
|
||||||
public class NumericSplitRule implements Covariate.SplitRule<Double>{
|
public class NumericSplitRule implements Covariate.SplitRule<Double>{
|
||||||
|
|
||||||
private final double threshold;
|
private final double threshold;
|
||||||
|
|
|
@ -1,11 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class CompetingRiskFunction {
|
|
||||||
|
|
||||||
private List<Point> pointList;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
private final int[] events;
|
||||||
|
private final double[] times; // We may restrict ourselves to specific times.
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingRiskFunctions combine(List<CompetingRiskFunctions> responses) {
|
||||||
|
|
||||||
|
final double[] timesToUse;
|
||||||
|
if(times != null){
|
||||||
|
timesToUse = times;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
timesToUse = responses.stream()
|
||||||
|
.map(functions -> functions.getSurvivalCurve())
|
||||||
|
.flatMapToDouble(
|
||||||
|
function -> function.getPoints().stream()
|
||||||
|
.mapToDouble(point -> point.getTime())
|
||||||
|
).sorted().distinct().toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
final double n = responses.size();
|
||||||
|
|
||||||
|
final List<Point> survivalPoints = new ArrayList<>(timesToUse.length);
|
||||||
|
for(final double time : timesToUse){
|
||||||
|
|
||||||
|
final double survivalY = responses.stream()
|
||||||
|
.mapToDouble(functions -> functions.getSurvivalCurve().evaluate(time).getY() / n)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
survivalPoints.add(new Point(time, survivalY));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
final MathFunction survivalFunction = new MathFunction(survivalPoints, new Point(0.0, 1.0));
|
||||||
|
final Map<Integer, MathFunction> causeSpecificCumulativeHazardFunctionMap = new HashMap<>();
|
||||||
|
final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap = new HashMap<>();
|
||||||
|
|
||||||
|
for(final int event : events){
|
||||||
|
|
||||||
|
final List<Point> cumulativeHazardFunctionPoints = new ArrayList<>(timesToUse.length);
|
||||||
|
final List<Point> cumulativeIncidenceFunctionPoints = new ArrayList<>(timesToUse.length);
|
||||||
|
|
||||||
|
for(final double time : timesToUse){
|
||||||
|
|
||||||
|
final double hazardY = responses.stream()
|
||||||
|
.mapToDouble(functions -> functions.getCauseSpecificHazardFunction(event).evaluate(time).getY() / n)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
final double incidenceY = responses.stream()
|
||||||
|
.mapToDouble(functions -> functions.getCumulativeIncidenceFunction(event).evaluate(time).getY() / n)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
cumulativeHazardFunctionPoints.add(new Point(time, hazardY));
|
||||||
|
cumulativeIncidenceFunctionPoints.add(new Point(time, incidenceY));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
causeSpecificCumulativeHazardFunctionMap.put(event, new MathFunction(cumulativeHazardFunctionPoints));
|
||||||
|
cumulativeIncidenceFunctionMap.put(event, new MathFunction(cumulativeIncidenceFunctionPoints));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return CompetingRiskFunctions.builder()
|
||||||
|
.causeSpecificHazardFunctionMap(causeSpecificCumulativeHazardFunctionMap)
|
||||||
|
.cumulativeIncidenceFunctionMap(cumulativeIncidenceFunctionMap)
|
||||||
|
.survivalCurve(survivalFunction)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public class CompetingRiskFunctions {
|
||||||
|
|
||||||
|
private final Map<Integer, MathFunction> causeSpecificHazardFunctionMap;
|
||||||
|
private final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final MathFunction survivalCurve;
|
||||||
|
|
||||||
|
public MathFunction getCauseSpecificHazardFunction(int cause){
|
||||||
|
return causeSpecificHazardFunctionMap.get(cause);
|
||||||
|
}
|
||||||
|
|
||||||
|
public MathFunction getCumulativeIncidenceFunction(int cause) {
|
||||||
|
return cumulativeIncidenceFunctionMap.get(cause);
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,14 +12,14 @@ import java.util.stream.Stream;
|
||||||
* modifies the abstract method.
|
* modifies the abstract method.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingResponse> implements GroupDifferentiator<Y>{
|
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y>{
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
abstract double riskSet(final List<Y> eventList, double time, int eventOfFocus);
|
abstract double riskSet(final List<Y> eventList, double time, int eventOfFocus);
|
||||||
|
|
||||||
private double numberOFEventsAtTime(int eventOfFocus, List<Y> eventList, double time){
|
private double numberOfEventsAtTime(int eventOfFocus, List<Y> eventList, double time){
|
||||||
return (double) eventList.stream()
|
return (double) eventList.stream()
|
||||||
.filter(event -> event.getDelta() == eventOfFocus)
|
.filter(event -> event.getDelta() == eventOfFocus)
|
||||||
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
|
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
|
||||||
|
@ -31,8 +31,8 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRespon
|
||||||
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
||||||
*
|
*
|
||||||
* @param eventOfFocus
|
* @param eventOfFocus
|
||||||
* @param leftHand A non-empty list of CompetingResponse
|
* @param leftHand A non-empty list of CompetingRiskResponse
|
||||||
* @param rightHand A non-empty list of CompetingResponse
|
* @param rightHand A non-empty list of CompetingRiskResponse
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
LogRankValue specificLogRankValue(final int eventOfFocus, List<Y> leftHand, List<Y> rightHand){
|
LogRankValue specificLogRankValue(final int eventOfFocus, List<Y> leftHand, List<Y> rightHand){
|
||||||
|
@ -40,34 +40,42 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRespon
|
||||||
final double[] distinctEventTimes = Stream.concat(
|
final double[] distinctEventTimes = Stream.concat(
|
||||||
leftHand.stream(), rightHand.stream()
|
leftHand.stream(), rightHand.stream()
|
||||||
)
|
)
|
||||||
.filter(event -> event.getDelta() != 0) // remove censored events
|
.filter(event -> !event.isCensored())
|
||||||
.mapToDouble(event -> event.getU())
|
.mapToDouble(event -> event.getU())
|
||||||
.distinct()
|
.distinct()
|
||||||
.toArray();
|
.toArray();
|
||||||
|
|
||||||
double summation = 0.0;
|
double summation = 0.0;
|
||||||
double varianceSquared = 0.0;
|
double variance = 0.0;
|
||||||
|
|
||||||
for(final double time_k : distinctEventTimes){
|
for(final double time_k : distinctEventTimes){
|
||||||
final double weight = weight(time_k); // W_j(t_k)
|
final double weight = weight(time_k); // W_j(t_k)
|
||||||
final double numberEventsAtTimeDaughterLeft = numberOFEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k)
|
final double numberEventsAtTimeDaughterLeft = numberOfEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k)
|
||||||
final double numberEventsAtTimeDaughterRight = numberOFEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k)
|
final double numberEventsAtTimeDaughterRight = numberOfEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k)
|
||||||
final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)
|
final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)
|
||||||
|
|
||||||
final double individualsAtRiskDaughterLeft = riskSet(leftHand, time_k, eventOfFocus); // Y_l(t_k)
|
final double individualsAtRiskDaughterLeft = riskSet(leftHand, time_k, eventOfFocus); // Y_l(t_k)
|
||||||
final double individualsAtRiskDaughterRight = riskSet(rightHand, time_k, eventOfFocus); // Y_r(t_k)
|
final double individualsAtRiskDaughterRight = riskSet(rightHand, time_k, eventOfFocus); // Y_r(t_k)
|
||||||
final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)
|
final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)
|
||||||
|
|
||||||
summation = summation + weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);
|
final double deltaSummation = weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);
|
||||||
|
final double deltaVariance = weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk
|
||||||
varianceSquared = varianceSquared + weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk
|
|
||||||
* (1.0 - individualsAtRiskDaughterLeft / individualsAtRisk)
|
* (1.0 - individualsAtRiskDaughterLeft / individualsAtRisk)
|
||||||
* ((individualsAtRisk - numberOfEventsAtTime) / (individualsAtRisk - 1.0));
|
* ((individualsAtRisk - numberOfEventsAtTime) / (individualsAtRisk - 1.0));
|
||||||
|
|
||||||
|
// Note - notation differs slightly with what is found in STAT 855 notes, but they are equivalent.
|
||||||
|
// Note - if individualsAtRisk == 1 then variance will be NaN.
|
||||||
|
if(!Double.isNaN(deltaVariance)){
|
||||||
|
summation += deltaSummation;
|
||||||
|
variance += deltaVariance;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
// Do nothing; else statement left for breakpoints.
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new LogRankValue(summation, varianceSquared);
|
return new LogRankValue(summation, variance);
|
||||||
}
|
}
|
||||||
|
|
||||||
double weight(double time){
|
double weight(double time){
|
||||||
|
@ -80,10 +88,10 @@ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRespon
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
static class LogRankValue{
|
static class LogRankValue{
|
||||||
private final double numerator;
|
private final double numerator;
|
||||||
private final double varianceSquared;
|
private final double variance;
|
||||||
|
|
||||||
public double getVariance(){
|
public double getVarianceSqrt(){
|
||||||
return Math.sqrt(varianceSquared);
|
return Math.sqrt(variance);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,14 +7,18 @@ import lombok.RequiredArgsConstructor;
|
||||||
import org.apache.commons.csv.CSVRecord;
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class CompetingResponse {
|
public class CompetingRiskResponse {
|
||||||
|
|
||||||
private final int delta;
|
private final int delta;
|
||||||
private final double u;
|
private final double u;
|
||||||
|
|
||||||
|
public boolean isCensored(){
|
||||||
|
return delta == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public static class CompetingResponseLoader implements DataLoader.ResponseLoader<CompetingResponse>{
|
public static class CompetingResponseLoader implements DataLoader.ResponseLoader<CompetingRiskResponse>{
|
||||||
|
|
||||||
private final String deltaName;
|
private final String deltaName;
|
||||||
private final String uName;
|
private final String uName;
|
||||||
|
@ -25,11 +29,11 @@ public class CompetingResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompetingResponse parse(CSVRecord record) {
|
public CompetingRiskResponse parse(CSVRecord record) {
|
||||||
final int delta = Integer.parseInt(record.get(deltaName));
|
final int delta = Integer.parseInt(record.get(deltaName));
|
||||||
final double u = Double.parseDouble(record.get(uName));
|
final double u = Double.parseDouble(record.get(uName));
|
||||||
|
|
||||||
return new CompetingResponse(delta, u);
|
return new CompetingRiskResponse(delta, u);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class takes all of the observations in a terminal node and combines them to produce estimates of the cause-specific hazard function
|
||||||
|
* and the cumulative incidence curve.
|
||||||
|
*
|
||||||
|
* See https://kogalur.github.io/randomForestSRC/theory.html for details.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CompetingRiskResponseCombiner implements ResponseCombiner<CompetingRiskResponse, CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
private final int[] events;
|
||||||
|
private final double[] times; // We may restrict ourselves to specific times.
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
|
||||||
|
|
||||||
|
final Map<Integer, MathFunction> causeSpecificCumulativeHazardFunctionMap = new HashMap<>();
|
||||||
|
final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap = new HashMap<>();
|
||||||
|
|
||||||
|
final double[] timesToUse;
|
||||||
|
if(times != null){
|
||||||
|
timesToUse = this.times;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
timesToUse = responses.stream()
|
||||||
|
.filter(response -> !response.isCensored())
|
||||||
|
.mapToDouble(response -> response.getU())
|
||||||
|
.sorted().distinct()
|
||||||
|
.toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
final double[] individualsAtRiskArray = Arrays.stream(timesToUse).map(time -> riskSet(responses, time)).toArray();
|
||||||
|
|
||||||
|
// First we need to develop the overall survival curve!
|
||||||
|
final List<Point> survivalPoints = new ArrayList<>(timesToUse.length);
|
||||||
|
double previousSurvivalValue = 1.0;
|
||||||
|
for(int i=0; i<timesToUse.length; i++){
|
||||||
|
final double time_k = timesToUse[i];
|
||||||
|
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
|
||||||
|
final double numberOfEventsAtTime = (double) responses.stream()
|
||||||
|
.filter(event -> !event.isCensored())
|
||||||
|
.filter(event -> event.getU() == time_k) // since delta != 0 we know censoring didn't occur prior to this
|
||||||
|
.count();
|
||||||
|
|
||||||
|
final double newValue = previousSurvivalValue * (1.0 - numberOfEventsAtTime / individualsAtRisk);
|
||||||
|
survivalPoints.add(new Point(time_k, newValue));
|
||||||
|
previousSurvivalValue = newValue;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
final MathFunction survivalCurve = new MathFunction(survivalPoints, new Point(0.0, 1.0));
|
||||||
|
|
||||||
|
|
||||||
|
for(final int event : events){
|
||||||
|
|
||||||
|
final List<Point> hazardFunctionPoints = new ArrayList<>(timesToUse.length);
|
||||||
|
Point previousHazardFunctionPoint = new Point(0.0, 0.0);
|
||||||
|
|
||||||
|
final List<Point> cifPoints = new ArrayList<>(timesToUse.length);
|
||||||
|
Point previousCIFPoint = new Point(0.0, 0.0);
|
||||||
|
|
||||||
|
for(int i=0; i<timesToUse.length; i++){
|
||||||
|
final double time_k = timesToUse[i];
|
||||||
|
final double individualsAtRisk = individualsAtRiskArray[i]; // Y(t_k)
|
||||||
|
final double numberEventsAtTime = numberOfEventsAtTime(event, responses, time_k); // d_j(t_k)
|
||||||
|
|
||||||
|
// Cause-specific cumulative hazard function
|
||||||
|
final double hazardDeltaY = numberEventsAtTime / individualsAtRisk;
|
||||||
|
final Point newHazardPoint = new Point(time_k, previousHazardFunctionPoint.getY() + hazardDeltaY);
|
||||||
|
hazardFunctionPoints.add(newHazardPoint);
|
||||||
|
previousHazardFunctionPoint = newHazardPoint;
|
||||||
|
|
||||||
|
|
||||||
|
// Cumulative incidence function
|
||||||
|
// TODO - confirm this behaviour
|
||||||
|
//final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : survivalCurve.evaluate(0.0).getY();
|
||||||
|
final double previousSurvivalEvaluation = i > 0 ? survivalCurve.evaluate(timesToUse[i-1]).getY() : 1.0;
|
||||||
|
|
||||||
|
final double cifDeltaY = previousSurvivalEvaluation * (numberEventsAtTime / individualsAtRisk);
|
||||||
|
final Point newCIFPoint = new Point(time_k, previousCIFPoint.getY() + cifDeltaY);
|
||||||
|
cifPoints.add(newCIFPoint);
|
||||||
|
previousCIFPoint = newCIFPoint;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
final MathFunction causeSpecificCumulativeHazardFunction = new MathFunction(hazardFunctionPoints);
|
||||||
|
causeSpecificCumulativeHazardFunctionMap.put(event, causeSpecificCumulativeHazardFunction);
|
||||||
|
|
||||||
|
final MathFunction cifFunction = new MathFunction(cifPoints);
|
||||||
|
cumulativeIncidenceFunctionMap.put(event, cifFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return CompetingRiskFunctions.builder()
|
||||||
|
.causeSpecificHazardFunctionMap(causeSpecificCumulativeHazardFunctionMap)
|
||||||
|
.cumulativeIncidenceFunctionMap(cumulativeIncidenceFunctionMap)
|
||||||
|
.survivalCurve(survivalCurve)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private double riskSet(List<CompetingRiskResponse> eventList, double time) {
|
||||||
|
return eventList.stream()
|
||||||
|
.filter(event -> event.getU() >= time)
|
||||||
|
.count();
|
||||||
|
}
|
||||||
|
|
||||||
|
private double numberOfEventsAtTime(int eventOfFocus, List<CompetingRiskResponse> eventList, double time){
|
||||||
|
return (double) eventList.stream()
|
||||||
|
.filter(event -> event.getDelta() == eventOfFocus)
|
||||||
|
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
|
||||||
|
.count();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -11,16 +11,16 @@ import org.apache.commons.csv.CSVRecord;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
public class CompetingResponseWithCensorTime extends CompetingResponse{
|
public class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
||||||
private final double c;
|
private final double c;
|
||||||
|
|
||||||
public CompetingResponseWithCensorTime(int delta, double u, double c) {
|
public CompetingRiskResponseWithCensorTime(int delta, double u, double c) {
|
||||||
super(delta, u);
|
super(delta, u);
|
||||||
this.c = c;
|
this.c = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader<CompetingResponseWithCensorTime>{
|
public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader<CompetingRiskResponseWithCensorTime>{
|
||||||
|
|
||||||
private final String deltaName;
|
private final String deltaName;
|
||||||
private final String uName;
|
private final String uName;
|
||||||
|
@ -33,12 +33,12 @@ public class CompetingResponseWithCensorTime extends CompetingResponse{
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompetingResponseWithCensorTime parse(CSVRecord record) {
|
public CompetingRiskResponseWithCensorTime parse(CSVRecord record) {
|
||||||
final int delta = Integer.parseInt(record.get(deltaName));
|
final int delta = Integer.parseInt(record.get(deltaName));
|
||||||
final double u = Double.parseDouble(record.get(uName));
|
final double u = Double.parseDouble(record.get(uName));
|
||||||
final double c = Double.parseDouble(record.get(cName));
|
final double c = Double.parseDouble(record.get(cName));
|
||||||
|
|
||||||
return new CompetingResponseWithCensorTime(delta, u, c);
|
return new CompetingRiskResponseWithCensorTime(delta, u, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,12 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -14,12 +9,12 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponseWithCensorTime> {
|
public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> {
|
||||||
|
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<CompetingResponseWithCensorTime> leftHand, List<CompetingResponseWithCensorTime> rightHand) {
|
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -30,8 +25,8 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
|
||||||
for(final int eventOfFocus : events){
|
for(final int eventOfFocus : events){
|
||||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||||
|
|
||||||
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVariance();
|
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
|
||||||
denominatorSquared += valueOfInterest.getVarianceSquared();
|
denominatorSquared += valueOfInterest.getVariance();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,7 +35,7 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
double riskSet(List<CompetingResponseWithCensorTime> eventList, double time, int eventOfFocus) {
|
double riskSet(List<CompetingRiskResponseWithCensorTime> eventList, double time, int eventOfFocus) {
|
||||||
return eventList.stream()
|
return eventList.stream()
|
||||||
.filter(event -> event.getU() >= time ||
|
.filter(event -> event.getU() >= time ||
|
||||||
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
|
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
|
||||||
|
|
|
@ -1,11 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -13,24 +9,24 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponseWithCensorTime> {
|
public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponseWithCensorTime> {
|
||||||
|
|
||||||
private final int eventOfFocus;
|
private final int eventOfFocus;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<CompetingResponseWithCensorTime> leftHand, List<CompetingResponseWithCensorTime> rightHand) {
|
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||||
|
|
||||||
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVariance());
|
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
double riskSet(List<CompetingResponseWithCensorTime> eventList, double time, int eventOfFocus) {
|
double riskSet(List<CompetingRiskResponseWithCensorTime> eventList, double time, int eventOfFocus) {
|
||||||
return eventList.stream()
|
return eventList.stream()
|
||||||
.filter(event -> event.getU() >= time ||
|
.filter(event -> event.getU() >= time ||
|
||||||
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
|
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
|
||||||
|
|
|
@ -1,11 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -13,12 +9,12 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponse> {
|
public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> {
|
||||||
|
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<CompetingResponse> leftHand, List<CompetingResponse> rightHand) {
|
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -29,8 +25,8 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
|
||||||
for(final int eventOfFocus : events){
|
for(final int eventOfFocus : events){
|
||||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||||
|
|
||||||
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVariance();
|
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
|
||||||
denominatorSquared += valueOfInterest.getVarianceSquared();
|
denominatorSquared += valueOfInterest.getVariance();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +35,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
double riskSet(List<CompetingResponse> eventList, double time, int eventOfFocus) {
|
double riskSet(List<CompetingRiskResponse> eventList, double time, int eventOfFocus) {
|
||||||
return eventList.stream()
|
return eventList.stream()
|
||||||
.filter(event -> event.getU() >= time)
|
.filter(event -> event.getU() >= time)
|
||||||
.count();
|
.count();
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -10,24 +9,24 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponse> {
|
public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingRiskResponse> {
|
||||||
|
|
||||||
private final int eventOfFocus;
|
private final int eventOfFocus;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<CompetingResponse> leftHand, List<CompetingResponse> rightHand) {
|
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||||
|
|
||||||
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVariance());
|
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
double riskSet(List<CompetingResponse> eventList, double time, int eventOfFocus) {
|
double riskSet(List<CompetingRiskResponse> eventList, double time, int eventOfFocus) {
|
||||||
return eventList.stream()
|
return eventList.stream()
|
||||||
.filter(event -> event.getU() >= time)
|
.filter(event -> event.getU() >= time)
|
||||||
.count();
|
.count();
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a function represented by discrete points. We assume that the function is a stepwise continuous function,
|
||||||
|
* constant at the value of the previous encountered point.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class MathFunction {
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final List<Point> points;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||||
|
*
|
||||||
|
* Map be null.
|
||||||
|
*/
|
||||||
|
private final Point defaultValue;
|
||||||
|
|
||||||
|
public MathFunction(final List<Point> points){
|
||||||
|
this(points, new Point(0.0, 0.0));
|
||||||
|
}
|
||||||
|
|
||||||
|
public MathFunction(final List<Point> points, final Point defaultValue){
|
||||||
|
this.points = Collections.unmodifiableList(points);
|
||||||
|
this.defaultValue = defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Point evaluate(double time){
|
||||||
|
final Optional<Point> pointOptional = points.stream()
|
||||||
|
.filter(point -> point.getTime() <= time)
|
||||||
|
.max(Comparator.comparingDouble(Point::getTime));
|
||||||
|
|
||||||
|
return pointOptional.orElse(defaultValue);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
final StringBuilder builder = new StringBuilder();
|
||||||
|
builder.append("Default point: ");
|
||||||
|
builder.append(defaultValue);
|
||||||
|
builder.append("\n");
|
||||||
|
|
||||||
|
for(final Point point : points){
|
||||||
|
builder.append(point);
|
||||||
|
builder.append("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -8,8 +8,6 @@ import lombok.Data;
|
||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
public class Point {
|
public class Point {
|
||||||
|
|
||||||
private final Double time;
|
private final Double time;
|
||||||
private final Double y;
|
private final Double y;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,11 +3,6 @@ package ca.joeltherrien.randomforest.responses.regression;
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.function.BiConsumer;
|
|
||||||
import java.util.function.BinaryOperator;
|
|
||||||
import java.util.function.Function;
|
|
||||||
import java.util.function.Supplier;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This implementation of the collector isn't great... but good enough given that I'm not planning to fully support regression trees.
|
* This implementation of the collector isn't great... but good enough given that I'm not planning to fully support regression trees.
|
||||||
|
@ -15,11 +10,7 @@ import java.util.function.Supplier;
|
||||||
* (It's not great because you'll lose accuracy as you sum up the doubles, since dividing by n is the very last step.)
|
* (It's not great because you'll lose accuracy as you sum up the doubles, since dividing by n is the very last step.)
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
|
public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
|
||||||
|
|
||||||
static{
|
|
||||||
ResponseCombiner.registerResponseCombiner("MeanResponseCombiner", new MeanResponseCombiner());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double combine(List<Double> responses) {
|
public Double combine(List<Double> responses) {
|
||||||
|
@ -29,51 +20,5 @@ public class MeanResponseCombiner implements ResponseCombiner<Double, MeanRespon
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public Supplier<Container> supplier() {
|
|
||||||
return () -> new Container(0 ,0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public BiConsumer<Container, Double> accumulator() {
|
|
||||||
return (container, number) -> {
|
|
||||||
container.number+=number;
|
|
||||||
container.n++;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public BinaryOperator<Container> combiner() {
|
|
||||||
return (c1, c2) -> {
|
|
||||||
c1.number += c2.number;
|
|
||||||
c1.n += c2.n;
|
|
||||||
|
|
||||||
return c1;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Function<Container, Double> finisher() {
|
|
||||||
return (container) -> container.number/(double)container.n;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Set<Characteristics> characteristics() {
|
|
||||||
return Set.of(Characteristics.UNORDERED);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public static class Container{
|
|
||||||
|
|
||||||
Container(double number, int n){
|
|
||||||
this.number = number;
|
|
||||||
this.n = n;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Double number;
|
|
||||||
public int n;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,17 +4,22 @@ import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public class Forest<Y> {
|
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
||||||
|
|
||||||
private final Collection<Node<Y>> trees;
|
private final Collection<Node<O>> trees;
|
||||||
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||||
|
|
||||||
public Y evaluate(CovariateRow row){
|
public FO evaluate(CovariateRow row){
|
||||||
return trees.parallelStream()
|
|
||||||
|
return treeResponseCombiner.combine(
|
||||||
|
trees.parallelStream()
|
||||||
.map(node -> node.evaluate(row))
|
.map(node -> node.evaluate(row))
|
||||||
.collect(treeResponseCombiner);
|
.collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,17 +16,18 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor(access=AccessLevel.PRIVATE)
|
@AllArgsConstructor(access=AccessLevel.PRIVATE)
|
||||||
public class ForestTrainer<Y> {
|
public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
private final TreeTrainer<Y> treeTrainer;
|
private final TreeTrainer<Y, TO> treeTrainer;
|
||||||
private final List<Covariate> covariatesToTry;
|
private final List<Covariate> covariates;
|
||||||
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
private final ResponseCombiner<TO, FO> treeResponseCombiner;
|
||||||
private final List<Row<Y>> data;
|
private final List<Row<Y>> data;
|
||||||
|
|
||||||
// number of covariates to randomly try
|
// number of covariates to randomly try
|
||||||
|
@ -45,15 +46,15 @@ public class ForestTrainer<Y> {
|
||||||
this.displayProgress = true;
|
this.displayProgress = true;
|
||||||
this.saveTreeLocation = settings.getSaveTreeLocation();
|
this.saveTreeLocation = settings.getSaveTreeLocation();
|
||||||
|
|
||||||
this.covariatesToTry = covariates;
|
this.covariates = covariates;
|
||||||
this.treeResponseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getTreeResponseCombiner());
|
this.treeResponseCombiner = settings.getTreeCombiner();
|
||||||
this.treeTrainer = new TreeTrainer<>(settings);
|
this.treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Forest<Y> trainSerial(){
|
public Forest<TO, FO> trainSerial(){
|
||||||
|
|
||||||
final List<Node<Y>> trees = new ArrayList<>(ntree);
|
final List<Node<TO>> trees = new ArrayList<>(ntree);
|
||||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||||
|
|
||||||
for(int j=0; j<ntree; j++){
|
for(int j=0; j<ntree; j++){
|
||||||
|
@ -71,18 +72,18 @@ public class ForestTrainer<Y> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Forest.<Y>builder()
|
return Forest.<TO, FO>builder()
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.trees(trees)
|
.trees(trees)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Forest<Y> trainParallelInMemory(int threads){
|
public Forest<TO, FO> trainParallelInMemory(int threads){
|
||||||
|
|
||||||
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
||||||
// the earlier indexes being filled.
|
// the earlier indexes being filled.
|
||||||
final List<Node<Y>> trees = Stream.<Node<Y>>generate(() -> null).limit(ntree).collect(Collectors.toList());
|
final List<Node<TO>> trees = Stream.<Node<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
|
||||||
|
|
||||||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||||
|
|
||||||
|
@ -102,7 +103,7 @@ public class ForestTrainer<Y> {
|
||||||
|
|
||||||
if(displayProgress) {
|
if(displayProgress) {
|
||||||
int numberTreesSet = 0;
|
int numberTreesSet = 0;
|
||||||
for (final Node<Y> tree : trees) {
|
for (final Node<TO> tree : trees) {
|
||||||
if (tree != null) {
|
if (tree != null) {
|
||||||
numberTreesSet++;
|
numberTreesSet++;
|
||||||
}
|
}
|
||||||
|
@ -117,7 +118,7 @@ public class ForestTrainer<Y> {
|
||||||
System.out.println("\nFinished");
|
System.out.println("\nFinished");
|
||||||
}
|
}
|
||||||
|
|
||||||
return Forest.<Y>builder()
|
return Forest.<TO, FO>builder()
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.trees(trees)
|
.trees(trees)
|
||||||
.build();
|
.build();
|
||||||
|
@ -156,20 +157,12 @@ public class ForestTrainer<Y> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
private Node<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
||||||
final List<Covariate> treeCovariates = new ArrayList<>(covariatesToTry);
|
|
||||||
Collections.shuffle(treeCovariates);
|
|
||||||
|
|
||||||
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
|
|
||||||
treeCovariates.remove(treeIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
||||||
|
return treeTrainer.growTree(bootstrappedData);
|
||||||
return treeTrainer.growTree(bootstrappedData, treeCovariates);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void saveTree(final Node<Y> tree, String name) throws IOException {
|
public void saveTree(final Node<TO> tree, String name) throws IOException {
|
||||||
final String filename = saveTreeLocation + "/" + name;
|
final String filename = saveTreeLocation + "/" + name;
|
||||||
|
|
||||||
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
|
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
|
||||||
|
@ -184,9 +177,9 @@ public class ForestTrainer<Y> {
|
||||||
|
|
||||||
private final Bootstrapper<Row<Y>> bootstrapper;
|
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||||
private final int treeIndex;
|
private final int treeIndex;
|
||||||
private final List<Node<Y>> treeList;
|
private final List<Node<TO>> treeList;
|
||||||
|
|
||||||
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<Y>> treeList) {
|
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<TO>> treeList) {
|
||||||
this.bootstrapper = new Bootstrapper<>(data);
|
this.bootstrapper = new Bootstrapper<>(data);
|
||||||
this.treeIndex = treeIndex;
|
this.treeIndex = treeIndex;
|
||||||
this.treeList = treeList;
|
this.treeList = treeList;
|
||||||
|
@ -195,7 +188,7 @@ public class ForestTrainer<Y> {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
|
||||||
final Node<Y> tree = trainTree(bootstrapper);
|
final Node<TO> tree = trainTree(bootstrapper);
|
||||||
|
|
||||||
// should be okay as the list structure isn't changing
|
// should be okay as the list structure isn't changing
|
||||||
treeList.set(treeIndex, tree);
|
treeList.set(treeIndex, tree);
|
||||||
|
@ -218,7 +211,7 @@ public class ForestTrainer<Y> {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
|
||||||
final Node<Y> tree = trainTree(bootstrapper);
|
final Node<TO> tree = trainTree(bootstrapper);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
||||||
|
@ -16,11 +12,4 @@ public interface GroupDifferentiator<Y> {
|
||||||
|
|
||||||
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
@FunctionalInterface
|
|
||||||
interface GroupDifferentiatorConstructor<Y>{
|
|
||||||
|
|
||||||
GroupDifferentiator<Y> construct(ObjectNode node);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,20 +1,9 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.stream.Collector;
|
|
||||||
|
|
||||||
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
|
public interface ResponseCombiner<I, O> {
|
||||||
|
|
||||||
Y combine(List<Y> responses);
|
O combine(List<I> responses);
|
||||||
|
|
||||||
final static Map<String, ResponseCombiner> RESPONSE_COMBINER_MAP = new HashMap<>();
|
|
||||||
static ResponseCombiner loadResponseCombinerByName(final String name){
|
|
||||||
return RESPONSE_COMBINER_MAP.get(name);
|
|
||||||
}
|
|
||||||
static void registerResponseCombiner(final String name, final ResponseCombiner responseCombiner){
|
|
||||||
RESPONSE_COMBINER_MAP.put(name, responseCombiner);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,13 +7,14 @@ import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
public class TreeTrainer<Y> {
|
public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
private final ResponseCombiner<Y, ?> responseCombiner;
|
private final ResponseCombiner<Y, O> responseCombiner;
|
||||||
private final GroupDifferentiator<Y> groupDifferentiator;
|
private final GroupDifferentiator<Y> groupDifferentiator;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -23,54 +24,76 @@ public class TreeTrainer<Y> {
|
||||||
private final int numberOfSplits;
|
private final int numberOfSplits;
|
||||||
private final int nodeSize;
|
private final int nodeSize;
|
||||||
private final int maxNodeDepth;
|
private final int maxNodeDepth;
|
||||||
|
private final int mtry;
|
||||||
|
|
||||||
public TreeTrainer(final Settings settings){
|
private final List<Covariate> covariates;
|
||||||
|
|
||||||
|
public TreeTrainer(final Settings settings, final List<Covariate> covariates){
|
||||||
this.numberOfSplits = settings.getNumberOfSplits();
|
this.numberOfSplits = settings.getNumberOfSplits();
|
||||||
this.nodeSize = settings.getNodeSize();
|
this.nodeSize = settings.getNodeSize();
|
||||||
this.maxNodeDepth = settings.getMaxNodeDepth();
|
this.maxNodeDepth = settings.getMaxNodeDepth();
|
||||||
|
this.mtry = settings.getMtry();
|
||||||
|
|
||||||
this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner());
|
this.responseCombiner = settings.getResponseCombiner();
|
||||||
this.groupDifferentiator = settings.getGroupDifferentiator();
|
this.groupDifferentiator = settings.getGroupDifferentiator();
|
||||||
|
this.covariates = covariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
public Node<O> growTree(List<Row<Y>> data){
|
||||||
return growNode(data, covariatesToTry, 0);
|
return growNode(data, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node<Y> growNode(List<Row<Y>> data, List<Covariate> covariatesToTry, int depth){
|
private Node<O> growNode(List<Row<Y>> data, int depth){
|
||||||
// TODO; what is minimum per tree?
|
// TODO; what is minimum per tree?
|
||||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||||
|
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
|
||||||
final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
||||||
|
|
||||||
if(bestSplitRule == null){
|
if(bestSplitRule == null){
|
||||||
return new TerminalNode<>(
|
|
||||||
data.stream()
|
|
||||||
.map(row -> row.getResponse())
|
|
||||||
.collect(responseCombiner)
|
|
||||||
|
|
||||||
|
return new TerminalNode<>(
|
||||||
|
responseCombiner.combine(
|
||||||
|
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
||||||
|
|
||||||
final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
|
final Node<O> leftNode = growNode(split.leftHand, depth+1);
|
||||||
final Node<Y> rightNode = growNode(split.rightHand, covariatesToTry, depth+1);
|
final Node<O> rightNode = growNode(split.rightHand, depth+1);
|
||||||
|
|
||||||
return new SplitNode<>(leftNode, rightNode, bestSplitRule);
|
return new SplitNode<>(leftNode, rightNode, bestSplitRule);
|
||||||
|
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
return new TerminalNode<>(
|
return new TerminalNode<>(
|
||||||
data.stream()
|
responseCombiner.combine(
|
||||||
.map(row -> row.getResponse())
|
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
.collect(responseCombiner)
|
)
|
||||||
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<Covariate> selectCovariates(int mtry){
|
||||||
|
if(mtry >= covariates.size()){
|
||||||
|
return covariates;
|
||||||
|
}
|
||||||
|
|
||||||
|
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
|
||||||
|
Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
|
||||||
|
|
||||||
|
for(int treeIndex = splitCovariates.size()-1; treeIndex >= mtry; treeIndex--){
|
||||||
|
splitCovariates.remove(treeIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
return splitCovariates;
|
||||||
|
}
|
||||||
|
|
||||||
private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||||
Covariate.SplitRule bestSplitRule = null;
|
Covariate.SplitRule bestSplitRule = null;
|
||||||
double bestSplitScore = 0.0;
|
double bestSplitScore = 0.0;
|
||||||
|
@ -96,7 +119,9 @@ public class TreeTrainer<Y> {
|
||||||
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
);
|
);
|
||||||
|
|
||||||
if(score != null && (score > bestSplitScore || first)){
|
|
||||||
|
|
||||||
|
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
|
||||||
bestSplitRule = possibleRule;
|
bestSplitRule = possibleRule;
|
||||||
bestSplitScore = score;
|
bestSplitScore = score;
|
||||||
first = false;
|
first = false;
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
# Comparing output between my two packages
|
||||||
|
|
||||||
|
require(survival)
|
||||||
|
require(randomForestSRC)
|
||||||
|
|
||||||
|
data(wihs, package = "randomForestSRC")
|
||||||
|
|
||||||
|
wihs$idu = as.logical(wihs$idu)
|
||||||
|
wihs$black = as.logical(wihs$black)
|
||||||
|
|
||||||
|
set.seed(100)
|
||||||
|
|
||||||
|
wihs = wihs[sample(1:nrow(wihs), replace = FALSE),]
|
||||||
|
|
||||||
|
# example row
|
||||||
|
# time status ageatfda idu black cd4nadir
|
||||||
|
#409 1.3 1 35 FALSE FALSE 0.81
|
||||||
|
newData = data.frame(ageatfda=35, idu=FALSE, black=FALSE, cd4nadir=0.81)
|
||||||
|
|
||||||
|
one.tree <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 1, splitrule="logrank", cause=1, mtry=2, seed=-5, membership=TRUE)
|
||||||
|
# next get membership
|
||||||
|
membership=one.tree$inbag[,1]
|
||||||
|
|
||||||
|
if(FALSE){
|
||||||
|
bootstrappedData = wihs[c(),]
|
||||||
|
for(i in 1:length(membership)){
|
||||||
|
times = membership[i]
|
||||||
|
|
||||||
|
if(times > 0){
|
||||||
|
for(j in 1:times){
|
||||||
|
bootstrappedData = rbind(bootstrappedData, wihs[i,])
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output.one.tree = predict(one.tree, newData)
|
||||||
|
output.one.tree$cif[,,1]
|
||||||
|
output.one.tree$chf[,c(11,66,103),1]
|
||||||
|
|
||||||
|
# Note that "... ~ ." means 'use all explanatory variables"
|
||||||
|
#output = predict(wihs.obj, newData)
|
||||||
|
#output$cif[,,1] # CIF for cause 1
|
||||||
|
#output$cif[,,2] # CIF for cause 2
|
||||||
|
|
||||||
|
many.trees <- rfsrc(Surv(time, status) ~ idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE)
|
||||||
|
output.many.trees = predict(many.trees, newData)
|
||||||
|
output.many.trees$cif[,41,1]
|
||||||
|
output.many.trees$cif[,41,2]
|
||||||
|
|
||||||
|
many.trees.all <- rfsrc(Surv(time, status) ~ ageatfda + cd4nadir + idu + black, wihs, nsplit = 5, ntree = 100, splitrule="logrank", cause=1, mtry=2, membership=TRUE)
|
||||||
|
output.many.trees.all = predict(many.trees.all, newData)
|
||||||
|
output.many.trees.all$cif[,103,1]
|
||||||
|
output.many.trees.all$cif[,103,2]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
end.numbers = c()
|
||||||
|
end.times = c()
|
||||||
|
lgths = c()
|
||||||
|
trees = list()
|
||||||
|
for(i in 1:100){
|
||||||
|
one.tree = rfsrc(Surv(time, status) ~ ageatfda + cd4nadir + idu + black, wihs, nsplit = 0, ntree = 1, splitrule="logrank", cause=1, mtry=4, membership=TRUE, statistics = TRUE)
|
||||||
|
trees[[i]] = one.tree
|
||||||
|
prediction = predict(one.tree, newData)
|
||||||
|
lgth = length(prediction$cif[,,1])
|
||||||
|
lgths = c(lgths, lgth)
|
||||||
|
end.numbers = c(end.numbers, prediction$cif[,lgth,1])
|
||||||
|
end.times = c(end.times, max(prediction$time.interest))
|
||||||
|
}
|
||||||
|
|
||||||
|
special.tree = trees[[100]]
|
||||||
|
|
||||||
|
|
||||||
|
prediction = predict(special.tree, newData)
|
||||||
|
prediction$cif[,,1]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
membership = special.tree$inbag[,1]
|
||||||
|
bootstrappedData = wihs[c(),]
|
||||||
|
for(i in 1:length(membership)){
|
||||||
|
times = membership[i]
|
||||||
|
|
||||||
|
if(times > 0){
|
||||||
|
for(j in 1:times){
|
||||||
|
bootstrappedData = rbind(bootstrappedData, wihs[i,])
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write.csv(bootstrappedData, "RandomSurvivalForests/src/test/resources/wihs.bootstrapped2.csv", row.names=FALSE)
|
||||||
|
prediction$cif[,,1]
|
||||||
|
|
|
@ -0,0 +1,325 @@
|
||||||
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.*;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.*;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.Point;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
import com.fasterxml.jackson.databind.node.*;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
public class TestCompetingRisk {
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* By default uses single log-rank test.
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Settings getSettings(){
|
||||||
|
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
|
||||||
|
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
|
||||||
|
|
||||||
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
||||||
|
responseCombinerSettings.set("events",
|
||||||
|
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
||||||
|
);
|
||||||
|
// not setting times
|
||||||
|
|
||||||
|
|
||||||
|
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
|
||||||
|
treeCombinerSettings.set("events",
|
||||||
|
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
||||||
|
);
|
||||||
|
// not setting times
|
||||||
|
|
||||||
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
|
||||||
|
yVarSettings.set("u", new TextNode("time"));
|
||||||
|
yVarSettings.set("delta", new TextNode("status"));
|
||||||
|
|
||||||
|
return Settings.builder()
|
||||||
|
.covariates(List.of(
|
||||||
|
new NumericCovariateSettings("ageatfda"),
|
||||||
|
new BooleanCovariateSettings("idu"),
|
||||||
|
new BooleanCovariateSettings("black"),
|
||||||
|
new NumericCovariateSettings("cd4nadir")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.dataFileLocation("src/test/resources/wihs.csv")
|
||||||
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
|
.yVarSettings(yVarSettings)
|
||||||
|
.maxNodeDepth(100000)
|
||||||
|
// TODO fill in these settings
|
||||||
|
.mtry(2)
|
||||||
|
.nodeSize(6)
|
||||||
|
.ntree(100)
|
||||||
|
.numberOfSplits(5)
|
||||||
|
.numberOfThreads(3)
|
||||||
|
.saveProgress(true)
|
||||||
|
.saveTreeLocation("trees/")
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
||||||
|
return CovariateRow.createSimple(Map.of(
|
||||||
|
"ageatfda", "35",
|
||||||
|
"idu", "false",
|
||||||
|
"black", "false",
|
||||||
|
"cd4nadir", "0.81")
|
||||||
|
, covariates, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSingleTree() throws IOException {
|
||||||
|
final Settings settings = getSettings();
|
||||||
|
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
|
||||||
|
settings.setCovariates(List.of(
|
||||||
|
new BooleanCovariateSettings("idu"),
|
||||||
|
new BooleanCovariateSettings("black")
|
||||||
|
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||||
|
|
||||||
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
|
||||||
|
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||||
|
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
||||||
|
|
||||||
|
final CovariateRow newRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions functions = node.evaluate(newRow);
|
||||||
|
|
||||||
|
final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
|
||||||
|
final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
|
||||||
|
final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
|
||||||
|
final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
|
||||||
|
|
||||||
|
final double margin = 0.0000001;
|
||||||
|
closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02).getY(), margin);
|
||||||
|
closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00).getY(), margin);
|
||||||
|
closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50).getY(), margin);
|
||||||
|
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60).getY(), margin);
|
||||||
|
closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80).getY(), margin);
|
||||||
|
|
||||||
|
|
||||||
|
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
|
||||||
|
closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin);
|
||||||
|
closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin);
|
||||||
|
|
||||||
|
|
||||||
|
closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin);
|
||||||
|
closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin);
|
||||||
|
closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin);
|
||||||
|
|
||||||
|
|
||||||
|
closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin);
|
||||||
|
closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin);
|
||||||
|
closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Note - this test triggers a situation where the variance calculation in the log-rank test experiences an NaN.
|
||||||
|
*
|
||||||
|
* @throws IOException
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testSingleTree2() throws IOException {
|
||||||
|
final Settings settings = getSettings();
|
||||||
|
settings.setMtry(4);
|
||||||
|
settings.setNumberOfSplits(0);
|
||||||
|
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped2.csv");
|
||||||
|
|
||||||
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
|
||||||
|
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||||
|
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
||||||
|
|
||||||
|
final CovariateRow newRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions functions = node.evaluate(newRow);
|
||||||
|
|
||||||
|
final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
|
||||||
|
final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
|
||||||
|
final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
|
||||||
|
final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);
|
||||||
|
|
||||||
|
|
||||||
|
final double margin = 0.0000001;
|
||||||
|
closeEnough(0, causeOneCIFFunction.evaluate(0.02).getY(), margin);
|
||||||
|
closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4).getY(), margin);
|
||||||
|
closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8).getY(), margin);
|
||||||
|
closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9).getY(), margin);
|
||||||
|
closeEnough(1.0, causeOneCIFFunction.evaluate(1.0).getY(), margin);
|
||||||
|
|
||||||
|
/*
|
||||||
|
closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
|
||||||
|
closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin);
|
||||||
|
closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin);
|
||||||
|
|
||||||
|
|
||||||
|
closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin);
|
||||||
|
closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin);
|
||||||
|
closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin);
|
||||||
|
|
||||||
|
|
||||||
|
closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin);
|
||||||
|
closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin);
|
||||||
|
closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin);
|
||||||
|
*/
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Covariate> getCovariates(Settings settings){
|
||||||
|
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
||||||
|
final Settings settings = getSettings();
|
||||||
|
settings.setCovariates(List.of(
|
||||||
|
new BooleanCovariateSettings("idu"),
|
||||||
|
new BooleanCovariateSettings("black")
|
||||||
|
));
|
||||||
|
|
||||||
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
|
||||||
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||||
|
|
||||||
|
// prediction row
|
||||||
|
// time status ageatfda idu black cd4nadir
|
||||||
|
//409 1.3 1 35 FALSE FALSE 0.81
|
||||||
|
final CovariateRow newRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions functions = forest.evaluate(newRow);
|
||||||
|
|
||||||
|
assertCumulativeFunction(functions.getCauseSpecificHazardFunction(1));
|
||||||
|
assertCumulativeFunction(functions.getCauseSpecificHazardFunction(2));
|
||||||
|
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1));
|
||||||
|
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
|
||||||
|
|
||||||
|
|
||||||
|
closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0).getY(), 0.01);
|
||||||
|
closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8).getY(), 0.01);
|
||||||
|
|
||||||
|
closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
|
||||||
|
closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void verifyDataset() throws IOException {
|
||||||
|
final Settings settings = getSettings();
|
||||||
|
|
||||||
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
|
||||||
|
// Let's count the events and make sure the data was correctly read.
|
||||||
|
int countCensored = 0;
|
||||||
|
int countEventOne = 0;
|
||||||
|
int countEventTwo = 0;
|
||||||
|
for(final Row<CompetingRiskResponse> row : dataset){
|
||||||
|
final CompetingRiskResponse response = row.getResponse();
|
||||||
|
|
||||||
|
if(response.getDelta() == 0){
|
||||||
|
countCensored++;
|
||||||
|
}
|
||||||
|
else if(response.getDelta() == 1){
|
||||||
|
countEventOne++;
|
||||||
|
}
|
||||||
|
else if(response.getDelta() == 2){
|
||||||
|
countEventTwo++;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
throw new RuntimeException("There's an event of type " + response.getDelta());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(126, countCensored);
|
||||||
|
assertEquals(679, countEventOne);
|
||||||
|
assertEquals(359, countEventTwo);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException {
|
||||||
|
|
||||||
|
final Settings settings = getSettings();
|
||||||
|
|
||||||
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||||
|
|
||||||
|
// prediction row
|
||||||
|
// time status ageatfda idu black cd4nadir
|
||||||
|
//409 1.3 1 35 FALSE FALSE 0.81
|
||||||
|
final CovariateRow newRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions functions = forest.evaluate(newRow);
|
||||||
|
|
||||||
|
assertCumulativeFunction(functions.getCauseSpecificHazardFunction(1));
|
||||||
|
assertCumulativeFunction(functions.getCauseSpecificHazardFunction(2));
|
||||||
|
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1));
|
||||||
|
assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
|
||||||
|
|
||||||
|
final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();
|
||||||
|
|
||||||
|
// We seem to consistently underestimate the results.
|
||||||
|
assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size()-1).getY() > 0.75, "Results should match randomForestSRC");
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We know the function is cumulative; make sure it is ordered correctly and that that function is monotone.
|
||||||
|
*
|
||||||
|
* @param function
|
||||||
|
*/
|
||||||
|
private void assertCumulativeFunction(MathFunction function){
|
||||||
|
Point previousPoint = null;
|
||||||
|
for(final Point point : function.getPoints()){
|
||||||
|
|
||||||
|
if(previousPoint != null){
|
||||||
|
assertTrue(previousPoint.getTime() < point.getTime(), "Points should be ordered and strictly different");
|
||||||
|
assertTrue(previousPoint.getY() <= point.getY(), "Cumulative incidence functions are monotone");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
previousPoint = point;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void closeEnough(double expected, double actual, double margin){
|
||||||
|
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
testData = data.frame(delta=c(1,1,1,2,2,0,0), u=c(1,1,2,1.5,2,1.5,2.5))
|
||||||
|
testData$survDelta = ifelse(testData$delta==0, 0, 1) # for KM curves on any events
|
||||||
|
|
||||||
|
require(survival)
|
||||||
|
|
||||||
|
kmCurve = survfit(Surv(u, survDelta, type="right") ~ 1, data=testData)
|
||||||
|
kmCurve$surv
|
||||||
|
|
||||||
|
curve = survfit(Surv(u, event=delta, type="mstate") ~ 1, data=testData)
|
||||||
|
curve$cumhaz[3,1:2,]
|
||||||
|
|
||||||
|
print(t(curve$pstate[,1:2]))
|
|
@ -0,0 +1,94 @@
|
||||||
|
package ca.joeltherrien.randomforest.competingrisk;

import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;

import java.util.ArrayList;
import java.util.List;

/**
 * Verifies CompetingRiskResponseCombiner against reference values computed in R
 * (survival package) for a small two-event dataset of 7 subjects.
 */
public class TestCompetingRiskResponseCombiner {

    private static final double MARGIN = 0.0000001;

    // Event indicators (0 = censored) and observation times for the 7 subjects.
    private static final int[] DELTAS = {1, 1, 1, 2, 2, 0, 0};
    private static final double[] TIMES = {1.0, 1.0, 2.0, 1.5, 2.0, 1.5, 2.5};

    private CompetingRiskFunctions generateFunctions(){
        final List<CompetingRiskResponse> data = new ArrayList<>();
        for(int i = 0; i < DELTAS.length; i++){
            data.add(new CompetingRiskResponse(DELTAS[i], TIMES[i]));
        }

        final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);

        return combiner.combine(data);
    }

    @Test
    public void testCompetingRiskResponseCombiner(){
        final CompetingRiskFunctions functions = generateFunctions();

        // All curves are checked at the same four evaluation times.
        final double[] evaluationTimes = {1.0, 1.5, 2.0, 2.5};

        // Overall survival curve (reference values from R's survfit).
        assertCurve(functions.getSurvivalCurve(), evaluationTimes,
                new double[]{0.7142857, 0.5714286, 0.1904762, 0.1904762});

        // Cause-specific cumulative hazard functions, one per event type.
        assertCurve(functions.getCauseSpecificHazardFunction(1), evaluationTimes,
                new double[]{0.2857143, 0.2857143, 0.6190476, 0.6190476});
        assertCurve(functions.getCauseSpecificHazardFunction(2), evaluationTimes,
                new double[]{0.0, 0.2, 0.5333333, 0.5333333});

        // Cumulative incidence curves, one per event type.
        assertCurve(functions.getCumulativeIncidenceFunction(1), evaluationTimes,
                new double[]{0.2857143, 0.2857143, 0.4761905, 0.4761905});
        assertCurve(functions.getCumulativeIncidenceFunction(2), evaluationTimes,
                new double[]{0.0, 0.1428571, 0.3333333, 0.3333333});
    }

    // Checks that the curve matches the expected value at every evaluation time.
    private void assertCurve(MathFunction curve, double[] times, double[] expected){
        for(int i = 0; i < times.length; i++){
            closeEnough(expected[i], curve.evaluate(times[i]).getY(), MARGIN);
        }
    }

    private void closeEnough(double expected, double actual, double margin){
        assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
    }

}
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Two groups with identical event/censoring patterns; group FALSE has every
# time doubled, so a log-rank test should report a survival difference.
testData1 = data.frame(delta=c(1,1,1,1,0,0,0), u=c(1,1,2,1.5,2,1.5,2.5), group=TRUE)
testData2 = data.frame(delta=c(1,1,1,1,0,0,0), u=c(2,2,4,3,4,3,5), group=FALSE) # just doubled everything

testData = rbind(testData1, testData2)
require(survival)

# Reference result from the survival package to compare mantelTest against.
results = survdiff(Surv(u, delta, type="right") ~ group, data=testData)
|
||||||
|
|
||||||
|
# Two-sample Mantel (log-rank) test, treating group1 as the treatment arm.
#
# Args:
#   times:    observation times for all subjects
#   observed: logical vector, TRUE when the failure was observed (not censored)
#   group0:   logical selector for the placebo/reference group
#   group1:   logical selector for the treatment group (defaults to !group0)
#
# Returns a list with the Z statistic, its two-sided normal p-value, and the
# numerator / variance the statistic was built from.
mantelTest = function(times, observed, group0, group1=!group0){
  U0 = times[group0]
  observed0 = observed[group0]
  U1 = times[group1]
  observed1 = observed[group1]

  # Distinct times at which at least one failure was observed, in order.
  Vs = sort(unique(c(U0[observed0], U1[observed1])))

  # Number of subjects still at risk (u >= v) at time v.
  atRisk = function(v, u){
    u = subset(u, u >= v)
    return(length(u))
  }

  Os = c()     # observed treatment failures at each failure time
  Es = c()     # expected treatment failures under the null hypothesis
  varOs = c()  # hypergeometric variance contribution at each failure time

  # we're going to treat group 1 as treatment
  for(v in Vs){
    placeboAtRisk = atRisk(v, U0)
    treatmentAtRisk = atRisk(v, U1)
    totalAtRisk = placeboAtRisk + treatmentAtRisk

    numTreatmentFailures = length(subset(U1, observed1 & U1 == v))
    numPlaceboFailures = length(subset(U0, observed0 & U0 == v))
    totalFailures = numTreatmentFailures + numPlaceboFailures

    Os = c(Os, numTreatmentFailures)
    Es = c(Es, (totalFailures)*treatmentAtRisk/totalAtRisk)

    # Guard the degenerate single-subject case first: the general formula
    # divides by totalAtRisk - 1, which would be 0 here (the original code
    # computed NaN and then overwrote it; guarding first avoids that).
    if(totalAtRisk == 1){
      varOfO = 0
    } else {
      varOfO = (totalAtRisk - treatmentAtRisk)/(totalAtRisk - 1) *
        treatmentAtRisk * (totalFailures / totalAtRisk) *
        (1 - totalFailures / totalAtRisk)
    }

    varOs = c(varOs, varOfO)
  }

  numerator = sum(Os - Es)
  variance = sum(varOs)

  Z = numerator/sqrt(variance)
  return(list(
    statistic = Z,
    pvalue = 2*pnorm(abs(Z), lower.tail=FALSE),
    numerator = numerator,
    variance = variance
  ))
}
|
||||||
|
|
||||||
|
# Run the hand-rolled Mantel test on the combined data; the statistic should
# agree with the survdiff reference fit above.
myTest = mantelTest(testData$u, testData$delta == 1, group0=testData$group==1)
|
|
@ -0,0 +1,61 @@
|
||||||
|
package ca.joeltherrien.randomforest.competingrisk;

import ca.joeltherrien.randomforest.responses.competingrisk.*;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Verifies LogRankSingleGroupDifferentiator on two small groups where the
 * second group repeats the first group's event pattern with all times doubled.
 */
public class TestLogRankSingleGroupDifferentiator {

    // Builds one response per (delta, time) pair, preserving order.
    private List<CompetingRiskResponse> buildData(int[] deltas, double[] times){
        final List<CompetingRiskResponse> data = new ArrayList<>();
        for(int i = 0; i < deltas.length; i++){
            data.add(new CompetingRiskResponse(deltas[i], times[i]));
        }
        return data;
    }

    private List<CompetingRiskResponse> generateData1(){
        return buildData(new int[]{1, 1, 1, 1, 0, 0, 0},
                new double[]{1.0, 1.0, 2.0, 1.5, 2.0, 1.5, 2.5});
    }

    private List<CompetingRiskResponse> generateData2(){
        // Same event pattern as generateData1 with every time doubled.
        return buildData(new int[]{1, 1, 1, 1, 0, 0, 0},
                new double[]{2.0, 2.0, 4.0, 3.0, 4.0, 3.0, 5.0});
    }

    @Test
    public void testCompetingRiskResponseCombiner(){
        final List<CompetingRiskResponse> data1 = generateData1();
        final List<CompetingRiskResponse> data2 = generateData2();

        final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1);

        final double score = differentiator.differentiate(data1, data2);
        final double margin = 0.000001;

        // Tested using 855 method
        closeEnough(1.540139, score, margin);

    }

    private void closeEnough(double expected, double actual, double margin){
        assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
    }

}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package ca.joeltherrien.randomforest.competingrisk;

import ca.joeltherrien.randomforest.responses.competingrisk.MathFunction;
import ca.joeltherrien.randomforest.responses.competingrisk.Point;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;

import java.util.ArrayList;
import java.util.List;

/**
 * Verifies that MathFunction evaluates as a right-continuous step function:
 * evaluating at or just after a point's time yields that point, and
 * evaluating before the first point yields the supplied default point.
 */
public class TestMathFunction {

    private MathFunction generateMathFunction(){
        final double[] time = {1.0, 2.0, 3.0};
        final double[] y = {-1.0, 1.0, 0.5};

        final List<Point> points = new ArrayList<>();
        for(int i = 0; i < time.length; i++){
            points.add(new Point(time[i], y[i]));
        }

        // The second argument is the default point returned before time 1.0.
        return new MathFunction(points, new Point(0.0, 0.1));
    }

    @Test
    public void test(){
        final MathFunction function = generateMathFunction();

        // Evaluating exactly at each knot returns that knot; before the
        // first knot the default point is returned.
        assertEquals(new Point(1.0, -1.0), function.evaluate(1.0));
        assertEquals(new Point(2.0, 1.0), function.evaluate(2.0));
        assertEquals(new Point(3.0, 0.5), function.evaluate(3.0));
        assertEquals(new Point(0.0, 0.1), function.evaluate(0.5));

        // Evaluating just past a knot still returns that knot (step behavior).
        assertEquals(new Point(1.0, -1.0), function.evaluate(1.1));
        assertEquals(new Point(2.0, 1.0), function.evaluate(2.1));
        assertEquals(new Point(3.0, 0.5), function.evaluate(3.1));
        assertEquals(new Point(0.0, 0.1), function.evaluate(0.6));

    }

}
|
|
@ -20,6 +20,12 @@ public class TestPersistence {
|
||||||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
|
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
|
||||||
|
|
||||||
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
||||||
|
|
||||||
|
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
treeCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
||||||
|
|
||||||
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
yVarSettings.set("type", new TextNode("Double"));
|
yVarSettings.set("type", new TextNode("Double"));
|
||||||
yVarSettings.set("name", new TextNode("y"));
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
@ -32,8 +38,8 @@ public class TestPersistence {
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.dataFileLocation("data.csv")
|
.dataFileLocation("data.csv")
|
||||||
.responseCombiner("MeanResponseCombiner")
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
.treeResponseCombiner("MeanResponseCombiner")
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
.yVarSettings(yVarSettings)
|
.yVarSettings(yVarSettings)
|
||||||
.maxNodeDepth(100000)
|
.maxNodeDepth(100000)
|
||||||
|
|
|
@ -51,7 +51,7 @@ public class TrainForest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.numberOfSplits(5)
|
.numberOfSplits(5)
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
.maxNodeDepth(100000000)
|
.maxNodeDepth(100000000)
|
||||||
|
@ -59,10 +59,10 @@ public class TrainForest {
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
|
final ForestTrainer<Double, Double, Double> forestTrainer = ForestTrainer.<Double, Double, Double>builder()
|
||||||
.treeTrainer(treeTrainer)
|
.treeTrainer(treeTrainer)
|
||||||
.data(data)
|
.data(data)
|
||||||
.covariatesToTry(covariateList)
|
.covariates(covariateList)
|
||||||
.mtry(4)
|
.mtry(4)
|
||||||
.ntree(100)
|
.ntree(100)
|
||||||
.treeResponseCombiner(new MeanResponseCombiner())
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
|
|
|
@ -48,18 +48,20 @@ public class TrainSingleTree {
|
||||||
trainingSet.add(generateRow(x1, x2, i));
|
trainingSet.add(generateRow(x1, x2, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
|
.covariates(covariateNames)
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
.maxNodeDepth(30)
|
.maxNodeDepth(30)
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
.numberOfSplits(0)
|
.numberOfSplits(0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
|
||||||
|
|
||||||
final long startTime = System.currentTimeMillis();
|
final long startTime = System.currentTimeMillis();
|
||||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
|
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
|
||||||
final long endTime = System.currentTimeMillis();
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
System.out.println(((double)(endTime - startTime))/1000.0);
|
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||||
|
|
|
@ -69,18 +69,21 @@ public class TrainSingleTreeFactor {
|
||||||
trainingSet.add(generateRow(x1, x2, x3, i));
|
trainingSet.add(generateRow(x1, x2, x3, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.covariates(covariateNames)
|
||||||
.maxNodeDepth(30)
|
.maxNodeDepth(30)
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
.numberOfSplits(5)
|
.numberOfSplits(5)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
|
||||||
|
|
||||||
final long startTime = System.currentTimeMillis();
|
final long startTime = System.currentTimeMillis();
|
||||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
|
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
|
||||||
final long endTime = System.currentTimeMillis();
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
System.out.println(((double)(endTime - startTime))/1000.0);
|
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||||
|
|
1165
src/test/resources/wihs.bootstrapped.csv
Normal file
1165
src/test/resources/wihs.bootstrapped.csv
Normal file
File diff suppressed because it is too large
Load diff
1165
src/test/resources/wihs.bootstrapped2.csv
Normal file
1165
src/test/resources/wihs.bootstrapped2.csv
Normal file
File diff suppressed because it is too large
Load diff
1165
src/test/resources/wihs.csv
Normal file
1165
src/test/resources/wihs.csv
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue