Add ability to compute different error rates.
This commit is contained in:
parent
d3994212b6
commit
bf56dfb59d
12 changed files with 225 additions and 46 deletions
|
@ -4,7 +4,11 @@ import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
|||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import com.fasterxml.jackson.databind.node.TextNode;
|
||||
|
@ -12,10 +16,7 @@ import org.apache.commons.csv.CSVFormat;
|
|||
import org.apache.commons.csv.CSVParser;
|
||||
import org.apache.commons.csv.CSVRecord;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.io.Reader;
|
||||
import java.io.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
@ -25,9 +26,10 @@ import java.util.stream.Collectors;
|
|||
public class Main {
|
||||
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
if(args.length != 1){
|
||||
System.out.println("Must provide one argument - the path to the settings.yaml file.");
|
||||
public static void main(String[] args) throws IOException, ClassNotFoundException {
|
||||
if(args.length < 2){
|
||||
System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
|
||||
System.out.println("Note that analyzing only supports competing risk data, and that you must then specify a sample size for testing errors.");
|
||||
if(args.length == 0){
|
||||
System.out.println("Generating template file.");
|
||||
defaultTemplate().save(new File("template.yaml"));
|
||||
|
@ -40,24 +42,99 @@ public class Main {
|
|||
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||
|
||||
if(args[1].equalsIgnoreCase("train")){
|
||||
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
|
||||
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
|
||||
|
||||
if(settings.isSaveProgress()){
|
||||
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||
}
|
||||
else{
|
||||
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
|
||||
}
|
||||
}
|
||||
else if(args[1].equalsIgnoreCase("analyze")){
|
||||
// Perform different prediction measures
|
||||
|
||||
if(args.length < 3){
|
||||
System.out.println("Specify error sample size");
|
||||
}
|
||||
|
||||
final String yVarType = settings.getYVarSettings().get("type").asText();
|
||||
|
||||
if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
|
||||
System.out.println("Analyze currently only works on competing risk data");
|
||||
}
|
||||
|
||||
final CompetingRiskFunctionCombiner responseCombiner = (CompetingRiskFunctionCombiner) settings.getTreeCombiner();
|
||||
final int[] events = responseCombiner.getEvents();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
|
||||
// Let's reduce this down to n
|
||||
final int n = Integer.parseInt(args[2]);
|
||||
Utils.reduceListToSize(dataset, n);
|
||||
|
||||
final File folder = new File(settings.getSaveTreeLocation());
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
||||
|
||||
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
|
||||
|
||||
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest);
|
||||
final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
|
||||
|
||||
System.out.println("Running Naive Mortality");
|
||||
|
||||
final double naiveMortality = errorRateCalculator.calculateNaiveMortalityError(events);
|
||||
printWriter.write("Naive Mortality: ");
|
||||
printWriter.write(Double.toString(naiveMortality));
|
||||
printWriter.write('\n');
|
||||
|
||||
System.out.println("Running Naive Concordance");
|
||||
|
||||
final double[] naiveConcordance = errorRateCalculator.calculateConcordance(events);
|
||||
printWriter.write("Naive concordance:\n");
|
||||
for(int i=0; i<events.length; i++){
|
||||
printWriter.write('\t');
|
||||
printWriter.write(Integer.toString(events[i]));
|
||||
printWriter.write(": ");
|
||||
printWriter.write(Double.toString(naiveConcordance[i]));
|
||||
printWriter.write('\n');
|
||||
}
|
||||
|
||||
if(yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
|
||||
System.out.println("Running IPCW Concordance - creating censor distribution");
|
||||
|
||||
final double[] censorTimes = dataset.stream()
|
||||
.mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC())
|
||||
.toArray();
|
||||
final MathFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes);
|
||||
|
||||
System.out.println("Finished generating censor distribution - running concordance");
|
||||
|
||||
final double[] ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(events, censorDistribution);
|
||||
printWriter.write("IPCW concordance:\n");
|
||||
for(int i=0; i<events.length; i++){
|
||||
printWriter.write('\t');
|
||||
printWriter.write(Integer.toString(events[i]));
|
||||
printWriter.write(": ");
|
||||
printWriter.write(Double.toString(ipcwConcordance[i]));
|
||||
printWriter.write('\n');
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
final List<Row<Double>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
if(settings.isSaveProgress()){
|
||||
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||
printWriter.close();
|
||||
}
|
||||
else{
|
||||
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
|
||||
System.out.println("Invalid instruction; use either train or analyze.");
|
||||
System.out.println("Note that analyzing only supports competing risk data.");
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
private static Settings defaultTemplate(){
|
||||
|
||||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
|
|
|
@ -50,7 +50,6 @@ public interface Covariate<V> extends Serializable {
|
|||
final List<Row<Y>> leftHand = new LinkedList<>();
|
||||
final List<Row<Y>> rightHand = new LinkedList<>();
|
||||
|
||||
final List<Boolean> nonMissingDecisions = new ArrayList<>();
|
||||
final List<Row<Y>> missingValueRows = new ArrayList<>();
|
||||
|
||||
|
||||
|
@ -63,8 +62,6 @@ public interface Covariate<V> extends Serializable {
|
|||
}
|
||||
|
||||
final boolean isLeftHand = isLeftHand(value);
|
||||
nonMissingDecisions.add(isLeftHand);
|
||||
|
||||
if(isLeftHand){
|
||||
leftHand.add(row);
|
||||
}
|
||||
|
@ -74,27 +71,17 @@ public interface Covariate<V> extends Serializable {
|
|||
|
||||
}
|
||||
|
||||
if(nonMissingDecisions.size() == 0 && missingValueRows.size() > 0){
|
||||
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
||||
}
|
||||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
for(final Row<Y> missingValueRow : missingValueRows){
|
||||
final boolean randomDecision = nonMissingDecisions.get(random.nextInt(nonMissingDecisions.size()));
|
||||
|
||||
if(randomDecision){
|
||||
leftHand.add(missingValueRow);
|
||||
}
|
||||
else{
|
||||
rightHand.add(missingValueRow);
|
||||
}
|
||||
}
|
||||
|
||||
return new Split<>(leftHand, rightHand);
|
||||
return new Split<>(leftHand, rightHand, missingValueRows);
|
||||
}
|
||||
|
||||
default boolean isLeftHand(CovariateRow row){
|
||||
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
||||
|
||||
if(value.isNA()){
|
||||
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
|
||||
}
|
||||
|
||||
return isLeftHand(value);
|
||||
}
|
||||
|
||||
|
|
|
@ -191,6 +191,9 @@ public class CompetingRiskErrorRateCalculator {
|
|||
}
|
||||
|
||||
final double mortalityI = mortalityArray[i];
|
||||
final double Ti = responseI.getU();
|
||||
final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY();
|
||||
final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus);
|
||||
|
||||
for(int j=0; j<mortalityArray.length; j++){
|
||||
final CompetingRiskResponse responseJ = responseList.get(j);
|
||||
|
@ -198,11 +201,10 @@ public class CompetingRiskErrorRateCalculator {
|
|||
final double AijWeightPlusBijWeight;
|
||||
|
||||
if(responseI.getU() < responseJ.getU()){ // Aij == 1
|
||||
final double Ti = responseI.getU();
|
||||
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * censoringDistribution.evaluatePrevious(Ti).getY());
|
||||
AijWeightPlusBijWeight = AijWeight;
|
||||
}
|
||||
else if(responseI.getU() >= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1
|
||||
AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluatePrevious(responseI.getU()).getY() * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
||||
AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY());
|
||||
}
|
||||
else{
|
||||
continue;
|
||||
|
|
|
@ -16,6 +16,13 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
|
|||
private final int[] events;
|
||||
private final double[] times; // We may restrict ourselves to specific times.
|
||||
|
||||
public int[] getEvents(){
|
||||
return events.clone();
|
||||
}
|
||||
|
||||
public double[] getTimes(){
|
||||
return times.clone();
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompetingRiskFunctions combine(List<CompetingRiskFunctions> responses) {
|
||||
|
|
|
@ -20,6 +20,14 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
|
|||
private final int[] events;
|
||||
private final double[] times; // We may restrict ourselves to specific times.
|
||||
|
||||
public int[] getEvents(){
|
||||
return events.clone();
|
||||
}
|
||||
|
||||
public double[] getTimes(){
|
||||
return times.clone();
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import org.apache.commons.csv.CSVRecord;
|
|||
*
|
||||
*/
|
||||
@Data
|
||||
public class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
||||
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
||||
private final double c;
|
||||
|
||||
public CompetingRiskResponseWithCensorTime(int delta, double u, double c) {
|
||||
|
|
|
@ -6,7 +6,7 @@ import lombok.Data;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Very simple class that contains two lists; it's essentially a tuple.
|
||||
* Very simple class that contains three lists; it's essentially a thruple.
|
||||
*
|
||||
* @author joel
|
||||
*
|
||||
|
@ -16,5 +16,6 @@ public class Split<Y> {
|
|||
|
||||
public final List<Row<Y>> leftHand;
|
||||
public final List<Row<Y>> rightHand;
|
||||
public final List<Row<Y>> naHand;
|
||||
|
||||
}
|
||||
|
|
|
@ -10,11 +10,12 @@ public class SplitNode<Y> implements Node<Y> {
|
|||
private final Node<Y> leftHand;
|
||||
private final Node<Y> rightHand;
|
||||
private final Covariate.SplitRule splitRule;
|
||||
private final double probabilityNaLeftHand; // used when assigning NA values
|
||||
|
||||
@Override
|
||||
public Y evaluate(CovariateRow row) {
|
||||
|
||||
if(splitRule.isLeftHand(row)){
|
||||
if(splitRule.isLeftHand(row, probabilityNaLeftHand)){
|
||||
return leftHand.evaluate(row);
|
||||
}
|
||||
else{
|
||||
|
|
|
@ -67,10 +67,29 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
||||
|
||||
// We have to handle any NAs
|
||||
if(split.leftHand.size() == 0 && split.rightHand.size() == 0 && split.naHand.size() > 0){
|
||||
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
||||
}
|
||||
|
||||
final double probabilityLeftHand = (double) split.leftHand.size() / (double) (split.leftHand.size() + split.rightHand.size());
|
||||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
for(final Row<Y> missingValueRow : split.naHand){
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
if(randomDecision){
|
||||
split.leftHand.add(missingValueRow);
|
||||
}
|
||||
else{
|
||||
split.rightHand.add(missingValueRow);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
final Node<O> leftNode = growNode(split.leftHand, depth+1);
|
||||
final Node<O> rightNode = growNode(split.rightHand, depth+1);
|
||||
|
||||
return new SplitNode<>(leftNode, rightNode, bestSplitRule);
|
||||
return new SplitNode<>(leftNode, rightNode, bestSplitRule, probabilityLeftHand);
|
||||
|
||||
}
|
||||
else{
|
||||
|
@ -119,13 +138,30 @@ public class TreeTrainer<Y, O> {
|
|||
for(final Covariate.SplitRule possibleRule : splitRulesToTry){
|
||||
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
||||
|
||||
// We have to handle any NAs
|
||||
if(possibleSplit.leftHand.size() == 0 && possibleSplit.rightHand.size() == 0 && possibleSplit.naHand.size() > 0){
|
||||
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
||||
}
|
||||
|
||||
final double probabilityLeftHand = (double) possibleSplit.leftHand.size() / (double) (possibleSplit.leftHand.size() + possibleSplit.rightHand.size());
|
||||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
for(final Row<Y> missingValueRow : possibleSplit.naHand){
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
if(randomDecision){
|
||||
possibleSplit.leftHand.add(missingValueRow);
|
||||
}
|
||||
else{
|
||||
possibleSplit.rightHand.add(missingValueRow);
|
||||
}
|
||||
}
|
||||
|
||||
final Double score = groupDifferentiator.differentiate(
|
||||
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
|
||||
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||
);
|
||||
|
||||
|
||||
|
||||
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
|
||||
bestSplitRule = possibleRule;
|
||||
bestSplitScore = score;
|
||||
|
|
|
@ -10,6 +10,6 @@ import java.io.Serializable;
|
|||
*/
|
||||
@Data
|
||||
public class Point implements Serializable {
|
||||
private final Double time;
|
||||
private final Double y;
|
||||
private final double time;
|
||||
private final double y;
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
public class Utils {
|
||||
|
||||
|
@ -36,5 +37,32 @@ public class Utils {
|
|||
|
||||
}
|
||||
|
||||
public static <T> void reduceListToSize(List<T> list, int n){
|
||||
if(list.size() <= n){
|
||||
return;
|
||||
}
|
||||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
if(n > list.size()/2){
|
||||
// faster to randomly remove items
|
||||
while(list.size() > n){
|
||||
final int indexToRemove = random.nextInt(list.size());
|
||||
list.remove(indexToRemove);
|
||||
}
|
||||
}
|
||||
else{
|
||||
// Faster to create a new list
|
||||
final List<T> newList = new ArrayList<>(n);
|
||||
while(newList.size() < n){
|
||||
final int indexToAdd = random.nextInt(list.size());
|
||||
newList.add(list.remove(indexToAdd));
|
||||
}
|
||||
|
||||
list.clear();
|
||||
list.addAll(newList);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -5,6 +5,11 @@ import ca.joeltherrien.randomforest.utils.Point;
|
|||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestUtils {
|
||||
|
@ -70,4 +75,31 @@ public class TestUtils {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void reduceListToSize(){
|
||||
final List<Integer> testList = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||
|
||||
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
||||
final List<Integer> testList1 = new ArrayList<>(testList);
|
||||
// Test when removing elements
|
||||
Utils.reduceListToSize(testList1, 7);
|
||||
assertEquals(7, testList1.size()); // verify proper size
|
||||
assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique
|
||||
|
||||
|
||||
final List<Integer> testList2 = new ArrayList<>(testList);
|
||||
// Test when adding elements
|
||||
Utils.reduceListToSize(testList2, 3);
|
||||
assertEquals(3, testList2.size()); // verify proper size
|
||||
assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique
|
||||
|
||||
final List<Integer> testList3 = new ArrayList<>(testList);
|
||||
// verify no change
|
||||
Utils.reduceListToSize(testList3, 15);
|
||||
assertEquals(10, testList3.size()); // verify proper size
|
||||
assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue