Implement Response & GroupDifferentiators for CompetingRisk problems.

Also adjusted how settings are done to allow for specifying
differentiators & responses that may require arguments.

Note that CompetingRisk code is untested at this point.
This commit is contained in:
Joel Therrien 2018-07-10 14:43:51 -07:00
parent 4bbb0e0948
commit 462b0d9c35
24 changed files with 609 additions and 78 deletions

1
.gitignore vendored
View file

@ -4,3 +4,4 @@
target/
*.iml
.idea
template.yaml

View file

@ -0,0 +1,71 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Reads a CSV data file into a list of {@link Row}s.
 *
 * The covariate columns are parsed by the supplied {@link Covariate}s and the response column(s)
 * are parsed by a pluggable {@link ResponseLoader}, so the same loading code can be used for any
 * response type.
 */
public class DataLoader {

    /**
     * Loads every record of the CSV file at {@code filename} into a {@link Row}.
     *
     * The file is parsed as RFC 4180 CSV with the first record used as the header. Rows are given
     * sequential ids starting at 1, in file order.
     *
     * @param covariates the covariates to read; each one is looked up in the record by its name
     * @param responseLoader parses the response value out of each record
     * @param filename path of the CSV file to read
     * @param <Y> the response type
     * @return the rows, in file order
     * @throws IOException if the file cannot be opened or read
     */
    public static <Y> List<Row<Y>> loadData(final List<Covariate> covariates, final ResponseLoader<Y> responseLoader, String filename) throws IOException {
        final List<Row<Y>> dataset = new ArrayList<>();

        // try-with-resources so the file handle is released even when parsing a record throws
        try (final Reader input = new FileReader(filename);
             final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input)) {

            int id = 1;
            for (final CSVRecord record : parser) {
                final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();

                for (final Covariate<?> covariate : covariates) {
                    covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
                }

                final Y y = responseLoader.parse(record);

                dataset.add(new Row<>(covariateValueMap, id++, y));
            }
        }

        return dataset;
    }

    /**
     * Extracts the response value from a single CSV record.
     */
    @FunctionalInterface
    public interface ResponseLoader<Y>{
        Y parse(CSVRecord record);
    }

    /**
     * Builds a {@link ResponseLoader} from its settings node.
     */
    @FunctionalInterface
    public interface ResponseLoaderConstructor<Y>{
        ResponseLoader<Y> construct(ObjectNode node);
    }

    /**
     * Response loader that reads a single column and parses it as a {@code double}.
     */
    @RequiredArgsConstructor
    public static class DoubleLoader implements ResponseLoader<Double> {

        private final String yName;

        /**
         * @param node settings node; its {@code "name"} entry gives the response column name
         */
        public DoubleLoader(final ObjectNode node){
            this.yName = node.get("name").asText();
        }

        @Override
        public Double parse(CSVRecord record) {
            return Double.parseDouble(record.get(yName));
        }
    }
}

View file

@ -5,6 +5,9 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
@ -21,6 +24,7 @@ import java.util.stream.Collectors;
public class Main {
public static void main(String[] args) throws IOException {
if(args.length != 1){
System.out.println("Must provide one argument - the path to the settings.yaml file.");
@ -36,7 +40,9 @@ public class Main {
final List<Covariate> covariates = settings.getCovariates().stream()
.map(cs -> cs.build()).collect(Collectors.toList());
final List<Row<Double>> dataset = loadData(covariates, settings);
final List<Row<Double>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
@ -51,46 +57,28 @@ public class Main {
}
public static List<Row<Double>> loadData(final List<Covariate> covariates, final Settings settings) throws IOException {
final List<Row<Double>> dataset = new ArrayList<>();
final Reader input = new FileReader(settings.getDataFileLocation());
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
int id = 1;
for(final CSVRecord record : parser){
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
for(final Covariate<?> covariate : covariates){
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
}
final String yStr = record.get(settings.getYVar());
final Double yNum = Double.parseDouble(yStr);
dataset.add(new Row<>(covariateValueMap, id++, yNum));
}
return dataset;
}
private static Settings defaultTemplate(){
return Settings.builder()
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
// "type" selects the response loader registered in Settings; "Double" resolves to the
// loader that parses a single numeric column. "name" is the CSV column holding the response.
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
final Settings settings = Settings.builder()
.covariates(List.of(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
)
)
.yVar("y")
.dataFileLocation("data.csv")
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
.responseCombiner("MeanResponseCombiner")
.treeResponseCombiner("MeanResponseCombiner")
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
.mtry(2)
.nodeSize(5)
@ -100,6 +88,9 @@ public class Main {
.saveProgress(true)
.saveTreeLocation("trees/")
.build();
return settings;
}
}

View file

@ -1,20 +1,21 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import lombok.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
/**
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
@ -25,16 +26,88 @@ import java.util.Map;
@EqualsAndHashCode
public class Settings {
/**
 * Registry of response loader constructors, keyed by lower-cased name so that lookups are
 * case-insensitive.
 */
private static final Map<String, DataLoader.ResponseLoaderConstructor> RESPONSE_LOADER_MAP = new HashMap<>();

/**
 * Looks up a registered response loader constructor.
 *
 * @param name the registered name; matched case-insensitively
 * @return the constructor, or {@code null} if nothing is registered under that name
 */
public static DataLoader.ResponseLoaderConstructor getResponseLoaderConstructor(final String name){
    // Locale.ROOT keeps the key independent of the JVM's default locale (e.g. Turkish dotless i)
    return RESPONSE_LOADER_MAP.get(name.toLowerCase(Locale.ROOT));
}

/**
 * Registers a response loader constructor, replacing any previous one with the same name.
 *
 * @param name the name used to select this loader in the settings; stored lower-cased
 * @param responseLoaderConstructor builds the loader from its settings node
 */
public static void registerResponseLoaderConstructor(final String name, final DataLoader.ResponseLoaderConstructor responseLoaderConstructor){
    RESPONSE_LOADER_MAP.put(name.toLowerCase(Locale.ROOT), responseLoaderConstructor);
}

// Built-in response loaders.
static{
    registerResponseLoaderConstructor("double",
            node -> new DataLoader.DoubleLoader(node)
    );
    registerResponseLoaderConstructor("CompetingResponse",
            node -> new CompetingResponse.CompetingResponseLoader(node)
    );
    registerResponseLoaderConstructor("CompetingResponseWithCensorTime",
            node -> new CompetingResponseWithCensorTime.CompetingResponseWithCensorTimeLoader(node)
    );
}
/**
 * Registry of group differentiator constructors, keyed by lower-cased name so that lookups are
 * case-insensitive.
 */
private static final Map<String, GroupDifferentiator.GroupDifferentiatorConstructor> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();

/**
 * Looks up a registered group differentiator constructor.
 *
 * @param name the registered name; matched case-insensitively
 * @return the constructor, or {@code null} if nothing is registered under that name
 */
public static GroupDifferentiator.GroupDifferentiatorConstructor getGroupDifferentiatorConstructor(final String name){
    // Locale.ROOT keeps the key independent of the JVM's default locale (e.g. Turkish dotless i)
    return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase(Locale.ROOT));
}

/**
 * Registers a group differentiator constructor, replacing any previous one with the same name.
 *
 * @param name the name used to select this differentiator in the settings; stored lower-cased
 * @param groupDifferentiatorConstructor builds the differentiator from its settings node
 */
public static void registerGroupDifferentiatorConstructor(final String name, final GroupDifferentiator.GroupDifferentiatorConstructor groupDifferentiatorConstructor){
    GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(Locale.ROOT), groupDifferentiatorConstructor);
}

// Built-in group differentiators. The regression ones take no arguments; the competing-risk
// ones read either a single "eventOfFocus" or a list of "events" from their settings node.
static{
    registerGroupDifferentiatorConstructor("MeanGroupDifferentiator",
            (node) -> new MeanGroupDifferentiator()
    );
    registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
            (node) -> new WeightedVarianceGroupDifferentiator()
    );
    registerGroupDifferentiatorConstructor("LogRankSingleGroupDifferentiator",
            (objectNode) -> new LogRankSingleGroupDifferentiator(objectNode.get("eventOfFocus").asInt())
    );
    registerGroupDifferentiatorConstructor("GrayLogRankMultipleGroupDifferentiator",
            (objectNode) -> new GrayLogRankMultipleGroupDifferentiator(parseEventArray(objectNode))
    );
    registerGroupDifferentiatorConstructor("LogRankMultipleGroupDifferentiator",
            (objectNode) -> new LogRankMultipleGroupDifferentiator(parseEventArray(objectNode))
    );
    registerGroupDifferentiatorConstructor("GrayLogRankSingleGroupDifferentiator",
            (objectNode) -> new GrayLogRankSingleGroupDifferentiator(objectNode.get("eventOfFocus").asInt())
    );
}

/**
 * Reads the {@code "events"} entry of a settings node as an array of ints, in element order.
 */
private static int[] parseEventArray(final ObjectNode objectNode){
    final Iterator<JsonNode> elements = objectNode.get("events").elements();
    final List<Integer> eventList = new ArrayList<>();
    while(elements.hasNext()){
        eventList.add(elements.next().asInt());
    }
    return eventList.stream().mapToInt(Integer::intValue).toArray();
}
private int numberOfSplits = 5;
private int nodeSize = 5;
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
private String responseCombiner;
private String groupDifferentiator;
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
private String treeResponseCombiner;
private List<CovariateSettings> covariates = new ArrayList<>();
private String yVar = "y";
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
// number of covariates to randomly try
private int mtry = 0;
@ -48,7 +121,8 @@ public class Settings {
private int numberOfThreads = 1;
private boolean saveProgress = false;
public Settings(){} // required for Jackson
public Settings(){
} // required for Jackson
public static Settings load(File file) throws IOException {
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
@ -70,4 +144,18 @@ public class Settings {
mapper.writeValue(file, this);
}
/**
 * Builds the group differentiator described by {@code groupDifferentiatorSettings}.
 *
 * The settings node's {@code "type"} entry selects a registered constructor, which is then
 * handed the whole node so it can read any further arguments it needs.
 *
 * @return a newly constructed group differentiator
 * @throws IllegalStateException if the settings have no {@code "type"} entry or the type is not registered
 */
@JsonIgnore
public GroupDifferentiator getGroupDifferentiator(){
    final JsonNode typeNode = groupDifferentiatorSettings.get("type");
    if(typeNode == null){
        throw new IllegalStateException("groupDifferentiatorSettings must contain a 'type' entry");
    }

    final String type = typeNode.asText();
    final GroupDifferentiator.GroupDifferentiatorConstructor constructor = getGroupDifferentiatorConstructor(type);
    if(constructor == null){
        // fail with the offending name instead of a bare NullPointerException
        throw new IllegalStateException("Unknown group differentiator type: " + type);
    }

    return constructor.construct(groupDifferentiatorSettings);
}
/**
 * Builds the response loader described by {@code yVarSettings}.
 *
 * The settings node's {@code "type"} entry selects a registered constructor, which is then
 * handed the whole node so it can read any further arguments it needs (e.g. column names).
 *
 * @return a newly constructed response loader
 * @throws IllegalStateException if the settings have no {@code "type"} entry or the type is not registered
 */
@JsonIgnore
public DataLoader.ResponseLoader getResponseLoader(){
    final JsonNode typeNode = yVarSettings.get("type");
    if(typeNode == null){
        throw new IllegalStateException("yVarSettings must contain a 'type' entry");
    }

    final String type = typeNode.asText();
    final DataLoader.ResponseLoaderConstructor constructor = getResponseLoaderConstructor(type);
    if(constructor == null){
        // fail with the offending name instead of a bare NullPointerException
        throw new IllegalStateException("Unknown response loader type: " + type);
    }

    return constructor.construct(yVarSettings);
}
}

View file

@ -0,0 +1,36 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.DataLoader;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord;
/**
 * Response of a single observation in a competing-risk problem: an integer event indicator
 * {@code delta} and a time {@code u}. The differentiators in this package treat a {@code delta}
 * of 0 as a censored observation.
 */
@Data
public class CompetingResponse {

    private final int delta;
    private final double u;

    /**
     * Creates {@link CompetingResponse} values from CSV records by reading one column for
     * {@code delta} and one for {@code u}.
     */
    @RequiredArgsConstructor
    public static class CompetingResponseLoader implements DataLoader.ResponseLoader<CompetingResponse>{

        private final String deltaName;
        private final String uName;

        /**
         * @param node settings node whose {@code "delta"} and {@code "u"} entries name the CSV columns to read
         */
        public CompetingResponseLoader(ObjectNode node){
            final String deltaColumn = node.get("delta").asText();
            final String uColumn = node.get("u").asText();

            this.deltaName = deltaColumn;
            this.uName = uColumn;
        }

        @Override
        public CompetingResponse parse(CSVRecord record) {
            return new CompetingResponse(
                    Integer.parseInt(record.get(deltaName)),
                    Double.parseDouble(record.get(uName))
            );
        }
    }
}

View file

@ -0,0 +1,44 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.DataLoader;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord;
/**
* See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times.
*
*/
@Data
public class CompetingResponseWithCensorTime extends CompetingResponse{
private final double c;
public CompetingResponseWithCensorTime(int delta, double u, double c) {
super(delta, u);
this.c = c;
}
@RequiredArgsConstructor
public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader<CompetingResponseWithCensorTime>{
private final String deltaName;
private final String uName;
private final String cName;
public CompetingResponseWithCensorTimeLoader(ObjectNode node){
this.deltaName = node.get("delta").asText();
this.uName = node.get("u").asText();
this.cName = node.get("c").asText();
}
@Override
public CompetingResponseWithCensorTime parse(CSVRecord record) {
final int delta = Integer.parseInt(record.get(deltaName));
final double u = Double.parseDouble(record.get(uName));
final double c = Double.parseDouble(record.get(cName));
return new CompetingResponseWithCensorTime(delta, u, c);
}
}
}

View file

@ -0,0 +1,11 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import java.util.List;
/**
 * Holder for a list of {@link Point}s. The class currently declares no constructor, accessors or
 * behaviour beyond the field itself.
 */
public class CompetingRiskFunction {
// NOTE(review): presumably the points of an estimated function ordered by time - confirm once this class is used.
private List<Point> pointList;
}

View file

@ -0,0 +1,91 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.util.List;
import java.util.stream.Stream;
/**
 * Base class for the competing-risk splitting rules. It computes the numerator and variance of the
 * log-rank statistic for one event cause; subclasses decide how the risk set is counted and how the
 * per-cause values are combined into a split score.
 *
 * See page 761 of Random survival forests for competing risks by Ishwaran et al. The class is abstract as Gray's test
 * modifies the abstract method.
 *
 */
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingResponse> implements GroupDifferentiator<Y>{

    @Override
    public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);

    /**
     * Counts the individuals in {@code eventList} considered at risk at {@code time} for the given event.
     *
     * @param eventList the responses of one daughter node
     * @param time the time at which to evaluate the risk set
     * @param eventOfFocus the event cause being tested
     * @return the number of individuals at risk
     */
    abstract double riskSet(final List<Y> eventList, double time, int eventOfFocus);

    /**
     * Counts the responses in {@code eventList} whose delta equals {@code eventOfFocus} and whose
     * time is exactly {@code time}.
     */
    private double countEventsAtTime(int eventOfFocus, List<Y> eventList, double time){
        return (double) eventList.stream()
                .filter(event -> event.getDelta() == eventOfFocus)
                .filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this
                .count();
    }

    /**
     * Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
     *
     * The sums run over every distinct time at which any non-censored event occurred in either
     * daughter node.
     *
     * @param eventOfFocus the event cause to compute the statistic for
     * @param leftHand A non-empty list of CompetingResponse
     * @param rightHand A non-empty list of CompetingResponse
     * @return the accumulated numerator and variance of the statistic
     */
    LogRankValue specificLogRankValue(final int eventOfFocus, List<Y> leftHand, List<Y> rightHand){

        final double[] distinctEventTimes = Stream.concat(
                leftHand.stream(), rightHand.stream()
        )
                .filter(event -> event.getDelta() != 0) // remove censored events
                .mapToDouble(event -> event.getU())
                .distinct()
                .toArray();

        double summation = 0.0;
        double varianceSquared = 0.0;

        for(final double time_k : distinctEventTimes){
            final double weight = weight(time_k); // W_j(t_k)
            final double numberEventsAtTimeDaughterLeft = countEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k)
            final double numberEventsAtTimeDaughterRight = countEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k)
            final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k)

            final double individualsAtRiskDaughterLeft = riskSet(leftHand, time_k, eventOfFocus); // Y_l(t_k)
            final double individualsAtRiskDaughterRight = riskSet(rightHand, time_k, eventOfFocus); // Y_r(t_k)
            final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k)

            summation = summation + weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk);

            // With a single individual at risk the variance term is zero, but evaluating the
            // formula divides by (Y - 1) = 0 and would turn the whole sum into NaN; skip it.
            if(individualsAtRisk > 1.0){
                varianceSquared = varianceSquared + weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk
                        * (1.0 - individualsAtRiskDaughterLeft / individualsAtRisk)
                        * ((individualsAtRisk - numberOfEventsAtTime) / (individualsAtRisk - 1.0));
            }
        }

        return new LogRankValue(summation, varianceSquared);
    }

    /**
     * The weight W_j(t) applied to each time point; currently always 1.
     */
    double weight(double time){
        return 1.0; // TODO - make configurable
        // A value of 1 "corresponds to the standard log-rank test which has optimal power for detecting alternatives where the cause-specific hazards are proportional"
        //TODO - look into what weights might be more appropriate.
    }

    /**
     * The two accumulated sums of a cause-specific log-rank computation.
     */
    @Data
    @AllArgsConstructor
    static class LogRankValue{
        private final double numerator;
        private final double varianceSquared;

        /**
         * @return the square root of {@code varianceSquared}
         */
        public double getVariance(){
            return Math.sqrt(varianceSquared);
        }
    }
}

View file

@ -0,0 +1,51 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
 * Split score that combines the cause-specific log-rank values of several events, using a risk set
 * that depends on the censor time of each response.
 *
 * See page 761 of Random survival forests for competing risks by Ishwaran et al.
 */
@RequiredArgsConstructor
public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponseWithCensorTime> {

    private final int[] events;

    /**
     * @return {@code null} if either side is empty; otherwise the absolute value of the summed
     *         (numerator * variance) terms divided by the square root of the summed squared variances
     */
    @Override
    public Double differentiate(List<CompetingResponseWithCensorTime> leftHand, List<CompetingResponseWithCensorTime> rightHand) {
        if(leftHand.size() == 0 || rightHand.size() == 0){
            return null;
        }

        double weightedSum = 0.0;
        double totalVarianceSquared = 0.0;

        for(int i = 0; i < events.length; i++){
            final LogRankValue causeValue = specificLogRankValue(events[i], leftHand, rightHand);

            weightedSum += causeValue.getNumerator()*causeValue.getVariance();
            totalVarianceSquared += causeValue.getVarianceSquared();
        }

        final double score = weightedSum / Math.sqrt(totalVarianceSquared);
        return Math.abs(score);
    }

    /**
     * Counts the responses whose time is at or after {@code time}, plus those whose time is before
     * {@code time} but whose delta differs from {@code eventOfFocus} and whose censor time is after {@code time}.
     */
    @Override
    double riskSet(List<CompetingResponseWithCensorTime> eventList, double time, int eventOfFocus) {
        long atRisk = 0;

        for(final CompetingResponseWithCensorTime response : eventList){
            if(response.getU() >= time){
                atRisk++;
            }
            else if(response.getU() < time && response.getDelta() != eventOfFocus && response.getC() > time){
                atRisk++;
            }
        }

        return atRisk;
    }
}

View file

@ -0,0 +1,41 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
 * Split score based on the log-rank value of a single event of focus, using a risk set that
 * depends on the censor time of each response.
 *
 * See page 761 of Random survival forests for competing risks by Ishwaran et al.
 */
@RequiredArgsConstructor
public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponseWithCensorTime> {

    private final int eventOfFocus;

    /**
     * @return {@code null} if either side is empty; otherwise the absolute value of the
     *         numerator divided by the variance of the event's log-rank value
     */
    @Override
    public Double differentiate(List<CompetingResponseWithCensorTime> leftHand, List<CompetingResponseWithCensorTime> rightHand) {
        if(leftHand.size() == 0 || rightHand.size() == 0){
            return null;
        }

        final LogRankValue causeValue = specificLogRankValue(eventOfFocus, leftHand, rightHand);
        final double score = causeValue.getNumerator() / causeValue.getVariance();

        return Math.abs(score);
    }

    /**
     * Counts the responses whose time is at or after {@code time}, plus those whose time is before
     * {@code time} but whose delta differs from {@code eventOfFocus} and whose censor time is after {@code time}.
     */
    @Override
    double riskSet(List<CompetingResponseWithCensorTime> eventList, double time, int eventOfFocus) {
        long atRisk = 0;

        for(final CompetingResponseWithCensorTime response : eventList){
            if(response.getU() >= time){
                atRisk++;
            }
            else if(response.getU() < time && response.getDelta() != eventOfFocus && response.getC() > time){
                atRisk++;
            }
        }

        return atRisk;
    }
}

View file

@ -0,0 +1,48 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
 * Split score that combines the cause-specific log-rank values of several events, using the
 * plain risk set of responses whose time has not yet passed.
 *
 * See page 761 of Random survival forests for competing risks by Ishwaran et al.
 */
@RequiredArgsConstructor
public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponse> {

    private final int[] events;

    /**
     * @return {@code null} if either side is empty; otherwise the absolute value of the summed
     *         (numerator * variance) terms divided by the square root of the summed squared variances
     */
    @Override
    public Double differentiate(List<CompetingResponse> leftHand, List<CompetingResponse> rightHand) {
        if(leftHand.size() == 0 || rightHand.size() == 0){
            return null;
        }

        double weightedSum = 0.0;
        double totalVarianceSquared = 0.0;

        for(int i = 0; i < events.length; i++){
            final LogRankValue causeValue = specificLogRankValue(events[i], leftHand, rightHand);

            weightedSum += causeValue.getNumerator()*causeValue.getVariance();
            totalVarianceSquared += causeValue.getVarianceSquared();
        }

        final double score = weightedSum / Math.sqrt(totalVarianceSquared);
        return Math.abs(score);
    }

    /**
     * Counts the responses whose time is at or after {@code time}; {@code eventOfFocus} is not used.
     */
    @Override
    double riskSet(List<CompetingResponse> eventList, double time, int eventOfFocus) {
        long atRisk = 0;

        for(final CompetingResponse response : eventList){
            if(response.getU() >= time){
                atRisk++;
            }
        }

        return atRisk;
    }
}

View file

@ -0,0 +1,36 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import lombok.RequiredArgsConstructor;
import java.util.List;
/**
 * Split score based on the log-rank value of a single event of focus, using the plain risk set of
 * responses whose time has not yet passed.
 *
 * See page 761 of Random survival forests for competing risks by Ishwaran et al.
 */
@RequiredArgsConstructor
public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator<CompetingResponse> {

    private final int eventOfFocus;

    /**
     * @return {@code null} if either side is empty; otherwise the absolute value of the
     *         numerator divided by the variance of the event's log-rank value
     */
    @Override
    public Double differentiate(List<CompetingResponse> leftHand, List<CompetingResponse> rightHand) {
        if(leftHand.size() == 0 || rightHand.size() == 0){
            return null;
        }

        final LogRankValue causeValue = specificLogRankValue(eventOfFocus, leftHand, rightHand);
        final double score = causeValue.getNumerator() / causeValue.getVariance();

        return Math.abs(score);
    }

    /**
     * Counts the responses whose time is at or after {@code time}; {@code eventOfFocus} is not used.
     */
    @Override
    double riskSet(List<CompetingResponse> eventList, double time, int eventOfFocus) {
        long atRisk = 0;

        for(final CompetingResponse response : eventList){
            if(response.getU() >= time){
                atRisk++;
            }
        }

        return atRisk;
    }
}

View file

@ -0,0 +1,15 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import lombok.Data;
/**
 * Represents a point in our estimate of either the cumulative hazard function or the cumulative incidence function.
 *
 * Immutable pair of boxed values; Lombok's {@code @Data} generates the constructor, getters,
 * {@code equals}/{@code hashCode} and {@code toString}.
 */
@Data
public class Point {
// The time coordinate of the point.
private final Double time;
// The function value at that time.
private final Double y;
}

View file

@ -1,16 +1,11 @@
package ca.joeltherrien.randomforest.regression;
package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import java.util.List;
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
static{
GroupDifferentiator.registerGroupDifferentiator("MeanGroupDifferentiator", new MeanGroupDifferentiator());
}
@Override
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {

View file

@ -1,4 +1,4 @@
package ca.joeltherrien.randomforest.regression;
package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;

View file

@ -1,4 +1,4 @@
package ca.joeltherrien.randomforest.regression;
package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
@ -6,10 +6,6 @@ import java.util.List;
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
static{
GroupDifferentiator.registerGroupDifferentiator("WeightedVarianceGroupDifferentiator", new WeightedVarianceGroupDifferentiator());
}
@Override
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {

View file

@ -1,5 +1,7 @@
package ca.joeltherrien.randomforest.tree;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -14,12 +16,11 @@ public interface GroupDifferentiator<Y> {
Double differentiate(List<Y> leftHand, List<Y> rightHand);
Map<String, GroupDifferentiator> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
static GroupDifferentiator loadGroupDifferentiatorByName(final String name){
return GROUP_DIFFERENTIATOR_MAP.get(name);
}
static void registerGroupDifferentiator(final String name, final GroupDifferentiator groupDifferentiator){
GROUP_DIFFERENTIATOR_MAP.put(name, groupDifferentiator);
@FunctionalInterface
interface GroupDifferentiatorConstructor<Y>{
GroupDifferentiator<Y> construct(ObjectNode node);
}
}

View file

@ -1,7 +1,5 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

View file

@ -30,7 +30,7 @@ public class TreeTrainer<Y> {
this.maxNodeDepth = settings.getMaxNodeDepth();
this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner());
this.groupDifferentiator = GroupDifferentiator.loadGroupDifferentiatorByName(settings.getGroupDifferentiator());
this.groupDifferentiator = settings.getGroupDifferentiator();
}
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){

View file

@ -1,12 +1,15 @@
package ca.joeltherrien.randomforest.csv;
import ca.joeltherrien.randomforest.Main;
import ca.joeltherrien.randomforest.DataLoader;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import org.junit.jupiter.api.Test;
import java.io.IOException;
@ -26,7 +29,11 @@ public class TestLoadingCSV {
*/
@Test
public void verifyLoading() throws IOException {
public void verifyLoading() throws IOException, ClassNotFoundException {
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
final Settings settings = Settings.builder()
.dataFileLocation("src/test/resources/testCSV.csv")
.covariates(
@ -34,13 +41,16 @@ public class TestLoadingCSV {
new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")),
new BooleanCovariateSettings("x3"))
)
.yVar("y")
.yVarSettings(yVarSettings)
.build();
final List<Covariate> covariates = settings.getCovariates().stream()
.map(cs -> cs.build()).collect(Collectors.toList());
final List<Row<Double>> data = Main.loadData(covariates, settings);
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
final List<Row<Double>> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
assertEquals(4, data.size());

View file

@ -4,31 +4,38 @@ import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals;
import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class TestPersistence {
@Test
public void testSaving() throws IOException {
final Settings settingsOriginal = Settings.builder()
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator"));
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
final Settings settingsOriginal = Settings.builder()
.covariates(List.of(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
)
)
.yVar("y")
.dataFileLocation("data.csv")
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
.responseCombiner("MeanResponseCombiner")
.treeResponseCombiner("MeanResponseCombiner")
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
.mtry(2)
.nodeSize(5)
@ -46,7 +53,7 @@ public class TestPersistence {
assertEquals(settingsOriginal, reloadedSettings);
templateFile.delete();
//templateFile.delete();
}

View file

@ -3,8 +3,8 @@ package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer;

View file

@ -5,8 +5,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer;

View file

@ -5,8 +5,8 @@ import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer;