Add support for offline forests.
parent 54af805d4d
commit f3a4ef01ed
10 changed files with 296 additions and 125 deletions
Settings.java

@@ -28,6 +28,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogR
 import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
 import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
 import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
+import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
 import ca.joeltherrien.randomforest.tree.ResponseCombiner;
 import ca.joeltherrien.randomforest.tree.SplitFinder;
 import ca.joeltherrien.randomforest.utils.JsonUtils;
@@ -110,6 +111,8 @@ public class Settings {
 
 
     private static Map<String, Function<ObjectNode, ResponseCombiner>> RESPONSE_COMBINER_MAP = new HashMap<>();
+    private static Map<String, Function<ObjectNode, ForestResponseCombiner>> FOREST_RESPONSE_COMBINER_MAP = new HashMap<>();
+
     public static Function<ObjectNode, ResponseCombiner> getResponseCombinerConstructor(final String name){
         return RESPONSE_COMBINER_MAP.get(name.toLowerCase());
     }
@@ -117,11 +120,21 @@ public class Settings {
         RESPONSE_COMBINER_MAP.put(name.toLowerCase(), responseCombinerConstructor);
     }
 
+    public static Function<ObjectNode, ForestResponseCombiner> getForestResponseCombinerConstructor(final String name){
+        return FOREST_RESPONSE_COMBINER_MAP.get(name.toLowerCase());
+    }
+
+    public static void registerForestResponseCombinerConstructor(final String name, final Function<ObjectNode, ForestResponseCombiner> responseCombinerConstructor){
+        FOREST_RESPONSE_COMBINER_MAP.put(name.toLowerCase(), responseCombinerConstructor);
+    }
+
     static{
 
         registerResponseCombinerConstructor("MeanResponseCombiner",
                 (node) -> new MeanResponseCombiner()
         );
+        registerForestResponseCombinerConstructor("MeanResponseCombiner",
+                (node) -> new MeanResponseCombiner()
+        );
         registerResponseCombinerConstructor("CompetingRiskResponseCombiner",
                 (node) -> {
                     final int[] events = JsonUtils.jsonToIntArray(node.get("events"));
@@ -131,7 +144,7 @@ public class Settings {
                 }
         );
 
-        registerResponseCombinerConstructor("CompetingRiskFunctionCombiner",
+        registerForestResponseCombinerConstructor("CompetingRiskFunctionCombiner",
                 (node) -> {
                     final int[] events = JsonUtils.jsonToIntArray(node.get("events"));
 
@@ -144,8 +157,6 @@ public class Settings {
 
                 }
         );
-
-
     }
 
     private int numberOfSplits = 5;
@@ -217,10 +228,10 @@ public class Settings {
     }
 
     @JsonIgnore
-    public ResponseCombiner getTreeCombiner(){
+    public ForestResponseCombiner getTreeCombiner(){
         final String type = treeCombinerSettings.get("type").asText();
 
-        return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
+        return getForestResponseCombinerConstructor(type).apply(treeCombinerSettings);
     }
 
     @JsonIgnore
CompetingRiskFunctionCombiner.java

@@ -17,17 +17,15 @@
 package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
 
 import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
-import ca.joeltherrien.randomforest.tree.ResponseCombiner;
-import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
-import ca.joeltherrien.randomforest.utils.Utils;
+import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
+import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
 import lombok.RequiredArgsConstructor;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
 @RequiredArgsConstructor
-public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
+public class CompetingRiskFunctionCombiner implements ForestResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
 
     private static final long serialVersionUID = 1L;
 
@@ -57,72 +55,22 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
             ).sorted().distinct().toArray();
         }
 
-        final double n = responses.size();
-
-        final double[] survivalY = new double[timesToUse.length];
-        final double[][] csCHFY = new double[events.length][timesToUse.length];
-        final double[][] cifY = new double[events.length][timesToUse.length];
-
-        /*
-            We're going to try to efficiently put our predictions together -
-                Assumptions - for each event on a response, the hazard and CIF functions share the same x points
-
-            Plan - go through the time on each response and make use of that so that when we search for a time index
-            to evaluate the function at, we don't need to re-search the earlier times.
-         */
-
-        for(final CompetingRiskFunctions currentFunctions : responses){
-            final double[] survivalXPoints = currentFunctions.getSurvivalCurve().getX();
-            final double[][] eventSpecificXPoints = new double[events.length][];
-
-            for(final int event : events){
-                eventSpecificXPoints[event-1] = currentFunctions.getCumulativeIncidenceFunction(event)
-                        .getX();
-            }
-
-            int previousSurvivalIndex = 0;
-            final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
-
-            for(int i=0; i<timesToUse.length; i++){
-                final double time = timesToUse[i];
-
-                // Survival curve
-                final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
-                survivalY[i] = survivalY[i] + currentFunctions.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / n;
-                previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
-                // -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
-
-                // CHFs and CIFs
-                for(final int event : events){
-                    final double[] xPoints = eventSpecificXPoints[event-1];
-                    final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
-                            xPoints, time);
-                    csCHFY[event-1][i] = csCHFY[event-1][i] + currentFunctions.getCauseSpecificHazardFunction(event)
-                            .evaluateByIndex(eventTimeIndex) / n;
-                    cifY[event-1][i] = cifY[event-1][i] + currentFunctions.getCumulativeIncidenceFunction(event)
-                            .evaluateByIndex(eventTimeIndex) / n;
-
-                    previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
-                }
-            }
+        final IntermediateCompetingRisksFunctionsTimesKnown intermediateResult = new IntermediateCompetingRisksFunctionsTimesKnown(responses.size(), this.events, timesToUse);
+
+        for(CompetingRiskFunctions input : responses){
+            intermediateResult.processNewInput(input);
         }
 
-        final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
-        final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
-        final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
+        return intermediateResult.transformToOutput();
+    }
 
-        for(final int event : events){
-            causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
-            cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
+    @Override
+    public IntermediateCombinedResponse<CompetingRiskFunctions, CompetingRiskFunctions> startIntermediateCombinedResponse(int countInputs) {
+        if(this.times != null){
+            return new IntermediateCompetingRisksFunctionsTimesKnown(countInputs, this.events, this.times);
         }
 
-        return CompetingRiskFunctions.builder()
-                .causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
-                .cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
-                .survivalCurve(survivalFunction)
-                .build();
+        // TODO - implement
+        throw new RuntimeException("startIntermediateCombinedResponse when times is unknown is not yet implemented");
     }
 }
IntermediateCompetingRisksFunctionsTimesKnown.java

@@ -0,0 +1,118 @@
+package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
+
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
+import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
+import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
+import ca.joeltherrien.randomforest.utils.Utils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class IntermediateCompetingRisksFunctionsTimesKnown implements IntermediateCombinedResponse<CompetingRiskFunctions, CompetingRiskFunctions> {
+
+    private double expectedN;
+    private final int[] events;
+    private final double[] timesToUse;
+    private int actualN;
+
+    private final double[] survivalY;
+    private final double[][] csCHFY;
+    private final double[][] cifY;
+
+    public IntermediateCompetingRisksFunctionsTimesKnown(int n, int[] events, double[] timesToUse){
+        this.expectedN = n;
+        this.events = events;
+        this.timesToUse = timesToUse;
+        this.actualN = 0;
+
+        this.survivalY = new double[timesToUse.length];
+        this.csCHFY = new double[events.length][timesToUse.length];
+        this.cifY = new double[events.length][timesToUse.length];
+    }
+
+    @Override
+    public void processNewInput(CompetingRiskFunctions input) {
+        /*
+            We're going to try to efficiently put our predictions together -
+                Assumptions - for each event on a response, the hazard and CIF functions share the same x points
+
+            Plan - go through the time on each response and make use of that so that when we search for a time index
+            to evaluate the function at, we don't need to re-search the earlier times.
+         */
+
+        this.actualN++;
+
+        final double[] survivalXPoints = input.getSurvivalCurve().getX();
+        final double[][] eventSpecificXPoints = new double[events.length][];
+
+        for(final int event : events){
+            eventSpecificXPoints[event-1] = input.getCumulativeIncidenceFunction(event)
+                    .getX();
+        }
+
+        int previousSurvivalIndex = 0;
+        final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
+
+        for(int i=0; i<timesToUse.length; i++){
+            final double time = timesToUse[i];
+
+            // Survival curve
+            final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
+            survivalY[i] = survivalY[i] + input.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / expectedN;
+            previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
+            // -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
+
+            // CHFs and CIFs
+            for(final int event : events){
+                final double[] xPoints = eventSpecificXPoints[event-1];
+                final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
+                        xPoints, time);
+                csCHFY[event-1][i] = csCHFY[event-1][i] + input.getCauseSpecificHazardFunction(event)
+                        .evaluateByIndex(eventTimeIndex) / expectedN;
+                cifY[event-1][i] = cifY[event-1][i] + input.getCumulativeIncidenceFunction(event)
+                        .evaluateByIndex(eventTimeIndex) / expectedN;
+
+                previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
+            }
+        }
+    }
+
+    @Override
+    public CompetingRiskFunctions transformToOutput() {
+        rescaleOutput();
+
+        final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
+        final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
+        final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
+
+        for(final int event : events){
+            causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
+            cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
+        }
+
+        return CompetingRiskFunctions.builder()
+                .causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
+                .cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
+                .survivalCurve(survivalFunction)
+                .build();
+    }
+
+    private void rescaleOutput() {
+        rescaleArray(actualN, this.survivalY);
+
+        for(int event : events){
+            rescaleArray(actualN, this.cifY[event - 1]);
+            rescaleArray(actualN, this.csCHFY[event - 1]);
+        }
+
+        this.expectedN = actualN;
+    }
+
+    private void rescaleArray(double newN, double[] array){
+        for(int i=0; i<array.length; i++){
+            array[i] = array[i] * (this.expectedN / newN);
+        }
+    }
+}
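A note on the processNewInput() scan above: because timesToUse is sorted, each call to Utils.binarySearchLessThan starts from the previous match instead of index 0, so later queries never re-search the earlier portion of the x array. A self-contained sketch of the same cursor idea, with plain arrays and hypothetical names (Utils.binarySearchLessThan plays this role in the real code):

public final class StepFunctionScan {
    // Evaluate a right-continuous step function (value ys[i] on [xs[i], xs[i+1]),
    // and y0 before xs[0]) at sorted query times, advancing a cursor rather than
    // searching from scratch each time.
    static double[] evaluateSorted(double[] xs, double[] ys, double y0, double[] sortedTimes) {
        final double[] out = new double[sortedTimes.length];
        int cursor = -1; // index of the largest x <= current time; -1 means "before xs[0]"
        for (int i = 0; i < sortedTimes.length; i++) {
            while (cursor + 1 < xs.length && xs[cursor + 1] <= sortedTimes[i]) {
                cursor++; // never moves backwards, so the whole pass is O(xs.length + sortedTimes.length)
            }
            out[i] = (cursor >= 0) ? ys[cursor] : y0;
        }
        return out;
    }
}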
MeanResponseCombiner.java

@@ -16,7 +16,8 @@
 
 package ca.joeltherrien.randomforest.responses.regression;
 
-import ca.joeltherrien.randomforest.tree.ResponseCombiner;
+import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
+import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
 
 import java.util.List;
 
@@ -24,7 +25,7 @@ import java.util.List;
  * Returns the Mean value of a group of Doubles.
  *
  */
-public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
+public class MeanResponseCombiner implements ForestResponseCombiner<Double, Double> {
     private static final long serialVersionUID = 1L;
 
     @Override
@@ -35,5 +36,39 @@ public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
 
     }
 
+    @Override
+    public IntermediateCombinedResponse<Double, Double> startIntermediateCombinedResponse(int countInputs) {
+        return new MeanIntermediateCombinedResponse(countInputs);
+    }
+
+    public static class MeanIntermediateCombinedResponse implements IntermediateCombinedResponse<Double, Double>{
+
+        private double expectedN;
+        private int actualN;
+        private double currentMean;
+
+        public MeanIntermediateCombinedResponse(int n){
+            this.expectedN = n;
+            this.actualN = 0;
+            this.currentMean = 0.0;
+        }
+
+        @Override
+        public void processNewInput(Double input) {
+            this.currentMean = this.currentMean + input / expectedN;
+            this.actualN ++;
+        }
+
+        @Override
+        public Double transformToOutput() {
+            // rescale if necessary
+            this.currentMean = this.currentMean * (this.expectedN / (double) actualN);
+            this.expectedN = actualN;
+
+            return currentMean;
+        }
+
+    }
+
 }
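The running mean above leans on a small trick: each input is divided by the number of inputs we expect (expectedN), and transformToOutput() rescales by expectedN / actualN so the result is the mean over the inputs actually seen. That matters for OOB evaluation, where a row skips every tree whose bootstrap sample contained it. A standalone sketch of just that arithmetic (the class name and values are illustrative, not part of the commit):

public class RescaleDemo {
    public static void main(String[] args) {
        final double expectedN = 4.0;            // trees we planned to fold in
        final double[] inputs = {2.0, 4.0, 6.0}; // only 3 trees actually predicted (OOB case)

        double mean = 0.0;
        int actualN = 0;
        for (double x : inputs) {
            mean += x / expectedN; // accumulate as if expectedN inputs were coming
            actualN++;
        }

        // mean currently holds (sum of inputs) / expectedN; correct the denominator:
        mean *= expectedN / actualN;

        System.out.println(mean); // prints 4.0, the mean of the three inputs
    }
}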
ForestResponseCombiner.java

@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2019 Joel Therrien.
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package ca.joeltherrien.randomforest.tree;
+
+public interface ForestResponseCombiner<I, O> extends ResponseCombiner<I, O>{
+
+    IntermediateCombinedResponse<I, O> startIntermediateCombinedResponse(int countInputs);
+
+}
ForestTrainer.java

@@ -38,7 +38,7 @@ public class ForestTrainer<Y, TO, FO> {
 
     private final TreeTrainer<Y, TO> treeTrainer;
     private final List<Covariate> covariates;
-    private final ResponseCombiner<TO, FO> treeResponseCombiner;
+    private final ForestResponseCombiner<TO, FO> treeResponseCombiner;
     private final List<Row<Y>> data;
 
     // number of trees to try
IntermediateCombinedResponse.java

@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2019 Joel Therrien.
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package ca.joeltherrien.randomforest.tree;
+
+/**
+ * Similar to ResponseCombiner, but an IntermediateCombinedResponse represents the intermediate state of a single output in the process of being combined.
+ * This class is only used in OfflineForests where we can only load one Tree in memory at a time.
+ *
+ */
+public interface IntermediateCombinedResponse<I, O> {
+
+    void processNewInput(I input);
+
+    O transformToOutput();
+
+}
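Taken together with ForestResponseCombiner above, the intended calling sequence is: ask the combiner for one intermediate state per output row, fold in each tree's prediction as that tree is loaded, then finalize. A sketch of that protocol under those assumptions (the helper class below is illustrative; OfflineForest.evaluate() in the next file is the real caller):

import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;

import java.util.List;

public class StreamingCombineSketch {
    static <I, O> O combineOneRow(ForestResponseCombiner<I, O> combiner, List<I> treePredictions) {
        // One intermediate state for this row; only one prediction needs to be in memory at a time.
        final IntermediateCombinedResponse<I, O> state =
                combiner.startIntermediateCombinedResponse(treePredictions.size());
        for (final I prediction : treePredictions) {
            state.processNewInput(prediction); // fold in one tree's prediction
        }
        return state.transformToOutput(); // finalize; rescales if fewer inputs arrived than promised
    }
}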
OfflineForest.java

@@ -22,7 +22,6 @@ import lombok.AllArgsConstructor;
 
 import java.io.File;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
 import java.util.stream.Collectors;
@@ -32,9 +31,9 @@ import java.util.stream.IntStream;
 public class OfflineForest<O, FO> extends Forest<O, FO> {
 
     private final File[] treeFiles;
-    private final ResponseCombiner<O, FO> treeResponseCombiner;
+    private final ForestResponseCombiner<O, FO> treeResponseCombiner;
 
-    public OfflineForest(File treeDirectoryPath, ResponseCombiner<O, FO> treeResponseCombiner){
+    public OfflineForest(File treeDirectoryPath, ForestResponseCombiner<O, FO> treeResponseCombiner){
         this.treeResponseCombiner = treeResponseCombiner;
 
         if(!treeDirectoryPath.isDirectory()){
@@ -42,7 +41,6 @@ public class OfflineForest<O, FO> extends Forest<O, FO> {
         }
 
         this.treeFiles = treeDirectoryPath.listFiles((file, s) -> s.endsWith(".tree"));
-
     }
 
     @Override
@@ -72,116 +70,108 @@ public class OfflineForest<O, FO> extends Forest<O, FO> {
 
     @Override
     public List<FO> evaluate(List<? extends CovariateRow> rowList){
-        final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
-        final Iterator<Tree<O>> treeIterator = getTrees().iterator();
+        final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
+                IntStream.range(0, rowList.size())
+                        .mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
+                        .collect(Collectors.toList());
+
+        final Iterator<Tree<O>> treeIterator = getTrees().iterator();
         for(int treeId = 0; treeId < treeFiles.length; treeId++){
             final Tree<O> currentTree = treeIterator.next();
 
-            final int tempTreeId = treeId; // Java workaround
             IntStream.range(0, rowList.size()).parallel().forEach(
                     rowId -> {
                         final CovariateRow row = rowList.get(rowId);
                         final O prediction = currentTree.evaluate(row);
-                        predictions[rowId][tempTreeId] = prediction;
+                        intermediatePredictions.get(rowId).processNewInput(prediction);
                     }
             );
         }
 
-        return Arrays.stream(predictions).parallel()
-                .map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray)))
+        return intermediatePredictions.stream().parallel()
+                .map(intPred -> intPred.transformToOutput())
                 .collect(Collectors.toList());
     }
 
     @Override
     public List<FO> evaluateSerial(List<? extends CovariateRow> rowList){
-        final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
-        final Iterator<Tree<O>> treeIterator = getTrees().iterator();
+        final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
+                IntStream.range(0, rowList.size())
+                        .mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
+                        .collect(Collectors.toList());
+
+        final Iterator<Tree<O>> treeIterator = getTrees().iterator();
         for(int treeId = 0; treeId < treeFiles.length; treeId++){
             final Tree<O> currentTree = treeIterator.next();
 
-            final int tempTreeId = treeId; // Java workaround
             IntStream.range(0, rowList.size()).sequential().forEach(
                     rowId -> {
                         final CovariateRow row = rowList.get(rowId);
                         final O prediction = currentTree.evaluate(row);
-                        predictions[rowId][tempTreeId] = prediction;
+                        intermediatePredictions.get(rowId).processNewInput(prediction);
                     }
             );
         }
 
-        return Arrays.stream(predictions).sequential()
-                .map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray)))
+        return intermediatePredictions.stream().sequential()
+                .map(intPred -> intPred.transformToOutput())
                 .collect(Collectors.toList());
     }
 
 
     @Override
     public List<FO> evaluateOOB(List<? extends CovariateRow> rowList){
-        final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
-        final Iterator<Tree<O>> treeIterator = getTrees().iterator();
+        final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
+                IntStream.range(0, rowList.size())
+                        .mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
+                        .collect(Collectors.toList());
+
+        final Iterator<Tree<O>> treeIterator = getTrees().iterator();
         for(int treeId = 0; treeId < treeFiles.length; treeId++){
             final Tree<O> currentTree = treeIterator.next();
 
-            final int tempTreeId = treeId; // Java workaround
             IntStream.range(0, rowList.size()).parallel().forEach(
                     rowId -> {
                         final CovariateRow row = rowList.get(rowId);
                         if(!currentTree.idInBootstrapSample(row.getId())){
                             final O prediction = currentTree.evaluate(row);
-                            predictions[rowId][tempTreeId] = prediction;
-                        } else{
-                            predictions[rowId][tempTreeId] = null;
+                            intermediatePredictions.get(rowId).processNewInput(prediction);
                         }
+                        // else do nothing; when we get the final output it will get scaled for the smaller N
                     }
             );
         }
 
-        return Arrays.stream(predictions).parallel()
-                .map(predArray -> {
-                    final List<O> predList = Arrays.stream(predArray).parallel()
-                            .filter(pred -> pred != null).collect(Collectors.toList());
-
-                    return treeResponseCombiner.combine(predList);
-
-                })
+        return intermediatePredictions.stream().parallel()
+                .map(intPred -> intPred.transformToOutput())
                 .collect(Collectors.toList());
     }
 
     @Override
     public List<FO> evaluateSerialOOB(List<? extends CovariateRow> rowList){
-        final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
-        final Iterator<Tree<O>> treeIterator = getTrees().iterator();
+        final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
+                IntStream.range(0, rowList.size())
+                        .mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
+                        .collect(Collectors.toList());
+
+        final Iterator<Tree<O>> treeIterator = getTrees().iterator();
         for(int treeId = 0; treeId < treeFiles.length; treeId++){
             final Tree<O> currentTree = treeIterator.next();
 
-            final int tempTreeId = treeId; // Java workaround
             IntStream.range(0, rowList.size()).sequential().forEach(
                     rowId -> {
                         final CovariateRow row = rowList.get(rowId);
                         if(!currentTree.idInBootstrapSample(row.getId())){
                             final O prediction = currentTree.evaluate(row);
-                            predictions[rowId][tempTreeId] = prediction;
-                        } else{
-                            predictions[rowId][tempTreeId] = null;
+                            intermediatePredictions.get(rowId).processNewInput(prediction);
                         }
+                        // else do nothing; when we get the final output it will get scaled for the smaller N
                     }
             );
         }
 
-        return Arrays.stream(predictions).sequential()
-                .map(predArray -> {
-                    final List<O> predList = Arrays.stream(predArray).sequential()
-                            .filter(pred -> pred != null).collect(Collectors.toList());
-
-                    return treeResponseCombiner.combine(predList);
-
-                })
+        return intermediatePredictions.stream().sequential()
+                .map(intPred -> intPred.transformToOutput())
                 .collect(Collectors.toList());
     }
 
@@ -194,5 +184,15 @@ public class OfflineForest<O, FO> extends Forest<O, FO> {
     public int getNumberOfTrees() {
         return treeFiles.length;
     }
 
+    public OnlineForest<O, FO> createOnlineCopy(){
+        final List<Tree<O>> allTrees = new ArrayList<>(getNumberOfTrees());
+        getTrees().forEach(allTrees::add);
+
+        return OnlineForest.<O, FO>builder()
+                .trees(allTrees)
+                .treeResponseCombiner(treeResponseCombiner)
+                .build();
+    }
+
 }
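For reference, a hypothetical end-to-end use of the class after this commit; the directory name, the regression combiner, and rowList are assumptions for illustration, not part of the diff:

// Trees were previously serialized one per "*.tree" file in this directory.
final OfflineForest<Double, Double> forest =
        new OfflineForest<>(new File("trees/"), new MeanResponseCombiner());

// rowList: a List<? extends CovariateRow> to predict on; one tree is loaded at a time.
final List<Double> predictions = forest.evaluate(rowList);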
DataUtils.java

@@ -31,6 +31,11 @@ public class DataUtils {
         }
 
         final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
+
+        return loadOnlineForest(treeFiles, treeResponseCombiner);
+    }
+
+    public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File[] treeFiles, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
         final List<File> treeFileList = Arrays.asList(treeFiles);
 
         Collections.sort(treeFileList, Comparator.comparing(File::getName));
RUtils.java

@@ -25,7 +25,6 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
 import java.io.*;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.stream.IntStream;
 import java.util.zip.GZIPInputStream;
 import java.util.zip.GZIPOutputStream;
 
@@ -200,11 +199,13 @@ public final class RUtils {
     }
 
     public static File[] getTreeFileArray(String folderPath, int endingId){
-        return (File[]) IntStream.rangeClosed(1, endingId).sequential()
-                .mapToObj(i -> folderPath + "/tree-" + i + ".tree")
-                .map(File::new)
-                .toArray();
+        final File[] fileArray = new File[endingId];
+
+        for(int i = 1; i <= endingId; i++){
+            fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree");
+        }
+
+        return fileArray;
     }
 
 }
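The loop rewrite in getTreeFileArray() also fixes a latent bug: Stream.toArray() with no argument returns Object[], so the old (File[]) cast would have thrown a ClassCastException at runtime. Had the stream form been kept, the typed overload would be the fix; an equivalent sketch:

return IntStream.rangeClosed(1, endingId)
        .mapToObj(i -> new File(folderPath + "/tree-" + i + ".tree"))
        .toArray(File[]::new); // typed array, no unchecked cast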