From 05f9122b58aeeca1e3e907b78805cca46c13e37a Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Tue, 17 Jul 2018 13:54:59 -0700 Subject: [PATCH] Add capability to load trees back into memory. --- .../joeltherrien/randomforest/DataLoader.java | 42 ++++- .../competingrisk/CompetingRiskFunctions.java | 3 +- .../responses/competingrisk/MathFunction.java | 3 +- .../responses/competingrisk/Point.java | 4 +- .../randomforest/tree/Forest.java | 7 +- .../randomforest/tree/ForestTrainer.java | 24 ++- .../joeltherrien/randomforest/tree/Tree.java | 42 +++++ .../randomforest/tree/TreeTrainer.java | 9 +- .../randomforest/TestSavingLoading.java | 144 ++++++++++++++++++ 9 files changed, 252 insertions(+), 26 deletions(-) create mode 100644 src/main/java/ca/joeltherrien/randomforest/tree/Tree.java create mode 100644 src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java diff --git a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java index 01f0992..97ebbfc 100644 --- a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java +++ b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java @@ -1,19 +1,17 @@ package ca.joeltherrien.randomforest; import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.tree.ResponseCombiner; +import ca.joeltherrien.randomforest.tree.Tree; import com.fasterxml.jackson.databind.node.ObjectNode; import lombok.RequiredArgsConstructor; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVRecord; -import java.io.FileReader; -import java.io.IOException; -import java.io.Reader; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.io.*; +import java.util.*; public class DataLoader { @@ -43,6 +41,36 @@ public class DataLoader { } + public static Forest loadForest(File folder, ResponseCombiner treeResponseCombiner) throws IOException, ClassNotFoundException { + if(!folder.isDirectory()){ + throw new IllegalArgumentException("Tree directory must be a directory!"); + } + + final File[] treeFiles = folder.listFiles(((file, s) -> s.endsWith(".tree"))); + final List treeFileList = Arrays.asList(treeFiles); + + Collections.sort(treeFileList, Comparator.comparing(File::getName)); + + final List> treeList = new ArrayList<>(treeFileList.size()); + + for(final File treeFile : treeFileList){ + final ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(treeFile)); + + final Tree tree = (Tree) inputStream.readObject(); + + treeList.add(tree); + + } + + final Forest forest = Forest.builder() + .trees(treeList) + .treeResponseCombiner(treeResponseCombiner) + .build(); + + return forest; + + } + @FunctionalInterface public interface ResponseLoader{ Y parse(CSVRecord record); diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java index 6fef9f8..743e507 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java @@ -4,10 +4,11 @@ import lombok.Builder; import lombok.Getter; import lombok.RequiredArgsConstructor; +import java.io.Serializable; import java.util.Map; @Builder -public class CompetingRiskFunctions { +public class CompetingRiskFunctions implements Serializable { private final Map causeSpecificHazardFunctionMap; private final Map cumulativeIncidenceFunctionMap; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/MathFunction.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/MathFunction.java index fcb8cb9..fde343d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/MathFunction.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/MathFunction.java @@ -2,6 +2,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import lombok.Getter; +import java.io.Serializable; import java.util.Collections; import java.util.Comparator; import java.util.List; @@ -12,7 +13,7 @@ import java.util.Optional; * constant at the value of the previous encountered point. * */ -public class MathFunction { +public class MathFunction implements Serializable { @Getter private final List points; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java index f417af9..56a2831 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/Point.java @@ -2,12 +2,14 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import lombok.Data; +import java.io.Serializable; + /** * Represents a point in our estimate of either the cumulative hazard function or the cumulative incidence function. * */ @Data -public class Point { +public class Point implements Serializable { private final Double time; private final Double y; } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index c25db8c..449cfcd 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -4,12 +4,13 @@ import ca.joeltherrien.randomforest.CovariateRow; import lombok.Builder; import java.util.Collection; +import java.util.Collections; import java.util.stream.Collectors; @Builder public class Forest { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings - private final Collection> trees; + private final Collection> trees; private final ResponseCombiner treeResponseCombiner; public FO evaluate(CovariateRow row){ @@ -22,4 +23,8 @@ public class Forest { // O = output of trees, FO = forest output. In prac } + public Collection> getTrees(){ + return Collections.unmodifiableCollection(trees); + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 1617b67..fcdd111 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -1,9 +1,9 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Bootstrapper; +import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.Row; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Builder; @@ -12,11 +12,9 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -54,7 +52,7 @@ public class ForestTrainer { public Forest trainSerial(){ - final List> trees = new ArrayList<>(ntree); + final List> trees = new ArrayList<>(ntree); final Bootstrapper> bootstrapper = new Bootstrapper<>(data); for(int j=0; j { // create a list that is prespecified in size (I can call the .set method at any index < ntree without // the earlier indexes being filled. - final List> trees = Stream.>generate(() -> null).limit(ntree).collect(Collectors.toList()); + final List> trees = Stream.>generate(() -> null).limit(ntree).collect(Collectors.toList()); final ExecutorService executorService = Executors.newFixedThreadPool(threads); @@ -103,7 +101,7 @@ public class ForestTrainer { if(displayProgress) { int numberTreesSet = 0; - for (final Node tree : trees) { + for (final Tree tree : trees) { if (tree != null) { numberTreesSet++; } @@ -131,7 +129,7 @@ public class ForestTrainer { final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished for(int j=0; j { } - private Node trainTree(final Bootstrapper> bootstrapper){ + private Tree trainTree(final Bootstrapper> bootstrapper){ final List> bootstrappedData = bootstrapper.bootstrap(); return treeTrainer.growTree(bootstrappedData); } - public void saveTree(final Node tree, String name) throws IOException { + public void saveTree(final Tree tree, String name) throws IOException { final String filename = saveTreeLocation + "/" + name; final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename)); @@ -177,9 +175,9 @@ public class ForestTrainer { private final Bootstrapper> bootstrapper; private final int treeIndex; - private final List> treeList; + private final List> treeList; - public TreeInMemoryWorker(final List> data, final int treeIndex, final List> treeList) { + public TreeInMemoryWorker(final List> data, final int treeIndex, final List> treeList) { this.bootstrapper = new Bootstrapper<>(data); this.treeIndex = treeIndex; this.treeList = treeList; @@ -188,7 +186,7 @@ public class ForestTrainer { @Override public void run() { - final Node tree = trainTree(bootstrapper); + final Tree tree = trainTree(bootstrapper); // should be okay as the list structure isn't changing treeList.set(treeIndex, tree); @@ -211,7 +209,7 @@ public class ForestTrainer { @Override public void run() { - final Node tree = trainTree(bootstrapper); + final Tree tree = trainTree(bootstrapper); try { saveTree(tree, filename); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java new file mode 100644 index 0000000..a6d4252 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java @@ -0,0 +1,42 @@ +package ca.joeltherrien.randomforest.tree; + +import ca.joeltherrien.randomforest.CovariateRow; +import lombok.RequiredArgsConstructor; + +import java.util.Arrays; + +@RequiredArgsConstructor +public class Tree implements Node { + + private final Node rootNode; + private final int[] bootstrapRowIds; + private boolean bootStrapRowIdsSorted = false; + + @Override + public Y evaluate(CovariateRow row) { + return rootNode.evaluate(row); + } + + public int[] getBootstrapRowIds(){ + return bootstrapRowIds.clone(); + } + + /** + * Sort bootstrapRowIds. This is not done automatically for efficiency purposes, as in many cases we may not be interested in using bootstrapRowIds(); + * + */ + public void sortBootstrapRowIds(){ + if(!bootStrapRowIdsSorted){ + Arrays.sort(bootstrapRowIds); + bootStrapRowIdsSorted = true; + } + + } + + public boolean idInBootstrapSample(int id){ + this.sortBootstrapRowIds(); + + return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0; + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 5bcf510..fc466f6 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -39,8 +39,13 @@ public class TreeTrainer { this.covariates = covariates; } - public Node growTree(List> data){ - return growNode(data, 0); + public Tree growTree(List> data){ + + final Node rootNode = growNode(data, 0); + final Tree tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray()); + + return tree; + } private Node growNode(List> data, int depth){ diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java new file mode 100644 index 0000000..71a6ff8 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -0,0 +1,144 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctionCombiner; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.tree.ForestTrainer; +import com.fasterxml.jackson.databind.node.*; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class TestSavingLoading { + + private final int NTREE = 10; + + /** + * By default uses single log-rank test. + * + * @return + */ + public Settings getSettings(){ + final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); + groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator")); + groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1)); + + final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); + responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); + responseCombinerSettings.set("events", + new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2))) + ); + // not setting times + + + final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); + treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner")); + treeCombinerSettings.set("events", + new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2))) + ); + // not setting times + + final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); + yVarSettings.set("type", new TextNode("CompetingRiskResponse")); + yVarSettings.set("u", new TextNode("time")); + yVarSettings.set("delta", new TextNode("status")); + + return Settings.builder() + .covariates(List.of( + new NumericCovariateSettings("ageatfda"), + new BooleanCovariateSettings("idu"), + new BooleanCovariateSettings("black"), + new NumericCovariateSettings("cd4nadir") + ) + ) + .dataFileLocation("src/test/resources/wihs.csv") + .responseCombinerSettings(responseCombinerSettings) + .treeCombinerSettings(treeCombinerSettings) + .groupDifferentiatorSettings(groupDifferentiatorSettings) + .yVarSettings(yVarSettings) + .maxNodeDepth(100000) + // TODO fill in these settings + .mtry(2) + .nodeSize(6) + .ntree(NTREE) + .numberOfSplits(5) + .numberOfThreads(3) + .saveProgress(true) + .saveTreeLocation("src/test/resources/trees/") + .build(); + } + + public CovariateRow getPredictionRow(List covariates){ + return CovariateRow.createSimple(Map.of( + "ageatfda", "35", + "idu", "false", + "black", "false", + "cd4nadir", "0.81") + , covariates, 1); + } + + public List getCovariates(Settings settings){ + return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList()); + } + + @Test + public void testSavingLoading() throws IOException, ClassNotFoundException { + final Settings settings = getSettings(); + final List covariates = getCovariates(settings); + final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation()); + + final File directory = new File(settings.getSaveTreeLocation()); + assertFalse(directory.exists()); + directory.mkdir(); + + final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); + + forestTrainer.trainParallelOnDisk(1); + + assertTrue(directory.exists()); + assertTrue(directory.isDirectory()); + assertEquals(NTREE, directory.listFiles().length); + + + + final Forest forest = DataLoader.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); + + final CovariateRow predictionRow = getPredictionRow(covariates); + + final CompetingRiskFunctions functions = forest.evaluate(predictionRow); + assertNotNull(functions); + assertTrue(functions.getCumulativeIncidenceFunction(1).getPoints().size() > 2); + + + assertEquals(NTREE, forest.getTrees().size()); + + cleanup(directory); + + assertFalse(directory.exists()); + + } + + private void cleanup(File file){ + if(file.isFile()){ + file.delete(); + } + else{ + for(final File inner : file.listFiles()){ + cleanup(inner); + } + file.delete(); + } + + + } + +}