diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index 3f211bc..ee36967 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -6,8 +6,10 @@ import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; +import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; +import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.JsonNodeFactory; @@ -54,16 +56,29 @@ public class Main { if(args.length < 3){ System.out.println("Specify error sample size"); + return; } final String yVarType = settings.getYVarSettings().get("type").asText(); if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){ System.out.println("Analyze currently only works on competing risk data"); + return; } - final CompetingRiskFunctionCombiner responseCombiner = (CompetingRiskFunctionCombiner) settings.getTreeCombiner(); - final int[] events = responseCombiner.getEvents(); + final ResponseCombiner responseCombiner = settings.getTreeCombiner(); + final int[] events; + + if(responseCombiner instanceof CompetingRiskFunctionCombiner){ + events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents(); + } + else if(responseCombiner instanceof CompetingRiskListCombiner){ + events = ((CompetingRiskListCombiner) responseCombiner).getOriginalCombiner().getEvents(); + } + else{ + System.out.println("Unsupported tree combiner"); + return; + } final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); @@ -72,7 +87,7 @@ public class Main { Utils.reduceListToSize(dataset, n); final File folder = new File(settings.getSaveTreeLocation()); - final Forest forest = DataLoader.loadForest(folder, responseCombiner); + final Forest forest = DataLoader.loadForest(folder, responseCombiner); System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions"); diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index f8a24f1..bbbfb29 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.covariates.CovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; +import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner; +import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskResponseCombinerToList; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; @@ -158,6 +160,30 @@ public class Settings { } ); + registerResponseCombinerConstructor("CompetingRiskListCombiner", + (node) -> { + final List eventList = new ArrayList<>(); + node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt())); + final int[] events = eventList.stream().mapToInt(i -> i).toArray(); + + double[] times = null; + // note that times may be null + if(node.hasNonNull("times")){ + final List timeList = new ArrayList<>(); + node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble())); + times = eventList.stream().mapToDouble(db -> db).toArray(); + } + + final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times); + return new CompetingRiskListCombiner(responseCombiner); + + } + ); + + registerResponseCombinerConstructor("CompetingRiskResponseCombinerToList", + (node) -> new CompetingRiskResponseCombinerToList() + ); + } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java index e3926ef..134673c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -19,7 +19,7 @@ public class CompetingRiskErrorRateCalculator { private final List> dataset; private final List riskFunctions; - public CompetingRiskErrorRateCalculator(final List> dataset, final Forest forest){ + public CompetingRiskErrorRateCalculator(final List> dataset, final Forest forest){ this.dataset = dataset; this.riskFunctions = dataset.stream() .map(forest::evaluateOOB) diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse.java index 4ef9ee5..7321a08 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse.java @@ -6,8 +6,10 @@ import lombok.Data; import lombok.RequiredArgsConstructor; import org.apache.commons.csv.CSVRecord; +import java.io.Serializable; + @Data -public class CompetingRiskResponse { +public class CompetingRiskResponse implements Serializable { private final int delta; private final double u; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskListCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskListCombiner.java new file mode 100644 index 0000000..f2f5fa5 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskListCombiner.java @@ -0,0 +1,26 @@ +package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative; + +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; +import ca.joeltherrien.randomforest.tree.ResponseCombiner; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +@RequiredArgsConstructor +public class CompetingRiskListCombiner implements ResponseCombiner { + + @Getter + private final CompetingRiskResponseCombiner originalCombiner; + + @Override + public CompetingRiskFunctions combine(List responses) { + final List completeList = responses.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + + return originalCombiner.combine(completeList); + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java new file mode 100644 index 0000000..5526d0a --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/alternative/CompetingRiskResponseCombinerToList.java @@ -0,0 +1,29 @@ +package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative; + +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.tree.ResponseCombiner; + +import java.util.List; + +/** + * This class takes all of the observations in a terminal node and 'combines' them into just a list of the observations. + * + * This is used in the alternative approach to only compute the functions at the final stage when combining trees. + * + */ +public class CompetingRiskResponseCombinerToList implements ResponseCombiner { + + @Override + public CompetingRiskResponse[] combine(List responses) { + final CompetingRiskResponse[] array = new CompetingRiskResponse[responses.size()]; + + for(int i=0; i { // O = output of trees, FO = forest output. In prac public FO evaluateOOB(CovariateRow row){ return treeResponseCombiner.combine( - trees.stream() + trees.parallelStream() .filter(tree -> !tree.idInBootstrapSample(row.getId())) .map(node -> node.evaluate(row)) .collect(Collectors.toList()) diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java index a6d4252..48b4663 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java @@ -10,7 +10,13 @@ public class Tree implements Node { private final Node rootNode; private final int[] bootstrapRowIds; - private boolean bootStrapRowIdsSorted = false; + + + public Tree(Node rootNode, int[] bootstrapRowIds) { + this.rootNode = rootNode; + this.bootstrapRowIds = bootstrapRowIds; + Arrays.sort(bootstrapRowIds); + } @Override public Y evaluate(CovariateRow row) { @@ -21,21 +27,8 @@ public class Tree implements Node { return bootstrapRowIds.clone(); } - /** - * Sort bootstrapRowIds. This is not done automatically for efficiency purposes, as in many cases we may not be interested in using bootstrapRowIds(); - * - */ - public void sortBootstrapRowIds(){ - if(!bootStrapRowIdsSorted){ - Arrays.sort(bootstrapRowIds); - bootStrapRowIdsSorted = true; - } - - } public boolean idInBootstrapSample(int id){ - this.sortBootstrapRowIds(); - return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0; }