Add support for providing an initial forest to add trees to
This commit is contained in:
parent
7da3bd14a5
commit
22accdb263
7 changed files with 365 additions and 51 deletions
|
@ -39,6 +39,7 @@ import java.io.File;
|
|||
import java.io.IOException;
|
||||
import java.io.PrintWriter;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
|
||||
public class Main {
|
||||
|
@ -73,16 +74,16 @@ public class Main {
|
|||
|
||||
if(settings.isSaveProgress()){
|
||||
if(settings.getNumberOfThreads() > 1){
|
||||
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||
forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads());
|
||||
} else{
|
||||
forestTrainer.trainSerialOnDisk();
|
||||
forestTrainer.trainSerialOnDisk(Optional.empty());
|
||||
}
|
||||
}
|
||||
else{
|
||||
if(settings.getNumberOfThreads() > 1){
|
||||
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
|
||||
forestTrainer.trainParallelInMemory(Optional.empty(), settings.getNumberOfThreads());
|
||||
} else{
|
||||
forestTrainer.trainSerialInMemory();
|
||||
forestTrainer.trainSerialInMemory(Optional.empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,16 +21,11 @@ import ca.joeltherrien.randomforest.Row;
|
|||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.*;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
@ -48,7 +43,9 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
private final List<Row<Y>> data;
|
||||
|
||||
// number of trees to try
|
||||
private final int ntree;
|
||||
@Getter
|
||||
@Setter
|
||||
private int ntree;
|
||||
|
||||
private final boolean displayProgress; // whether to print to standard output our progress; not always desirable
|
||||
private final String saveTreeLocation;
|
||||
|
@ -72,12 +69,21 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
}
|
||||
}
|
||||
|
||||
public Forest<TO, FO> trainSerialInMemory(){
|
||||
/**
|
||||
* Train a forest in memory using a single core
|
||||
*
|
||||
* @param initialForest An Optional possibly containing a pre-trained forest,
|
||||
* in which case its trees are combined with the new one.
|
||||
* @return A trained forest.
|
||||
*/
|
||||
public Forest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
||||
|
||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||
initialForest.ifPresent(forest -> trees.addAll(forest.getTrees()));
|
||||
|
||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
for(int j=trees.size(); j<ntree; j++){
|
||||
if(displayProgress){
|
||||
System.out.print("\rFinished tree " + j + "/" + ntree + " trees");
|
||||
}
|
||||
|
@ -99,7 +105,15 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
}
|
||||
|
||||
public void trainSerialOnDisk(){
|
||||
/**
|
||||
* Train a forest on the disk using a single core.
|
||||
*
|
||||
* @param initialForest An Optional possibly containing a pre-trained forest,
|
||||
* in which case its trees are combined with the new one.
|
||||
* There cannot be existing trees if the initial forest is
|
||||
* specified.
|
||||
*/
|
||||
public void trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
||||
// First we need to see how many trees there currently are
|
||||
final File folder = new File(saveTreeLocation);
|
||||
if(!folder.exists()){
|
||||
|
@ -112,21 +126,42 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
||||
final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList());
|
||||
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
|
||||
// Using an AtomicInteger is overkill for serial code, but this lets use reuse TreeSavedWorker
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
if(initialForest.isPresent() & treeFiles.length > 0){
|
||||
throw new IllegalArgumentException("An initial forest is present but trees are also present; not clear how to integrate initial forest into new forest");
|
||||
}
|
||||
|
||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
||||
if(initialForest.isPresent()){
|
||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||
|
||||
for(int j=0; j<initialTrees.size(); j++){
|
||||
final String filename = "tree-" + (j+1) + ".tree";
|
||||
final Tree<TO> tree = initialTrees.get(j);
|
||||
|
||||
saveTree(tree, filename);
|
||||
|
||||
}
|
||||
|
||||
treeCount = new AtomicInteger(initialTrees.size());
|
||||
} else{
|
||||
treeCount = new AtomicInteger(treeFiles.length);
|
||||
}
|
||||
|
||||
|
||||
while(treeCount.get() < ntree){
|
||||
if(displayProgress) {
|
||||
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
|
||||
}
|
||||
|
||||
final String treeFileName = "tree-" + (j+1) + ".tree";
|
||||
final String treeFileName = "tree-" + (treeCount.get() + 1) + ".tree";
|
||||
|
||||
if(treeFileNames.contains(treeFileName)){
|
||||
continue;
|
||||
}
|
||||
|
||||
final Random random = new Random(this.randomSeed + j);
|
||||
final Random random = new Random(this.randomSeed + treeCount.get());
|
||||
final Runnable worker = new TreeSavedWorker(data, treeFileName, treeCount, random);
|
||||
worker.run();
|
||||
|
||||
|
@ -139,15 +174,34 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
}
|
||||
|
||||
public Forest<TO, FO> trainParallelInMemory(int threads){
|
||||
/**
|
||||
* Train a forest in memory using the specified number of threads.
|
||||
*
|
||||
* @param initialForest An Optional possibly containing a pre-trained forest,
|
||||
* in which case its trees are combined with the new one.
|
||||
* @param threads The number of trees to train at once.
|
||||
*/
|
||||
public Forest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||
|
||||
// create a list that is prespecified in size (I can call the .set method at any index < ntree without
|
||||
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
|
||||
// the earlier indexes being filled.
|
||||
final List<Tree<TO>> trees = Stream.<Tree<TO>>generate(() -> null).limit(ntree).collect(Collectors.toList());
|
||||
|
||||
final int startingCount;
|
||||
if(initialForest.isPresent()){
|
||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||
for(int j=0; j<initialTrees.size(); j++) {
|
||||
trees.set(j, initialTrees.get(j));
|
||||
}
|
||||
startingCount = initialTrees.size();
|
||||
}
|
||||
else{
|
||||
startingCount = 0;
|
||||
}
|
||||
|
||||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
for(int j=startingCount; j<ntree; j++){
|
||||
final Random random = new Random(this.randomSeed + j);
|
||||
final Runnable worker = new TreeInMemoryWorker(data, j, trees, random);
|
||||
executorService.execute(worker);
|
||||
|
@ -191,8 +245,16 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
}
|
||||
|
||||
|
||||
public void trainParallelOnDisk(int threads){
|
||||
/**
|
||||
* Train a forest on the disk using a specified number of threads.
|
||||
*
|
||||
* @param initialForest An Optional possibly containing a pre-trained forest,
|
||||
* in which case its trees are combined with the new one.
|
||||
* There cannot be existing trees if the initial forest is
|
||||
* specified.
|
||||
* @param threads The number of trees to train at once.
|
||||
*/
|
||||
public void trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||
// First we need to see how many trees there currently are
|
||||
final File folder = new File(saveTreeLocation);
|
||||
if(!folder.exists()){
|
||||
|
@ -205,11 +267,31 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
||||
final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList());
|
||||
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
|
||||
|
||||
if(initialForest.isPresent() & treeFiles.length > 0){
|
||||
throw new IllegalArgumentException("An initial forest is present but trees are also present; not clear how to integrate initial forest into new forest");
|
||||
}
|
||||
|
||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||
if(initialForest.isPresent()){
|
||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||
|
||||
for(int j=0; j<initialTrees.size(); j++){
|
||||
final String filename = "tree-" + (j+1) + ".tree";
|
||||
final Tree<TO> tree = initialTrees.get(j);
|
||||
|
||||
saveTree(tree, filename);
|
||||
|
||||
}
|
||||
|
||||
treeCount = new AtomicInteger(initialTrees.size());
|
||||
} else{
|
||||
treeCount = new AtomicInteger(treeFiles.length);
|
||||
}
|
||||
|
||||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
for(int j=treeCount.get(); j<ntree; j++){
|
||||
final String treeFileName = "tree-" + (j+1) + ".tree";
|
||||
if(treeFileNames.contains(treeFileName)){
|
||||
continue;
|
||||
|
@ -253,6 +335,17 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
return treeTrainer.growTree(bootstrappedData, random);
|
||||
}
|
||||
|
||||
private void saveTree(Tree<TO> tree, String filename){
|
||||
try {
|
||||
DataUtils.saveObject(tree, saveTreeLocation + "/" + filename);
|
||||
} catch (IOException e) {
|
||||
System.err.println("IOException while saving " + filename);
|
||||
e.printStackTrace();
|
||||
System.err.println("Quitting program");
|
||||
System.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private class TreeInMemoryWorker implements Runnable {
|
||||
|
||||
|
@ -297,14 +390,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
public void run() {
|
||||
final Tree<TO> tree = trainTree(bootstrapper, random);
|
||||
|
||||
try {
|
||||
DataUtils.saveObject(tree, saveTreeLocation + "/" + filename);
|
||||
} catch (IOException e) {
|
||||
System.err.println("IOException while saving " + filename);
|
||||
e.printStackTrace();
|
||||
System.err.println("Quitting program");
|
||||
System.exit(1);
|
||||
}
|
||||
saveTree(tree, filename);
|
||||
|
||||
treeCount.incrementAndGet();
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ import java.io.File;
|
|||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
|
||||
public class TestDeterministicForests {
|
||||
|
@ -144,7 +145,7 @@ public class TestDeterministicForests {
|
|||
.build();
|
||||
|
||||
// By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
|
||||
final Forest<Double, Double> referenceForest = forestTrainer.trainSerialInMemory();
|
||||
final Forest<Double, Double> referenceForest = forestTrainer.trainSerialInMemory(Optional.empty());
|
||||
|
||||
verifySerialInMemoryTraining(referenceForest, forestTrainer, testData);
|
||||
verifyParallelInMemoryTraining(referenceForest, forestTrainer, testData);
|
||||
|
@ -206,20 +207,20 @@ public class TestDeterministicForests {
|
|||
.build();
|
||||
|
||||
// By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
|
||||
final Forest<Double, Double> referenceForest = forestTrainer10Trees.trainSerialInMemory();
|
||||
final Forest<Double, Double> referenceForest = forestTrainer10Trees.trainSerialInMemory(Optional.empty());
|
||||
|
||||
final File saveTreeFile = new File(saveTreeLocation);
|
||||
|
||||
|
||||
forestTrainer5Trees.trainSerialOnDisk();
|
||||
forestTrainer10Trees.trainSerialOnDisk();
|
||||
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
|
||||
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
|
||||
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
||||
|
||||
|
||||
forestTrainer5Trees.trainParallelOnDisk(4);
|
||||
forestTrainer10Trees.trainParallelOnDisk(4);
|
||||
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
||||
|
@ -232,7 +233,7 @@ public class TestDeterministicForests {
|
|||
List<Row<Double>> testData){
|
||||
|
||||
for(int k=0; k<3; k++){
|
||||
final Forest<Double, Double> replicantForest = forestTrainer.trainSerialInMemory();
|
||||
final Forest<Double, Double> replicantForest = forestTrainer.trainSerialInMemory(Optional.empty());
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
}
|
||||
}
|
||||
|
@ -243,7 +244,7 @@ public class TestDeterministicForests {
|
|||
List<Row<Double>> testData){
|
||||
|
||||
for(int k=0; k<3; k++){
|
||||
final Forest<Double, Double> replicantForest = forestTrainer.trainParallelInMemory(4);
|
||||
final Forest<Double, Double> replicantForest = forestTrainer.trainParallelInMemory(Optional.empty(), 4);
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
}
|
||||
}
|
||||
|
@ -257,7 +258,7 @@ public class TestDeterministicForests {
|
|||
final File saveTreeFile = new File(saveTreeLocation);
|
||||
|
||||
for(int k=0; k<3; k++){
|
||||
forestTrainer.trainSerialOnDisk();
|
||||
forestTrainer.trainSerialOnDisk(Optional.empty());
|
||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
|
@ -272,7 +273,7 @@ public class TestDeterministicForests {
|
|||
final File saveTreeFile = new File(saveTreeLocation);
|
||||
|
||||
for(int k=0; k<3; k++){
|
||||
forestTrainer.trainParallelOnDisk(4);
|
||||
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
|
||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
|
|
|
@ -0,0 +1,224 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.Tree;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import com.fasterxml.jackson.databind.node.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestProvidingInitialForest {
|
||||
|
||||
private Forest<Double, Double> initialForest;
|
||||
private List<Covariate> covariateList;
|
||||
private List<Row<Double>> data;
|
||||
|
||||
public TestProvidingInitialForest(){
|
||||
covariateList = Collections.singletonList(new NumericCovariate("x", 0));
|
||||
|
||||
data = Utils.easyList(
|
||||
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
|
||||
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 2, 1.5),
|
||||
Row.createSimple(Utils.easyMap("x", "2.0"), covariateList, 3, 5.0),
|
||||
Row.createSimple(Utils.easyMap("x", "2.0"), covariateList, 4, 6.0)
|
||||
);
|
||||
|
||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||
.splitFinder(new WeightedVarianceSplitFinder())
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.checkNodePurity(false)
|
||||
.numberOfSplits(0)
|
||||
.nodeSize(1)
|
||||
.mtry(1)
|
||||
.maxNodeDepth(100000)
|
||||
.covariates(covariateList)
|
||||
.build();
|
||||
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = ForestTrainer.<Double, Double, Double>builder()
|
||||
.treeResponseCombiner(new MeanResponseCombiner())
|
||||
.ntree(10)
|
||||
.displayProgress(false)
|
||||
.data(data)
|
||||
.covariates(covariateList)
|
||||
.treeTrainer(treeTrainer)
|
||||
.build();
|
||||
|
||||
initialForest = forestTrainer.trainSerialInMemory(Optional.empty());
|
||||
}
|
||||
|
||||
private final int NTREE = 10;
|
||||
|
||||
private ForestTrainer<Double, Double, Double> getForestTrainer(String saveTreeLocation, int ntree){
|
||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||
.splitFinder(new WeightedVarianceSplitFinder())
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.checkNodePurity(false)
|
||||
.numberOfSplits(0)
|
||||
.nodeSize(1)
|
||||
.mtry(1)
|
||||
.maxNodeDepth(100000)
|
||||
.covariates(covariateList)
|
||||
.build();
|
||||
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = ForestTrainer.<Double, Double, Double>builder()
|
||||
.treeResponseCombiner(new MeanResponseCombiner())
|
||||
.ntree(ntree)
|
||||
.displayProgress(false)
|
||||
.data(data)
|
||||
.covariates(covariateList)
|
||||
.treeTrainer(treeTrainer)
|
||||
.saveTreeLocation(saveTreeLocation)
|
||||
.build();
|
||||
|
||||
return forestTrainer;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerialInMemory(){
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||
|
||||
final Forest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
||||
assertEquals(20, newForest.getTrees().size());
|
||||
|
||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||
assertTrue(newForest.getTrees().contains(initialTree));
|
||||
}
|
||||
for(int j=10; j<20; j++){
|
||||
final Tree<Double> newTree = newForest.getTrees().get(j);
|
||||
assertFalse(initialForest.getTrees().contains(newTree));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testParallelInMemory(){
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||
|
||||
final Forest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
||||
assertEquals(20, newForest.getTrees().size());
|
||||
|
||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||
assertTrue(newForest.getTrees().contains(initialTree));
|
||||
}
|
||||
for(int j=10; j<20; j++){
|
||||
final Tree<Double> newTree = newForest.getTrees().get(j);
|
||||
assertFalse(initialForest.getTrees().contains(newTree));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testParallelOnDisk() throws IOException, ClassNotFoundException {
|
||||
final String filePath = "src/test/resources/trees/";
|
||||
final File directory = new File(filePath);
|
||||
if(directory.exists()){
|
||||
TestUtils.removeFolder(directory);
|
||||
}
|
||||
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(filePath, 20);
|
||||
|
||||
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
|
||||
|
||||
assertEquals(20, directory.listFiles().length);
|
||||
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
||||
|
||||
|
||||
|
||||
assertEquals(20, newForest.getTrees().size());
|
||||
|
||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||
|
||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||
assertTrue(newForestTreesAsStrings.contains(initialTree.toString()));
|
||||
}
|
||||
|
||||
TestUtils.removeFolder(directory);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerialOnDisk() throws IOException, ClassNotFoundException {
|
||||
final String filePath = "src/test/resources/trees/";
|
||||
final File directory = new File(filePath);
|
||||
if(directory.exists()){
|
||||
TestUtils.removeFolder(directory);
|
||||
}
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(filePath, 20);
|
||||
|
||||
|
||||
forestTrainer.trainSerialOnDisk(Optional.of(initialForest));
|
||||
|
||||
assertEquals(20, directory.listFiles().length);
|
||||
|
||||
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
||||
|
||||
assertEquals(20, newForest.getTrees().size());
|
||||
|
||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||
|
||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||
assertTrue(newForestTreesAsStrings.contains(initialTree.toString()));
|
||||
}
|
||||
|
||||
TestUtils.removeFolder(directory);
|
||||
}
|
||||
|
||||
/*
|
||||
We throw IllegalArgumentExceptions when we try providing an initial forest when trees were already saved, because
|
||||
it's not clear if the forest being provided is the same one that trees were saved from.
|
||||
*/
|
||||
@Test
|
||||
public void verifyExceptions(){
|
||||
final String filePath = "src/test/resources/trees/";
|
||||
final File directory = new File(filePath);
|
||||
if(directory.exists()){
|
||||
TestUtils.removeFolder(directory);
|
||||
}
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(filePath, 10);
|
||||
forestTrainer.trainSerialOnDisk(Optional.empty());
|
||||
|
||||
forestTrainer.setNtree(20);
|
||||
assertThrows(IllegalArgumentException.class, () -> forestTrainer.trainSerialOnDisk(Optional.of(initialForest)));
|
||||
assertThrows(IllegalArgumentException.class, () -> forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2));
|
||||
|
||||
TestUtils.removeFolder(directory);
|
||||
}
|
||||
|
||||
}
|
|
@ -33,6 +33,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class TestSavingLoading {
|
||||
|
||||
|
@ -123,7 +124,7 @@ public class TestSavingLoading {
|
|||
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
forestTrainer.trainSerialOnDisk();
|
||||
forestTrainer.trainSerialOnDisk(Optional.empty());
|
||||
|
||||
assertTrue(directory.exists());
|
||||
assertTrue(directory.isDirectory());
|
||||
|
@ -164,7 +165,7 @@ public class TestSavingLoading {
|
|||
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||
forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads());
|
||||
|
||||
assertTrue(directory.exists());
|
||||
assertTrue(directory.isDirectory());
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
|
||||
import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
|
||||
|
@ -235,7 +236,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory();
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(Optional.empty());
|
||||
|
||||
// prediction row
|
||||
// time status ageatfda idu black cd4nadir
|
||||
|
@ -328,7 +329,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
for(int i=0; i<50; i++){
|
||||
forestTrainer.trainSerialInMemory();
|
||||
forestTrainer.trainSerialInMemory(Optional.empty());
|
||||
}
|
||||
final long endTime = System.currentTimeMillis();
|
||||
|
||||
|
@ -346,7 +347,7 @@ public class TestCompetingRisk {
|
|||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
|
||||
settings.getTrainingDataLocation());
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory();
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(Optional.empty());
|
||||
|
||||
// prediction row
|
||||
// time status ageatfda idu black cd4nadir
|
||||
|
|
|
@ -90,7 +90,7 @@ public class TrainForest {
|
|||
|
||||
//final Forest<Double> forest = forestTrainer.trainSerialInMemory();
|
||||
//final Forest<Double> forest = forestTrainer.trainParallelInMemory(3);
|
||||
forestTrainer.trainParallelOnDisk(3);
|
||||
forestTrainer.trainParallelOnDisk(Optional.empty(), 3);
|
||||
|
||||
final long endTime = System.currentTimeMillis();
|
||||
|
||||
|
|
Loading…
Reference in a new issue