Add alternative way where functions are computed only at final step.

This commit is contained in:
Joel Therrien 2018-08-07 15:49:55 -07:00
parent d85f4eb099
commit 74151b94db
8 changed files with 111 additions and 20 deletions

View file

@ -6,8 +6,10 @@ import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
@ -54,16 +56,29 @@ public class Main {
if(args.length < 3){
System.out.println("Specify error sample size");
return;
}
final String yVarType = settings.getYVarSettings().get("type").asText();
if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
System.out.println("Analyze currently only works on competing risk data");
return;
}
final CompetingRiskFunctionCombiner responseCombiner = (CompetingRiskFunctionCombiner) settings.getTreeCombiner();
final int[] events = responseCombiner.getEvents();
final ResponseCombiner<?, CompetingRiskFunctions> responseCombiner = settings.getTreeCombiner();
final int[] events;
if(responseCombiner instanceof CompetingRiskFunctionCombiner){
events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents();
}
else if(responseCombiner instanceof CompetingRiskListCombiner){
events = ((CompetingRiskListCombiner) responseCombiner).getOriginalCombiner().getEvents();
}
else{
System.out.println("Unsupported tree combiner");
return;
}
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
@ -72,7 +87,7 @@ public class Main {
Utils.reduceListToSize(dataset, n);
final File folder = new File(settings.getSaveTreeLocation());
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");

View file

@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskResponseCombinerToList;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
@ -158,6 +160,30 @@ public class Settings {
}
);
registerResponseCombinerConstructor("CompetingRiskListCombiner",
(node) -> {
final List<Integer> eventList = new ArrayList<>();
node.get("events").elements().forEachRemaining(event -> eventList.add(event.asInt()));
final int[] events = eventList.stream().mapToInt(i -> i).toArray();
double[] times = null;
// note that times may be null
if(node.hasNonNull("times")){
final List<Double> timeList = new ArrayList<>();
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
times = eventList.stream().mapToDouble(db -> db).toArray();
}
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times);
return new CompetingRiskListCombiner(responseCombiner);
}
);
registerResponseCombinerConstructor("CompetingRiskResponseCombinerToList",
(node) -> new CompetingRiskResponseCombinerToList()
);
}

View file

@ -19,7 +19,7 @@ public class CompetingRiskErrorRateCalculator {
private final List<Row<CompetingRiskResponse>> dataset;
private final List<CompetingRiskFunctions> riskFunctions;
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest){
public CompetingRiskErrorRateCalculator(final List<Row<CompetingRiskResponse>> dataset, final Forest<?, CompetingRiskFunctions> forest){
this.dataset = dataset;
this.riskFunctions = dataset.stream()
.map(forest::evaluateOOB)

View file

@ -6,8 +6,10 @@ import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord;
import java.io.Serializable;
@Data
public class CompetingRiskResponse {
public class CompetingRiskResponse implements Serializable {
private final int delta;
private final double u;

View file

@ -0,0 +1,26 @@
package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
@RequiredArgsConstructor
public class CompetingRiskListCombiner implements ResponseCombiner<CompetingRiskResponse[], CompetingRiskFunctions> {
@Getter
private final CompetingRiskResponseCombiner originalCombiner;
@Override
public CompetingRiskFunctions combine(List<CompetingRiskResponse[]> responses) {
final List<CompetingRiskResponse> completeList = responses.stream().flatMap(Arrays::stream).collect(Collectors.toList());
return originalCombiner.combine(completeList);
}
}

View file

@ -0,0 +1,29 @@
package ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import java.util.List;
/**
* This class takes all of the observations in a terminal node and 'combines' them into just a list of the observations.
*
* This is used in the alternative approach to only compute the functions at the final stage when combining trees.
*
*/
public class CompetingRiskResponseCombinerToList implements ResponseCombiner<CompetingRiskResponse, CompetingRiskResponse[]> {
@Override
public CompetingRiskResponse[] combine(List<CompetingRiskResponse> responses) {
final CompetingRiskResponse[] array = new CompetingRiskResponse[responses.size()];
for(int i=0; i<array.length; i++){
array[i] = responses.get(i);
}
return array;
}
}

View file

@ -27,7 +27,7 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
public FO evaluateOOB(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
trees.parallelStream()
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
.map(node -> node.evaluate(row))
.collect(Collectors.toList())

View file

@ -10,7 +10,13 @@ public class Tree<Y> implements Node<Y> {
private final Node<Y> rootNode;
private final int[] bootstrapRowIds;
private boolean bootStrapRowIdsSorted = false;
public Tree(Node<Y> rootNode, int[] bootstrapRowIds) {
this.rootNode = rootNode;
this.bootstrapRowIds = bootstrapRowIds;
Arrays.sort(bootstrapRowIds);
}
@Override
public Y evaluate(CovariateRow row) {
@ -21,21 +27,8 @@ public class Tree<Y> implements Node<Y> {
return bootstrapRowIds.clone();
}
/**
* Sort bootstrapRowIds. This is not done automatically for efficiency purposes, as in many cases we may not be interested in using bootstrapRowIds();
*
*/
public void sortBootstrapRowIds(){
if(!bootStrapRowIdsSorted){
Arrays.sort(bootstrapRowIds);
bootStrapRowIdsSorted = true;
}
}
public boolean idInBootstrapSample(int id){
this.sortBootstrapRowIds();
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
}