Implement Response & GroupDifferentiators for CompetingRisk problems.
Also adjusted how settings are done to allow for specifying differentiators & responses that may require arguments. Note that CompetingRisk code is untested at this point.
This commit is contained in:
parent
4bbb0e0948
commit
462b0d9c35
24 changed files with 609 additions and 78 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -4,3 +4,4 @@
|
||||||
target/
|
target/
|
||||||
*.iml
|
*.iml
|
||||||
.idea
|
.idea
|
||||||
|
template.yaml
|
||||||
|
|
71
src/main/java/ca/joeltherrien/randomforest/DataLoader.java
Normal file
71
src/main/java/ca/joeltherrien/randomforest/DataLoader.java
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.apache.commons.csv.CSVFormat;
|
||||||
|
import org.apache.commons.csv.CSVParser;
|
||||||
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
import java.io.FileReader;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.Reader;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class DataLoader {
|
||||||
|
|
||||||
|
public static <Y> List<Row<Y>> loadData(final List<Covariate> covariates, final ResponseLoader<Y> responseLoader, String filename) throws IOException {
|
||||||
|
|
||||||
|
final List<Row<Y>> dataset = new ArrayList<>();
|
||||||
|
|
||||||
|
final Reader input = new FileReader(filename);
|
||||||
|
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
|
||||||
|
|
||||||
|
|
||||||
|
int id = 1;
|
||||||
|
for(final CSVRecord record : parser){
|
||||||
|
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
|
||||||
|
|
||||||
|
for(final Covariate<?> covariate : covariates){
|
||||||
|
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
|
||||||
|
}
|
||||||
|
|
||||||
|
final Y y = responseLoader.parse(record);
|
||||||
|
|
||||||
|
dataset.add(new Row<>(covariateValueMap, id++, y));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return dataset;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@FunctionalInterface
|
||||||
|
public interface ResponseLoader<Y>{
|
||||||
|
Y parse(CSVRecord record);
|
||||||
|
}
|
||||||
|
|
||||||
|
@FunctionalInterface
|
||||||
|
public interface ResponseLoaderConstructor<Y>{
|
||||||
|
ResponseLoader<Y> construct(ObjectNode node);
|
||||||
|
}
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public static class DoubleLoader implements ResponseLoader<Double> {
|
||||||
|
|
||||||
|
private final String yName;
|
||||||
|
|
||||||
|
public DoubleLoader(final ObjectNode node){
|
||||||
|
this.yName = node.get("name").asText();
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public Double parse(CSVRecord record) {
|
||||||
|
return Double.parseDouble(record.get(yName));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -5,6 +5,9 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.fasterxml.jackson.databind.node.TextNode;
|
||||||
import org.apache.commons.csv.CSVFormat;
|
import org.apache.commons.csv.CSVFormat;
|
||||||
import org.apache.commons.csv.CSVParser;
|
import org.apache.commons.csv.CSVParser;
|
||||||
import org.apache.commons.csv.CSVRecord;
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
@ -21,6 +24,7 @@ import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException {
|
public static void main(String[] args) throws IOException {
|
||||||
if(args.length != 1){
|
if(args.length != 1){
|
||||||
System.out.println("Must provide one argument - the path to the settings.yaml file.");
|
System.out.println("Must provide one argument - the path to the settings.yaml file.");
|
||||||
|
@ -36,7 +40,9 @@ public class Main {
|
||||||
final List<Covariate> covariates = settings.getCovariates().stream()
|
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||||
|
|
||||||
final List<Row<Double>> dataset = loadData(covariates, settings);
|
|
||||||
|
|
||||||
|
final List<Row<Double>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
|
||||||
final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
|
@ -51,46 +57,28 @@ public class Main {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static List<Row<Double>> loadData(final List<Covariate> covariates, final Settings settings) throws IOException {
|
|
||||||
|
|
||||||
final List<Row<Double>> dataset = new ArrayList<>();
|
|
||||||
|
|
||||||
final Reader input = new FileReader(settings.getDataFileLocation());
|
|
||||||
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
|
|
||||||
|
|
||||||
|
|
||||||
int id = 1;
|
|
||||||
for(final CSVRecord record : parser){
|
|
||||||
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
|
|
||||||
|
|
||||||
for(final Covariate<?> covariate : covariates){
|
|
||||||
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
|
|
||||||
}
|
|
||||||
|
|
||||||
final String yStr = record.get(settings.getYVar());
|
|
||||||
final Double yNum = Double.parseDouble(yStr);
|
|
||||||
|
|
||||||
dataset.add(new Row<>(covariateValueMap, id++, yNum));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return dataset;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Settings defaultTemplate(){
|
private static Settings defaultTemplate(){
|
||||||
return Settings.builder()
|
|
||||||
|
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
|
||||||
|
|
||||||
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
yVarSettings.set("type", new TextNode("y"));
|
||||||
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
|
final Settings settings = Settings.builder()
|
||||||
.covariates(List.of(
|
.covariates(List.of(
|
||||||
new NumericCovariateSettings("x1"),
|
new NumericCovariateSettings("x1"),
|
||||||
new BooleanCovariateSettings("x2"),
|
new BooleanCovariateSettings("x2"),
|
||||||
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.yVar("y")
|
|
||||||
.dataFileLocation("data.csv")
|
.dataFileLocation("data.csv")
|
||||||
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
|
||||||
.responseCombiner("MeanResponseCombiner")
|
.responseCombiner("MeanResponseCombiner")
|
||||||
.treeResponseCombiner("MeanResponseCombiner")
|
.treeResponseCombiner("MeanResponseCombiner")
|
||||||
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
|
.yVarSettings(yVarSettings)
|
||||||
.maxNodeDepth(100000)
|
.maxNodeDepth(100000)
|
||||||
.mtry(2)
|
.mtry(2)
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
|
@ -100,6 +88,9 @@ public class Main {
|
||||||
.saveProgress(true)
|
.saveProgress(true)
|
||||||
.saveTreeLocation("trees/")
|
.saveTreeLocation("trees/")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
return settings;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,20 +1,21 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
|
||||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
||||||
|
@ -25,16 +26,88 @@ import java.util.Map;
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public class Settings {
|
public class Settings {
|
||||||
|
|
||||||
|
private static Map<String, DataLoader.ResponseLoaderConstructor> RESPONSE_LOADER_MAP = new HashMap<>();
|
||||||
|
public static DataLoader.ResponseLoaderConstructor getResponseLoaderConstructor(final String name){
|
||||||
|
return RESPONSE_LOADER_MAP.get(name.toLowerCase());
|
||||||
|
}
|
||||||
|
public static void registerResponseLoaderConstructor(final String name, final DataLoader.ResponseLoaderConstructor responseLoaderConstructor){
|
||||||
|
RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the built-in response loaders; each is constructed from its settings node.
static{
    registerResponseLoaderConstructor("double", DataLoader.DoubleLoader::new);
    registerResponseLoaderConstructor("CompetingResponse", CompetingResponse.CompetingResponseLoader::new);
    registerResponseLoaderConstructor("CompetingResponseWithCensorTime", CompetingResponseWithCensorTime.CompetingResponseWithCensorTimeLoader::new);
}
|
||||||
|
|
||||||
|
private static Map<String, GroupDifferentiator.GroupDifferentiatorConstructor> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
|
||||||
|
public static GroupDifferentiator.GroupDifferentiatorConstructor getGroupDifferentiatorConstructor(final String name){
|
||||||
|
return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase());
|
||||||
|
}
|
||||||
|
public static void registerGroupDifferentiatorConstructor(final String name, final GroupDifferentiator.GroupDifferentiatorConstructor groupDifferentiatorConstructor){
|
||||||
|
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
|
||||||
|
}
|
||||||
|
// Register the built-in group differentiators. The regression differentiators take no
// arguments; the competing risk ones read their arguments from the settings node
// ("eventOfFocus" for the single-event rules, "events" for the multiple-event rules).
static{
    registerGroupDifferentiatorConstructor("MeanGroupDifferentiator",
            settingsNode -> new MeanGroupDifferentiator());

    registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
            settingsNode -> new WeightedVarianceGroupDifferentiator());

    registerGroupDifferentiatorConstructor("LogRankSingleGroupDifferentiator",
            settingsNode -> new LogRankSingleGroupDifferentiator(settingsNode.get("eventOfFocus").asInt()));

    registerGroupDifferentiatorConstructor("GrayLogRankMultipleGroupDifferentiator",
            settingsNode -> {
                // collect the child nodes of "events" as ints, preserving their order
                final List<Integer> eventCodes = new ArrayList<>();
                final Iterator<JsonNode> iterator = settingsNode.get("events").elements();
                while(iterator.hasNext()){
                    eventCodes.add(iterator.next().asInt());
                }

                final int[] eventArray = eventCodes.stream().mapToInt(Integer::intValue).toArray();
                return new GrayLogRankMultipleGroupDifferentiator(eventArray);
            });

    registerGroupDifferentiatorConstructor("LogRankMultipleGroupDifferentiator",
            settingsNode -> {
                // collect the child nodes of "events" as ints, preserving their order
                final List<Integer> eventCodes = new ArrayList<>();
                final Iterator<JsonNode> iterator = settingsNode.get("events").elements();
                while(iterator.hasNext()){
                    eventCodes.add(iterator.next().asInt());
                }

                final int[] eventArray = eventCodes.stream().mapToInt(Integer::intValue).toArray();
                return new LogRankMultipleGroupDifferentiator(eventArray);
            });

    registerGroupDifferentiatorConstructor("GrayLogRankSingleGroupDifferentiator",
            settingsNode -> new GrayLogRankSingleGroupDifferentiator(settingsNode.get("eventOfFocus").asInt()));
}
|
||||||
|
|
||||||
private int numberOfSplits = 5;
|
private int numberOfSplits = 5;
|
||||||
private int nodeSize = 5;
|
private int nodeSize = 5;
|
||||||
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
|
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
|
||||||
|
|
||||||
private String responseCombiner;
|
private String responseCombiner;
|
||||||
private String groupDifferentiator;
|
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
private String treeResponseCombiner;
|
private String treeResponseCombiner;
|
||||||
|
|
||||||
private List<CovariateSettings> covariates = new ArrayList<>();
|
private List<CovariateSettings> covariates = new ArrayList<>();
|
||||||
private String yVar = "y";
|
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
|
||||||
// number of covariates to randomly try
|
// number of covariates to randomly try
|
||||||
private int mtry = 0;
|
private int mtry = 0;
|
||||||
|
@ -48,7 +121,8 @@ public class Settings {
|
||||||
private int numberOfThreads = 1;
|
private int numberOfThreads = 1;
|
||||||
private boolean saveProgress = false;
|
private boolean saveProgress = false;
|
||||||
|
|
||||||
public Settings(){} // required for Jackson
|
public Settings(){
|
||||||
|
} // required for Jackson
|
||||||
|
|
||||||
public static Settings load(File file) throws IOException {
|
public static Settings load(File file) throws IOException {
|
||||||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||||
|
@ -70,4 +144,18 @@ public class Settings {
|
||||||
mapper.writeValue(file, this);
|
mapper.writeValue(file, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@JsonIgnore
|
||||||
|
public GroupDifferentiator getGroupDifferentiator(){
|
||||||
|
final String type = groupDifferentiatorSettings.get("type").asText();
|
||||||
|
|
||||||
|
return getGroupDifferentiatorConstructor(type).construct(groupDifferentiatorSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@JsonIgnore
|
||||||
|
public DataLoader.ResponseLoader getResponseLoader(){
|
||||||
|
final String type = yVarSettings.get("type").asText();
|
||||||
|
|
||||||
|
return getResponseLoaderConstructor(type).construct(yVarSettings);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.DataLoader;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class CompetingResponse {
|
||||||
|
|
||||||
|
private final int delta;
|
||||||
|
private final double u;
|
||||||
|
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public static class CompetingResponseLoader implements DataLoader.ResponseLoader<CompetingResponse>{
|
||||||
|
|
||||||
|
private final String deltaName;
|
||||||
|
private final String uName;
|
||||||
|
|
||||||
|
public CompetingResponseLoader(ObjectNode node){
|
||||||
|
this.deltaName = node.get("delta").asText();
|
||||||
|
this.uName = node.get("u").asText();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingResponse parse(CSVRecord record) {
|
||||||
|
final int delta = Integer.parseInt(record.get(deltaName));
|
||||||
|
final double u = Double.parseDouble(record.get(uName));
|
||||||
|
|
||||||
|
return new CompetingResponse(delta, u);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.DataLoader;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class CompetingResponseWithCensorTime extends CompetingResponse{
|
||||||
|
private final double c;
|
||||||
|
|
||||||
|
public CompetingResponseWithCensorTime(int delta, double u, double c) {
|
||||||
|
super(delta, u);
|
||||||
|
this.c = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader<CompetingResponseWithCensorTime>{
|
||||||
|
|
||||||
|
private final String deltaName;
|
||||||
|
private final String uName;
|
||||||
|
private final String cName;
|
||||||
|
|
||||||
|
public CompetingResponseWithCensorTimeLoader(ObjectNode node){
|
||||||
|
this.deltaName = node.get("delta").asText();
|
||||||
|
this.uName = node.get("u").asText();
|
||||||
|
this.cName = node.get("c").asText();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingResponseWithCensorTime parse(CSVRecord record) {
|
||||||
|
final int delta = Integer.parseInt(record.get(deltaName));
|
||||||
|
final double u = Double.parseDouble(record.get(uName));
|
||||||
|
final double c = Double.parseDouble(record.get(cName));
|
||||||
|
|
||||||
|
return new CompetingResponseWithCensorTime(delta, u, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class CompetingRiskFunction {
|
||||||
|
|
||||||
|
private List<Point> pointList;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See page 761 of Random survival forests for competing risks by Ishwaran et al. The class is abstract as Gray's test
|
||||||
|
* modifies the abstract method.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingResponse> implements GroupDifferentiator<Y>{
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
|
abstract double riskSet(final List<Y> eventList, double time, int eventOfFocus);
|
||||||
|
|
||||||
|
private double numberOFEventsAtTime(int eventOfFocus, List<Y> eventList, double time){
|
||||||
|
return (double) eventList.stream()
|
||||||
|
.filter(event -> event.getDelta() == eventOfFocus)
|
||||||
|
.filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
|
||||||
|
.count();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
||||||
|
*
|
||||||
|
* @param eventOfFocus
|
||||||
|
* @param leftHand A non-empty list of CompetingResponse
|
||||||
|
* @param rightHand A non-empty list of CompetingResponse
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
LogRankValue specificLogRankValue(final int eventOfFocus, List<Y> leftHand, List<Y> rightHand){
|
||||||
|
|
||||||
|
final double[] distinctEventTimes = Stream.concat(
|
||||||
|
leftHand.stream(), rightHand.stream()
|
||||||
|
)
|
||||||
|
.filter(event -> event.getDelta() != 0) // remove censored events
|
||||||
|
.mapToDouble(event -> event.getU())
|
||||||
|
.distinct()
|
||||||
|
.toArray();
|
||||||
|
|
||||||
|
double summation = 0.0;
|
||||||
|
double varianceSquared = 0.0;
|
||||||
|
|
||||||
|
for(final double time_k : distinctEventTimes){
|
||||||
|
final double weight = weight(time_k); // W_j(t_k)
|
||||||
|
final double numberEventsAtTimeDaughterLeft = numberOFEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k)
|
||||||
|
final double numberEventsAtTimeDaughterRight = numberOFEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k)
|
||||||
|
final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)
|
||||||
|
|
||||||
|
final double individualsAtRiskDaughterLeft = riskSet(leftHand, time_k, eventOfFocus); // Y_l(t_k)
|
||||||
|
final double individualsAtRiskDaughterRight = riskSet(rightHand, time_k, eventOfFocus); // Y_r(t_k)
|
||||||
|
final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)
|
||||||
|
|
||||||
|
summation = summation + weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);
|
||||||
|
|
||||||
|
varianceSquared = varianceSquared + weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk
|
||||||
|
* (1.0 - individualsAtRiskDaughterLeft / individualsAtRisk)
|
||||||
|
* ((individualsAtRisk - numberOfEventsAtTime) / (individualsAtRisk - 1.0));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return new LogRankValue(summation, varianceSquared);
|
||||||
|
}
|
||||||
|
|
||||||
|
double weight(double time){
|
||||||
|
return 1.0; // TODO - make configurable
|
||||||
|
// A value of 1 "corresponds to the standard log-rank test which has optimal power for detecting alternatives where the cause-specific hazards are proportional"
|
||||||
|
//TODO - look into what weights might be more appropriate.
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
static class LogRankValue{
|
||||||
|
private final double numerator;
|
||||||
|
private final double varianceSquared;
|
||||||
|
|
||||||
|
public double getVariance(){
|
||||||
|
return Math.sqrt(varianceSquared);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponseWithCensorTime> {
|
||||||
|
|
||||||
|
private final int[] events;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double differentiate(List<CompetingResponseWithCensorTime> leftHand, List<CompetingResponseWithCensorTime> rightHand) {
|
||||||
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
double numerator = 0.0;
|
||||||
|
double denominatorSquared = 0.0;
|
||||||
|
|
||||||
|
for(final int eventOfFocus : events){
|
||||||
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||||
|
|
||||||
|
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVariance();
|
||||||
|
denominatorSquared += valueOfInterest.getVarianceSquared();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return Math.abs(numerator / Math.sqrt(denominatorSquared));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
double riskSet(List<CompetingResponseWithCensorTime> eventList, double time, int eventOfFocus) {
|
||||||
|
return eventList.stream()
|
||||||
|
.filter(event -> event.getU() >= time ||
|
||||||
|
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
|
||||||
|
)
|
||||||
|
.count();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponseWithCensorTime> {
|
||||||
|
|
||||||
|
private final int eventOfFocus;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double differentiate(List<CompetingResponseWithCensorTime> leftHand, List<CompetingResponseWithCensorTime> rightHand) {
|
||||||
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||||
|
|
||||||
|
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVariance());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
double riskSet(List<CompetingResponseWithCensorTime> eventList, double time, int eventOfFocus) {
|
||||||
|
return eventList.stream()
|
||||||
|
.filter(event -> event.getU() >= time ||
|
||||||
|
(event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time)
|
||||||
|
)
|
||||||
|
.count();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponse> {
|
||||||
|
|
||||||
|
private final int[] events;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double differentiate(List<CompetingResponse> leftHand, List<CompetingResponse> rightHand) {
|
||||||
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
double numerator = 0.0;
|
||||||
|
double denominatorSquared = 0.0;
|
||||||
|
|
||||||
|
for(final int eventOfFocus : events){
|
||||||
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||||
|
|
||||||
|
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVariance();
|
||||||
|
denominatorSquared += valueOfInterest.getVarianceSquared();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return Math.abs(numerator / Math.sqrt(denominatorSquared));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
double riskSet(List<CompetingResponse> eventList, double time, int eventOfFocus) {
|
||||||
|
return eventList.stream()
|
||||||
|
.filter(event -> event.getU() >= time)
|
||||||
|
.count();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See page 761 of Random survival forests for competing risks by Ishwaran et al.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponse> {
|
||||||
|
|
||||||
|
private final int eventOfFocus;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double differentiate(List<CompetingResponse> leftHand, List<CompetingResponse> rightHand) {
|
||||||
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand);
|
||||||
|
|
||||||
|
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVariance());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
double riskSet(List<CompetingResponse> eventList, double time, int eventOfFocus) {
|
||||||
|
return eventList.stream()
|
||||||
|
.filter(event -> event.getU() >= time)
|
||||||
|
.count();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a point in our estimate of either the cumulative hazard function or the cumulative incidence function.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class Point {
|
||||||
|
|
||||||
|
private final Double time;
|
||||||
|
private final Double y;
|
||||||
|
|
||||||
|
}
|
|
@ -1,16 +1,11 @@
|
||||||
package ca.joeltherrien.randomforest.regression;
|
package ca.joeltherrien.randomforest.responses.regression;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
|
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||||
|
|
||||||
static{
|
|
||||||
GroupDifferentiator.registerGroupDifferentiator("MeanGroupDifferentiator", new MeanGroupDifferentiator());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package ca.joeltherrien.randomforest.regression;
|
package ca.joeltherrien.randomforest.responses.regression;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package ca.joeltherrien.randomforest.regression;
|
package ca.joeltherrien.randomforest.responses.regression;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
|
||||||
|
@ -6,10 +6,6 @@ import java.util.List;
|
||||||
|
|
||||||
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||||
|
|
||||||
static{
|
|
||||||
GroupDifferentiator.registerGroupDifferentiator("WeightedVarianceGroupDifferentiator", new WeightedVarianceGroupDifferentiator());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -14,12 +16,11 @@ public interface GroupDifferentiator<Y> {
|
||||||
|
|
||||||
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
Map<String, GroupDifferentiator> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
|
@FunctionalInterface
|
||||||
static GroupDifferentiator loadGroupDifferentiatorByName(final String name){
|
interface GroupDifferentiatorConstructor<Y>{
|
||||||
return GROUP_DIFFERENTIATOR_MAP.get(name);
|
|
||||||
}
|
GroupDifferentiator<Y> construct(ObjectNode node);
|
||||||
static void registerGroupDifferentiator(final String name, final GroupDifferentiator groupDifferentiator){
|
|
||||||
GROUP_DIFFERENTIATOR_MAP.put(name, groupDifferentiator);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
|
@ -30,7 +30,7 @@ public class TreeTrainer<Y> {
|
||||||
this.maxNodeDepth = settings.getMaxNodeDepth();
|
this.maxNodeDepth = settings.getMaxNodeDepth();
|
||||||
|
|
||||||
this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner());
|
this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner());
|
||||||
this.groupDifferentiator = GroupDifferentiator.loadGroupDifferentiatorByName(settings.getGroupDifferentiator());
|
this.groupDifferentiator = settings.getGroupDifferentiator();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
package ca.joeltherrien.randomforest.csv;
|
package ca.joeltherrien.randomforest.csv;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Main;
|
import ca.joeltherrien.randomforest.DataLoader;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.Settings;
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||||
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.fasterxml.jackson.databind.node.TextNode;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -26,7 +29,11 @@ public class TestLoadingCSV {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void verifyLoading() throws IOException {
|
public void verifyLoading() throws IOException, ClassNotFoundException {
|
||||||
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
yVarSettings.set("type", new TextNode("Double"));
|
||||||
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
final Settings settings = Settings.builder()
|
final Settings settings = Settings.builder()
|
||||||
.dataFileLocation("src/test/resources/testCSV.csv")
|
.dataFileLocation("src/test/resources/testCSV.csv")
|
||||||
.covariates(
|
.covariates(
|
||||||
|
@ -34,13 +41,16 @@ public class TestLoadingCSV {
|
||||||
new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")),
|
new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")),
|
||||||
new BooleanCovariateSettings("x3"))
|
new BooleanCovariateSettings("x3"))
|
||||||
)
|
)
|
||||||
.yVar("y")
|
.yVarSettings(yVarSettings)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final List<Covariate> covariates = settings.getCovariates().stream()
|
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||||
|
|
||||||
final List<Row<Double>> data = Main.loadData(covariates, settings);
|
|
||||||
|
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||||
|
|
||||||
|
final List<Row<Double>> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
|
||||||
|
|
||||||
assertEquals(4, data.size());
|
assertEquals(4, data.size());
|
||||||
|
|
||||||
|
|
|
@ -4,19 +4,26 @@ import ca.joeltherrien.randomforest.Settings;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.*;
|
import ca.joeltherrien.randomforest.covariates.*;
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.fasterxml.jackson.databind.node.TextNode;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class TestPersistence {
|
public class TestPersistence {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSaving() throws IOException {
|
public void testSaving() throws IOException {
|
||||||
|
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
|
||||||
|
|
||||||
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
yVarSettings.set("type", new TextNode("Double"));
|
||||||
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
final Settings settingsOriginal = Settings.builder()
|
final Settings settingsOriginal = Settings.builder()
|
||||||
.covariates(List.of(
|
.covariates(List.of(
|
||||||
new NumericCovariateSettings("x1"),
|
new NumericCovariateSettings("x1"),
|
||||||
|
@ -24,11 +31,11 @@ public class TestPersistence {
|
||||||
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.yVar("y")
|
|
||||||
.dataFileLocation("data.csv")
|
.dataFileLocation("data.csv")
|
||||||
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
|
||||||
.responseCombiner("MeanResponseCombiner")
|
.responseCombiner("MeanResponseCombiner")
|
||||||
.treeResponseCombiner("MeanResponseCombiner")
|
.treeResponseCombiner("MeanResponseCombiner")
|
||||||
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
|
.yVarSettings(yVarSettings)
|
||||||
.maxNodeDepth(100000)
|
.maxNodeDepth(100000)
|
||||||
.mtry(2)
|
.mtry(2)
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
|
@ -46,7 +53,7 @@ public class TestPersistence {
|
||||||
|
|
||||||
assertEquals(settingsOriginal, reloadedSettings);
|
assertEquals(settingsOriginal, reloadedSettings);
|
||||||
|
|
||||||
templateFile.delete();
|
//templateFile.delete();
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,8 @@ package ca.joeltherrien.randomforest.workshop;
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,8 @@ import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue