diff --git a/.gitignore b/.gitignore index 73c3108..ecad204 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ target/ *.iml .idea +template.yaml diff --git a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java new file mode 100644 index 0000000..56f6431 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java @@ -0,0 +1,71 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import com.fasterxml.jackson.databind.node.ObjectNode; +import lombok.RequiredArgsConstructor; +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVRecord; + +import java.io.FileReader; +import java.io.IOException; +import java.io.Reader; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class DataLoader { + + public static List> loadData(final List covariates, final ResponseLoader responseLoader, String filename) throws IOException { + + final List> dataset = new ArrayList<>(); + + final Reader input = new FileReader(filename); + final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input); + + + int id = 1; + for(final CSVRecord record : parser){ + final Map covariateValueMap = new HashMap<>(); + + for(final Covariate covariate : covariates){ + covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName()))); + } + + final Y y = responseLoader.parse(record); + + dataset.add(new Row<>(covariateValueMap, id++, y)); + + } + + return dataset; + + } + + @FunctionalInterface + public interface ResponseLoader{ + Y parse(CSVRecord record); + } + + @FunctionalInterface + public interface ResponseLoaderConstructor{ + ResponseLoader construct(ObjectNode node); + } + + @RequiredArgsConstructor + public static class DoubleLoader implements ResponseLoader { + + private final String yName; + + public DoubleLoader(final ObjectNode node){ + this.yName = node.get("name").asText(); + } + @Override + public Double parse(CSVRecord record) { + return Double.parseDouble(record.get(yName)); + } + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index f557663..49db7a4 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -5,6 +5,9 @@ import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; import ca.joeltherrien.randomforest.tree.ForestTrainer; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.TextNode; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVRecord; @@ -21,6 +24,7 @@ import java.util.stream.Collectors; public class Main { + public static void main(String[] args) throws IOException { if(args.length != 1){ System.out.println("Must provide one argument - the path to the settings.yaml file."); @@ -36,7 +40,9 @@ public class Main { final List covariates = settings.getCovariates().stream() .map(cs -> cs.build()).collect(Collectors.toList()); - final List> dataset = loadData(covariates, settings); + + + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); @@ -51,46 +57,28 @@ public class Main { } - public static List> loadData(final List covariates, final Settings settings) throws IOException { - - final List> dataset = new ArrayList<>(); - - final Reader input = new FileReader(settings.getDataFileLocation()); - final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input); - - - int id = 1; - for(final CSVRecord record : parser){ - final Map covariateValueMap = new HashMap<>(); - - for(final Covariate covariate : covariates){ - covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName()))); - } - - final String yStr = record.get(settings.getYVar()); - final Double yNum = Double.parseDouble(yStr); - - dataset.add(new Row<>(covariateValueMap, id++, yNum)); - - } - - return dataset; - - } private static Settings defaultTemplate(){ - return Settings.builder() + + final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); + groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator")); + + final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); + yVarSettings.set("type", new TextNode("y")); + yVarSettings.set("name", new TextNode("y")); + + final Settings settings = Settings.builder() .covariates(List.of( new NumericCovariateSettings("x1"), new BooleanCovariateSettings("x2"), new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog")) ) ) - .yVar("y") .dataFileLocation("data.csv") - .groupDifferentiator("WeightedVarianceGroupDifferentiator") .responseCombiner("MeanResponseCombiner") .treeResponseCombiner("MeanResponseCombiner") + .groupDifferentiatorSettings(groupDifferentiatorSettings) + .yVarSettings(yVarSettings) .maxNodeDepth(100000) .mtry(2) .nodeSize(5) @@ -100,6 +88,9 @@ public class Main { .saveProgress(true) .saveTreeLocation("trees/") .build(); + + + return settings; } } diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index da24933..cdc7c70 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -1,20 +1,21 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.CovariateSettings; -import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.responses.competingrisk.*; +import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.GroupDifferentiator; -import ca.joeltherrien.randomforest.tree.ResponseCombiner; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import lombok.*; import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; /** * This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest. @@ -25,16 +26,88 @@ import java.util.Map; @EqualsAndHashCode public class Settings { + private static Map RESPONSE_LOADER_MAP = new HashMap<>(); + public static DataLoader.ResponseLoaderConstructor getResponseLoaderConstructor(final String name){ + return RESPONSE_LOADER_MAP.get(name.toLowerCase()); + } + public static void registerResponseLoaderConstructor(final String name, final DataLoader.ResponseLoaderConstructor responseLoaderConstructor){ + RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor); + } + + static{ + registerResponseLoaderConstructor("double", + node -> new DataLoader.DoubleLoader(node) + ); + registerResponseLoaderConstructor("CompetingResponse", + node -> new CompetingResponse.CompetingResponseLoader(node) + ); + registerResponseLoaderConstructor("CompetingResponseWithCensorTime", + node -> new CompetingResponseWithCensorTime.CompetingResponseWithCensorTimeLoader(node) + ); + } + + private static Map GROUP_DIFFERENTIATOR_MAP = new HashMap<>(); + public static GroupDifferentiator.GroupDifferentiatorConstructor getGroupDifferentiatorConstructor(final String name){ + return GROUP_DIFFERENTIATOR_MAP.get(name.toLowerCase()); + } + public static void registerGroupDifferentiatorConstructor(final String name, final GroupDifferentiator.GroupDifferentiatorConstructor groupDifferentiatorConstructor){ + GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor); + } + static{ + registerGroupDifferentiatorConstructor("MeanGroupDifferentiator", + (node) -> new MeanGroupDifferentiator() + ); + registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator", + (node) -> new WeightedVarianceGroupDifferentiator() + ); + registerGroupDifferentiatorConstructor("LogRankSingleGroupDifferentiator", + (objectNode) -> { + final int eventOfFocus = objectNode.get("eventOfFocus").asInt(); + + return new LogRankSingleGroupDifferentiator(eventOfFocus); + } + ); + registerGroupDifferentiatorConstructor("GrayLogRankMultipleGroupDifferentiator", + (objectNode) -> { + final Iterator elements = objectNode.get("events").elements(); + final List elementList = new ArrayList<>(); + elements.forEachRemaining(node -> elementList.add(node)); + + final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray(); + + return new GrayLogRankMultipleGroupDifferentiator(eventArray); + } + ); + registerGroupDifferentiatorConstructor("LogRankMultipleGroupDifferentiator", + (objectNode) -> { + final Iterator elements = objectNode.get("events").elements(); + final List elementList = new ArrayList<>(); + elements.forEachRemaining(node -> elementList.add(node)); + + final int[] eventArray = elementList.stream().mapToInt(node -> node.asInt()).toArray(); + + return new LogRankMultipleGroupDifferentiator(eventArray); + } + ); + registerGroupDifferentiatorConstructor("GrayLogRankSingleGroupDifferentiator", + (objectNode) -> { + final int eventOfFocus = objectNode.get("eventOfFocus").asInt(); + + return new GrayLogRankSingleGroupDifferentiator(eventOfFocus); + } + ); + } + private int numberOfSplits = 5; private int nodeSize = 5; private int maxNodeDepth = 1000000; // basically no maxNodeDepth private String responseCombiner; - private String groupDifferentiator; + private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); private String treeResponseCombiner; private List covariates = new ArrayList<>(); - private String yVar = "y"; + private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); // number of covariates to randomly try private int mtry = 0; @@ -48,7 +121,8 @@ public class Settings { private int numberOfThreads = 1; private boolean saveProgress = false; - public Settings(){} // required for Jackson + public Settings(){ + } // required for Jackson public static Settings load(File file) throws IOException { final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); @@ -70,4 +144,18 @@ public class Settings { mapper.writeValue(file, this); } + @JsonIgnore + public GroupDifferentiator getGroupDifferentiator(){ + final String type = groupDifferentiatorSettings.get("type").asText(); + + return getGroupDifferentiatorConstructor(type).construct(groupDifferentiatorSettings); + } + + @JsonIgnore + public DataLoader.ResponseLoader getResponseLoader(){ + final String type = yVarSettings.get("type").asText(); + + return getResponseLoaderConstructor(type).construct(yVarSettings); + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingResponse.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingResponse.java new file mode 100644 index 0000000..a910f91 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingResponse.java @@ -0,0 +1,36 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.DataLoader; +import com.fasterxml.jackson.databind.node.ObjectNode; +import lombok.Data; +import lombok.RequiredArgsConstructor; +import org.apache.commons.csv.CSVRecord; + +@Data +public class CompetingResponse { + + private final int delta; + private final double u; + + + @RequiredArgsConstructor + public static class CompetingResponseLoader implements DataLoader.ResponseLoader{ + + private final String deltaName; + private final String uName; + + public CompetingResponseLoader(ObjectNode node){ + this.deltaName = node.get("delta").asText(); + this.uName = node.get("u").asText(); + } + + @Override + public CompetingResponse parse(CSVRecord record) { + final int delta = Integer.parseInt(record.get(deltaName)); + final double u = Double.parseDouble(record.get(uName)); + + return new CompetingResponse(delta, u); + } + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingResponseWithCensorTime.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingResponseWithCensorTime.java new file mode 100644 index 0000000..adb56f1 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingResponseWithCensorTime.java @@ -0,0 +1,44 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.DataLoader; +import com.fasterxml.jackson.databind.node.ObjectNode; +import lombok.Data; +import lombok.RequiredArgsConstructor; +import org.apache.commons.csv.CSVRecord; + +/** + * See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times. + * + */ +@Data +public class CompetingResponseWithCensorTime extends CompetingResponse{ + private final double c; + + public CompetingResponseWithCensorTime(int delta, double u, double c) { + super(delta, u); + this.c = c; + } + + @RequiredArgsConstructor + public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader{ + + private final String deltaName; + private final String uName; + private final String cName; + + public CompetingResponseWithCensorTimeLoader(ObjectNode node){ + this.deltaName = node.get("delta").asText(); + this.uName = node.get("u").asText(); + this.cName = node.get("c").asText(); + } + + @Override + public CompetingResponseWithCensorTime parse(CSVRecord record) { + final int delta = Integer.parseInt(record.get(deltaName)); + final double u = Double.parseDouble(record.get(uName)); + final double c = Double.parseDouble(record.get(cName)); + + return new CompetingResponseWithCensorTime(delta, u, c); + } + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunction.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunction.java new file mode 100644 index 0000000..0cdd742 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunction.java @@ -0,0 +1,11 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import java.util.List; + +public class CompetingRiskFunction { + + private List pointList; + + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGroupDifferentiator.java new file mode 100644 index 0000000..3b15926 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGroupDifferentiator.java @@ -0,0 +1,91 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import lombok.AllArgsConstructor; +import lombok.Data; + +import java.util.List; +import java.util.stream.Stream; + +/** + * See page 761 of Random survival forests for competing risks by Ishwaran et al. The class is abstract as Gray's test + * modifies the abstract method. + * + */ +public abstract class CompetingRiskGroupDifferentiator implements GroupDifferentiator{ + + @Override + public abstract Double differentiate(List leftHand, List rightHand); + + abstract double riskSet(final List eventList, double time, int eventOfFocus); + + private double numberOFEventsAtTime(int eventOfFocus, List eventList, double time){ + return (double) eventList.stream() + .filter(event -> event.getDelta() == eventOfFocus) + .filter(event -> event.getU() == time) // since delta != 0 we know censoring didn't occur prior to this + .count(); + + } + + /** + * Calculates the log rank value (or the Gray's test value) for a *specific* event cause. + * + * @param eventOfFocus + * @param leftHand A non-empty list of CompetingResponse + * @param rightHand A non-empty list of CompetingResponse + * @return + */ + LogRankValue specificLogRankValue(final int eventOfFocus, List leftHand, List rightHand){ + + final double[] distinctEventTimes = Stream.concat( + leftHand.stream(), rightHand.stream() + ) + .filter(event -> event.getDelta() != 0) // remove censored events + .mapToDouble(event -> event.getU()) + .distinct() + .toArray(); + + double summation = 0.0; + double varianceSquared = 0.0; + + for(final double time_k : distinctEventTimes){ + final double weight = weight(time_k); // W_j(t_k) + final double numberEventsAtTimeDaughterLeft = numberOFEventsAtTime(eventOfFocus, leftHand, time_k); // d_{j,l}(t_k) + final double numberEventsAtTimeDaughterRight = numberOFEventsAtTime(eventOfFocus, rightHand, time_k); // d_{j,r}(t_k) + final double numberOfEventsAtTime = numberEventsAtTimeDaughterLeft + numberEventsAtTimeDaughterRight; // d_j(t_k) + + final double individualsAtRiskDaughterLeft = riskSet(leftHand, time_k, eventOfFocus); // Y_l(t_k) + final double individualsAtRiskDaughterRight = riskSet(rightHand, time_k, eventOfFocus); // Y_r(t_k) + final double individualsAtRisk = individualsAtRiskDaughterLeft + individualsAtRiskDaughterRight; // Y(t_k) + + summation = summation + weight*(numberEventsAtTimeDaughterLeft - numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk); + + varianceSquared = varianceSquared + weight*weight*numberOfEventsAtTime*individualsAtRiskDaughterLeft/individualsAtRisk + * (1.0 - individualsAtRiskDaughterLeft / individualsAtRisk) + * ((individualsAtRisk - numberOfEventsAtTime) / (individualsAtRisk - 1.0)); + + + } + + return new LogRankValue(summation, varianceSquared); + } + + double weight(double time){ + return 1.0; // TODO - make configurable + // A value of 1 "corresponds to the standard log-rank test which has optimal power for detecting alternatives where the cause-specific hazards are proportional" + //TODO - look into what weights might be more appropriate. + } + + @Data + @AllArgsConstructor + static class LogRankValue{ + private final double numerator; + private final double varianceSquared; + + public double getVariance(){ + return Math.sqrt(varianceSquared); + } + } + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/GrayLogRankMultipleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/GrayLogRankMultipleGroupDifferentiator.java new file mode 100644 index 0000000..10694eb --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/GrayLogRankMultipleGroupDifferentiator.java @@ -0,0 +1,51 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import com.fasterxml.jackson.databind.JsonNode; +import lombok.RequiredArgsConstructor; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * See page 761 of Random survival forests for competing risks by Ishwaran et al. + * + */ +@RequiredArgsConstructor +public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator { + + private final int[] events; + + @Override + public Double differentiate(List leftHand, List rightHand) { + if(leftHand.size() == 0 || rightHand.size() == 0){ + return null; + } + + double numerator = 0.0; + double denominatorSquared = 0.0; + + for(final int eventOfFocus : events){ + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + + numerator += valueOfInterest.getNumerator()*valueOfInterest.getVariance(); + denominatorSquared += valueOfInterest.getVarianceSquared(); + + } + + return Math.abs(numerator / Math.sqrt(denominatorSquared)); + + } + + @Override + double riskSet(List eventList, double time, int eventOfFocus) { + return eventList.stream() + .filter(event -> event.getU() >= time || + (event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time) + ) + .count(); + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/GrayLogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/GrayLogRankSingleGroupDifferentiator.java new file mode 100644 index 0000000..4982144 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/GrayLogRankSingleGroupDifferentiator.java @@ -0,0 +1,41 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import com.fasterxml.jackson.databind.JsonNode; +import lombok.RequiredArgsConstructor; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * See page 761 of Random survival forests for competing risks by Ishwaran et al. + * + */ +@RequiredArgsConstructor +public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator { + + private final int eventOfFocus; + + @Override + public Double differentiate(List leftHand, List rightHand) { + if(leftHand.size() == 0 || rightHand.size() == 0){ + return null; + } + + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + + return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVariance()); + + } + + @Override + double riskSet(List eventList, double time, int eventOfFocus) { + return eventList.stream() + .filter(event -> event.getU() >= time || + (event.getU() < time && event.getDelta() != eventOfFocus && event.getC() > time) + ) + .count(); + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/LogRankMultipleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/LogRankMultipleGroupDifferentiator.java new file mode 100644 index 0000000..8926b1a --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/LogRankMultipleGroupDifferentiator.java @@ -0,0 +1,48 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import com.fasterxml.jackson.databind.JsonNode; +import lombok.RequiredArgsConstructor; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * See page 761 of Random survival forests for competing risks by Ishwaran et al. + * + */ +@RequiredArgsConstructor +public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDifferentiator { + + private final int[] events; + + @Override + public Double differentiate(List leftHand, List rightHand) { + if(leftHand.size() == 0 || rightHand.size() == 0){ + return null; + } + + double numerator = 0.0; + double denominatorSquared = 0.0; + + for(final int eventOfFocus : events){ + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + + numerator += valueOfInterest.getNumerator()*valueOfInterest.getVariance(); + denominatorSquared += valueOfInterest.getVarianceSquared(); + + } + + return Math.abs(numerator / Math.sqrt(denominatorSquared)); + + } + + @Override + double riskSet(List eventList, double time, int eventOfFocus) { + return eventList.stream() + .filter(event -> event.getU() >= time) + .count(); + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/LogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/LogRankSingleGroupDifferentiator.java new file mode 100644 index 0000000..14688df --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/LogRankSingleGroupDifferentiator.java @@ -0,0 +1,36 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import lombok.RequiredArgsConstructor; + +import java.util.List; + +/** + * See page 761 of Random survival forests for competing risks by Ishwaran et al. + * + */ +@RequiredArgsConstructor +public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferentiator { + + private final int eventOfFocus; + + @Override + public Double differentiate(List leftHand, List rightHand) { + if(leftHand.size() == 0 || rightHand.size() == 0){ + return null; + } + + final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, leftHand, rightHand); + + return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVariance()); + + } + + @Override + double riskSet(List eventList, double time, int eventOfFocus) { + return eventList.stream() + .filter(event -> event.getU() >= time) + .count(); + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java new file mode 100644 index 0000000..1bef210 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java @@ -0,0 +1,15 @@ +package ca.joeltherrien.randomforest.responses.competingrisk; + +import lombok.Data; + +/** + * Represents a point in our estimate of either the cumulative hazard function or the cumulative incidence function. + * + */ +@Data +public class Point { + + private final Double time; + private final Double y; + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java similarity index 74% rename from src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java index da3b0b1..75bc129 100644 --- a/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java @@ -1,16 +1,11 @@ -package ca.joeltherrien.randomforest.regression; +package ca.joeltherrien.randomforest.responses.regression; import ca.joeltherrien.randomforest.tree.GroupDifferentiator; -import ca.joeltherrien.randomforest.tree.ResponseCombiner; import java.util.List; public class MeanGroupDifferentiator implements GroupDifferentiator { - static{ - GroupDifferentiator.registerGroupDifferentiator("MeanGroupDifferentiator", new MeanGroupDifferentiator()); - } - @Override public Double differentiate(List leftHand, List rightHand) { diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner.java similarity index 97% rename from src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java rename to src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner.java index 3ff43f3..36e5701 100644 --- a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner.java @@ -1,4 +1,4 @@ -package ca.joeltherrien.randomforest.regression; +package ca.joeltherrien.randomforest.responses.regression; import ca.joeltherrien.randomforest.tree.ResponseCombiner; diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java similarity index 83% rename from src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java rename to src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java index c27363f..25f7e6e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java @@ -1,4 +1,4 @@ -package ca.joeltherrien.randomforest.regression; +package ca.joeltherrien.randomforest.responses.regression; import ca.joeltherrien.randomforest.tree.GroupDifferentiator; @@ -6,10 +6,6 @@ import java.util.List; public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator { - static{ - GroupDifferentiator.registerGroupDifferentiator("WeightedVarianceGroupDifferentiator", new WeightedVarianceGroupDifferentiator()); - } - @Override public Double differentiate(List leftHand, List rightHand) { diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java index 8950aa8..bb67a95 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java @@ -1,5 +1,7 @@ package ca.joeltherrien.randomforest.tree; +import com.fasterxml.jackson.databind.node.ObjectNode; + import java.util.HashMap; import java.util.List; import java.util.Map; @@ -14,12 +16,11 @@ public interface GroupDifferentiator { Double differentiate(List leftHand, List rightHand); - Map GROUP_DIFFERENTIATOR_MAP = new HashMap<>(); - static GroupDifferentiator loadGroupDifferentiatorByName(final String name){ - return GROUP_DIFFERENTIATOR_MAP.get(name); - } - static void registerGroupDifferentiator(final String name, final GroupDifferentiator groupDifferentiator){ - GROUP_DIFFERENTIATOR_MAP.put(name, groupDifferentiator); + @FunctionalInterface + interface GroupDifferentiatorConstructor{ + + GroupDifferentiator construct(ObjectNode node); + } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java index 6430aa0..8c3eab6 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java @@ -1,7 +1,5 @@ package ca.joeltherrien.randomforest.tree; -import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; - import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index d447952..9fb0eaf 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -30,7 +30,7 @@ public class TreeTrainer { this.maxNodeDepth = settings.getMaxNodeDepth(); this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner()); - this.groupDifferentiator = GroupDifferentiator.loadGroupDifferentiatorByName(settings.getGroupDifferentiator()); + this.groupDifferentiator = settings.getGroupDifferentiator(); } public Node growTree(List> data, List covariatesToTry){ diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index 8032185..96a2a48 100644 --- a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -1,12 +1,15 @@ package ca.joeltherrien.randomforest.csv; -import ca.joeltherrien.randomforest.Main; +import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.TextNode; import org.junit.jupiter.api.Test; import java.io.IOException; @@ -26,7 +29,11 @@ public class TestLoadingCSV { */ @Test - public void verifyLoading() throws IOException { + public void verifyLoading() throws IOException, ClassNotFoundException { + final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); + yVarSettings.set("type", new TextNode("Double")); + yVarSettings.set("name", new TextNode("y")); + final Settings settings = Settings.builder() .dataFileLocation("src/test/resources/testCSV.csv") .covariates( @@ -34,13 +41,16 @@ public class TestLoadingCSV { new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")), new BooleanCovariateSettings("x3")) ) - .yVar("y") + .yVarSettings(yVarSettings) .build(); final List covariates = settings.getCovariates().stream() .map(cs -> cs.build()).collect(Collectors.toList()); - final List> data = Main.loadData(covariates, settings); + + final DataLoader.ResponseLoader loader = settings.getResponseLoader(); + + final List> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation()); assertEquals(4, data.size()); diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index 421ca65..f48212f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -4,31 +4,38 @@ import ca.joeltherrien.randomforest.Settings; import static org.junit.jupiter.api.Assertions.assertEquals; import ca.joeltherrien.randomforest.covariates.*; -import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.TextNode; import org.junit.jupiter.api.Test; import java.io.File; import java.io.IOException; -import java.util.ArrayList; import java.util.List; public class TestPersistence { @Test public void testSaving() throws IOException { - final Settings settingsOriginal = Settings.builder() + final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); + groupDifferentiatorSettings.set("type", new TextNode("WeightedVarianceGroupDifferentiator")); + + final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); + yVarSettings.set("type", new TextNode("Double")); + yVarSettings.set("name", new TextNode("y")); + + final Settings settingsOriginal = Settings.builder() .covariates(List.of( new NumericCovariateSettings("x1"), new BooleanCovariateSettings("x2"), new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog")) ) ) - .yVar("y") .dataFileLocation("data.csv") - .groupDifferentiator("WeightedVarianceGroupDifferentiator") .responseCombiner("MeanResponseCombiner") .treeResponseCombiner("MeanResponseCombiner") + .groupDifferentiatorSettings(groupDifferentiatorSettings) + .yVarSettings(yVarSettings) .maxNodeDepth(100000) .mtry(2) .nodeSize(5) @@ -46,7 +53,7 @@ public class TestPersistence { assertEquals(settingsOriginal, reloadedSettings); - templateFile.delete(); + //templateFile.delete(); } diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index 00ba886..cbce172 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -3,8 +3,8 @@ package ca.joeltherrien.randomforest.workshop; import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.NumericCovariate; -import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer; diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index 6e5dd55..dec417e 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -5,8 +5,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.covariates.NumericCovariate; import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer; diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index 665675c..9b30229 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -5,8 +5,8 @@ import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.FactorCovariate; import ca.joeltherrien.randomforest.covariates.NumericCovariate; -import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; -import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; +import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.TreeTrainer;