Add capability to load trees back into memory.
This commit is contained in:
parent
fffdfe85bf
commit
05f9122b58
9 changed files with 252 additions and 26 deletions
|
@ -1,19 +1,17 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.apache.commons.csv.CSVFormat;
|
import org.apache.commons.csv.CSVFormat;
|
||||||
import org.apache.commons.csv.CSVParser;
|
import org.apache.commons.csv.CSVParser;
|
||||||
import org.apache.commons.csv.CSVRecord;
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
import java.io.FileReader;
|
import java.io.*;
|
||||||
import java.io.IOException;
|
import java.util.*;
|
||||||
import java.io.Reader;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public class DataLoader {
|
public class DataLoader {
|
||||||
|
|
||||||
|
@ -43,6 +41,36 @@ public class DataLoader {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||||
|
if(!folder.isDirectory()){
|
||||||
|
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||||
|
}
|
||||||
|
|
||||||
|
final File[] treeFiles = folder.listFiles(((file, s) -> s.endsWith(".tree")));
|
||||||
|
final List<File> treeFileList = Arrays.asList(treeFiles);
|
||||||
|
|
||||||
|
Collections.sort(treeFileList, Comparator.comparing(File::getName));
|
||||||
|
|
||||||
|
final List<Tree<O>> treeList = new ArrayList<>(treeFileList.size());
|
||||||
|
|
||||||
|
for(final File treeFile : treeFileList){
|
||||||
|
final ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(treeFile));
|
||||||
|
|
||||||
|
final Tree<O> tree = (Tree) inputStream.readObject();
|
||||||
|
|
||||||
|
treeList.add(tree);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
final Forest forest = Forest.<O, FO>builder()
|
||||||
|
.trees(treeList)
|
||||||
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return forest;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@FunctionalInterface
|
@FunctionalInterface
|
||||||
public interface ResponseLoader<Y>{
|
public interface ResponseLoader<Y>{
|
||||||
Y parse(CSVRecord record);
|
Y parse(CSVRecord record);
|
||||||
|
|
|
@ -4,10 +4,11 @@ import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public class CompetingRiskFunctions {
|
public class CompetingRiskFunctions implements Serializable {
|
||||||
|
|
||||||
private final Map<Integer, MathFunction> causeSpecificHazardFunctionMap;
|
private final Map<Integer, MathFunction> causeSpecificHazardFunctionMap;
|
||||||
private final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap;
|
private final Map<Integer, MathFunction> cumulativeIncidenceFunctionMap;
|
||||||
|
|
|
@ -2,6 +2,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -12,7 +13,7 @@ import java.util.Optional;
|
||||||
* constant at the value of the previous encountered point.
|
* constant at the value of the previous encountered point.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public class MathFunction {
|
public class MathFunction implements Serializable {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final List<Point> points;
|
private final List<Point> points;
|
||||||
|
|
|
@ -2,12 +2,14 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents a point in our estimate of either the cumulative hazard function or the cumulative incidence function.
|
* Represents a point in our estimate of either the cumulative hazard function or the cumulative incidence function.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
public class Point {
|
public class Point implements Serializable {
|
||||||
private final Double time;
|
private final Double time;
|
||||||
private final Double y;
|
private final Double y;
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,12 +4,13 @@ import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
||||||
|
|
||||||
private final Collection<Node<O>> trees;
|
private final Collection<Tree<O>> trees;
|
||||||
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||||
|
|
||||||
public FO evaluate(CovariateRow row){
|
public FO evaluate(CovariateRow row){
|
||||||
|
@ -22,4 +23,8 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Collection<Tree<O>> getTrees(){
|
||||||
|
return Collections.unmodifiableCollection(trees);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Bootstrapper;
|
import ca.joeltherrien.randomforest.Bootstrapper;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.Settings;
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
|
||||||
import lombok.AccessLevel;
|
import lombok.AccessLevel;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
@ -12,11 +12,9 @@ import java.io.FileOutputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.ObjectOutputStream;
|
import java.io.ObjectOutputStream;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
@ -54,7 +52,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
public Forest<TO, FO> trainSerial(){
|
public Forest<TO, FO> trainSerial(){
|
||||||
|
|
||||||
final List<Node<TO>> trees = new ArrayList<>(ntree);
|
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||||
|
|
||||||
for(int j=0; j<ntree; j++){
|
for(int j=0; j<ntree; j++){
|
||||||
|
@ -83,7 +81,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
||||||
// the earlier indexes being filled.
|
// the earlier indexes being filled.
|
||||||
final List<Node<TO>> trees = Stream.<Node<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
|
final List<Tree<TO>> trees = Stream.<Tree<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
|
||||||
|
|
||||||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||||
|
|
||||||
|
@ -103,7 +101,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
if(displayProgress) {
|
if(displayProgress) {
|
||||||
int numberTreesSet = 0;
|
int numberTreesSet = 0;
|
||||||
for (final Node<TO> tree : trees) {
|
for (final Tree<TO> tree : trees) {
|
||||||
if (tree != null) {
|
if (tree != null) {
|
||||||
numberTreesSet++;
|
numberTreesSet++;
|
||||||
}
|
}
|
||||||
|
@ -131,7 +129,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished
|
final AtomicInteger treeCount = new AtomicInteger(0); // tracks how many trees are finished
|
||||||
|
|
||||||
for(int j=0; j<ntree; j++){
|
for(int j=0; j<ntree; j++){
|
||||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1), treeCount);
|
final Runnable worker = new TreeSavedWorker(data, "tree-" + (j+1) + ".tree", treeCount);
|
||||||
executorService.execute(worker);
|
executorService.execute(worker);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,12 +155,12 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
||||||
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
||||||
return treeTrainer.growTree(bootstrappedData);
|
return treeTrainer.growTree(bootstrappedData);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void saveTree(final Node<TO> tree, String name) throws IOException {
|
public void saveTree(final Tree<TO> tree, String name) throws IOException {
|
||||||
final String filename = saveTreeLocation + "/" + name;
|
final String filename = saveTreeLocation + "/" + name;
|
||||||
|
|
||||||
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
|
final ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(filename));
|
||||||
|
@ -177,9 +175,9 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
private final Bootstrapper<Row<Y>> bootstrapper;
|
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||||
private final int treeIndex;
|
private final int treeIndex;
|
||||||
private final List<Node<TO>> treeList;
|
private final List<Tree<TO>> treeList;
|
||||||
|
|
||||||
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Node<TO>> treeList) {
|
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
|
||||||
this.bootstrapper = new Bootstrapper<>(data);
|
this.bootstrapper = new Bootstrapper<>(data);
|
||||||
this.treeIndex = treeIndex;
|
this.treeIndex = treeIndex;
|
||||||
this.treeList = treeList;
|
this.treeList = treeList;
|
||||||
|
@ -188,7 +186,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
|
||||||
final Node<TO> tree = trainTree(bootstrapper);
|
final Tree<TO> tree = trainTree(bootstrapper);
|
||||||
|
|
||||||
// should be okay as the list structure isn't changing
|
// should be okay as the list structure isn't changing
|
||||||
treeList.set(treeIndex, tree);
|
treeList.set(treeIndex, tree);
|
||||||
|
@ -211,7 +209,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
|
||||||
final Node<TO> tree = trainTree(bootstrapper);
|
final Tree<TO> tree = trainTree(bootstrapper);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
|
|
42
src/main/java/ca/joeltherrien/randomforest/tree/Tree.java
Normal file
42
src/main/java/ca/joeltherrien/randomforest/tree/Tree.java
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class Tree<Y> implements Node<Y> {
|
||||||
|
|
||||||
|
private final Node<Y> rootNode;
|
||||||
|
private final int[] bootstrapRowIds;
|
||||||
|
private boolean bootStrapRowIdsSorted = false;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Y evaluate(CovariateRow row) {
|
||||||
|
return rootNode.evaluate(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int[] getBootstrapRowIds(){
|
||||||
|
return bootstrapRowIds.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sort bootstrapRowIds. This is not done automatically for efficiency purposes, as in many cases we may not be interested in using bootstrapRowIds();
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public void sortBootstrapRowIds(){
|
||||||
|
if(!bootStrapRowIdsSorted){
|
||||||
|
Arrays.sort(bootstrapRowIds);
|
||||||
|
bootStrapRowIdsSorted = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean idInBootstrapSample(int id){
|
||||||
|
this.sortBootstrapRowIds();
|
||||||
|
|
||||||
|
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -39,8 +39,13 @@ public class TreeTrainer<Y, O> {
|
||||||
this.covariates = covariates;
|
this.covariates = covariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Node<O> growTree(List<Row<Y>> data){
|
public Tree<O> growTree(List<Row<Y>> data){
|
||||||
return growNode(data, 0);
|
|
||||||
|
final Node<O> rootNode = growNode(data, 0);
|
||||||
|
final Tree<O> tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
||||||
|
|
||||||
|
return tree;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node<O> growNode(List<Row<Y>> data, int depth){
|
private Node<O> growNode(List<Row<Y>> data, int depth){
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctionCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import com.fasterxml.jackson.databind.node.*;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
public class TestSavingLoading {
|
||||||
|
|
||||||
|
private final int NTREE = 10;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* By default uses single log-rank test.
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Settings getSettings(){
|
||||||
|
final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
|
||||||
|
groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));
|
||||||
|
|
||||||
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
||||||
|
responseCombinerSettings.set("events",
|
||||||
|
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
||||||
|
);
|
||||||
|
// not setting times
|
||||||
|
|
||||||
|
|
||||||
|
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
|
||||||
|
treeCombinerSettings.set("events",
|
||||||
|
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
||||||
|
);
|
||||||
|
// not setting times
|
||||||
|
|
||||||
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
|
||||||
|
yVarSettings.set("u", new TextNode("time"));
|
||||||
|
yVarSettings.set("delta", new TextNode("status"));
|
||||||
|
|
||||||
|
return Settings.builder()
|
||||||
|
.covariates(List.of(
|
||||||
|
new NumericCovariateSettings("ageatfda"),
|
||||||
|
new BooleanCovariateSettings("idu"),
|
||||||
|
new BooleanCovariateSettings("black"),
|
||||||
|
new NumericCovariateSettings("cd4nadir")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.dataFileLocation("src/test/resources/wihs.csv")
|
||||||
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
|
.groupDifferentiatorSettings(groupDifferentiatorSettings)
|
||||||
|
.yVarSettings(yVarSettings)
|
||||||
|
.maxNodeDepth(100000)
|
||||||
|
// TODO fill in these settings
|
||||||
|
.mtry(2)
|
||||||
|
.nodeSize(6)
|
||||||
|
.ntree(NTREE)
|
||||||
|
.numberOfSplits(5)
|
||||||
|
.numberOfThreads(3)
|
||||||
|
.saveProgress(true)
|
||||||
|
.saveTreeLocation("src/test/resources/trees/")
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
||||||
|
return CovariateRow.createSimple(Map.of(
|
||||||
|
"ageatfda", "35",
|
||||||
|
"idu", "false",
|
||||||
|
"black", "false",
|
||||||
|
"cd4nadir", "0.81")
|
||||||
|
, covariates, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Covariate> getCovariates(Settings settings){
|
||||||
|
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
||||||
|
final Settings settings = getSettings();
|
||||||
|
final List<Covariate> covariates = getCovariates(settings);
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());
|
||||||
|
|
||||||
|
final File directory = new File(settings.getSaveTreeLocation());
|
||||||
|
assertFalse(directory.exists());
|
||||||
|
directory.mkdir();
|
||||||
|
|
||||||
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
|
forestTrainer.trainParallelOnDisk(1);
|
||||||
|
|
||||||
|
assertTrue(directory.exists());
|
||||||
|
assertTrue(directory.isDirectory());
|
||||||
|
assertEquals(NTREE, directory.listFiles().length);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||||
|
|
||||||
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
||||||
|
assertNotNull(functions);
|
||||||
|
assertTrue(functions.getCumulativeIncidenceFunction(1).getPoints().size() > 2);
|
||||||
|
|
||||||
|
|
||||||
|
assertEquals(NTREE, forest.getTrees().size());
|
||||||
|
|
||||||
|
cleanup(directory);
|
||||||
|
|
||||||
|
assertFalse(directory.exists());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void cleanup(File file){
|
||||||
|
if(file.isFile()){
|
||||||
|
file.delete();
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
for(final File inner : file.listFiles()){
|
||||||
|
cleanup(inner);
|
||||||
|
}
|
||||||
|
file.delete();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue