Add alternative way where functions are computed only at final step.
This commit is contained in:
parent
d85f4eb099
commit
74151b94db
8 changed files with 111 additions and 20 deletions
|
@ -6,8 +6,10 @@ import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
|
@ -54,16 +56,29 @@ public class Main {
|
||||||
|
|
||||||
if(args.length < 3){
|
if(args.length < 3){
|
||||||
System.out.println("Specify error sample size");
|
System.out.println("Specify error sample size");
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
final String yVarType = settings.getYVarSettings().get("type").asText();
|
final String yVarType = settings.getYVarSettings().get("type").asText();
|
||||||
|
|
||||||
if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
|
if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
|
||||||
System.out.println("Analyze currently only works on competing risk data");
|
System.out.println("Analyze currently only works on competing risk data");
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
final CompetingRiskFunctionCombiner responseCombiner = (CompetingRiskFunctionCombiner) settings.getTreeCombiner();
|
final ResponseCombiner<?, CompetingRiskFunctions> responseCombiner = settings.getTreeCombiner();
|
||||||
final int[] events = responseCombiner.getEvents();
|
final int[] events;
|
||||||
|
|
||||||
|
if(responseCombiner instanceof CompetingRiskFunctionCombiner){
|
||||||
|
events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents();
|
||||||
|
}
|
||||||
|
else if(responseCombiner instanceof CompetingRiskListCombiner){
|
||||||
|
events = ((CompetingRiskListCombiner) responseCombiner).getOriginalCombiner().getEvents();
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
System.out.println("Unsupported tree combiner");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
|
||||||
|
@ -72,7 +87,7 @@ public class Main {
|
||||||
Utils.reduceListToSize(dataset, n);
|
Utils.reduceListToSize(dataset, n);
|
||||||
|
|
||||||
final File folder = new File(settings.getSaveTreeLocation());
|
final File folder = new File(settings.getSaveTreeLocation());
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
||||||
|
|
||||||
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
|
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskResponseCombinerToList;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
||||||
|
@ -158,6 +160,30 @@ public class Settings {
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Registers the constructor used for the "CompetingRiskListCombiner" setting.
// The settings node (presumably a Jackson JsonNode - confirm against the register method's signature)
// must contain an "events" array and may contain a "times" array.
registerResponseCombinerConstructor("CompetingRiskListCombiner",
        (node) -> {
            // Collect the event codes into a primitive int[].
            final List<Integer> eventList = new ArrayList<>();
            node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
            final int[] events = eventList.stream().mapToInt(i -> i).toArray();

            // "times" is optional; when it is missing or null, null is passed on to the combiner.
            double[] times = null;
            if(node.hasNonNull("times")){
                final List<Double> timeList = new ArrayList<>();
                node.get("times").elements().forEachRemaining(time -> timeList.add(time.asDouble()));
                // Build the array from timeList. Streaming eventList here (as the code did before)
                // compiles because Integer widens to double, but silently replaces the configured
                // times with the event codes.
                times = timeList.stream().mapToDouble(db -> db).toArray();
            }

            // The list combiner wraps an ordinary response combiner and delegates to it.
            final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times);
            return new CompetingRiskListCombiner(responseCombiner);
        }
);
|
||||||
|
|
||||||
|
// Registers the constructor used for the "CompetingRiskResponseCombinerToList" setting.
// This combiner takes no configuration, so the settings node is ignored.
registerResponseCombinerConstructor(
        "CompetingRiskResponseCombinerToList",
        ignoredNode -> new CompetingRiskResponseCombinerToList()
);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ public class CompetingRiskErrorRateCalculator {
|
||||||
private final List<Row<CompetingRiskResponse>> dataset;
|
private final List<Row<CompetingRiskResponse>> dataset;
|
||||||
private final List<CompetingRiskFunctions> riskFunctions;
|
private final List<CompetingRiskFunctions> riskFunctions;
|
||||||
|
|
||||||
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
|
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest){
|
||||||
this.dataset = dataset;
|
this.dataset = dataset;
|
||||||
this.riskFunctions = dataset.stream()
|
this.riskFunctions = dataset.stream()
|
||||||
.map(forest::evaluateOOB)
|
.map(forest::evaluateOOB)
|
||||||
|
|
|
@ -6,8 +6,10 @@ import lombok.Data;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.apache.commons.csv.CSVRecord;
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class CompetingRiskResponse {
|
public class CompetingRiskResponse implements Serializable {
|
||||||
|
|
||||||
private final int delta;
|
private final int delta;
|
||||||
private final double u;
|
private final double u;
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CompetingRiskListCombiner implements ResponseCombiner<CompetingRiskResponse[], CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final CompetingRiskResponseCombiner originalCombiner;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingRiskFunctions combine(List<CompetingRiskResponse[]> responses) {
|
||||||
|
final List<CompetingRiskResponse> completeList = responses.stream().flatMap(Arrays::stream).collect(Collectors.toList());
|
||||||
|
|
||||||
|
return originalCombiner.combine(completeList);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class takes all of the observations in a terminal node and 'combines' them into just a list of the observations.
|
||||||
|
*
|
||||||
|
* This is used in the alternative approach to only compute the functions at the final stage when combining trees.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class CompetingRiskResponseCombinerToList implements ResponseCombiner<CompetingRiskResponse, CompetingRiskResponse[]> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingRiskResponse[] combine(List<CompetingRiskResponse> responses) {
|
||||||
|
final CompetingRiskResponse[] array = new CompetingRiskResponse[responses.size()];
|
||||||
|
|
||||||
|
for(int i=0; i<array.length; i++){
|
||||||
|
array[i] = responses.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -27,7 +27,7 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
public FO evaluateOOB(CovariateRow row){
|
public FO evaluateOOB(CovariateRow row){
|
||||||
|
|
||||||
return treeResponseCombiner.combine(
|
return treeResponseCombiner.combine(
|
||||||
trees.stream()
|
trees.parallelStream()
|
||||||
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
||||||
.map(node -> node.evaluate(row))
|
.map(node -> node.evaluate(row))
|
||||||
.collect(Collectors.toList())
|
.collect(Collectors.toList())
|
||||||
|
|
|
@ -10,7 +10,13 @@ public class Tree<Y> implements Node<Y> {
|
||||||
|
|
||||||
private final Node<Y> rootNode;
|
private final Node<Y> rootNode;
|
||||||
private final int[] bootstrapRowIds;
|
private final int[] bootstrapRowIds;
|
||||||
private boolean bootStrapRowIdsSorted = false;
|
|
||||||
|
|
||||||
|
/**
 * Creates a tree from its root node and the ids of the rows in its bootstrap sample.
 *
 * The ids are stored sorted so that {@code idInBootstrapSample} can binary-search them.
 *
 * @param rootNode        root of the tree
 * @param bootstrapRowIds ids of the rows in the bootstrap sample; the array is copied and
 *                        the caller's array is left untouched
 */
public Tree(Node<Y> rootNode, int[] bootstrapRowIds) {
    this.rootNode = rootNode;
    // Sort a private copy: sorting the argument in place would reorder the caller's array, and
    // sharing it would let later external writes break the sorted invariant that binarySearch needs.
    this.bootstrapRowIds = bootstrapRowIds.clone();
    Arrays.sort(this.bootstrapRowIds);
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Y evaluate(CovariateRow row) {
|
public Y evaluate(CovariateRow row) {
|
||||||
|
@ -21,21 +27,8 @@ public class Tree<Y> implements Node<Y> {
|
||||||
return bootstrapRowIds.clone();
|
return bootstrapRowIds.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Sort bootstrapRowIds. This is not done automatically for efficiency purposes, as in many cases we may not be interested in using bootstrapRowIds();
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public void sortBootstrapRowIds(){
|
|
||||||
if(!bootStrapRowIdsSorted){
|
|
||||||
Arrays.sort(bootstrapRowIds);
|
|
||||||
bootStrapRowIdsSorted = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean idInBootstrapSample(int id){
|
public boolean idInBootstrapSample(int id){
|
||||||
this.sortBootstrapRowIds();
|
|
||||||
|
|
||||||
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
|
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue