Add capability to load trees back into memory.

This commit is contained in:
Joel Therrien 2018-07-17 13:54:59 -07:00
parent fffdfe85bf
commit 05f9122b58
9 changed files with 252 additions and 26 deletions

View file

@ -1,19 +1,17 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.tree.Tree;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.io.*;
import java.util.*;
public class DataLoader {
@ -43,6 +41,36 @@ public class DataLoader {
}
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
if(!folder.isDirectory()){
throw new IllegalArgumentException("Tree directory must be a directory!");
}
final File[] treeFiles = folder.listFiles(((file, s) -> s.endsWith(".tree")));
final List<File> treeFileList = Arrays.asList(treeFiles);
Collections.sort(treeFileList, Comparator.comparing(File::getName));
final List<Tree<O>> treeList = new ArrayList<>(treeFileList.size());
for(final File treeFile : treeFileList){
final ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(treeFile));
final Tree<O> tree = (Tree) inputStream.readObject();
treeList.add(tree);
}
final Forest forest = Forest.<O, FO>builder()
.trees(treeList)
.treeResponseCombiner(treeResponseCombiner)
.build();
return forest;
}
@FunctionalInterface
public interface ResponseLoader<Y>{
Y parse(CSVRecord record);

View file

@ -4,10 +4,11 @@ import lombok.Builder;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.io.Serializable;
import java.util.Map;
@Builder
public class CompetingRiskFunctions {
public class CompetingRiskFunctions implements Serializable {
private final Map<Integer, MathFunction> causeSpecificHazardFunctionMap;
private final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap;

View file

@ -2,6 +2,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
import lombok.Getter;
import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
@ -12,7 +13,7 @@ import java.util.Optional;
* constant at the value of the previous encountered point.
*
*/
public class MathFunction {
public class MathFunction implements Serializable {
@Getter
private final List<Point> points;

View file

@ -2,12 +2,14 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
import lombok.Data;
import java.io.Serializable;
/**
* Represents a point in our estimate of either the cumulative hazard function or the cumulative incidence function.
*
*/
@Data
public class Point {
public class Point implements Serializable {
private final Double time;
private final Double y;
}

View file

@ -4,12 +4,13 @@ import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Builder;
import java.util.Collection;
import java.util.Collections;
import java.util.stream.Collectors;
@Builder
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
private final Collection<Node<O>> trees;
private final Collection<Tree<O>> trees;
private final ResponseCombiner<O, FO> treeResponseCombiner;
public FO evaluate(CovariateRow row){
@ -22,4 +23,8 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
}
public Collection<Tree<O>> getTrees(){
return Collections.unmodifiableCollection(trees);
}
}

View file

@ -1,9 +1,9 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Bootstrapper;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.Row;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
@ -12,11 +12,9 @@ import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -54,7 +52,7 @@ public class ForestTrainer<Y, TO, FO> {
public Forest<TO, FO> trainSerial(){
final List<Node<TO>> trees = new ArrayList<>(ntree);
final List<Tree<TO>> trees = new ArrayList<>(ntree);
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
for(int j=0; j<ntree; j++){
@ -83,7 +81,7 @@ public class ForestTrainer<Y, TO, FO> {
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
// the earlier indexes being filled.
final List<Node<TO>> trees = Stream.<Node<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
final List<Tree<TO>> trees = Stream.<Tree<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
@ -103,7 +101,7 @@ public class ForestTrainer<Y, TO, FO> {
if(displayProgress) {
int numberTreesSet = 0;
for (final Node<TO> tree : trees) {
for (final Tree<TO> tree : trees) {
if (tree != null) {
numberTreesSet++;
}
@ -131,7 +129,7 @@ public class ForestTrainer<Y, TO, FO> {
final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished
for(int j=0; j<ntree; j++){
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1), treeCount);
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1) + ".tree", treeCount);
executorService.execute(worker);
}
@ -157,12 +155,12 @@ public class ForestTrainer<Y, TO, FO> {
}
private Node<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
return treeTrainer.growTree(bootstrappedData);
}
public void saveTree(final Node<TO> tree, String name) throws IOException {
public void saveTree(final Tree<TO> tree, String name) throws IOException {
final String filename = saveTreeLocation + "/" + name;
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
@ -177,9 +175,9 @@ public class ForestTrainer<Y, TO, FO> {
private final Bootstrapper<Row<Y>> bootstrapper;
private final int treeIndex;
private final List<Node<TO>> treeList;
private final List<Tree<TO>> treeList;
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<TO>> treeList) {
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
this.bootstrapper = new Bootstrapper<>(data);
this.treeIndex = treeIndex;
this.treeList = treeList;
@ -188,7 +186,7 @@ public class ForestTrainer<Y, TO, FO> {
@Override
public void run() {
final Node<TO> tree = trainTree(bootstrapper);
final Tree<TO> tree = trainTree(bootstrapper);
// should be okay as the list structure isn't changing
treeList.set(treeIndex, tree);
@ -211,7 +209,7 @@ public class ForestTrainer<Y, TO, FO> {
@Override
public void run() {
final Node<TO> tree = trainTree(bootstrapper);
final Tree<TO> tree = trainTree(bootstrapper);
try {
saveTree(tree, filename);

View file

@ -0,0 +1,42 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import lombok.RequiredArgsConstructor;
import java.util.Arrays;
@RequiredArgsConstructor
public class Tree<Y> implements Node<Y> {
private final Node<Y> rootNode;
private final int[] bootstrapRowIds;
private boolean bootStrapRowIdsSorted = false;
@Override
public Y evaluate(CovariateRow row) {
return rootNode.evaluate(row);
}
public int[] getBootstrapRowIds(){
return bootstrapRowIds.clone();
}
/**
* Sort bootstrapRowIds. This is not done automatically for efficiency purposes, as in many cases we may not be interested in using bootstrapRowIds();
*
*/
public void sortBootstrapRowIds(){
if(!bootStrapRowIdsSorted){
Arrays.sort(bootstrapRowIds);
bootStrapRowIdsSorted = true;
}
}
public boolean idInBootstrapSample(int id){
this.sortBootstrapRowIds();
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
}
}

View file

@ -39,8 +39,13 @@ public class TreeTrainer<Y, O> {
this.covariates = covariates;
}
public Node<O> growTree(List<Row<Y>> data){
return growNode(data, 0);
public Tree<O> growTree(List<Row<Y>> data){
final Node<O> rootNode = growNode(data, 0);
final Tree<O> tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
return tree;
}
private Node<O> growNode(List<Row<Y>> data, int depth){

View file

@ -0,0 +1,144 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class TestSavingLoading {
private final int NTREE = 10;
/**
* By default uses single log-rank test.
*
* @return
*/
public Settings getSettings(){
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
responseCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
);
// not setting times
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
treeCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
);
// not setting times
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
yVarSettings.set("u", new TextNode("time"));
yVarSettings.set("delta", new TextNode("status"));
return Settings.builder()
.covariates(List.of(
new NumericCovariateSettings("ageatfda"),
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black"),
new NumericCovariateSettings("cd4nadir")
)
)
.dataFileLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.groupDifferentiatorSettings(groupDifferentiatorSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
// TODO fill in these settings
.mtry(2)
.nodeSize(6)
.ntree(NTREE)
.numberOfSplits(5)
.numberOfThreads(3)
.saveProgress(true)
.saveTreeLocation("src/test/resources/trees/")
.build();
}
public CovariateRow getPredictionRow(List<Covariate> covariates){
return CovariateRow.createSimple(Map.of(
"ageatfda", "35",
"idu", "false",
"black", "false",
"cd4nadir", "0.81")
, covariates, 1);
}
public List<Covariate> getCovariates(Settings settings){
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
}
@Test
public void testSavingLoading() throws IOException, ClassNotFoundException {
final Settings settings = getSettings();
final List<Covariate> covariates = getCovariates(settings);
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
final File directory = new File(settings.getSaveTreeLocation());
assertFalse(directory.exists());
directory.mkdir();
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
forestTrainer.trainParallelOnDisk(1);
assertTrue(directory.exists());
assertTrue(directory.isDirectory());
assertEquals(NTREE, directory.listFiles().length);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
final CovariateRow predictionRow = getPredictionRow(covariates);
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
assertNotNull(functions);
assertTrue(functions.getCumulativeIncidenceFunction(1).getPoints().size() > 2);
assertEquals(NTREE, forest.getTrees().size());
cleanup(directory);
assertFalse(directory.exists());
}
private void cleanup(File file){
if(file.isFile()){
file.delete();
}
else{
for(final File inner : file.listFiles()){
cleanup(inner);
}
file.delete();
}
}
}