Add support for providing an initial forest to add trees to

This commit is contained in:
Joel Therrien 2019-06-07 19:55:44 -07:00
parent 7da3bd14a5
commit 22accdb263
7 changed files with 365 additions and 51 deletions

View file

@ -39,6 +39,7 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.Random; import java.util.Random;
public class Main { public class Main {
@ -73,16 +74,16 @@ public class Main {
if(settings.isSaveProgress()){ if(settings.isSaveProgress()){
if(settings.getNumberOfThreads() > 1){ if(settings.getNumberOfThreads() > 1){
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads());
} else{ } else{
forestTrainer.trainSerialOnDisk(); forestTrainer.trainSerialOnDisk(Optional.empty());
} }
} }
else{ else{
if(settings.getNumberOfThreads() > 1){ if(settings.getNumberOfThreads() > 1){
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); forestTrainer.trainParallelInMemory(Optional.empty(), settings.getNumberOfThreads());
} else{ } else{
forestTrainer.trainSerialInMemory(); forestTrainer.trainSerialInMemory(Optional.empty());
} }
} }
} }

View file

@ -21,16 +21,11 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.DataUtils;
import lombok.AccessLevel; import lombok.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.*;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -48,7 +43,9 @@ public class ForestTrainer<Y, TO, FO> {
private final List<Row<Y>> data; private final List<Row<Y>> data;
// number of trees to try // number of trees to try
private final int ntree; @Getter
@Setter
private int ntree;
private final boolean displayProgress; // whether to print to standard output our progress; not always desirable private final boolean displayProgress; // whether to print to standard output our progress; not always desirable
private final String saveTreeLocation; private final String saveTreeLocation;
@ -72,12 +69,21 @@ public class ForestTrainer<Y, TO, FO> {
} }
} }
public Forest<TO, FO> trainSerialInMemory(){ /**
* Train a forest in memory using a single core
*
* @param initialForest An Optional possibly containing a pre-trained forest,
* in which case its trees are combined with the new one.
* @return A trained forest.
*/
public Forest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
final List<Tree<TO>> trees = new ArrayList<>(ntree); final List<Tree<TO>> trees = new ArrayList<>(ntree);
initialForest.ifPresent(forest -> trees.addAll(forest.getTrees()));
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data); final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
for(int j=0; j<ntree; j++){ for(int j=trees.size(); j<ntree; j++){
if(displayProgress){ if(displayProgress){
System.out.print("\rFinished tree " + j + "/" + ntree + " trees"); System.out.print("\rFinished tree " + j + "/" + ntree + " trees");
} }
@ -99,7 +105,15 @@ public class ForestTrainer<Y, TO, FO> {
} }
public void trainSerialOnDisk(){ /**
* Train a forest on the disk using a single core.
*
* @param initialForest An Optional possibly containing a pre-trained forest,
* in which case its trees are combined with the new one.
* There cannot be existing trees if the initial forest is
* specified.
*/
public void trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
// First we need to see how many trees there currently are // First we need to see how many trees there currently are
final File folder = new File(saveTreeLocation); final File folder = new File(saveTreeLocation);
if(!folder.exists()){ if(!folder.exists()){
@ -112,21 +126,42 @@ public class ForestTrainer<Y, TO, FO> {
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList()); final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList());
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
// Using an AtomicInteger is overkill for serial code, but this lets use reuse TreeSavedWorker
for(int j=0; j<ntree; j++){ if(initialForest.isPresent() & treeFiles.length > 0){
throw new IllegalArgumentException("An initial forest is present but trees are also present; not clear how to integrate initial forest into new forest");
}
final AtomicInteger treeCount; // tracks how many trees are finished
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
if(initialForest.isPresent()){
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
for(int j=0; j<initialTrees.size(); j++){
final String filename = "tree-" + (j+1) + ".tree";
final Tree<TO> tree = initialTrees.get(j);
saveTree(tree, filename);
}
treeCount = new AtomicInteger(initialTrees.size());
} else{
treeCount = new AtomicInteger(treeFiles.length);
}
while(treeCount.get() < ntree){
if(displayProgress) { if(displayProgress) {
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees"); System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
} }
final String treeFileName = "tree-" + (j+1) + ".tree"; final String treeFileName = "tree-" + (treeCount.get() + 1) + ".tree";
if(treeFileNames.contains(treeFileName)){ if(treeFileNames.contains(treeFileName)){
continue; continue;
} }
final Random random = new Random(this.randomSeed + j); final Random random = new Random(this.randomSeed + treeCount.get());
final Runnable worker = new TreeSavedWorker(data, treeFileName, treeCount, random); final Runnable worker = new TreeSavedWorker(data, treeFileName, treeCount, random);
worker.run(); worker.run();
@ -139,15 +174,34 @@ public class ForestTrainer<Y, TO, FO> {
} }
public Forest<TO, FO> trainParallelInMemory(int threads){ /**
* Train a forest in memory using the specified number of threads.
*
* @param initialForest An Optional possibly containing a pre-trained forest,
* in which case its trees are combined with the new one.
* @param threads The number of trees to train at once.
*/
public Forest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
// create a list that is prespecified in size (I can call the .set method at any index < ntree without // create a list that is pre-specified in size (I can call the .set method at any index < ntree without
// the earlier indexes being filled. // the earlier indexes being filled.
final List<Tree<TO>> trees = Stream.<Tree<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList()); final List<Tree<TO>> trees = Stream.<Tree<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
final int startingCount;
if(initialForest.isPresent()){
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
for(int j=0; j<initialTrees.size(); j++) {
trees.set(j, initialTrees.get(j));
}
startingCount = initialTrees.size();
}
else{
startingCount = 0;
}
final ExecutorService executorService = Executors.newFixedThreadPool(threads); final ExecutorService executorService = Executors.newFixedThreadPool(threads);
for(int j=0; j<ntree; j++){ for(int j=startingCount; j<ntree; j++){
final Random random = new Random(this.randomSeed + j); final Random random = new Random(this.randomSeed + j);
final Runnable worker = new TreeInMemoryWorker(data, j, trees, random); final Runnable worker = new TreeInMemoryWorker(data, j, trees, random);
executorService.execute(worker); executorService.execute(worker);
@ -191,8 +245,16 @@ public class ForestTrainer<Y, TO, FO> {
} }
/**
public void trainParallelOnDisk(int threads){ * Train a forest on the disk using a specified number of threads.
*
* @param initialForest An Optional possibly containing a pre-trained forest,
* in which case its trees are combined with the new one.
* There cannot be existing trees if the initial forest is
* specified.
* @param threads The number of trees to train at once.
*/
public void trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
// First we need to see how many trees there currently are // First we need to see how many trees there currently are
final File folder = new File(saveTreeLocation); final File folder = new File(saveTreeLocation);
if(!folder.exists()){ if(!folder.exists()){
@ -205,11 +267,31 @@ public class ForestTrainer<Y, TO, FO> {
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList()); final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList());
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
if(initialForest.isPresent() & treeFiles.length > 0){
throw new IllegalArgumentException("An initial forest is present but trees are also present; not clear how to integrate initial forest into new forest");
}
final AtomicInteger treeCount; // tracks how many trees are finished
if(initialForest.isPresent()){
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
for(int j=0; j<initialTrees.size(); j++){
final String filename = "tree-" + (j+1) + ".tree";
final Tree<TO> tree = initialTrees.get(j);
saveTree(tree, filename);
}
treeCount = new AtomicInteger(initialTrees.size());
} else{
treeCount = new AtomicInteger(treeFiles.length);
}
final ExecutorService executorService = Executors.newFixedThreadPool(threads); final ExecutorService executorService = Executors.newFixedThreadPool(threads);
for(int j=0; j<ntree; j++){ for(int j=treeCount.get(); j<ntree; j++){
final String treeFileName = "tree-" + (j+1) + ".tree"; final String treeFileName = "tree-" + (j+1) + ".tree";
if(treeFileNames.contains(treeFileName)){ if(treeFileNames.contains(treeFileName)){
continue; continue;
@ -253,6 +335,17 @@ public class ForestTrainer<Y, TO, FO> {
return treeTrainer.growTree(bootstrappedData, random); return treeTrainer.growTree(bootstrappedData, random);
} }
private void saveTree(Tree<TO> tree, String filename){
try {
DataUtils.saveObject(tree, saveTreeLocation + "/" + filename);
} catch (IOException e) {
System.err.println("IOException while saving " + filename);
e.printStackTrace();
System.err.println("Quitting program");
System.exit(1);
}
}
private class TreeInMemoryWorker implements Runnable { private class TreeInMemoryWorker implements Runnable {
@ -297,14 +390,7 @@ public class ForestTrainer<Y, TO, FO> {
public void run() { public void run() {
final Tree<TO> tree = trainTree(bootstrapper, random); final Tree<TO> tree = trainTree(bootstrapper, random);
try { saveTree(tree, filename);
DataUtils.saveObject(tree, saveTreeLocation + "/" + filename);
} catch (IOException e) {
System.err.println("IOException while saving " + filename);
e.printStackTrace();
System.err.println("Quitting program");
System.exit(1);
}
treeCount.incrementAndGet(); treeCount.incrementAndGet();

View file

@ -33,6 +33,7 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.Random; import java.util.Random;
public class TestDeterministicForests { public class TestDeterministicForests {
@ -144,7 +145,7 @@ public class TestDeterministicForests {
.build(); .build();
// By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed. // By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
final Forest<Double, Double> referenceForest = forestTrainer.trainSerialInMemory(); final Forest<Double, Double> referenceForest = forestTrainer.trainSerialInMemory(Optional.empty());
verifySerialInMemoryTraining(referenceForest, forestTrainer, testData); verifySerialInMemoryTraining(referenceForest, forestTrainer, testData);
verifyParallelInMemoryTraining(referenceForest, forestTrainer, testData); verifyParallelInMemoryTraining(referenceForest, forestTrainer, testData);
@ -206,20 +207,20 @@ public class TestDeterministicForests {
.build(); .build();
// By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed. // By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
final Forest<Double, Double> referenceForest = forestTrainer10Trees.trainSerialInMemory(); final Forest<Double, Double> referenceForest = forestTrainer10Trees.trainSerialInMemory(Optional.empty());
final File saveTreeFile = new File(saveTreeLocation); final File saveTreeFile = new File(saveTreeLocation);
forestTrainer5Trees.trainSerialOnDisk(); forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
forestTrainer10Trees.trainSerialOnDisk(); forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
TestUtils.removeFolder(saveTreeFile); TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, forestSerial); verifyTwoForestsEqual(testData, referenceForest, forestSerial);
forestTrainer5Trees.trainParallelOnDisk(4); forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
forestTrainer10Trees.trainParallelOnDisk(4); forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
TestUtils.removeFolder(saveTreeFile); TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, forestParallel); verifyTwoForestsEqual(testData, referenceForest, forestParallel);
@ -232,7 +233,7 @@ public class TestDeterministicForests {
List<Row<Double>> testData){ List<Row<Double>> testData){
for(int k=0; k<3; k++){ for(int k=0; k<3; k++){
final Forest<Double, Double> replicantForest = forestTrainer.trainSerialInMemory(); final Forest<Double, Double> replicantForest = forestTrainer.trainSerialInMemory(Optional.empty());
verifyTwoForestsEqual(testData, referenceForest, replicantForest); verifyTwoForestsEqual(testData, referenceForest, replicantForest);
} }
} }
@ -243,7 +244,7 @@ public class TestDeterministicForests {
List<Row<Double>> testData){ List<Row<Double>> testData){
for(int k=0; k<3; k++){ for(int k=0; k<3; k++){
final Forest<Double, Double> replicantForest = forestTrainer.trainParallelInMemory(4); final Forest<Double, Double> replicantForest = forestTrainer.trainParallelInMemory(Optional.empty(), 4);
verifyTwoForestsEqual(testData, referenceForest, replicantForest); verifyTwoForestsEqual(testData, referenceForest, replicantForest);
} }
} }
@ -257,7 +258,7 @@ public class TestDeterministicForests {
final File saveTreeFile = new File(saveTreeLocation); final File saveTreeFile = new File(saveTreeLocation);
for(int k=0; k<3; k++){ for(int k=0; k<3; k++){
forestTrainer.trainSerialOnDisk(); forestTrainer.trainSerialOnDisk(Optional.empty());
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
TestUtils.removeFolder(saveTreeFile); TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, replicantForest); verifyTwoForestsEqual(testData, referenceForest, replicantForest);
@ -272,7 +273,7 @@ public class TestDeterministicForests {
final File saveTreeFile = new File(saveTreeLocation); final File saveTreeFile = new File(saveTreeLocation);
for(int k=0; k<3; k++){ for(int k=0; k<3; k++){
forestTrainer.trainParallelOnDisk(4); forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
TestUtils.removeFolder(saveTreeFile); TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, replicantForest); verifyTwoForestsEqual(testData, referenceForest, replicantForest);

View file

@ -0,0 +1,224 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.Tree;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Tests that each ForestTrainer training method accepts a pre-trained initial forest and
 * keeps its trees while growing the forest up to the requested number of trees.
 */
public class TestProvidingInitialForest {

    /** Number of trees in the pre-trained forest that is handed to the trainers under test. */
    private static final int NTREE = 10;

    /** Number of trees requested from a trainer that is seeded with the initial forest. */
    private static final int TOTAL_TREES = 2 * NTREE;

    /** Folder that the on-disk tests save their trees into. */
    private static final String TREE_DIRECTORY = "src/test/resources/trees/";

    private final Forest<Double, Double> initialForest;
    private final List<Covariate> covariateList;
    private final List<Row<Double>> data;

    /**
     * Builds the four-row, single-covariate data set and trains the NTREE-tree forest
     * that the tests supply as the initial forest.
     */
    public TestProvidingInitialForest(){
        covariateList = Collections.singletonList(new NumericCovariate("x", 0));

        data = Utils.easyList(
                Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
                Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 2, 1.5),
                Row.createSimple(Utils.easyMap("x", "2.0"), covariateList, 3, 5.0),
                Row.createSimple(Utils.easyMap("x", "2.0"), covariateList, 4, 6.0)
        );

        // covariateList and data must be assigned before this call; getForestTrainer reads both.
        // No save location is needed because the initial forest is trained in memory.
        initialForest = getForestTrainer(null, NTREE).trainSerialInMemory(Optional.empty());
    }

    /**
     * Creates a ForestTrainer over the fixture data.
     *
     * @param saveTreeLocation Folder passed to the trainer for saving trees; the in-memory tests pass null.
     * @param ntree The number of trees the trainer should end up with.
     * @return A trainer configured with the shared tree settings.
     */
    private ForestTrainer<Double, Double, Double> getForestTrainer(String saveTreeLocation, int ntree){
        final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
                .splitFinder(new WeightedVarianceSplitFinder())
                .responseCombiner(new MeanResponseCombiner())
                .checkNodePurity(false)
                .numberOfSplits(0)
                .nodeSize(1)
                .mtry(1)
                .maxNodeDepth(100000)
                .covariates(covariateList)
                .build();

        return ForestTrainer.<Double, Double, Double>builder()
                .treeResponseCombiner(new MeanResponseCombiner())
                .ntree(ntree)
                .displayProgress(false)
                .data(data)
                .covariates(covariateList)
                .treeTrainer(treeTrainer)
                .saveTreeLocation(saveTreeLocation)
                .build();
    }

    /**
     * Returns the tree folder as a File, removing it first if an earlier run left it behind.
     */
    private File freshTreeDirectory(){
        final File directory = new File(TREE_DIRECTORY);
        removeIfExists(directory);
        return directory;
    }

    /**
     * Removes the folder only when it exists, so it is safe to call from a finally block
     * even if training failed before anything was written.
     */
    private void removeIfExists(File directory){
        if(directory.exists()){
            TestUtils.removeFolder(directory);
        }
    }

    /**
     * Asserts that an in-memory forest has TOTAL_TREES trees, contains every initial tree,
     * and that the trees at indexes NTREE and up are not taken from the initial forest.
     */
    private void assertExtendsInitialForest(Forest<Double, Double> newForest){
        assertEquals(TOTAL_TREES, newForest.getTrees().size());

        for(Tree<Double> initialTree : initialForest.getTrees()){
            assertTrue(newForest.getTrees().contains(initialTree));
        }

        for(int j=NTREE; j<TOTAL_TREES; j++){
            final Tree<Double> newTree = newForest.getTrees().get(j);
            assertFalse(initialForest.getTrees().contains(newTree));
        }
    }

    /**
     * Asserts that the folder holds TOTAL_TREES files and that the forest loaded from it
     * contains every initial tree. Trees are compared through toString().
     */
    private void assertSavedForestExtendsInitialForest(File directory) throws IOException, ClassNotFoundException {
        assertEquals(TOTAL_TREES, directory.listFiles().length);

        final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
        assertEquals(TOTAL_TREES, newForest.getTrees().size());

        // NOTE(review): presumably toString() is used because trees reloaded from disk are
        // not equal() to the in-memory originals - confirm against Tree's equals().
        final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
                .map(tree -> tree.toString()).collect(Collectors.toList());

        for(Tree<Double> initialTree : initialForest.getTrees()){
            assertTrue(newForestTreesAsStrings.contains(initialTree.toString()));
        }
    }

    @Test
    public void testSerialInMemory(){
        final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, TOTAL_TREES);

        final Forest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));

        assertExtendsInitialForest(newForest);
    }

    @Test
    public void testParallelInMemory(){
        final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, TOTAL_TREES);

        final Forest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);

        assertExtendsInitialForest(newForest);
    }

    @Test
    public void testParallelOnDisk() throws IOException, ClassNotFoundException {
        final File directory = freshTreeDirectory();

        try {
            final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(TREE_DIRECTORY, TOTAL_TREES);
            forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);

            assertSavedForestExtendsInitialForest(directory);
        } finally {
            // Clean up even when an assertion fails so saved trees do not linger on disk.
            removeIfExists(directory);
        }
    }

    @Test
    public void testSerialOnDisk() throws IOException, ClassNotFoundException {
        final File directory = freshTreeDirectory();

        try {
            final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(TREE_DIRECTORY, TOTAL_TREES);
            forestTrainer.trainSerialOnDisk(Optional.of(initialForest));

            assertSavedForestExtendsInitialForest(directory);
        } finally {
            // Clean up even when an assertion fails so saved trees do not linger on disk.
            removeIfExists(directory);
        }
    }

    /*
        We throw IllegalArgumentExceptions when we try providing an initial forest when trees were already saved, because
        it's not clear if the forest being provided is the same one that trees were saved from.
     */
    @Test
    public void verifyExceptions(){
        final File directory = freshTreeDirectory();

        try {
            // First save NTREE trees to disk, then ask for more while also supplying an initial forest.
            final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(TREE_DIRECTORY, NTREE);
            forestTrainer.trainSerialOnDisk(Optional.empty());

            forestTrainer.setNtree(TOTAL_TREES);

            assertThrows(IllegalArgumentException.class, () -> forestTrainer.trainSerialOnDisk(Optional.of(initialForest)));
            assertThrows(IllegalArgumentException.class, () -> forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2));
        } finally {
            removeIfExists(directory);
        }
    }

}

View file

@ -33,6 +33,7 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Optional;
public class TestSavingLoading { public class TestSavingLoading {
@ -123,7 +124,7 @@ public class TestSavingLoading {
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
forestTrainer.trainSerialOnDisk(); forestTrainer.trainSerialOnDisk(Optional.empty());
assertTrue(directory.exists()); assertTrue(directory.exists());
assertTrue(directory.isDirectory()); assertTrue(directory.isDirectory());
@ -164,7 +165,7 @@ public class TestSavingLoading {
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads());
assertTrue(directory.exists()); assertTrue(directory.exists());
assertTrue(directory.isDirectory()); assertTrue(directory.isDirectory());

View file

@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.Random; import java.util.Random;
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction; import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
@ -235,7 +236,7 @@ public class TestCompetingRisk {
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(Optional.empty());
// prediction row // prediction row
// time status ageatfda idu black cd4nadir // time status ageatfda idu black cd4nadir
@ -328,7 +329,7 @@ public class TestCompetingRisk {
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
for(int i=0; i<50; i++){ for(int i=0; i<50; i++){
forestTrainer.trainSerialInMemory(); forestTrainer.trainSerialInMemory(Optional.empty());
} }
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();
@ -346,7 +347,7 @@ public class TestCompetingRisk {
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
settings.getTrainingDataLocation()); settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(Optional.empty());
// prediction row // prediction row
// time status ageatfda idu black cd4nadir // time status ageatfda idu black cd4nadir

View file

@ -90,7 +90,7 @@ public class TrainForest {
//final Forest<Double> forest = forestTrainer.trainSerialInMemory(); //final Forest<Double> forest = forestTrainer.trainSerialInMemory();
//final Forest<Double> forest = forestTrainer.trainParallelInMemory(3); //final Forest<Double> forest = forestTrainer.trainParallelInMemory(3);
forestTrainer.trainParallelOnDisk(3); forestTrainer.trainParallelOnDisk(Optional.empty(), 3);
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();