Fairly significant refactoring;

Made Forest into an abstract class with OnlineForest (in memory; same as previous)
and OfflineForest (reads individual trees only as needed). Many methods were changed.
This commit is contained in:
Joel Therrien 2019-10-02 13:54:45 -07:00
parent 79a9522ba7
commit 54af805d4d
14 changed files with 484 additions and 107 deletions

View file

@ -122,7 +122,7 @@ public class Main {
Utils.reduceListToSize(dataset, n, new Random()); Utils.reduceListToSize(dataset, n, new Random());
final File folder = new File(settings.getSaveTreeLocation()); final File folder = new File(settings.getSaveTreeLocation());
final Forest<?, CompetingRiskFunctions> forest = DataUtils.loadForest(folder, responseCombiner); final Forest<?, CompetingRiskFunctions> forest = DataUtils.loadOnlineForest(folder, responseCombiner);
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation()); final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());

View file

@ -17,31 +17,18 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Builder;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Builder public abstract class Forest<O, FO> {
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
private final List<Tree<O>> trees; public abstract FO evaluate(CovariateRow row);
private final ResponseCombiner<O, FO> treeResponseCombiner; public abstract FO evaluateOOB(CovariateRow row);
private final List<Covariate> covariateList; public abstract Iterable<Tree<O>> getTrees();
public abstract int getNumberOfTrees();
public FO evaluate(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
/** /**
* Used primarily in the R package interface to avoid R loops; and for easier parallelization. * Used primarily in the R package interface to avoid R loops; and for easier parallelization.
@ -93,21 +80,6 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
public FO evaluateOOB(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
public List<Tree<O>> getTrees(){
return Collections.unmodifiableList(trees);
}
public Map<Integer, Integer> findSplitsByCovariate(){ public Map<Integer, Integer> findSplitsByCovariate(){
final Map<Integer, Integer> countMap = new TreeMap<>(); final Map<Integer, Integer> countMap = new TreeMap<>();
@ -158,4 +130,5 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
return countTerminalNodes; return countTerminalNodes;
} }
} }

View file

@ -57,10 +57,10 @@ public class ForestTrainer<Y, TO, FO> {
* in which case its trees are combined with the new one. * in which case its trees are combined with the new one.
* @return A trained forest. * @return A trained forest.
*/ */
public Forest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){ public OnlineForest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
final List<Tree<TO>> trees = new ArrayList<>(ntree); final List<Tree<TO>> trees = new ArrayList<>(ntree);
initialForest.ifPresent(forest -> trees.addAll(forest.getTrees())); initialForest.ifPresent(forest -> forest.getTrees().forEach(trees::add));
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data); final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
@ -77,11 +77,9 @@ public class ForestTrainer<Y, TO, FO> {
System.out.println("Finished"); System.out.println("Finished");
} }
return OnlineForest.<TO, FO>builder()
return Forest.<TO, FO>builder()
.treeResponseCombiner(treeResponseCombiner) .treeResponseCombiner(treeResponseCombiner)
.trees(trees) .trees(trees)
.covariateList(covariates)
.build(); .build();
} }
@ -94,7 +92,7 @@ public class ForestTrainer<Y, TO, FO> {
* There cannot be existing trees if the initial forest is * There cannot be existing trees if the initial forest is
* specified. * specified.
*/ */
public void trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){ public OfflineForest<TO, FO> trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
// First we need to see how many trees there currently are // First we need to see how many trees there currently are
final File folder = new File(saveTreeLocation); final File folder = new File(saveTreeLocation);
if(!folder.exists()){ if(!folder.exists()){
@ -115,17 +113,14 @@ public class ForestTrainer<Y, TO, FO> {
final AtomicInteger treeCount; // tracks how many trees are finished final AtomicInteger treeCount; // tracks how many trees are finished
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker // Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
if(initialForest.isPresent()){ if(initialForest.isPresent()){
final List<Tree<TO>> initialTrees = initialForest.get().getTrees(); int j=0;
for(final Tree<TO> tree : initialForest.get().getTrees()){
for(int j=0; j<initialTrees.size(); j++){
final String filename = "tree-" + (j+1) + ".tree"; final String filename = "tree-" + (j+1) + ".tree";
final Tree<TO> tree = initialTrees.get(j);
saveTree(tree, filename); saveTree(tree, filename);
j++;
} }
treeCount = new AtomicInteger(initialTrees.size()); treeCount = new AtomicInteger(j);
} else{ } else{
treeCount = new AtomicInteger(treeFiles.length); treeCount = new AtomicInteger(treeFiles.length);
} }
@ -153,6 +148,8 @@ public class ForestTrainer<Y, TO, FO> {
System.out.println("Finished"); System.out.println("Finished");
} }
return new OfflineForest<>(folder, treeResponseCombiner);
} }
/** /**
@ -162,7 +159,7 @@ public class ForestTrainer<Y, TO, FO> {
* in which case its trees are combined with the new one. * in which case its trees are combined with the new one.
* @param threads The number of trees to train at once. * @param threads The number of trees to train at once.
*/ */
public Forest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){ public OnlineForest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without // create a list that is pre-specified in size (I can call the .set method at any index < ntree without
// the earlier indexes being filled. // the earlier indexes being filled.
@ -170,11 +167,12 @@ public class ForestTrainer<Y, TO, FO> {
final int startingCount; final int startingCount;
if(initialForest.isPresent()){ if(initialForest.isPresent()){
final List<Tree<TO>> initialTrees = initialForest.get().getTrees(); int j = 0;
for(int j=0; j<initialTrees.size(); j++) { for(final Tree<TO> tree : initialForest.get().getTrees()){
trees.set(j, initialTrees.get(j)); trees.set(j, tree);
j++;
} }
startingCount = initialTrees.size(); startingCount = initialForest.get().getNumberOfTrees();
} }
else{ else{
startingCount = 0; startingCount = 0;
@ -219,7 +217,7 @@ public class ForestTrainer<Y, TO, FO> {
System.out.println("\nFinished"); System.out.println("\nFinished");
} }
return Forest.<TO, FO>builder() return OnlineForest.<TO, FO>builder()
.treeResponseCombiner(treeResponseCombiner) .treeResponseCombiner(treeResponseCombiner)
.trees(trees) .trees(trees)
.build(); .build();
@ -235,7 +233,7 @@ public class ForestTrainer<Y, TO, FO> {
* specified. * specified.
* @param threads The number of trees to train at once. * @param threads The number of trees to train at once.
*/ */
public void trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){ public OfflineForest<TO, FO> trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
// First we need to see how many trees there currently are // First we need to see how many trees there currently are
final File folder = new File(saveTreeLocation); final File folder = new File(saveTreeLocation);
if(!folder.exists()){ if(!folder.exists()){
@ -255,17 +253,14 @@ public class ForestTrainer<Y, TO, FO> {
final AtomicInteger treeCount; // tracks how many trees are finished final AtomicInteger treeCount; // tracks how many trees are finished
if(initialForest.isPresent()){ if(initialForest.isPresent()){
final List<Tree<TO>> initialTrees = initialForest.get().getTrees(); int j=0;
for(final Tree<TO> tree : initialForest.get().getTrees()){
for(int j=0; j<initialTrees.size(); j++){
final String filename = "tree-" + (j+1) + ".tree"; final String filename = "tree-" + (j+1) + ".tree";
final Tree<TO> tree = initialTrees.get(j);
saveTree(tree, filename); saveTree(tree, filename);
j++;
} }
treeCount = new AtomicInteger(initialTrees.size()); treeCount = new AtomicInteger(j);
} else{ } else{
treeCount = new AtomicInteger(treeFiles.length); treeCount = new AtomicInteger(treeFiles.length);
} }
@ -309,6 +304,8 @@ public class ForestTrainer<Y, TO, FO> {
System.out.println("\nFinished"); System.out.println("\nFinished");
} }
return new OfflineForest<>(folder, treeResponseCombiner);
} }
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){ private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){

View file

@ -0,0 +1,198 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.utils.IterableOfflineTree;
import lombok.AllArgsConstructor;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@AllArgsConstructor
public class OfflineForest<O, FO> extends Forest<O, FO> {
private final File[] treeFiles;
private final ResponseCombiner<O, FO> treeResponseCombiner;
public OfflineForest(File treeDirectoryPath, ResponseCombiner<O, FO> treeResponseCombiner){
this.treeResponseCombiner = treeResponseCombiner;
if(!treeDirectoryPath.isDirectory()){
throw new IllegalArgumentException("treeDirectoryPath must point to a directory!");
}
this.treeFiles = treeDirectoryPath.listFiles((file, s) -> s.endsWith(".tree"));
}
@Override
public FO evaluate(CovariateRow row) {
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
for(final Tree<O> tree : getTrees()){
final O prediction = tree.evaluate(row);
predictedOutputs.add(prediction);
}
return treeResponseCombiner.combine(predictedOutputs);
}
@Override
public FO evaluateOOB(CovariateRow row) {
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
for(final Tree<O> tree : getTrees()){
if(!tree.idInBootstrapSample(row.getId())){
final O prediction = tree.evaluate(row);
predictedOutputs.add(prediction);
}
}
return treeResponseCombiner.combine(predictedOutputs);
}
@Override
public List<FO> evaluate(List<? extends CovariateRow> rowList){
final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
for(int treeId = 0; treeId < treeFiles.length; treeId++){
final Tree<O> currentTree = treeIterator.next();
final int tempTreeId = treeId; // Java workaround
IntStream.range(0, rowList.size()).parallel().forEach(
rowId -> {
final CovariateRow row = rowList.get(rowId);
final O prediction = currentTree.evaluate(row);
predictions[rowId][tempTreeId] = prediction;
}
);
}
return Arrays.stream(predictions).parallel()
.map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray)))
.collect(Collectors.toList());
}
@Override
public List<FO> evaluateSerial(List<? extends CovariateRow> rowList){
final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
for(int treeId = 0; treeId < treeFiles.length; treeId++){
final Tree<O> currentTree = treeIterator.next();
final int tempTreeId = treeId; // Java workaround
IntStream.range(0, rowList.size()).sequential().forEach(
rowId -> {
final CovariateRow row = rowList.get(rowId);
final O prediction = currentTree.evaluate(row);
predictions[rowId][tempTreeId] = prediction;
}
);
}
return Arrays.stream(predictions).sequential()
.map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray)))
.collect(Collectors.toList());
}
@Override
public List<FO> evaluateOOB(List<? extends CovariateRow> rowList){
final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
for(int treeId = 0; treeId < treeFiles.length; treeId++){
final Tree<O> currentTree = treeIterator.next();
final int tempTreeId = treeId; // Java workaround
IntStream.range(0, rowList.size()).parallel().forEach(
rowId -> {
final CovariateRow row = rowList.get(rowId);
if(!currentTree.idInBootstrapSample(row.getId())){
final O prediction = currentTree.evaluate(row);
predictions[rowId][tempTreeId] = prediction;
} else{
predictions[rowId][tempTreeId] = null;
}
}
);
}
return Arrays.stream(predictions).parallel()
.map(predArray -> {
final List<O> predList = Arrays.stream(predArray).parallel()
.filter(pred -> pred != null).collect(Collectors.toList());
return treeResponseCombiner.combine(predList);
})
.collect(Collectors.toList());
}
@Override
public List<FO> evaluateSerialOOB(List<? extends CovariateRow> rowList){
final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length];
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
for(int treeId = 0; treeId < treeFiles.length; treeId++){
final Tree<O> currentTree = treeIterator.next();
final int tempTreeId = treeId; // Java workaround
IntStream.range(0, rowList.size()).sequential().forEach(
rowId -> {
final CovariateRow row = rowList.get(rowId);
if(!currentTree.idInBootstrapSample(row.getId())){
final O prediction = currentTree.evaluate(row);
predictions[rowId][tempTreeId] = prediction;
} else{
predictions[rowId][tempTreeId] = null;
}
}
);
}
return Arrays.stream(predictions).sequential()
.map(predArray -> {
final List<O> predList = Arrays.stream(predArray).sequential()
.filter(pred -> pred != null).collect(Collectors.toList());
return treeResponseCombiner.combine(predList);
})
.collect(Collectors.toList());
}
@Override
public Iterable<Tree<O>> getTrees() {
return new IterableOfflineTree<>(treeFiles);
}
@Override
public int getNumberOfTrees() {
return treeFiles.length;
}
}

