From f3a4ef01ed7b2fc69a75f3b59274ce678b4d983b Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 13 Nov 2019 17:08:31 -0800 Subject: [PATCH] Add support for offline forests. --- .../joeltherrien/randomforest/Settings.java | 21 +++- .../CompetingRiskFunctionCombiner.java | 80 +++--------- ...iateCompetingRisksFunctionsTimesKnown.java | 118 ++++++++++++++++++ .../regression/MeanResponseCombiner.java | 39 +++++- .../tree/ForestResponseCombiner.java | 23 ++++ .../randomforest/tree/ForestTrainer.java | 2 +- .../tree/IntermediateCombinedResponse.java | 30 +++++ .../randomforest/tree/OfflineForest.java | 92 +++++++------- .../randomforest/utils/DataUtils.java | 5 + .../randomforest/utils/RUtils.java | 11 +- 10 files changed, 296 insertions(+), 125 deletions(-) create mode 100644 library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/IntermediateCompetingRisksFunctionsTimesKnown.java create mode 100644 library/src/main/java/ca/joeltherrien/randomforest/tree/ForestResponseCombiner.java create mode 100644 library/src/main/java/ca/joeltherrien/randomforest/tree/IntermediateCombinedResponse.java diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/Settings.java b/executable/src/main/java/ca/joeltherrien/randomforest/Settings.java index 78a425a..a5ca343 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -28,6 +28,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogR import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; +import ca.joeltherrien.randomforest.tree.ForestResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.SplitFinder; import ca.joeltherrien.randomforest.utils.JsonUtils; @@ -110,6 +111,8 @@ public class Settings { private static Map> RESPONSE_COMBINER_MAP = new HashMap<>(); + private static Map> FOREST_RESPONSE_COMBINER_MAP = new HashMap<>(); + public static Function getResponseCombinerConstructor(final String name){ return RESPONSE_COMBINER_MAP.get(name.toLowerCase()); } @@ -117,11 +120,21 @@ public class Settings { RESPONSE_COMBINER_MAP.put(name.toLowerCase(), responseCombinerConstructor); } + public static Function getForestResponseCombinerConstructor(final String name){ + return FOREST_RESPONSE_COMBINER_MAP.get(name.toLowerCase()); + } + public static void registerForestResponseCombinerConstructor(final String name, final Function responseCombinerConstructor){ + FOREST_RESPONSE_COMBINER_MAP.put(name.toLowerCase(), responseCombinerConstructor); + } + static{ registerResponseCombinerConstructor("MeanResponseCombiner", (node) -> new MeanResponseCombiner() ); + registerForestResponseCombinerConstructor("MeanResponseCombiner", + (node) -> new MeanResponseCombiner() + ); registerResponseCombinerConstructor("CompetingRiskResponseCombiner", (node) -> { final int[] events = JsonUtils.jsonToIntArray(node.get("events")); @@ -131,7 +144,7 @@ public class Settings { } ); - registerResponseCombinerConstructor("CompetingRiskFunctionCombiner", + registerForestResponseCombinerConstructor("CompetingRiskFunctionCombiner", (node) -> { final int[] events = JsonUtils.jsonToIntArray(node.get("events")); @@ -144,8 +157,6 @@ public class Settings { } ); - - } private int numberOfSplits = 5; @@ -217,10 +228,10 @@ public class Settings { } @JsonIgnore - public ResponseCombiner getTreeCombiner(){ + public ForestResponseCombiner getTreeCombiner(){ final String type = treeCombinerSettings.get("type").asText(); - return getResponseCombinerConstructor(type).apply(treeCombinerSettings); + return getForestResponseCombinerConstructor(type).apply(treeCombinerSettings); } @JsonIgnore diff --git a/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java b/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java index 89d55f5..97b08fc 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java @@ -17,17 +17,15 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; -import ca.joeltherrien.randomforest.tree.ResponseCombiner; -import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; -import ca.joeltherrien.randomforest.utils.Utils; +import ca.joeltherrien.randomforest.tree.ForestResponseCombiner; +import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse; import lombok.RequiredArgsConstructor; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; @RequiredArgsConstructor -public class CompetingRiskFunctionCombiner implements ResponseCombiner { +public class CompetingRiskFunctionCombiner implements ForestResponseCombiner { private static final long serialVersionUID = 1L; @@ -57,72 +55,22 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); - final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); + return intermediateResult.transformToOutput(); + } - for(final int event : events){ - causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0)); - cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0)); + @Override + public IntermediateCombinedResponse startIntermediateCombinedResponse(int countInputs) { + if(this.times != null){ + return new IntermediateCompetingRisksFunctionsTimesKnown(countInputs, this.events, this.times); } - return CompetingRiskFunctions.builder() - .causeSpecificHazards(causeSpecificCumulativeHazardFunctionList) - .cumulativeIncidenceCurves(cumulativeIncidenceFunctionList) - .survivalCurve(survivalFunction) - .build(); + // TODO - implement + throw new RuntimeException("startIntermediateCombinedResponse when times is unknown is not yet implemented"); } } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/IntermediateCompetingRisksFunctionsTimesKnown.java b/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/IntermediateCompetingRisksFunctionsTimesKnown.java new file mode 100644 index 0000000..536dc7c --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/IntermediateCompetingRisksFunctionsTimesKnown.java @@ -0,0 +1,118 @@ +package ca.joeltherrien.randomforest.responses.competingrisk.combiner; + +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.Utils; + +import java.util.ArrayList; +import java.util.List; + +public class IntermediateCompetingRisksFunctionsTimesKnown implements IntermediateCombinedResponse { + + private double expectedN; + private final int[] events; + private final double[] timesToUse; + private int actualN; + + private final double[] survivalY; + private final double[][] csCHFY; + private final double[][] cifY; + + public IntermediateCompetingRisksFunctionsTimesKnown(int n, int[] events, double[] timesToUse){ + this.expectedN = n; + this.events = events; + this.timesToUse = timesToUse; + this.actualN = 0; + + this.survivalY = new double[timesToUse.length]; + this.csCHFY = new double[events.length][timesToUse.length]; + this.cifY = new double[events.length][timesToUse.length]; + } + + @Override + public void processNewInput(CompetingRiskFunctions input) { + /* + We're going to try to efficiently put our predictions together - + Assumptions - for each event on a response, the hazard and CIF functions share the same x points + + Plan - go through the time on each response and make use of that so that when we search for a time index + to evaluate the function at, we don't need to re-search the earlier times. + + */ + + this.actualN++; + + final double[] survivalXPoints = input.getSurvivalCurve().getX(); + final double[][] eventSpecificXPoints = new double[events.length][]; + + for(final int event : events){ + eventSpecificXPoints[event-1] = input.getCumulativeIncidenceFunction(event) + .getX(); + } + + int previousSurvivalIndex = 0; + final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value + + for(int i=0; i causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); + final List cumulativeIncidenceFunctionList = new ArrayList<>(events.length); + + for(final int event : events){ + causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0)); + cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0)); + } + + return CompetingRiskFunctions.builder() + .causeSpecificHazards(causeSpecificCumulativeHazardFunctionList) + .cumulativeIncidenceCurves(cumulativeIncidenceFunctionList) + .survivalCurve(survivalFunction) + .build(); + } + + private void rescaleOutput() { + rescaleArray(actualN, this.survivalY); + + for(int event : events){ + rescaleArray(actualN, this.cifY[event - 1]); + rescaleArray(actualN, this.csCHFY[event - 1]); + } + + this.expectedN = actualN; + + } + + private void rescaleArray(double newN, double[] array){ + for(int i=0; i { +public class MeanResponseCombiner implements ForestResponseCombiner { private static final long serialVersionUID = 1L; @Override @@ -35,5 +36,39 @@ public class MeanResponseCombiner implements ResponseCombiner { } + @Override + public IntermediateCombinedResponse startIntermediateCombinedResponse(int countInputs) { + return new MeanIntermediateCombinedResponse(countInputs); + } + + public static class MeanIntermediateCombinedResponse implements IntermediateCombinedResponse{ + + private double expectedN; + private int actualN; + private double currentMean; + + public MeanIntermediateCombinedResponse(int n){ + this.expectedN = n; + this.actualN = 0; + this.currentMean = 0.0; + } + + @Override + public void processNewInput(Double input) { + this.currentMean = this.currentMean + input / expectedN; + this.actualN ++; + } + + @Override + public Double transformToOutput() { + // rescale if necessary + this.currentMean = this.currentMean * (this.expectedN / (double) actualN); + this.expectedN = actualN; + + return currentMean; + } + + + } } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestResponseCombiner.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestResponseCombiner.java new file mode 100644 index 0000000..e872ae2 --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestResponseCombiner.java @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.tree; + +public interface ForestResponseCombiner extends ResponseCombiner{ + + IntermediateCombinedResponse startIntermediateCombinedResponse(int countInputs); + +} diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index aa2b3db..21d7f82 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -38,7 +38,7 @@ public class ForestTrainer { private final TreeTrainer treeTrainer; private final List covariates; - private final ResponseCombiner treeResponseCombiner; + private final ForestResponseCombiner treeResponseCombiner; private final List> data; // number of trees to try diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/IntermediateCombinedResponse.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/IntermediateCombinedResponse.java new file mode 100644 index 0000000..5ba0837 --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/IntermediateCombinedResponse.java @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.tree; + +/** + * Similar to ResponseCombiner, but an IntermediateCombinedResponse represents the intermediate state of a single output in the process of being combined. + * This class is only used in OfflineForests where we can only load one Tree in memory at a time. + * + */ +public interface IntermediateCombinedResponse { + + void processNewInput(I input); + + O transformToOutput(); + +} diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/OfflineForest.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/OfflineForest.java index 1c874a1..036130a 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/OfflineForest.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/OfflineForest.java @@ -22,7 +22,6 @@ import lombok.AllArgsConstructor; import java.io.File; import java.util.ArrayList; -import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; @@ -32,9 +31,9 @@ import java.util.stream.IntStream; public class OfflineForest extends Forest { private final File[] treeFiles; - private final ResponseCombiner treeResponseCombiner; + private final ForestResponseCombiner treeResponseCombiner; - public OfflineForest(File treeDirectoryPath, ResponseCombiner treeResponseCombiner){ + public OfflineForest(File treeDirectoryPath, ForestResponseCombiner treeResponseCombiner){ this.treeResponseCombiner = treeResponseCombiner; if(!treeDirectoryPath.isDirectory()){ @@ -42,7 +41,6 @@ public class OfflineForest extends Forest { } this.treeFiles = treeDirectoryPath.listFiles((file, s) -> s.endsWith(".tree")); - } @Override @@ -72,116 +70,108 @@ public class OfflineForest extends Forest { @Override public List evaluate(List rowList){ - final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length]; - final Iterator> treeIterator = getTrees().iterator(); + final List> intermediatePredictions = + IntStream.range(0, rowList.size()) + .mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length)) + .collect(Collectors.toList()); + final Iterator> treeIterator = getTrees().iterator(); for(int treeId = 0; treeId < treeFiles.length; treeId++){ final Tree currentTree = treeIterator.next(); - final int tempTreeId = treeId; // Java workaround IntStream.range(0, rowList.size()).parallel().forEach( rowId -> { final CovariateRow row = rowList.get(rowId); final O prediction = currentTree.evaluate(row); - predictions[rowId][tempTreeId] = prediction; + intermediatePredictions.get(rowId).processNewInput(prediction); } ); } - return Arrays.stream(predictions).parallel() - .map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray))) + return intermediatePredictions.stream().parallel() + .map(intPred -> intPred.transformToOutput()) .collect(Collectors.toList()); } @Override public List evaluateSerial(List rowList){ - final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length]; - final Iterator> treeIterator = getTrees().iterator(); + final List> intermediatePredictions = + IntStream.range(0, rowList.size()) + .mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length)) + .collect(Collectors.toList()); + final Iterator> treeIterator = getTrees().iterator(); for(int treeId = 0; treeId < treeFiles.length; treeId++){ final Tree currentTree = treeIterator.next(); - final int tempTreeId = treeId; // Java workaround IntStream.range(0, rowList.size()).sequential().forEach( rowId -> { final CovariateRow row = rowList.get(rowId); final O prediction = currentTree.evaluate(row); - predictions[rowId][tempTreeId] = prediction; + intermediatePredictions.get(rowId).processNewInput(prediction); } ); } - return Arrays.stream(predictions).sequential() - .map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray))) + return intermediatePredictions.stream().sequential() + .map(intPred -> intPred.transformToOutput()) .collect(Collectors.toList()); } @Override public List evaluateOOB(List rowList){ - final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length]; - final Iterator> treeIterator = getTrees().iterator(); + final List> intermediatePredictions = + IntStream.range(0, rowList.size()) + .mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length)) + .collect(Collectors.toList()); + final Iterator> treeIterator = getTrees().iterator(); for(int treeId = 0; treeId < treeFiles.length; treeId++){ final Tree currentTree = treeIterator.next(); - final int tempTreeId = treeId; // Java workaround IntStream.range(0, rowList.size()).parallel().forEach( rowId -> { final CovariateRow row = rowList.get(rowId); if(!currentTree.idInBootstrapSample(row.getId())){ final O prediction = currentTree.evaluate(row); - predictions[rowId][tempTreeId] = prediction; - } else{ - predictions[rowId][tempTreeId] = null; + intermediatePredictions.get(rowId).processNewInput(prediction); } - + // else do nothing; when we get the final output it will get scaled for the smaller N } ); } - return Arrays.stream(predictions).parallel() - .map(predArray -> { - final List predList = Arrays.stream(predArray).parallel() - .filter(pred -> pred != null).collect(Collectors.toList()); - - return treeResponseCombiner.combine(predList); - - }) + return intermediatePredictions.stream().parallel() + .map(intPred -> intPred.transformToOutput()) .collect(Collectors.toList()); } @Override public List evaluateSerialOOB(List rowList){ - final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length]; - final Iterator> treeIterator = getTrees().iterator(); + final List> intermediatePredictions = + IntStream.range(0, rowList.size()) + .mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length)) + .collect(Collectors.toList()); + final Iterator> treeIterator = getTrees().iterator(); for(int treeId = 0; treeId < treeFiles.length; treeId++){ final Tree currentTree = treeIterator.next(); - final int tempTreeId = treeId; // Java workaround IntStream.range(0, rowList.size()).sequential().forEach( rowId -> { final CovariateRow row = rowList.get(rowId); if(!currentTree.idInBootstrapSample(row.getId())){ final O prediction = currentTree.evaluate(row); - predictions[rowId][tempTreeId] = prediction; - } else{ - predictions[rowId][tempTreeId] = null; + intermediatePredictions.get(rowId).processNewInput(prediction); } - + // else do nothing; when we get the final output it will get scaled for the smaller N } ); } - return Arrays.stream(predictions).sequential() - .map(predArray -> { - final List predList = Arrays.stream(predArray).sequential() - .filter(pred -> pred != null).collect(Collectors.toList()); - - return treeResponseCombiner.combine(predList); - - }) + return intermediatePredictions.stream().sequential() + .map(intPred -> intPred.transformToOutput()) .collect(Collectors.toList()); } @@ -194,5 +184,15 @@ public class OfflineForest extends Forest { public int getNumberOfTrees() { return treeFiles.length; } + + public OnlineForest createOnlineCopy(){ + final List> allTrees = new ArrayList<>(getNumberOfTrees()); + getTrees().forEach(allTrees::add); + + return OnlineForest.builder() + .trees(allTrees) + .treeResponseCombiner(treeResponseCombiner) + .build(); + } } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java index 899c903..6f834bc 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java @@ -31,6 +31,11 @@ public class DataUtils { } final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); + + return loadOnlineForest(treeFiles, treeResponseCombiner); + } + + public static OnlineForest loadOnlineForest(File[] treeFiles, ResponseCombiner treeResponseCombiner) throws IOException, ClassNotFoundException { final List treeFileList = Arrays.asList(treeFiles); Collections.sort(treeFileList, Comparator.comparing(File::getName)); diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java index 8edcadb..7e19d58 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java @@ -25,7 +25,6 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons import java.io.*; import java.util.ArrayList; import java.util.List; -import java.util.stream.IntStream; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; @@ -200,11 +199,13 @@ public final class RUtils { } public static File[] getTreeFileArray(String folderPath, int endingId){ - return (File[]) IntStream.rangeClosed(1, endingId).sequential() - .mapToObj(i -> folderPath + "/tree-" + i + ".tree") - .map(File::new) - .toArray(); + final File[] fileArray = new File[endingId]; + for(int i = 1; i <= endingId; i++){ + fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree"); + } + + return fileArray; } }