Add ability to compute different error rates.

This commit is contained in:
Joel Therrien 2018-08-07 10:52:52 -07:00
parent d3994212b6
commit bf56dfb59d
12 changed files with 225 additions and 46 deletions

View file

@ -4,7 +4,11 @@ import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode; import com.fasterxml.jackson.databind.node.TextNode;
@ -12,10 +16,7 @@ import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord; import org.apache.commons.csv.CSVRecord;
import java.io.File; import java.io.*;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@ -25,9 +26,10 @@ import java.util.stream.Collectors;
public class Main { public class Main {
public static void main(String[] args) throws IOException { public static void main(String[] args) throws IOException, ClassNotFoundException {
if(args.length != 1){ if(args.length < 2){
System.out.println("Must provide one argument - the path to the settings.yaml file."); System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
System.out.println("Note that analyzing only supports competing risk data, and that you must then specify a sample size for testing errors.");
if(args.length == 0){ if(args.length == 0){
System.out.println("Generating template file."); System.out.println("Generating template file.");
defaultTemplate().save(new File("template.yaml")); defaultTemplate().save(new File("template.yaml"));
@ -40,24 +42,99 @@ public class Main {
final List<Covariate> covariates = settings.getCovariates().stream() final List<Covariate> covariates = settings.getCovariates().stream()
.map(cs -> cs.build()).collect(Collectors.toList()); .map(cs -> cs.build()).collect(Collectors.toList());
if(args[1].equalsIgnoreCase("train")){
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
if(settings.isSaveProgress()){
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
}
else{
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
}
}
else if(args[1].equalsIgnoreCase("analyze")){
// Perform different prediction measures
if(args.length < 3){
System.out.println("Specify error sample size");
}
final String yVarType = settings.getYVarSettings().get("type").asText();
if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
System.out.println("Analyze currently only works on competing risk data");
}
final CompetingRiskFunctionCombiner responseCombiner = (CompetingRiskFunctionCombiner) settings.getTreeCombiner();
final int[] events = responseCombiner.getEvents();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
// Let's reduce this down to n
final int n = Integer.parseInt(args[2]);
Utils.reduceListToSize(dataset, n);
final File folder = new File(settings.getSaveTreeLocation());
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
System.out.println("Running Naive Mortality");
final double naiveMortality = errorRateCalculator.calculateNaiveMortalityError(events);
printWriter.write("Naive Mortality: ");
printWriter.write(Double.toString(naiveMortality));
printWriter.write('\n');
System.out.println("Running Naive Concordance");
final double[] naiveConcordance = errorRateCalculator.calculateConcordance(events);
printWriter.write("Naive concordance:\n");
for(int i=0; i<events.length; i++){
printWriter.write('\t');
printWriter.write(Integer.toString(events[i]));
printWriter.write(": ");
printWriter.write(Double.toString(naiveConcordance[i]));
printWriter.write('\n');
}
if(yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
System.out.println("Running IPCW Concordance - creating censor distribution");
final double[] censorTimes = dataset.stream()
.mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC())
.toArray();
final MathFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes);
System.out.println("Finished generating censor distribution - running concordance");
final double[] ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(events, censorDistribution);
printWriter.write("IPCW concordance:\n");
for(int i=0; i<events.length; i++){
printWriter.write('\t');
printWriter.write(Integer.toString(events[i]));
printWriter.write(": ");
printWriter.write(Double.toString(ipcwConcordance[i]));
printWriter.write('\n');
}
}
final List<Row<Double>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); printWriter.close();
final ForestTrainer<Double, Double, Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
if(settings.isSaveProgress()){
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
} }
else{ else{
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); System.out.println("Invalid instruction; use either train or analyze.");
System.out.println("Note that analyzing only supports competing risk data.");
} }
} }
private static Settings defaultTemplate(){ private static Settings defaultTemplate(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);

View file

@ -50,7 +50,6 @@ public interface Covariate<V> extends Serializable {
final List<Row<Y>> leftHand = new LinkedList<>(); final List<Row<Y>> leftHand = new LinkedList<>();
final List<Row<Y>> rightHand = new LinkedList<>(); final List<Row<Y>> rightHand = new LinkedList<>();
final List<Boolean> nonMissingDecisions = new ArrayList<>();
final List<Row<Y>> missingValueRows = new ArrayList<>(); final List<Row<Y>> missingValueRows = new ArrayList<>();
@ -63,8 +62,6 @@ public interface Covariate<V> extends Serializable {
} }
final boolean isLeftHand = isLeftHand(value); final boolean isLeftHand = isLeftHand(value);
nonMissingDecisions.add(isLeftHand);
if(isLeftHand){ if(isLeftHand){
leftHand.add(row); leftHand.add(row);
} }
@ -74,27 +71,17 @@ public interface Covariate<V> extends Serializable {
} }
if(nonMissingDecisions.size() == 0 && missingValueRows.size() > 0){
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
}
final Random random = ThreadLocalRandom.current(); return new Split<>(leftHand, rightHand, missingValueRows);
for(final Row<Y> missingValueRow : missingValueRows){
final boolean randomDecision = nonMissingDecisions.get(random.nextInt(nonMissingDecisions.size()));
if(randomDecision){
leftHand.add(missingValueRow);
}
else{
rightHand.add(missingValueRow);
}
}
return new Split<>(leftHand, rightHand);
} }
default boolean isLeftHand(CovariateRow row){ default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName()); final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
if(value.isNA()){
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
}
return isLeftHand(value); return isLeftHand(value);
} }

View file

@ -191,6 +191,9 @@ public class CompetingRiskErrorRateCalculator {
} }
final double mortalityI = mortalityArray[i]; final double mortalityI = mortalityArray[i];
final double Ti = responseI.getU();
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
for(int j=0; j<mortalityArray.length; j++){ for(int j=0; j<mortalityArray.length; j++){
final CompetingRiskResponse responseJ = responseList.get(j); final CompetingRiskResponse responseJ = responseList.get(j);
@ -198,11 +201,10 @@ public class CompetingRiskErrorRateCalculator {
final double AijWeightPlusBijWeight; final double AijWeightPlusBijWeight;
if(responseI.getU() < responseJ.getU()){ // Aij == 1 if(responseI.getU() < responseJ.getU()){ // Aij == 1
final double Ti = responseI.getU(); AijWeightPlusBijWeight = AijWeight;
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * censoringDistribution.evaluatePrevious(Ti).getY());
} }
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1 else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluatePrevious(responseI.getU()).getY() * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
} }
else{ else{
continue; continue;

View file

@ -16,6 +16,13 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
private final int[] events; private final int[] events;
private final double[] times; // We may restrict ourselves to specific times. private final double[] times; // We may restrict ourselves to specific times.
public int[] getEvents(){
return events.clone();
}
public double[] getTimes(){
return times.clone();
}
@Override @Override
public CompetingRiskFunctions combine(List<CompetingRiskFunctions> responses) { public CompetingRiskFunctions combine(List<CompetingRiskFunctions> responses) {

View file

@ -20,6 +20,14 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
private final int[] events; private final int[] events;
private final double[] times; // We may restrict ourselves to specific times. private final double[] times; // We may restrict ourselves to specific times.
public int[] getEvents(){
return events.clone();
}
public double[] getTimes(){
return times.clone();
}
@Override @Override
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) { public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {

View file

@ -11,7 +11,7 @@ import org.apache.commons.csv.CSVRecord;
* *
*/ */
@Data @Data
public class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse { public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
private final double c; private final double c;
public CompetingRiskResponseWithCensorTime(int delta, double u, double c) { public CompetingRiskResponseWithCensorTime(int delta, double u, double c) {

View file

@ -6,7 +6,7 @@ import lombok.Data;
import java.util.List; import java.util.List;
/** /**
* Very simple class that contains two lists; it's essentially a tuple. * Very simple class that contains three lists; it's essentially a thruple.
* *
* @author joel * @author joel
* *
@ -16,5 +16,6 @@ public class Split<Y> {
public final List<Row<Y>> leftHand; public final List<Row<Y>> leftHand;
public final List<Row<Y>> rightHand; public final List<Row<Y>> rightHand;
public final List<Row<Y>> naHand;
} }

View file

@ -10,11 +10,12 @@ public class SplitNode<Y> implements Node<Y> {
private final Node<Y> leftHand; private final Node<Y> leftHand;
private final Node<Y> rightHand; private final Node<Y> rightHand;
private final Covariate.SplitRule splitRule; private final Covariate.SplitRule splitRule;
private final double probabilityNaLeftHand; // used when assigning NA values
@Override @Override
public Y evaluate(CovariateRow row) { public Y evaluate(CovariateRow row) {
if(splitRule.isLeftHand(row)){ if(splitRule.isLeftHand(row, probabilityNaLeftHand)){
return leftHand.evaluate(row); return leftHand.evaluate(row);
} }
else{ else{

View file

@ -67,10 +67,29 @@ public class TreeTrainer<Y, O> {
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
// We have to handle any NAs
if(split.leftHand.size() == 0 && split.rightHand.size() == 0 && split.naHand.size() > 0){
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
}
final double probabilityLeftHand = (double) split.leftHand.size() / (double) (split.leftHand.size() + split.rightHand.size());
final Random random = ThreadLocalRandom.current();
for(final Row<Y> missingValueRow : split.naHand){
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
split.leftHand.add(missingValueRow);
}
else{
split.rightHand.add(missingValueRow);
}
}
final Node<O> leftNode = growNode(split.leftHand, depth+1); final Node<O> leftNode = growNode(split.leftHand, depth+1);
final Node<O> rightNode = growNode(split.rightHand, depth+1); final Node<O> rightNode = growNode(split.rightHand, depth+1);
return new SplitNode<>(leftNode, rightNode, bestSplitRule); return new SplitNode<>(leftNode, rightNode, bestSplitRule, probabilityLeftHand);
} }
else{ else{
@ -119,13 +138,30 @@ public class TreeTrainer<Y, O> {
for(final Covariate.SplitRule possibleRule : splitRulesToTry){ for(final Covariate.SplitRule possibleRule : splitRulesToTry){
final Split<Y> possibleSplit = possibleRule.applyRule(data); final Split<Y> possibleSplit = possibleRule.applyRule(data);
// We have to handle any NAs
if(possibleSplit.leftHand.size() == 0 && possibleSplit.rightHand.size() == 0 && possibleSplit.naHand.size() > 0){
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
}
final double probabilityLeftHand = (double) possibleSplit.leftHand.size() / (double) (possibleSplit.leftHand.size() + possibleSplit.rightHand.size());
final Random random = ThreadLocalRandom.current();
for(final Row<Y> missingValueRow : possibleSplit.naHand){
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
possibleSplit.leftHand.add(missingValueRow);
}
else{
possibleSplit.rightHand.add(missingValueRow);
}
}
final Double score = groupDifferentiator.differentiate( final Double score = groupDifferentiator.differentiate(
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()), possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()) possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
); );
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){ if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
bestSplitRule = possibleRule; bestSplitRule = possibleRule;
bestSplitScore = score; bestSplitScore = score;

View file

@ -10,6 +10,6 @@ import java.io.Serializable;
*/ */
@Data @Data
public class Point implements Serializable { public class Point implements Serializable {
private final Double time; private final double time;
private final Double y; private final double y;
} }

View file

@ -1,6 +1,7 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import java.util.*; import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
public class Utils { public class Utils {
@ -36,5 +37,32 @@ public class Utils {
} }
public static <T> void reduceListToSize(List<T> list, int n){
if(list.size() <= n){
return;
}
final Random random = ThreadLocalRandom.current();
if(n > list.size()/2){
// faster to randomly remove items
while(list.size() > n){
final int indexToRemove = random.nextInt(list.size());
list.remove(indexToRemove);
}
}
else{
// Faster to create a new list
final List<T> newList = new ArrayList<>(n);
while(newList.size() < n){
final int indexToAdd = random.nextInt(list.size());
newList.add(list.remove(indexToAdd));
}
list.clear();
list.addAll(newList);
}
}
} }

View file

@ -5,6 +5,11 @@ import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestUtils { public class TestUtils {
@ -70,4 +75,31 @@ public class TestUtils {
} }
@Test
public void reduceListToSize(){
final List<Integer> testList = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
final List<Integer> testList1 = new ArrayList<>(testList);
// Test when removing elements
Utils.reduceListToSize(testList1, 7);
assertEquals(7, testList1.size()); // verify proper size
assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique
final List<Integer> testList2 = new ArrayList<>(testList);
// Test when adding elements
Utils.reduceListToSize(testList2, 3);
assertEquals(3, testList2.size()); // verify proper size
assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique
final List<Integer> testList3 = new ArrayList<>(testList);
// verify no change
Utils.reduceListToSize(testList3, 15);
assertEquals(10, testList3.size()); // verify proper size
assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique
}
}
} }