View file

@ -0,0 +1,66 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Builder;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@Builder
public class OnlineForest<O, FO> extends Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
private final List<Tree<O>> trees;
private final ResponseCombiner<O, FO> treeResponseCombiner;
@Override
public FO evaluate(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
@Override
public FO evaluateOOB(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
@Override
public List<Tree<O>> getTrees(){
return Collections.unmodifiableList(trees);
}
@Override
public int getNumberOfTrees() {
return trees.size();
}
}

View file

@ -16,9 +16,7 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.*;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.tree.Tree;
import java.io.*; import java.io.*;
import java.util.*; import java.util.*;
@ -27,7 +25,7 @@ import java.util.zip.GZIPOutputStream;
public class DataUtils { public class DataUtils {
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException { public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
if(!folder.isDirectory()){ if(!folder.isDirectory()){
throw new IllegalArgumentException("Tree directory must be a directory!"); throw new IllegalArgumentException("Tree directory must be a directory!");
} }
@ -48,16 +46,16 @@ public class DataUtils {
} }
return Forest.<O, FO>builder() return OnlineForest.<O, FO>builder()
.trees(treeList) .trees(treeList)
.treeResponseCombiner(treeResponseCombiner) .treeResponseCombiner(treeResponseCombiner)
.build(); .build();
} }
public static <O, FO> Forest<O, FO> loadForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException { public static <O, FO> OnlineForest<O, FO> loadOnlineForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
final File directory = new File(folder); final File directory = new File(folder);
return loadForest(directory, treeResponseCombiner); return loadOnlineForest(directory, treeResponseCombiner);
} }
public static void saveObject(Serializable object, String filename) throws IOException { public static void saveObject(Serializable object, String filename) throws IOException {

View file

@ -0,0 +1,68 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.tree.Tree;
import lombok.RequiredArgsConstructor;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Iterator;
import java.util.zip.GZIPInputStream;
@RequiredArgsConstructor
public class IterableOfflineTree<Y> implements Iterable<Tree<Y>> {
private final File[] treeFiles;
@Override
public Iterator<Tree<Y>> iterator() {
return new OfflineTreeIterator<>(treeFiles);
}
@RequiredArgsConstructor
public static class OfflineTreeIterator<Y> implements Iterator<Tree<Y>>{
private final File[] treeFiles;
private int position = 0;
@Override
public boolean hasNext() {
return position < treeFiles.length;
}
@Override
public Tree<Y> next() {
final File treeFile = treeFiles[position];
position++;
try {
final ObjectInputStream inputStream= new ObjectInputStream(new GZIPInputStream(new FileInputStream(treeFile)));
final Tree<Y> tree = (Tree) inputStream.readObject();
return tree;
} catch (IOException | ClassNotFoundException e) {
e.printStackTrace();
throw new RuntimeException("Failed to load tree for " + treeFile.toString());
}
}
}
}

View file

@ -25,6 +25,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
import java.io.*; import java.io.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.IntStream;
import java.util.zip.GZIPInputStream; import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;
@ -198,4 +199,12 @@ public final class RUtils {
return newList; return newList;
} }
public static File[] getTreeFileArray(String folderPath, int endingId){
return (File[]) IntStream.rangeClosed(1, endingId).sequential()
.mapToObj(i -> folderPath + "/tree-" + i + ".tree")
.map(File::new)
.toArray();
}
} }

View file

@ -16,6 +16,8 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import lombok.Getter;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -39,6 +41,7 @@ public final class RightContinuousStepFunction extends StepFunction {
* *
* May not be null. * May not be null.
*/ */
@Getter
private final double defaultY; private final double defaultY;
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) { public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {

View file

@ -214,14 +214,14 @@ public class TestDeterministicForests {
forestTrainer5Trees.trainSerialOnDisk(Optional.empty()); forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
forestTrainer10Trees.trainSerialOnDisk(Optional.empty()); forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); final Forest<Double, Double> forestSerial = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
TestUtils.removeFolder(saveTreeFile); TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, forestSerial); verifyTwoForestsEqual(testData, referenceForest, forestSerial);
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4); forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4); forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); final Forest<Double, Double> forestParallel = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
TestUtils.removeFolder(saveTreeFile); TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, forestParallel); verifyTwoForestsEqual(testData, referenceForest, forestParallel);
@ -259,7 +259,7 @@ public class TestDeterministicForests {
for(int k=0; k<3; k++){ for(int k=0; k<3; k++){
forestTrainer.trainSerialOnDisk(Optional.empty()); forestTrainer.trainSerialOnDisk(Optional.empty());
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
TestUtils.removeFolder(saveTreeFile); TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, replicantForest); verifyTwoForestsEqual(testData, referenceForest, replicantForest);
} }
@ -274,7 +274,7 @@ public class TestDeterministicForests {
for(int k=0; k<3; k++){ for(int k=0; k<3; k++){
forestTrainer.trainParallelOnDisk(Optional.empty(), 4); forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
TestUtils.removeFolder(saveTreeFile); TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, replicantForest); verifyTwoForestsEqual(testData, referenceForest, replicantForest);
} }

View file

@ -20,8 +20,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.OnlineForest;
import ca.joeltherrien.randomforest.tree.Tree; import ca.joeltherrien.randomforest.tree.Tree;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.DataUtils;
@ -39,7 +39,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class TestProvidingInitialForest { public class TestProvidingInitialForest {
private Forest<Double, Double> initialForest; private OnlineForest<Double, Double> initialForest;
private List<Covariate> covariateList; private List<Covariate> covariateList;
private List<Row<Double>> data; private List<Row<Double>> data;
@ -107,8 +107,8 @@ public class TestProvidingInitialForest {
public void testSerialInMemory(){ public void testSerialInMemory(){
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20); final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
final Forest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest)); final OnlineForest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
assertEquals(20, newForest.getTrees().size()); assertEquals(20, newForest.getNumberOfTrees());
for(Tree<Double> initialTree : initialForest.getTrees()){ for(Tree<Double> initialTree : initialForest.getTrees()){
assertTrue(newForest.getTrees().contains(initialTree)); assertTrue(newForest.getTrees().contains(initialTree));
@ -124,8 +124,8 @@ public class TestProvidingInitialForest {
public void testParallelInMemory(){ public void testParallelInMemory(){
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20); final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
final Forest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2); final OnlineForest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
assertEquals(20, newForest.getTrees().size()); assertEquals(20, newForest.getNumberOfTrees());
for(Tree<Double> initialTree : initialForest.getTrees()){ for(Tree<Double> initialTree : initialForest.getTrees()){
assertTrue(newForest.getTrees().contains(initialTree)); assertTrue(newForest.getTrees().contains(initialTree));
@ -149,11 +149,11 @@ public class TestProvidingInitialForest {
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2); forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
assertEquals(20, directory.listFiles().length); assertEquals(20, directory.listFiles().length);
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner()); final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
assertEquals(20, newForest.getTrees().size()); assertEquals(20, newForest.getNumberOfTrees());
final List<String> newForestTreesAsStrings = newForest.getTrees().stream() final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
.map(tree -> tree.toString()).collect(Collectors.toList()); .map(tree -> tree.toString()).collect(Collectors.toList());
@ -179,9 +179,9 @@ public class TestProvidingInitialForest {
assertEquals(20, directory.listFiles().length); assertEquals(20, directory.listFiles().length);
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner()); final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
assertEquals(20, newForest.getTrees().size()); assertEquals(20, newForest.getNumberOfTrees());
final List<String> newForestTreesAsStrings = newForest.getTrees().stream() final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
.map(tree -> tree.toString()).collect(Collectors.toList()); .map(tree -> tree.toString()).collect(Collectors.toList());

View file

@ -24,11 +24,10 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.*;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.ResponseLoader; import ca.joeltherrien.randomforest.utils.ResponseLoader;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -119,16 +118,21 @@ public class TestSavingLoading {
assertTrue(directory.isDirectory()); assertTrue(directory.isDirectory());
assertEquals(NTREE, directory.listFiles().length); assertEquals(NTREE, directory.listFiles().length);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
final CovariateRow predictionRow = getPredictionRow(covariates); final CovariateRow predictionRow = getPredictionRow(covariates);
final CompetingRiskFunctions functions = forest.evaluate(predictionRow); final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
assertNotNull(functions); assertNotNull(functionsOnline);
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2); assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
assertEquals(NTREE, forest.getTrees().size()); assertEquals(NTREE, onlineForest.getTrees().size());
TestUtils.removeFolder(directory); TestUtils.removeFolder(directory);
@ -159,17 +163,22 @@ public class TestSavingLoading {
assertEquals(NTREE, directory.listFiles().length); assertEquals(NTREE, directory.listFiles().length);
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
final CovariateRow predictionRow = getPredictionRow(covariates); final CovariateRow predictionRow = getPredictionRow(covariates);
final CompetingRiskFunctions functions = forest.evaluate(predictionRow); final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
assertNotNull(functions); assertNotNull(functionsOnline);
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2); assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
assertEquals(NTREE, forest.getTrees().size()); final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
assertEquals(NTREE, onlineForest.getTrees().size());
TestUtils.removeFolder(directory); TestUtils.removeFolder(directory);
@ -177,6 +186,64 @@ public class TestSavingLoading {
} }
/*
We don't implement equals() methods on the below mentioned classes because then we'd need to implement an
appropriate hashCode() method that's consistent with the equals(), and we only need plain equals() for
these tests.
*/
private boolean competingFunctionsEqual(CompetingRiskFunctions f1 ,CompetingRiskFunctions f2){
if(!functionsEqual(f1.getSurvivalCurve(), f2.getSurvivalCurve())){
return false;
}
for(int i=1; i<=2; i++){
if(!functionsEqual(f1.getCauseSpecificHazardFunction(i), f2.getCauseSpecificHazardFunction(i))){
return false;
}
if(!functionsEqual(f1.getCumulativeIncidenceFunction(i), f2.getCumulativeIncidenceFunction(i))){
return false;
}
}
return true;
}
private boolean functionsEqual(RightContinuousStepFunction f1, RightContinuousStepFunction f2){
final double[] f1X = f1.getX();
final double[] f2X = f2.getX();
final double[] f1Y = f1.getY();
final double[] f2Y = f2.getY();
// first compare array lengths
if(f1X.length != f2X.length){
return false;
}
if(f1Y.length != f2Y.length){
return false;
}
// TODO - better comparisons of doubles. I don't really care too much though as this equals method is only being used in tests
final double delta = 0.000001;
if(Math.abs(f1.getDefaultY() - f2.getDefaultY()) > delta){
return false;
}
for(int i=0; i < f1X.length; i++){
if(Math.abs(f1X[i] - f2X[i]) > delta){
return false;
}
if(Math.abs(f1Y[i] - f2Y[i]) > delta){
return false;
}
}
return true;
}
} }

View file

@ -16,10 +16,11 @@
package ca.joeltherrien.randomforest.competingrisk; package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.OnlineForest;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
@ -30,8 +31,6 @@ import java.util.List;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough; import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TestCompetingRiskErrorRateCalculator { public class TestCompetingRiskErrorRateCalculator {
@ -48,7 +47,7 @@ public class TestCompetingRiskErrorRateCalculator {
final int event = 1; final int event = 1;
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = OnlineForest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event); final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);

View file

@ -20,7 +20,7 @@ import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class VariableImportanceCalculatorTest { public class TestVariableImportanceCalculator {
/* /*
Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
@ -28,7 +28,7 @@ public class VariableImportanceCalculatorTest {
*/ */
// We'l have a very simple Forest of two trees // We'l have a very simple Forest of two trees
private final Forest<Double, Double> forest; private final OnlineForest<Double, Double> forest;
private final List<Covariate> covariates; private final List<Covariate> covariates;
@ -38,7 +38,7 @@ public class VariableImportanceCalculatorTest {
Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance. Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance.
*/ */
public VariableImportanceCalculatorTest(){ public TestVariableImportanceCalculator(){
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false); final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false);
final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false); final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false);
final FactorCovariate factorCovariate = new FactorCovariate("z", 2, final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
@ -67,10 +67,9 @@ public class VariableImportanceCalculatorTest {
final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4}); final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8}); final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});
this.forest = Forest.<Double, Double>builder() this.forest = OnlineForest.<Double, Double>builder()
.trees(Utils.easyList(tree1, tree2)) .trees(Utils.easyList(tree1, tree2))
.treeResponseCombiner(new MeanResponseCombiner()) .treeResponseCombiner(new MeanResponseCombiner())
.covariateList(this.covariates)
.build(); .build();
// formula; boolean high adds 100; high numeric adds 10 // formula; boolean high adds 100; high numeric adds 10