From bf56dfb59d0f589b9570215d8604dfea34325f8e Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Tue, 7 Aug 2018 10:52:52 -0700 Subject: [PATCH] Add ability to compute different error rates. --- .../ca/joeltherrien/randomforest/Main.java | 109 +++++++++++++++--- .../randomforest/covariates/Covariate.java | 27 ++--- .../CompetingRiskErrorRateCalculator.java | 8 +- .../CompetingRiskFunctionCombiner.java | 7 ++ .../CompetingRiskResponseCombiner.java | 8 ++ .../CompetingRiskResponseWithCensorTime.java | 2 +- .../joeltherrien/randomforest/tree/Split.java | 3 +- .../randomforest/tree/SplitNode.java | 3 +- .../randomforest/tree/TreeTrainer.java | 40 ++++++- .../randomforest/utils/Point.java | 4 +- .../randomforest/utils/Utils.java | 28 +++++ .../joeltherrien/randomforest/TestUtils.java | 32 +++++ 12 files changed, 225 insertions(+), 46 deletions(-) diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index 38d9a6f..c9c8d44 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -4,7 +4,11 @@ import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import ca.joeltherrien.randomforest.responses.competingrisk.*; +import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; +import ca.joeltherrien.randomforest.utils.MathFunction; +import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.TextNode; @@ -12,10 +16,7 @@ import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVRecord; -import java.io.File; -import java.io.FileReader; -import java.io.IOException; -import java.io.Reader; +import java.io.*; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -25,9 +26,10 @@ import java.util.stream.Collectors; public class Main { - public static void main(String[] args) throws IOException { - if(args.length != 1){ - System.out.println("Must provide one argument - the path to the settings.yaml file."); + public static void main(String[] args) throws IOException, ClassNotFoundException { + if(args.length < 2){ + System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze."); + System.out.println("Note that analyzing only supports competing risk data, and that you must then specify a sample size for testing errors."); if(args.length == 0){ System.out.println("Generating template file."); defaultTemplate().save(new File("template.yaml")); @@ -40,24 +42,99 @@ public class Main { final List covariates = settings.getCovariates().stream() .map(cs -> cs.build()).collect(Collectors.toList()); + if(args[1].equalsIgnoreCase("train")){ + final List dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + + final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates); + + if(settings.isSaveProgress()){ + forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); + } + else{ + forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); + } + } + else if(args[1].equalsIgnoreCase("analyze")){ + // Perform different prediction measures + + if(args.length < 3){ + System.out.println("Specify error sample size"); + } + + final String yVarType = settings.getYVarSettings().get("type").asText(); + + if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){ + System.out.println("Analyze currently only works on competing risk data"); + } + + final CompetingRiskFunctionCombiner responseCombiner = (CompetingRiskFunctionCombiner) settings.getTreeCombiner(); + final int[] events = responseCombiner.getEvents(); + + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + + // Let's reduce this down to n + final int n = Integer.parseInt(args[2]); + Utils.reduceListToSize(dataset, n); + + final File folder = new File(settings.getSaveTreeLocation()); + final Forest forest = DataLoader.loadForest(folder, responseCombiner); + + System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions"); + + final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest); + final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt"); + + System.out.println("Running Naive Mortality"); + + final double naiveMortality = errorRateCalculator.calculateNaiveMortalityError(events); + printWriter.write("Naive Mortality: "); + printWriter.write(Double.toString(naiveMortality)); + printWriter.write('\n'); + + System.out.println("Running Naive Concordance"); + + final double[] naiveConcordance = errorRateCalculator.calculateConcordance(events); + printWriter.write("Naive concordance:\n"); + for(int i=0; i ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC()) + .toArray(); + final MathFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes); + + System.out.println("Finished generating censor distribution - running concordance"); + + final double[] ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(events, censorDistribution); + printWriter.write("IPCW concordance:\n"); + for(int i=0; i> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); - - final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - - if(settings.isSaveProgress()){ - forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); + printWriter.close(); } else{ - forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); + System.out.println("Invalid instruction; use either train or analyze."); + System.out.println("Note that analyzing only supports competing risk data."); } - } - private static Settings defaultTemplate(){ final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index e7a5aef..83760a5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -50,7 +50,6 @@ public interface Covariate extends Serializable { final List> leftHand = new LinkedList<>(); final List> rightHand = new LinkedList<>(); - final List nonMissingDecisions = new ArrayList<>(); final List> missingValueRows = new ArrayList<>(); @@ -63,8 +62,6 @@ public interface Covariate extends Serializable { } final boolean isLeftHand = isLeftHand(value); - nonMissingDecisions.add(isLeftHand); - if(isLeftHand){ leftHand.add(row); } @@ -74,27 +71,17 @@ public interface Covariate extends Serializable { } - if(nonMissingDecisions.size() == 0 && missingValueRows.size() > 0){ - throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows"); - } - final Random random = ThreadLocalRandom.current(); - for(final Row missingValueRow : missingValueRows){ - final boolean randomDecision = nonMissingDecisions.get(random.nextInt(nonMissingDecisions.size())); - - if(randomDecision){ - leftHand.add(missingValueRow); - } - else{ - rightHand.add(missingValueRow); - } - } - - return new Split<>(leftHand, rightHand); + return new Split<>(leftHand, rightHand, missingValueRows); } - default boolean isLeftHand(CovariateRow row){ + default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ final Value value = (Value) row.getCovariateValue(getParent().getName()); + + if(value.isNA()){ + return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand; + } + return isLeftHand(value); } diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java index 3d95b84..e3926ef 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskErrorRateCalculator.java @@ -191,6 +191,9 @@ public class CompetingRiskErrorRateCalculator { } final double mortalityI = mortalityArray[i]; + final double Ti = responseI.getU(); + final double G_Ti_minus = censoringDistribution.evaluatePrevious(Ti).getY(); + final double AijWeight = 1.0 / (censoringDistribution.evaluate(Ti).getY() * G_Ti_minus); for(int j=0; j= responseJ.getU() && !responseJ.isCensored() && responseJ.getDelta() != event){ // Bij == 1 - AijWeightPlusBijWeight = 1.0 / (censoringDistribution.evaluatePrevious(responseI.getU()).getY() * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); + AijWeightPlusBijWeight = 1.0 / (G_Ti_minus * censoringDistribution.evaluatePrevious(responseJ.getU()).getY()); } else{ continue; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java index 6dae8c0..4a04c79 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctionCombiner.java @@ -16,6 +16,13 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner responses) { diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseCombiner.java index c417413..761cbae 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseCombiner.java @@ -20,6 +20,14 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner responses) { diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java index 220cfea..7aefb88 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java @@ -11,7 +11,7 @@ import org.apache.commons.csv.CSVRecord; * */ @Data -public class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse { +public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse { private final double c; public CompetingRiskResponseWithCensorTime(int delta, double u, double c) { diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java index eb90273..e566e64 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java @@ -6,7 +6,7 @@ import lombok.Data; import java.util.List; /** - * Very simple class that contains two lists; it's essentially a tuple. + * Very simple class that contains three lists; it's essentially a thruple. * * @author joel * @@ -16,5 +16,6 @@ public class Split { public final List> leftHand; public final List> rightHand; + public final List> naHand; } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java index c00249a..90d87e5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java @@ -10,11 +10,12 @@ public class SplitNode implements Node { private final Node leftHand; private final Node rightHand; private final Covariate.SplitRule splitRule; + private final double probabilityNaLeftHand; // used when assigning NA values @Override public Y evaluate(CovariateRow row) { - if(splitRule.isLeftHand(row)){ + if(splitRule.isLeftHand(row, probabilityNaLeftHand)){ return leftHand.evaluate(row); } else{ diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index fc466f6..e33be28 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -67,10 +67,29 @@ public class TreeTrainer { final Split split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule + // We have to handle any NAs + if(split.leftHand.size() == 0 && split.rightHand.size() == 0 && split.naHand.size() > 0){ + throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows"); + } + + final double probabilityLeftHand = (double) split.leftHand.size() / (double) (split.leftHand.size() + split.rightHand.size()); + + final Random random = ThreadLocalRandom.current(); + for(final Row missingValueRow : split.naHand){ + final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; + if(randomDecision){ + split.leftHand.add(missingValueRow); + } + else{ + split.rightHand.add(missingValueRow); + } + } + + final Node leftNode = growNode(split.leftHand, depth+1); final Node rightNode = growNode(split.rightHand, depth+1); - return new SplitNode<>(leftNode, rightNode, bestSplitRule); + return new SplitNode<>(leftNode, rightNode, bestSplitRule, probabilityLeftHand); } else{ @@ -119,13 +138,30 @@ public class TreeTrainer { for(final Covariate.SplitRule possibleRule : splitRulesToTry){ final Split possibleSplit = possibleRule.applyRule(data); + // We have to handle any NAs + if(possibleSplit.leftHand.size() == 0 && possibleSplit.rightHand.size() == 0 && possibleSplit.naHand.size() > 0){ + throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows"); + } + + final double probabilityLeftHand = (double) possibleSplit.leftHand.size() / (double) (possibleSplit.leftHand.size() + possibleSplit.rightHand.size()); + + final Random random = ThreadLocalRandom.current(); + for(final Row missingValueRow : possibleSplit.naHand){ + final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; + if(randomDecision){ + possibleSplit.leftHand.add(missingValueRow); + } + else{ + possibleSplit.rightHand.add(missingValueRow); + } + } + final Double score = groupDifferentiator.differentiate( possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()), possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()) ); - if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){ bestSplitRule = possibleRule; bestSplitScore = score; diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/Point.java b/src/main/java/ca/joeltherrien/randomforest/utils/Point.java index c97d194..5128a80 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/Point.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/Point.java @@ -10,6 +10,6 @@ import java.io.Serializable; */ @Data public class Point implements Serializable { - private final Double time; - private final Double y; + private final double time; + private final double y; } diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java index 0b351d3..5acbe8d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java @@ -1,6 +1,7 @@ package ca.joeltherrien.randomforest.utils; import java.util.*; +import java.util.concurrent.ThreadLocalRandom; public class Utils { @@ -36,5 +37,32 @@ public class Utils { } + public static void reduceListToSize(List list, int n){ + if(list.size() <= n){ + return; + } + + final Random random = ThreadLocalRandom.current(); + if(n > list.size()/2){ + // faster to randomly remove items + while(list.size() > n){ + final int indexToRemove = random.nextInt(list.size()); + list.remove(indexToRemove); + } + } + else{ + // Faster to create a new list + final List newList = new ArrayList<>(n); + while(newList.size() < n){ + final int indexToAdd = random.nextInt(list.size()); + newList.add(list.remove(indexToAdd)); + } + + list.clear(); + list.addAll(newList); + + } + } + } diff --git a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java index 72bee51..2c6ed7d 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -5,6 +5,11 @@ import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; public class TestUtils { @@ -70,4 +75,31 @@ public class TestUtils { } + @Test + public void reduceListToSize(){ + final List testList = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + + for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness + final List testList1 = new ArrayList<>(testList); + // Test when removing elements + Utils.reduceListToSize(testList1, 7); + assertEquals(7, testList1.size()); // verify proper size + assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique + + + final List testList2 = new ArrayList<>(testList); + // Test when adding elements + Utils.reduceListToSize(testList2, 3); + assertEquals(3, testList2.size()); // verify proper size + assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique + + final List testList3 = new ArrayList<>(testList); + // verify no change + Utils.reduceListToSize(testList3, 15); + assertEquals(10, testList3.size()); // verify proper size + assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique + + } + } + }