Fairly significant refactoring;
Made Forest into an abstract class with OnlineForest (in memory; same as previous) and OfflineForest (reads individual trees only as needed). Many methods were changed.
This commit is contained in:
parent
79a9522ba7
commit
54af805d4d
14 changed files with 484 additions and 107 deletions
|
@ -122,7 +122,7 @@ public class Main {
|
||||||
Utils.reduceListToSize(dataset, n, new Random());
|
Utils.reduceListToSize(dataset, n, new Random());
|
||||||
|
|
||||||
final File folder = new File(settings.getSaveTreeLocation());
|
final File folder = new File(settings.getSaveTreeLocation());
|
||||||
final Forest<?, CompetingRiskFunctions> forest = DataUtils.loadForest(folder, responseCombiner);
|
final Forest<?, CompetingRiskFunctions> forest = DataUtils.loadOnlineForest(folder, responseCombiner);
|
||||||
|
|
||||||
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
|
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
|
||||||
|
|
||||||
|
|
|
@ -17,31 +17,18 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
|
||||||
import lombok.Builder;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
public abstract class Forest<O, FO> {
|
||||||
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
|
||||||
|
|
||||||
private final List<Tree<O>> trees;
|
public abstract FO evaluate(CovariateRow row);
|
||||||
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
public abstract FO evaluateOOB(CovariateRow row);
|
||||||
private final List<Covariate> covariateList;
|
public abstract Iterable<Tree<O>> getTrees();
|
||||||
|
public abstract int getNumberOfTrees();
|
||||||
public FO evaluate(CovariateRow row){
|
|
||||||
|
|
||||||
return treeResponseCombiner.combine(
|
|
||||||
trees.stream()
|
|
||||||
.map(node -> node.evaluate(row))
|
|
||||||
.collect(Collectors.toList())
|
|
||||||
);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Used primarily in the R package interface to avoid R loops; and for easier parallelization.
|
* Used primarily in the R package interface to avoid R loops; and for easier parallelization.
|
||||||
|
@ -93,21 +80,6 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public FO evaluateOOB(CovariateRow row){
|
|
||||||
|
|
||||||
return treeResponseCombiner.combine(
|
|
||||||
trees.stream()
|
|
||||||
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
|
||||||
.map(node -> node.evaluate(row))
|
|
||||||
.collect(Collectors.toList())
|
|
||||||
);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Tree<O>> getTrees(){
|
|
||||||
return Collections.unmodifiableList(trees);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<Integer, Integer> findSplitsByCovariate(){
|
public Map<Integer, Integer> findSplitsByCovariate(){
|
||||||
final Map<Integer, Integer> countMap = new TreeMap<>();
|
final Map<Integer, Integer> countMap = new TreeMap<>();
|
||||||
|
|
||||||
|
@ -158,4 +130,5 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
return countTerminalNodes;
|
return countTerminalNodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,10 +57,10 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* in which case its trees are combined with the new one.
|
* in which case its trees are combined with the new one.
|
||||||
* @return A trained forest.
|
* @return A trained forest.
|
||||||
*/
|
*/
|
||||||
public Forest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
public OnlineForest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
||||||
|
|
||||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||||
initialForest.ifPresent(forest -> trees.addAll(forest.getTrees()));
|
initialForest.ifPresent(forest -> forest.getTrees().forEach(trees::add));
|
||||||
|
|
||||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||||
|
|
||||||
|
@ -77,11 +77,9 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("Finished");
|
System.out.println("Finished");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return OnlineForest.<TO, FO>builder()
|
||||||
return Forest.<TO, FO>builder()
|
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.trees(trees)
|
.trees(trees)
|
||||||
.covariateList(covariates)
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -94,7 +92,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* There cannot be existing trees if the initial forest is
|
* There cannot be existing trees if the initial forest is
|
||||||
* specified.
|
* specified.
|
||||||
*/
|
*/
|
||||||
public void trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
public OfflineForest<TO, FO> trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
||||||
// First we need to see how many trees there currently are
|
// First we need to see how many trees there currently are
|
||||||
final File folder = new File(saveTreeLocation);
|
final File folder = new File(saveTreeLocation);
|
||||||
if(!folder.exists()){
|
if(!folder.exists()){
|
||||||
|
@ -115,17 +113,14 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||||
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
int j=0;
|
||||||
|
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||||
for(int j=0; j<initialTrees.size(); j++){
|
|
||||||
final String filename = "tree-" + (j+1) + ".tree";
|
final String filename = "tree-" + (j+1) + ".tree";
|
||||||
final Tree<TO> tree = initialTrees.get(j);
|
|
||||||
|
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
|
j++;
|
||||||
}
|
}
|
||||||
|
|
||||||
treeCount = new AtomicInteger(initialTrees.size());
|
treeCount = new AtomicInteger(j);
|
||||||
} else{
|
} else{
|
||||||
treeCount = new AtomicInteger(treeFiles.length);
|
treeCount = new AtomicInteger(treeFiles.length);
|
||||||
}
|
}
|
||||||
|
@ -153,6 +148,8 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("Finished");
|
System.out.println("Finished");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return new OfflineForest<>(folder, treeResponseCombiner);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -162,7 +159,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* in which case its trees are combined with the new one.
|
* in which case its trees are combined with the new one.
|
||||||
* @param threads The number of trees to train at once.
|
* @param threads The number of trees to train at once.
|
||||||
*/
|
*/
|
||||||
public Forest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
public OnlineForest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||||
|
|
||||||
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
|
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
|
||||||
// the earlier indexes being filled.
|
// the earlier indexes being filled.
|
||||||
|
@ -170,11 +167,12 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
final int startingCount;
|
final int startingCount;
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
int j = 0;
|
||||||
for(int j=0; j<initialTrees.size(); j++) {
|
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||||
trees.set(j, initialTrees.get(j));
|
trees.set(j, tree);
|
||||||
|
j++;
|
||||||
}
|
}
|
||||||
startingCount = initialTrees.size();
|
startingCount = initialForest.get().getNumberOfTrees();
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
startingCount = 0;
|
startingCount = 0;
|
||||||
|
@ -219,7 +217,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("\nFinished");
|
System.out.println("\nFinished");
|
||||||
}
|
}
|
||||||
|
|
||||||
return Forest.<TO, FO>builder()
|
return OnlineForest.<TO, FO>builder()
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.trees(trees)
|
.trees(trees)
|
||||||
.build();
|
.build();
|
||||||
|
@ -235,7 +233,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* specified.
|
* specified.
|
||||||
* @param threads The number of trees to train at once.
|
* @param threads The number of trees to train at once.
|
||||||
*/
|
*/
|
||||||
public void trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
public OfflineForest<TO, FO> trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||||
// First we need to see how many trees there currently are
|
// First we need to see how many trees there currently are
|
||||||
final File folder = new File(saveTreeLocation);
|
final File folder = new File(saveTreeLocation);
|
||||||
if(!folder.exists()){
|
if(!folder.exists()){
|
||||||
|
@ -255,17 +253,14 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
int j=0;
|
||||||
|
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||||
for(int j=0; j<initialTrees.size(); j++){
|
|
||||||
final String filename = "tree-" + (j+1) + ".tree";
|
final String filename = "tree-" + (j+1) + ".tree";
|
||||||
final Tree<TO> tree = initialTrees.get(j);
|
|
||||||
|
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
|
j++;
|
||||||
}
|
}
|
||||||
|
|
||||||
treeCount = new AtomicInteger(initialTrees.size());
|
treeCount = new AtomicInteger(j);
|
||||||
} else{
|
} else{
|
||||||
treeCount = new AtomicInteger(treeFiles.length);
|
treeCount = new AtomicInteger(treeFiles.length);
|
||||||
}
|
}
|
||||||
|
@ -309,6 +304,8 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("\nFinished");
|
System.out.println("\nFinished");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return new OfflineForest<>(folder, treeResponseCombiner);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.utils.IterableOfflineTree;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class OfflineForest<O, FO> extends Forest<O, FO> {
|
||||||
|
|
||||||
|
private final File[] treeFiles;
|
||||||
|
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||||
|
|
||||||
|
public OfflineForest(File treeDirectoryPath, ResponseCombiner<O, FO> treeResponseCombiner){
|
||||||
|
this.treeResponseCombiner = treeResponseCombiner;
|
||||||
|
|
||||||
|
if(!treeDirectoryPath.isDirectory()){
|
||||||
|
throw new IllegalArgumentException("treeDirectoryPath must point to a directory!");
|
||||||
|
}
|
||||||
|
|
||||||
|
this.treeFiles = treeDirectoryPath.listFiles((file, s) -> s.endsWith(".tree"));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FO evaluate(CovariateRow row) {
|
||||||
|
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
|
||||||
|
for(final Tree<O> tree : getTrees()){
|
||||||
|
final O prediction = tree.evaluate(row);
|
||||||
|
predictedOutputs.add(prediction);
|
||||||
|
}
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(predictedOutputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FO evaluateOOB(CovariateRow row) {
|
||||||
|
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
|
||||||
|
for(final Tree<O> tree : getTrees()){
|
||||||
|
if(!tree.idInBootstrapSample(row.getId())){
|
||||||
|
final O prediction = tree.evaluate(row);
|
||||||
|
predictedOutputs.add(prediction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(predictedOutputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FO> evaluate(List<? extends CovariateRow> rowList){
|
||||||
|
final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
|
||||||
|
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||||
|
|
||||||
|
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||||
|
final Tree<O> currentTree = treeIterator.next();
|
||||||
|
|
||||||
|
final int tempTreeId = treeId; // Java workaround
|
||||||
|
IntStream.range(0, rowList.size()).parallel().forEach(
|
||||||
|
rowId -> {
|
||||||
|
final CovariateRow row = rowList.get(rowId);
|
||||||
|
final O prediction = currentTree.evaluate(row);
|
||||||
|
predictions[rowId][tempTreeId] = prediction;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Arrays.stream(predictions).parallel()
|
||||||
|
.map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray)))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FO> evaluateSerial(List<? extends CovariateRow> rowList){
|
||||||
|
final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
|
||||||
|
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||||
|
|
||||||
|
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||||
|
final Tree<O> currentTree = treeIterator.next();
|
||||||
|
|
||||||
|
final int tempTreeId = treeId; // Java workaround
|
||||||
|
IntStream.range(0, rowList.size()).sequential().forEach(
|
||||||
|
rowId -> {
|
||||||
|
final CovariateRow row = rowList.get(rowId);
|
||||||
|
final O prediction = currentTree.evaluate(row);
|
||||||
|
predictions[rowId][tempTreeId] = prediction;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Arrays.stream(predictions).sequential()
|
||||||
|
.map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray)))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FO> evaluateOOB(List<? extends CovariateRow> rowList){
|
||||||
|
final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
|
||||||
|
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||||
|
|
||||||
|
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||||
|
final Tree<O> currentTree = treeIterator.next();
|
||||||
|
|
||||||
|
final int tempTreeId = treeId; // Java workaround
|
||||||
|
IntStream.range(0, rowList.size()).parallel().forEach(
|
||||||
|
rowId -> {
|
||||||
|
final CovariateRow row = rowList.get(rowId);
|
||||||
|
if(!currentTree.idInBootstrapSample(row.getId())){
|
||||||
|
final O prediction = currentTree.evaluate(row);
|
||||||
|
predictions[rowId][tempTreeId] = prediction;
|
||||||
|
} else{
|
||||||
|
predictions[rowId][tempTreeId] = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Arrays.stream(predictions).parallel()
|
||||||
|
.map(predArray -> {
|
||||||
|
final List<O> predList = Arrays.stream(predArray).parallel()
|
||||||
|
.filter(pred -> pred != null).collect(Collectors.toList());
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(predList);
|
||||||
|
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FO> evaluateSerialOOB(List<? extends CovariateRow> rowList){
|
||||||
|
final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
|
||||||
|
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||||
|
|
||||||
|
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||||
|
final Tree<O> currentTree = treeIterator.next();
|
||||||
|
|
||||||
|
final int tempTreeId = treeId; // Java workaround
|
||||||
|
IntStream.range(0, rowList.size()).sequential().forEach(
|
||||||
|
rowId -> {
|
||||||
|
final CovariateRow row = rowList.get(rowId);
|
||||||
|
if(!currentTree.idInBootstrapSample(row.getId())){
|
||||||
|
final O prediction = currentTree.evaluate(row);
|
||||||
|
predictions[rowId][tempTreeId] = prediction;
|
||||||
|
} else{
|
||||||
|
predictions[rowId][tempTreeId] = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Arrays.stream(predictions).sequential()
|
||||||
|
.map(predArray -> {
|
||||||
|
final List<O> predList = Arrays.stream(predArray).sequential()
|
||||||
|
.filter(pred -> pred != null).collect(Collectors.toList());
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(predList);
|
||||||
|
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Iterable<Tree<O>> getTrees() {
|
||||||
|
return new IterableOfflineTree<>(treeFiles);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumberOfTrees() {
|
||||||
|
return treeFiles.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public class OnlineForest<O, FO> extends Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
||||||
|
|
||||||
|
private final List<Tree<O>> trees;
|
||||||
|
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FO evaluate(CovariateRow row){
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(
|
||||||
|
trees.stream()
|
||||||
|
.map(node -> node.evaluate(row))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FO evaluateOOB(CovariateRow row){
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(
|
||||||
|
trees.stream()
|
||||||
|
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
||||||
|
.map(node -> node.evaluate(row))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Tree<O>> getTrees(){
|
||||||
|
return Collections.unmodifiableList(trees);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumberOfTrees() {
|
||||||
|
return trees.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -16,9 +16,7 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.*;
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
|
||||||
import ca.joeltherrien.randomforest.tree.Tree;
|
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
@ -27,7 +25,7 @@ import java.util.zip.GZIPOutputStream;
|
||||||
|
|
||||||
public class DataUtils {
|
public class DataUtils {
|
||||||
|
|
||||||
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||||
if(!folder.isDirectory()){
|
if(!folder.isDirectory()){
|
||||||
throw new IllegalArgumentException("Tree directory must be a directory!");
|
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||||
}
|
}
|
||||||
|
@ -48,16 +46,16 @@ public class DataUtils {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Forest.<O, FO>builder()
|
return OnlineForest.<O, FO>builder()
|
||||||
.trees(treeList)
|
.trees(treeList)
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O, FO> Forest<O, FO> loadForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||||
final File directory = new File(folder);
|
final File directory = new File(folder);
|
||||||
return loadForest(directory, treeResponseCombiner);
|
return loadOnlineForest(directory, treeResponseCombiner);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void saveObject(Serializable object, String filename) throws IOException {
|
public static void saveObject(Serializable object, String filename) throws IOException {
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.zip.GZIPInputStream;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class IterableOfflineTree<Y> implements Iterable<Tree<Y>> {
|
||||||
|
|
||||||
|
private final File[] treeFiles;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Iterator<Tree<Y>> iterator() {
|
||||||
|
return new OfflineTreeIterator<>(treeFiles);
|
||||||
|
}
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public static class OfflineTreeIterator<Y> implements Iterator<Tree<Y>>{
|
||||||
|
private final File[] treeFiles;
|
||||||
|
private int position = 0;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return position < treeFiles.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Tree<Y> next() {
|
||||||
|
final File treeFile = treeFiles[position];
|
||||||
|
position++;
|
||||||
|
|
||||||
|
|
||||||
|
try {
|
||||||
|
final ObjectInputStream inputStream= new ObjectInputStream(new GZIPInputStream(new FileInputStream(treeFile)));
|
||||||
|
final Tree<Y> tree = (Tree) inputStream.readObject();
|
||||||
|
return tree;
|
||||||
|
} catch (IOException | ClassNotFoundException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
throw new RuntimeException("Failed to load tree for " + treeFile.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -25,6 +25,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
import java.util.zip.GZIPInputStream;
|
import java.util.zip.GZIPInputStream;
|
||||||
import java.util.zip.GZIPOutputStream;
|
import java.util.zip.GZIPOutputStream;
|
||||||
|
|
||||||
|
@ -198,4 +199,12 @@ public final class RUtils {
|
||||||
return newList;
|
return newList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static File[] getTreeFileArray(String folderPath, int endingId){
|
||||||
|
return (File[]) IntStream.rangeClosed(1, endingId).sequential()
|
||||||
|
.mapToObj(i -> folderPath + "/tree-" + i + ".tree")
|
||||||
|
.map(File::new)
|
||||||
|
.toArray();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -39,6 +41,7 @@ public final class RightContinuousStepFunction extends StepFunction {
|
||||||
*
|
*
|
||||||
* May not be null.
|
* May not be null.
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
private final double defaultY;
|
private final double defaultY;
|
||||||
|
|
||||||
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
||||||
|
|
|
@ -214,14 +214,14 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
|
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
|
||||||
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
|
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
|
||||||
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
final Forest<Double, Double> forestSerial = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
||||||
|
|
||||||
|
|
||||||
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
final Forest<Double, Double> forestParallel = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
for(int k=0; k<3; k++){
|
for(int k=0; k<3; k++){
|
||||||
forestTrainer.trainSerialOnDisk(Optional.empty());
|
forestTrainer.trainSerialOnDisk(Optional.empty());
|
||||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||||
}
|
}
|
||||||
|
@ -274,7 +274,7 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
for(int k=0; k<3; k++){
|
for(int k=0; k<3; k++){
|
||||||
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,8 +20,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.tree.OnlineForest;
|
||||||
import ca.joeltherrien.randomforest.tree.Tree;
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||||
|
@ -39,7 +39,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestProvidingInitialForest {
|
public class TestProvidingInitialForest {
|
||||||
|
|
||||||
private Forest<Double, Double> initialForest;
|
private OnlineForest<Double, Double> initialForest;
|
||||||
private List<Covariate> covariateList;
|
private List<Covariate> covariateList;
|
||||||
private List<Row<Double>> data;
|
private List<Row<Double>> data;
|
||||||
|
|
||||||
|
@ -107,8 +107,8 @@ public class TestProvidingInitialForest {
|
||||||
public void testSerialInMemory(){
|
public void testSerialInMemory(){
|
||||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||||
|
|
||||||
final Forest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
final OnlineForest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
||||||
assertEquals(20, newForest.getTrees().size());
|
assertEquals(20, newForest.getNumberOfTrees());
|
||||||
|
|
||||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||||
assertTrue(newForest.getTrees().contains(initialTree));
|
assertTrue(newForest.getTrees().contains(initialTree));
|
||||||
|
@ -124,8 +124,8 @@ public class TestProvidingInitialForest {
|
||||||
public void testParallelInMemory(){
|
public void testParallelInMemory(){
|
||||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||||
|
|
||||||
final Forest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
final OnlineForest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
||||||
assertEquals(20, newForest.getTrees().size());
|
assertEquals(20, newForest.getNumberOfTrees());
|
||||||
|
|
||||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||||
assertTrue(newForest.getTrees().contains(initialTree));
|
assertTrue(newForest.getTrees().contains(initialTree));
|
||||||
|
@ -149,11 +149,11 @@ public class TestProvidingInitialForest {
|
||||||
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
|
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
|
||||||
|
|
||||||
assertEquals(20, directory.listFiles().length);
|
assertEquals(20, directory.listFiles().length);
|
||||||
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
assertEquals(20, newForest.getTrees().size());
|
assertEquals(20, newForest.getNumberOfTrees());
|
||||||
|
|
||||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||||
|
@ -179,9 +179,9 @@ public class TestProvidingInitialForest {
|
||||||
|
|
||||||
assertEquals(20, directory.listFiles().length);
|
assertEquals(20, directory.listFiles().length);
|
||||||
|
|
||||||
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
|
||||||
|
|
||||||
assertEquals(20, newForest.getTrees().size());
|
assertEquals(20, newForest.getNumberOfTrees());
|
||||||
|
|
||||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||||
|
|
|
@ -24,11 +24,10 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.*;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
|
||||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||||
import ca.joeltherrien.randomforest.utils.ResponseLoader;
|
import ca.joeltherrien.randomforest.utils.ResponseLoader;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
@ -119,16 +118,21 @@ public class TestSavingLoading {
|
||||||
assertTrue(directory.isDirectory());
|
assertTrue(directory.isDirectory());
|
||||||
assertEquals(NTREE, directory.listFiles().length);
|
assertEquals(NTREE, directory.listFiles().length);
|
||||||
|
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
||||||
|
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
|
||||||
|
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
|
||||||
|
|
||||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
|
||||||
assertNotNull(functions);
|
assertNotNull(functionsOnline);
|
||||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
|
||||||
|
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
|
||||||
|
|
||||||
|
|
||||||
assertEquals(NTREE, forest.getTrees().size());
|
assertEquals(NTREE, onlineForest.getTrees().size());
|
||||||
|
|
||||||
TestUtils.removeFolder(directory);
|
TestUtils.removeFolder(directory);
|
||||||
|
|
||||||
|
@ -159,17 +163,22 @@ public class TestSavingLoading {
|
||||||
assertEquals(NTREE, directory.listFiles().length);
|
assertEquals(NTREE, directory.listFiles().length);
|
||||||
|
|
||||||
|
|
||||||
|
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
|
||||||
|
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
|
||||||
|
|
||||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
|
||||||
assertNotNull(functions);
|
assertNotNull(functionsOnline);
|
||||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||||
|
|
||||||
|
|
||||||
assertEquals(NTREE, forest.getTrees().size());
|
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
|
||||||
|
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
|
||||||
|
|
||||||
|
|
||||||
|
assertEquals(NTREE, onlineForest.getTrees().size());
|
||||||
|
|
||||||
TestUtils.removeFolder(directory);
|
TestUtils.removeFolder(directory);
|
||||||
|
|
||||||
|
@ -177,6 +186,64 @@ public class TestSavingLoading {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
We don't implement equals() methods on the below mentioned classes because then we'd need to implement an
|
||||||
|
appropriate hashCode() method that's consistent with the equals(), and we only need plain equals() for
|
||||||
|
these tests.
|
||||||
|
*/
|
||||||
|
|
||||||
|
private boolean competingFunctionsEqual(CompetingRiskFunctions f1 ,CompetingRiskFunctions f2){
|
||||||
|
if(!functionsEqual(f1.getSurvivalCurve(), f2.getSurvivalCurve())){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i=1; i<=2; i++){
|
||||||
|
if(!functionsEqual(f1.getCauseSpecificHazardFunction(i), f2.getCauseSpecificHazardFunction(i))){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if(!functionsEqual(f1.getCumulativeIncidenceFunction(i), f2.getCumulativeIncidenceFunction(i))){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean functionsEqual(RightContinuousStepFunction f1, RightContinuousStepFunction f2){
|
||||||
|
|
||||||
|
final double[] f1X = f1.getX();
|
||||||
|
final double[] f2X = f2.getX();
|
||||||
|
|
||||||
|
final double[] f1Y = f1.getY();
|
||||||
|
final double[] f2Y = f2.getY();
|
||||||
|
|
||||||
|
// first compare array lengths
|
||||||
|
if(f1X.length != f2X.length){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if(f1Y.length != f2Y.length){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - better comparisons of doubles. I don't really care too much though as this equals method is only being used in tests
|
||||||
|
final double delta = 0.000001;
|
||||||
|
|
||||||
|
if(Math.abs(f1.getDefaultY() - f2.getDefaultY()) > delta){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i=0; i < f1X.length; i++){
|
||||||
|
if(Math.abs(f1X[i] - f2X[i]) > delta){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if(Math.abs(f1Y[i] - f2Y[i]) > delta){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,10 +16,11 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.OnlineForest;
|
||||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
@ -30,8 +31,6 @@ import java.util.List;
|
||||||
|
|
||||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class TestCompetingRiskErrorRateCalculator {
|
public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
|
@ -48,7 +47,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
final int event = 1;
|
final int event = 1;
|
||||||
|
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = OnlineForest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||||
|
|
||||||
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class VariableImportanceCalculatorTest {
|
public class TestVariableImportanceCalculator {
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
|
Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
|
||||||
|
@ -28,7 +28,7 @@ public class VariableImportanceCalculatorTest {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// We'l have a very simple Forest of two trees
|
// We'l have a very simple Forest of two trees
|
||||||
private final Forest<Double, Double> forest;
|
private final OnlineForest<Double, Double> forest;
|
||||||
|
|
||||||
|
|
||||||
private final List<Covariate> covariates;
|
private final List<Covariate> covariates;
|
||||||
|
@ -38,7 +38,7 @@ public class VariableImportanceCalculatorTest {
|
||||||
Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance.
|
Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance.
|
||||||
|
|
||||||
*/
|
*/
|
||||||
public VariableImportanceCalculatorTest(){
|
public TestVariableImportanceCalculator(){
|
||||||
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false);
|
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false);
|
||||||
final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false);
|
final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false);
|
||||||
final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
|
final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
|
||||||
|
@ -67,10 +67,9 @@ public class VariableImportanceCalculatorTest {
|
||||||
final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
|
final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
|
||||||
final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});
|
final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});
|
||||||
|
|
||||||
this.forest = Forest.<Double, Double>builder()
|
this.forest = OnlineForest.<Double, Double>builder()
|
||||||
.trees(Utils.easyList(tree1, tree2))
|
.trees(Utils.easyList(tree1, tree2))
|
||||||
.treeResponseCombiner(new MeanResponseCombiner())
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
.covariateList(this.covariates)
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
// formula; boolean high adds 100; high numeric adds 10
|
// formula; boolean high adds 100; high numeric adds 10
|
Loading…
Reference in a new issue