Add support for seeds to control randomness when training forests
This commit is contained in:
parent
17ae3a9f5a
commit
6f318db79e
10 changed files with 408 additions and 52 deletions
|
@ -82,7 +82,7 @@ public class Main {
|
|||
if(settings.getNumberOfThreads() > 1){
|
||||
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
|
||||
} else{
|
||||
forestTrainer.trainSerial();
|
||||
forestTrainer.trainSerialInMemory();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -151,6 +151,7 @@ public class Settings {
|
|||
private int nodeSize = 5;
|
||||
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
|
||||
private boolean checkNodePurity = false;
|
||||
private Long randomSeed;
|
||||
|
||||
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
|
|
|
@ -28,12 +28,11 @@ import lombok.Builder;
|
|||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -51,8 +50,9 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
// number of trees to try
|
||||
private final int ntree;
|
||||
|
||||
private final boolean displayProgress;
|
||||
private final boolean displayProgress; // whether to print to standard output our progress; not always desirable
|
||||
private final String saveTreeLocation;
|
||||
private final long randomSeed;
|
||||
|
||||
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
|
||||
this.ntree = settings.getNtree();
|
||||
|
@ -63,19 +63,25 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
this.covariates = covariates;
|
||||
this.treeResponseCombiner = settings.getTreeCombiner();
|
||||
this.treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||
|
||||
if(settings.getRandomSeed() != null){
|
||||
this.randomSeed = settings.getRandomSeed();
|
||||
}
|
||||
else{
|
||||
this.randomSeed = System.nanoTime();
|
||||
}
|
||||
}
|
||||
|
||||
public Forest<TO, FO> trainSerial(){
|
||||
public Forest<TO, FO> trainSerialInMemory(){
|
||||
|
||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||
final Random random = new Random();
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
if(displayProgress){
|
||||
System.out.print("\rFinished tree " + j + "/" + ntree + " trees");
|
||||
}
|
||||
|
||||
final Random random = new Random(this.randomSeed + j);
|
||||
trees.add(trainTree(bootstrapper, random));
|
||||
}
|
||||
|
||||
|
@ -96,22 +102,32 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
public void trainSerialOnDisk(){
|
||||
// First we need to see how many trees there currently are
|
||||
final File folder = new File(saveTreeLocation);
|
||||
if(!folder.exists()){
|
||||
folder.mkdir();
|
||||
}
|
||||
|
||||
if(!folder.isDirectory()){
|
||||
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||
}
|
||||
|
||||
|
||||
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
||||
|
||||
final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList());
|
||||
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
|
||||
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
||||
|
||||
for(int j=treeCount.get(); j<ntree; j++){
|
||||
for(int j=0; j<ntree; j++){
|
||||
if(displayProgress) {
|
||||
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
|
||||
}
|
||||
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + UUID.randomUUID() + ".tree", treeCount);
|
||||
final String treeFileName = "tree-" + (j+1) + ".tree";
|
||||
|
||||
if(treeFileNames.contains(treeFileName)){
|
||||
continue;
|
||||
}
|
||||
|
||||
final Random random = new Random(this.randomSeed + j);
|
||||
final Runnable worker = new TreeSavedWorker(data, treeFileName, treeCount, random);
|
||||
worker.run();
|
||||
|
||||
}
|
||||
|
@ -132,7 +148,8 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
final Runnable worker = new TreeInMemoryWorker(data, j, trees);
|
||||
final Random random = new Random(this.randomSeed + j);
|
||||
final Runnable worker = new TreeInMemoryWorker(data, j, trees, random);
|
||||
executorService.execute(worker);
|
||||
}
|
||||
|
||||
|
@ -182,18 +199,28 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
public void trainParallelOnDisk(int threads){
|
||||
// First we need to see how many trees there currently are
|
||||
final File folder = new File(saveTreeLocation);
|
||||
if(!folder.exists()){
|
||||
folder.mkdir();
|
||||
}
|
||||
|
||||
if(!folder.isDirectory()){
|
||||
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||
}
|
||||
|
||||
|
||||
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
||||
|
||||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||
final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList());
|
||||
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
|
||||
|
||||
for(int j=treeCount.get(); j<ntree; j++){
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + UUID.randomUUID() + ".tree", treeCount);
|
||||
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
final String treeFileName = "tree-" + (j+1) + ".tree";
|
||||
if(treeFileNames.contains(treeFileName)){
|
||||
continue;
|
||||
}
|
||||
|
||||
final Random random = new Random(this.randomSeed + j);
|
||||
final Runnable worker = new TreeSavedWorker(data, treeFileName, treeCount, random);
|
||||
executorService.execute(worker);
|
||||
}
|
||||
|
||||
|
@ -240,18 +267,18 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||
private final int treeIndex;
|
||||
private final List<Tree<TO>> treeList;
|
||||
private final Random random;
|
||||
|
||||
TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
|
||||
TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList, final Random random) {
|
||||
this.bootstrapper = new Bootstrapper<>(data);
|
||||
this.treeIndex = treeIndex;
|
||||
this.treeList = treeList;
|
||||
this.random = random;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
|
||||
// ThreadLocalRandom should make sure we don't duplicate seeds
|
||||
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
|
||||
final Tree<TO> tree = trainTree(bootstrapper, random);
|
||||
|
||||
// should be okay as the list structure isn't changing
|
||||
treeList.set(treeIndex, tree);
|
||||
|
@ -265,18 +292,18 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||
private final String filename;
|
||||
private final AtomicInteger treeCount;
|
||||
private final Random random;
|
||||
|
||||
public TreeSavedWorker(final List<Row<Y>> data, final String filename, final AtomicInteger treeCount) {
|
||||
public TreeSavedWorker(final List<Row<Y>> data, final String filename, final AtomicInteger treeCount, final Random random) {
|
||||
this.bootstrapper = new Bootstrapper<>(data);
|
||||
this.filename = filename;
|
||||
this.treeCount = treeCount;
|
||||
this.random = random;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
|
||||
// ThreadLocalRandom should make sure we don't duplicate seeds
|
||||
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
|
||||
final Tree<TO> tree = trainTree(bootstrapper, random);
|
||||
|
||||
try {
|
||||
DataUtils.saveObject(tree, saveTreeLocation + "/" + filename);
|
||||
|
|
|
@ -0,0 +1,298 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
public class TestDeterministicForests {
|
||||
|
||||
private final String saveTreeLocation = "src/test/resources/trees/";
|
||||
|
||||
private List<Covariate> generateCovariates(){
|
||||
final List<Covariate> covariateList = new ArrayList<>();
|
||||
|
||||
int index = 0;
|
||||
for(int j=0; j<5; j++){
|
||||
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index);
|
||||
covariateList.add(numericCovariate);
|
||||
index++;
|
||||
}
|
||||
|
||||
for(int j=0; j<5; j++){
|
||||
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index);
|
||||
covariateList.add(booleanCovariate);
|
||||
index++;
|
||||
}
|
||||
|
||||
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
|
||||
for(int j=0; j<5; j++){
|
||||
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels);
|
||||
covariateList.add(factorCovariate);
|
||||
index++;
|
||||
}
|
||||
|
||||
return covariateList;
|
||||
}
|
||||
|
||||
private Covariate.Value generateRandomValue(Covariate covariate, Random random){
|
||||
if(covariate instanceof NumericCovariate){
|
||||
return covariate.createValue(random.nextGaussian());
|
||||
}
|
||||
if(covariate instanceof BooleanCovariate){
|
||||
return covariate.createValue(random.nextBoolean());
|
||||
}
|
||||
if(covariate instanceof FactorCovariate){
|
||||
final double itemSelection = random.nextDouble();
|
||||
final String item;
|
||||
if(itemSelection < 1.0/3.0){
|
||||
item = "cat";
|
||||
}
|
||||
else if(itemSelection < 2.0/3.0){
|
||||
item = "dog";
|
||||
}
|
||||
else{
|
||||
item = "mouse";
|
||||
}
|
||||
|
||||
return covariate.createValue(item);
|
||||
}
|
||||
else{
|
||||
throw new IllegalArgumentException("Unknown covariate type of class " + covariate.getClass().getName());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private List<Row<Double>> generateTestData(List<Covariate> covariateList, int n, Random random){
|
||||
final List<Row<Double>> rowList = new ArrayList<>();
|
||||
for(int i=0; i<n; i++){
|
||||
final double response = random.nextGaussian();
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||
|
||||
for(int j=0; j<covariateList.size(); j++){
|
||||
valueArray[j] = generateRandomValue(covariateList.get(j), random);
|
||||
}
|
||||
|
||||
rowList.add(new Row(valueArray, i, response));
|
||||
}
|
||||
|
||||
return rowList;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResultsAlwaysSame() throws IOException, ClassNotFoundException {
|
||||
final List<Covariate> covariateList = generateCovariates();
|
||||
|
||||
final Random dataGeneratingRandom = new Random();
|
||||
|
||||
final List<Row<Double>> trainingData = generateTestData(covariateList, 100, dataGeneratingRandom);
|
||||
final List<Row<Double>> testData = generateTestData(covariateList, 10, dataGeneratingRandom);
|
||||
|
||||
// pick a new seed at random
|
||||
final long trainingSeed = dataGeneratingRandom.nextLong();
|
||||
|
||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||
.checkNodePurity(false)
|
||||
.covariates(covariateList)
|
||||
.maxNodeDepth(100)
|
||||
.mtry(1)
|
||||
.nodeSize(10)
|
||||
.numberOfSplits(1) // want results to be dominated by randomness
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.splitFinder(new WeightedVarianceSplitFinder())
|
||||
.build();
|
||||
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = ForestTrainer.<Double, Double, Double>builder()
|
||||
.treeTrainer(treeTrainer)
|
||||
.covariates(covariateList)
|
||||
.data(trainingData)
|
||||
.displayProgress(false)
|
||||
.ntree(10)
|
||||
.randomSeed(trainingSeed)
|
||||
.treeResponseCombiner(new MeanResponseCombiner())
|
||||
.saveTreeLocation(saveTreeLocation)
|
||||
.build();
|
||||
|
||||
// By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
|
||||
final Forest<Double, Double> referenceForest = forestTrainer.trainSerialInMemory();
|
||||
|
||||
verifySerialInMemoryTraining(referenceForest, forestTrainer, testData);
|
||||
verifyParallelInMemoryTraining(referenceForest, forestTrainer, testData);
|
||||
verifySerialOnDiskTraining(referenceForest, forestTrainer, testData);
|
||||
verifyParallelOnDiskTraining(referenceForest, forestTrainer, testData);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests that if we train a forest under a specified seed for 10 trees, that it is equal to training a forest
|
||||
* for 5 trees only, and then starting from that point to train the last 5.
|
||||
*
|
||||
* @throws IOException
|
||||
* @throws ClassNotFoundException
|
||||
*/
|
||||
@Test
|
||||
public void testInterupptedTrainingProducesSameResults() throws IOException, ClassNotFoundException {
|
||||
final List<Covariate> covariateList = generateCovariates();
|
||||
|
||||
final Random dataGeneratingRandom = new Random();
|
||||
|
||||
final List<Row<Double>> trainingData = generateTestData(covariateList, 100, dataGeneratingRandom);
|
||||
final List<Row<Double>> testData = generateTestData(covariateList, 10, dataGeneratingRandom);
|
||||
|
||||
// pick a new seed at random
|
||||
final long trainingSeed = dataGeneratingRandom.nextLong();
|
||||
|
||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||
.checkNodePurity(false)
|
||||
.covariates(covariateList)
|
||||
.maxNodeDepth(100)
|
||||
.mtry(1)
|
||||
.nodeSize(10)
|
||||
.numberOfSplits(1) // want results to be dominated by randomness
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.splitFinder(new WeightedVarianceSplitFinder())
|
||||
.build();
|
||||
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer5Trees = ForestTrainer.<Double, Double, Double>builder()
|
||||
.treeTrainer(treeTrainer)
|
||||
.covariates(covariateList)
|
||||
.data(trainingData)
|
||||
.displayProgress(false)
|
||||
.ntree(5)
|
||||
.randomSeed(trainingSeed)
|
||||
.treeResponseCombiner(new MeanResponseCombiner())
|
||||
.saveTreeLocation(saveTreeLocation)
|
||||
.build();
|
||||
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer10Trees = ForestTrainer.<Double, Double, Double>builder()
|
||||
.treeTrainer(treeTrainer)
|
||||
.covariates(covariateList)
|
||||
.data(trainingData)
|
||||
.displayProgress(false)
|
||||
.ntree(10)
|
||||
.randomSeed(trainingSeed)
|
||||
.treeResponseCombiner(new MeanResponseCombiner())
|
||||
.saveTreeLocation(saveTreeLocation)
|
||||
.build();
|
||||
|
||||
// By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
|
||||
final Forest<Double, Double> referenceForest = forestTrainer10Trees.trainSerialInMemory();
|
||||
|
||||
final File saveTreeFile = new File(saveTreeLocation);
|
||||
|
||||
|
||||
forestTrainer5Trees.trainSerialOnDisk();
|
||||
forestTrainer10Trees.trainSerialOnDisk();
|
||||
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
||||
|
||||
|
||||
forestTrainer5Trees.trainParallelOnDisk(4);
|
||||
forestTrainer10Trees.trainParallelOnDisk(4);
|
||||
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
||||
|
||||
}
|
||||
|
||||
private void verifySerialInMemoryTraining(
|
||||
final Forest<Double, Double> referenceForest,
|
||||
ForestTrainer<Double, Double, Double> forestTrainer,
|
||||
List<Row<Double>> testData){
|
||||
|
||||
for(int k=0; k<3; k++){
|
||||
final Forest<Double, Double> replicantForest = forestTrainer.trainSerialInMemory();
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
}
|
||||
}
|
||||
|
||||
private void verifyParallelInMemoryTraining(
|
||||
Forest<Double, Double> referenceForest,
|
||||
ForestTrainer<Double, Double, Double> forestTrainer,
|
||||
List<Row<Double>> testData){
|
||||
|
||||
for(int k=0; k<3; k++){
|
||||
final Forest<Double, Double> replicantForest = forestTrainer.trainParallelInMemory(4);
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
}
|
||||
}
|
||||
|
||||
private void verifySerialOnDiskTraining(
|
||||
Forest<Double, Double> referenceForest,
|
||||
ForestTrainer<Double, Double, Double> forestTrainer,
|
||||
List<Row<Double>> testData) throws IOException, ClassNotFoundException {
|
||||
|
||||
final MeanResponseCombiner responseCombiner = new MeanResponseCombiner();
|
||||
final File saveTreeFile = new File(saveTreeLocation);
|
||||
|
||||
for(int k=0; k<3; k++){
|
||||
forestTrainer.trainSerialOnDisk();
|
||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
}
|
||||
}
|
||||
|
||||
private void verifyParallelOnDiskTraining(
|
||||
final Forest<Double, Double> referenceForest, ForestTrainer<Double, Double, Double> forestTrainer,
|
||||
List<Row<Double>> testData) throws IOException, ClassNotFoundException {
|
||||
|
||||
final MeanResponseCombiner responseCombiner = new MeanResponseCombiner();
|
||||
final File saveTreeFile = new File(saveTreeLocation);
|
||||
|
||||
for(int k=0; k<3; k++){
|
||||
forestTrainer.trainParallelOnDisk(4);
|
||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
}
|
||||
}
|
||||
|
||||
// Technically verifies the two forests give equal predictions on a given test dataset
|
||||
private void verifyTwoForestsEqual(final List<Row<Double>> testData,
|
||||
final Forest<Double, Double> forest1,
|
||||
final Forest<Double, Double> forest2){
|
||||
|
||||
for(Row row : testData){
|
||||
final Double prediction1 = forest1.evaluate(row);
|
||||
final Double prediction2 = forest2.evaluate(row);
|
||||
|
||||
// I've noticed that results aren't necessarily always *identical*
|
||||
TestUtils.closeEnough(prediction1, prediction2, 0.0000000001);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -116,7 +116,7 @@ public class TestSavingLoading {
|
|||
|
||||
final File directory = new File(settings.getSaveTreeLocation());
|
||||
if(directory.exists()){
|
||||
cleanup(directory);
|
||||
TestUtils.removeFolder(directory);
|
||||
}
|
||||
assertFalse(directory.exists());
|
||||
directory.mkdir();
|
||||
|
@ -142,7 +142,7 @@ public class TestSavingLoading {
|
|||
|
||||
assertEquals(NTREE, forest.getTrees().size());
|
||||
|
||||
cleanup(directory);
|
||||
TestUtils.removeFolder(directory);
|
||||
|
||||
assertFalse(directory.exists());
|
||||
|
||||
|
@ -157,7 +157,7 @@ public class TestSavingLoading {
|
|||
|
||||
final File directory = new File(settings.getSaveTreeLocation());
|
||||
if(directory.exists()){
|
||||
cleanup(directory);
|
||||
TestUtils.removeFolder(directory);
|
||||
}
|
||||
assertFalse(directory.exists());
|
||||
directory.mkdir();
|
||||
|
@ -183,24 +183,12 @@ public class TestSavingLoading {
|
|||
|
||||
assertEquals(NTREE, forest.getTrees().size());
|
||||
|
||||
cleanup(directory);
|
||||
TestUtils.removeFolder(directory);
|
||||
|
||||
assertFalse(directory.exists());
|
||||
|
||||
}
|
||||
|
||||
private void cleanup(File file){
|
||||
if(file.isFile()){
|
||||
file.delete();
|
||||
}
|
||||
else{
|
||||
for(final File inner : file.listFiles()){
|
||||
cleanup(inner);
|
||||
}
|
||||
file.delete();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import ca.joeltherrien.randomforest.utils.StepFunction;
|
|||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
@ -170,4 +171,16 @@ public class TestUtils {
|
|||
}
|
||||
}
|
||||
|
||||
public static void removeFolder(File file){
|
||||
if(file.isFile()){
|
||||
file.delete();
|
||||
}
|
||||
else{
|
||||
for(final File inner : file.listFiles()){
|
||||
removeFolder(inner);
|
||||
}
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -235,7 +235,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory();
|
||||
|
||||
// prediction row
|
||||
// time status ageatfda idu black cd4nadir
|
||||
|
@ -328,7 +328,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
for(int i=0; i<50; i++){
|
||||
forestTrainer.trainSerial();
|
||||
forestTrainer.trainSerialInMemory();
|
||||
}
|
||||
final long endTime = System.currentTimeMillis();
|
||||
|
||||
|
@ -346,7 +346,7 @@ public class TestCompetingRisk {
|
|||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
|
||||
settings.getTrainingDataLocation());
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory();
|
||||
|
||||
// prediction row
|
||||
// time status ageatfda idu black cd4nadir
|
||||
|
|
|
@ -23,6 +23,7 @@ import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings
|
|||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
||||
import ca.joeltherrien.randomforest.tree.Split;
|
||||
import ca.joeltherrien.randomforest.utils.Data;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
|
@ -89,11 +90,4 @@ public class TestLogRankSplitFinder {
|
|||
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
||||
}
|
||||
|
||||
// Pairs a list of rows with a list of covariates.
// Lombok's @Data generates getters, setters, equals/hashCode and toString for the fields below.
@lombok.Data
|
||||
// Lombok's @AllArgsConstructor generates a constructor taking (rows, covariateList), in field order.
@AllArgsConstructor
|
||||
public static class Data<Y> {
|
||||
// the rows, each carrying a response of type Y
private List<Row<Y>> rows;
|
||||
// the covariates; presumably the ones the rows' values belong to - confirm against callers
private List<Covariate> covariateList;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
35
src/test/java/ca/joeltherrien/randomforest/utils/Data.java
Normal file
35
src/test/java/ca/joeltherrien/randomforest/utils/Data.java
Normal file
|
@ -0,0 +1,35 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.AllArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Convenience class for unit tests
|
||||
*
|
||||
* @param <Y> The type of response.
|
||||
*/
|
||||
// Lombok's @Data generates getters, setters, equals/hashCode and toString for the fields below.
@lombok.Data
|
||||
// Lombok's @AllArgsConstructor generates a constructor taking (rows, covariateList), in field order.
@AllArgsConstructor
|
||||
public class Data<Y> {
|
||||
// the rows, each carrying a response of type Y
private List<Row<Y>> rows;
|
||||
// the covariates; presumably the ones the rows' values belong to - confirm against callers
private List<Covariate> covariateList;
|
||||
}
|
|
@ -88,7 +88,7 @@ public class TrainForest {
|
|||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
|
||||
//final Forest<Double> forest = forestTrainer.trainSerial();
|
||||
//final Forest<Double> forest = forestTrainer.trainSerialInMemory();
|
||||
//final Forest<Double> forest = forestTrainer.trainParallelInMemory(3);
|
||||
forestTrainer.trainParallelOnDisk(3);
|
||||
|
||||
|
|
Loading…
Reference in a new issue