Add capability to load trees back into memory.
This commit is contained in:
parent
fffdfe85bf
commit
05f9122b58
9 changed files with 252 additions and 26 deletions
|
@ -1,19 +1,17 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.tree.Tree;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.apache.commons.csv.CSVFormat;
|
||||
import org.apache.commons.csv.CSVParser;
|
||||
import org.apache.commons.csv.CSVRecord;
|
||||
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.io.Reader;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
|
||||
public class DataLoader {
|
||||
|
||||
|
@ -43,6 +41,36 @@ public class DataLoader {
|
|||
|
||||
}
|
||||
|
||||
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||
if(!folder.isDirectory()){
|
||||
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||
}
|
||||
|
||||
final File[] treeFiles = folder.listFiles(((file, s) -> s.endsWith(".tree")));
|
||||
final List<File> treeFileList = Arrays.asList(treeFiles);
|
||||
|
||||
Collections.sort(treeFileList, Comparator.comparing(File::getName));
|
||||
|
||||
final List<Tree<O>> treeList = new ArrayList<>(treeFileList.size());
|
||||
|
||||
for(final File treeFile : treeFileList){
|
||||
final ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(treeFile));
|
||||
|
||||
final Tree<O> tree = (Tree) inputStream.readObject();
|
||||
|
||||
treeList.add(tree);
|
||||
|
||||
}
|
||||
|
||||
final Forest forest = Forest.<O, FO>builder()
|
||||
.trees(treeList)
|
||||
.treeResponseCombiner(treeResponseCombiner)
|
||||
.build();
|
||||
|
||||
return forest;
|
||||
|
||||
}
|
||||
|
||||
@FunctionalInterface
|
||||
public interface ResponseLoader<Y>{
|
||||
Y parse(CSVRecord record);
|
||||
|
|
|
@ -4,10 +4,11 @@ import lombok.Builder;
|
|||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
@Builder
|
||||
public class CompetingRiskFunctions {
|
||||
public class CompetingRiskFunctions implements Serializable {
|
||||
|
||||
private final Map<Integer, MathFunction> causeSpecificHazardFunctionMap;
|
||||
private final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap;
|
||||
|
|
|
@ -2,6 +2,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
|||
|
||||
import lombok.Getter;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
@ -12,7 +13,7 @@ import java.util.Optional;
|
|||
* constant at the value of the previous encountered point.
|
||||
*
|
||||
*/
|
||||
public class MathFunction {
|
||||
public class MathFunction implements Serializable {
|
||||
|
||||
@Getter
|
||||
private final List<Point> points;
|
||||
|
|
|
@ -2,12 +2,14 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
|||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* Represents a point in our estimate of either the cumulative hazard function or the cumulative incidence function.
|
||||
*
|
||||
*/
|
||||
@Data
|
||||
public class Point {
|
||||
public class Point implements Serializable {
|
||||
// Time coordinate of the point; boxed Double, so it is nullable. NOTE(review): presumably the time at which the estimate takes the value y - confirm against MathFunction.
private final Double time;
|
||||
// Value of the estimated function recorded for this point; boxed Double, so it is nullable.
private final Double y;
|
||||
}
|
||||
|
|
|
@ -4,12 +4,13 @@ import ca.joeltherrien.randomforest.CovariateRow;
|
|||
import lombok.Builder;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Builder
|
||||
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
||||
|
||||
private final Collection<Node<O>> trees;
|
||||
private final Collection<Tree<O>> trees;
|
||||
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||
|
||||
public FO evaluate(CovariateRow row){
|
||||
|
@ -22,4 +23,8 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
|||
|
||||
}
|
||||
|
||||
public Collection<Tree<O>> getTrees(){
|
||||
return Collections.unmodifiableCollection(trees);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.Bootstrapper;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
|
@ -12,11 +12,9 @@ import java.io.FileOutputStream;
|
|||
import java.io.IOException;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
@ -54,7 +52,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
public Forest<TO, FO> trainSerial(){
|
||||
|
||||
final List<Node<TO>> trees = new ArrayList<>(ntree);
|
||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
|
@ -83,7 +81,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
||||
// the earlier indexes being filled.
|
||||
final List<Node<TO>> trees = Stream.<Node<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
|
||||
final List<Tree<TO>> trees = Stream.<Tree<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
|
||||
|
||||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||
|
||||
|
@ -103,7 +101,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
if(displayProgress) {
|
||||
int numberTreesSet = 0;
|
||||
for (final Node<TO> tree : trees) {
|
||||
for (final Tree<TO> tree : trees) {
|
||||
if (tree != null) {
|
||||
numberTreesSet++;
|
||||
}
|
||||
|
@ -131,7 +129,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1), treeCount);
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1) + ".tree", treeCount);
|
||||
executorService.execute(worker);
|
||||
}
|
||||
|
||||
|
@ -157,12 +155,12 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
}
|
||||
|
||||
private Node<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
||||
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
||||
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
||||
return treeTrainer.growTree(bootstrappedData);
|
||||
}
|
||||
|
||||
public void saveTree(final Node<TO> tree, String name) throws IOException {
|
||||
public void saveTree(final Tree<TO> tree, String name) throws IOException {
|
||||
final String filename = saveTreeLocation + "/" + name;
|
||||
|
||||
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
|
||||
|
@ -177,9 +175,9 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||
private final int treeIndex;
|
||||
private final List<Node<TO>> treeList;
|
||||
private final List<Tree<TO>> treeList;
|
||||
|
||||
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<TO>> treeList) {
|
||||
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
|
||||
this.bootstrapper = new Bootstrapper<>(data);
|
||||
this.treeIndex = treeIndex;
|
||||
this.treeList = treeList;
|
||||
|
@ -188,7 +186,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
@Override
|
||||
public void run() {
|
||||
|
||||
final Node<TO> tree = trainTree(bootstrapper);
|
||||
final Tree<TO> tree = trainTree(bootstrapper);
|
||||
|
||||
// should be okay as the list structure isn't changing
|
||||
treeList.set(treeIndex, tree);
|
||||
|
@ -211,7 +209,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
@Override
|
||||
public void run() {
|
||||
|
||||
final Node<TO> tree = trainTree(bootstrapper);
|
||||
final Tree<TO> tree = trainTree(bootstrapper);
|
||||
|
||||
try {
|
||||
saveTree(tree, filename);
|
||||
|
|
42
src/main/java/ca/joeltherrien/randomforest/tree/Tree.java
Normal file
42
src/main/java/ca/joeltherrien/randomforest/tree/Tree.java
Normal file
|
@ -0,0 +1,42 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class Tree<Y> implements Node<Y> {
|
||||
|
||||
private final Node<Y> rootNode;
|
||||
private final int[] bootstrapRowIds;
|
||||
private boolean bootStrapRowIdsSorted = false;
|
||||
|
||||
@Override
|
||||
public Y evaluate(CovariateRow row) {
|
||||
return rootNode.evaluate(row);
|
||||
}
|
||||
|
||||
public int[] getBootstrapRowIds(){
|
||||
return bootstrapRowIds.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort bootstrapRowIds. This is not done automatically for efficiency purposes, as in many cases we may not be interested in using bootstrapRowIds();
|
||||
*
|
||||
*/
|
||||
public void sortBootstrapRowIds(){
|
||||
if(!bootStrapRowIdsSorted){
|
||||
Arrays.sort(bootstrapRowIds);
|
||||
bootStrapRowIdsSorted = true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public boolean idInBootstrapSample(int id){
|
||||
this.sortBootstrapRowIds();
|
||||
|
||||
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
|
||||
}
|
||||
|
||||
}
|
|
@ -39,8 +39,13 @@ public class TreeTrainer<Y, O> {
|
|||
this.covariates = covariates;
|
||||
}
|
||||
|
||||
public Node<O> growTree(List<Row<Y>> data){
|
||||
return growNode(data, 0);
|
||||
public Tree<O> growTree(List<Row<Y>> data){
|
||||
|
||||
final Node<O> rootNode = growNode(data, 0);
|
||||
final Tree<O> tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
||||
|
||||
return tree;
|
||||
|
||||
}
|
||||
|
||||
private Node<O> growNode(List<Row<Y>> data, int depth){
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctionCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import com.fasterxml.jackson.databind.node.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class TestSavingLoading {
|
||||
|
||||
private final int NTREE = 10;
|
||||
|
||||
/**
|
||||
* By default uses single log-rank test.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public Settings getSettings(){
|
||||
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
|
||||
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
|
||||
|
||||
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
||||
responseCombinerSettings.set("events",
|
||||
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
||||
);
|
||||
// not setting times
|
||||
|
||||
|
||||
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
|
||||
treeCombinerSettings.set("events",
|
||||
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
||||
);
|
||||
// not setting times
|
||||
|
||||
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
|
||||
yVarSettings.set("u", new TextNode("time"));
|
||||
yVarSettings.set("delta", new TextNode("status"));
|
||||
|
||||
return Settings.builder()
|
||||
.covariates(List.of(
|
||||
new NumericCovariateSettings("ageatfda"),
|
||||
new BooleanCovariateSettings("idu"),
|
||||
new BooleanCovariateSettings("black"),
|
||||
new NumericCovariateSettings("cd4nadir")
|
||||
)
|
||||
)
|
||||
.dataFileLocation("src/test/resources/wihs.csv")
|
||||
.responseCombinerSettings(responseCombinerSettings)
|
||||
.treeCombinerSettings(treeCombinerSettings)
|
||||
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||
.yVarSettings(yVarSettings)
|
||||
.maxNodeDepth(100000)
|
||||
// TODO fill in these settings
|
||||
.mtry(2)
|
||||
.nodeSize(6)
|
||||
.ntree(NTREE)
|
||||
.numberOfSplits(5)
|
||||
.numberOfThreads(3)
|
||||
.saveProgress(true)
|
||||
.saveTreeLocation("src/test/resources/trees/")
|
||||
.build();
|
||||
}
|
||||
|
||||
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
||||
return CovariateRow.createSimple(Map.of(
|
||||
"ageatfda", "35",
|
||||
"idu", "false",
|
||||
"black", "false",
|
||||
"cd4nadir", "0.81")
|
||||
, covariates, 1);
|
||||
}
|
||||
|
||||
public List<Covariate> getCovariates(Settings settings){
|
||||
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
||||
final Settings settings = getSettings();
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||
|
||||
final File directory = new File(settings.getSaveTreeLocation());
|
||||
assertFalse(directory.exists());
|
||||
directory.mkdir();
|
||||
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
forestTrainer.trainParallelOnDisk(1);
|
||||
|
||||
assertTrue(directory.exists());
|
||||
assertTrue(directory.isDirectory());
|
||||
assertEquals(NTREE, directory.listFiles().length);
|
||||
|
||||
|
||||
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||
|
||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||
|
||||
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
||||
assertNotNull(functions);
|
||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getPoints().size() > 2);
|
||||
|
||||
|
||||
assertEquals(NTREE, forest.getTrees().size());
|
||||
|
||||
cleanup(directory);
|
||||
|
||||
assertFalse(directory.exists());
|
||||
|
||||
}
|
||||
|
||||
private void cleanup(File file){
|
||||
if(file.isFile()){
|
||||
file.delete();
|
||||
}
|
||||
else{
|
||||
for(final File inner : file.listFiles()){
|
||||
cleanup(inner);
|
||||
}
|
||||
file.delete();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue