Add support for seeds to control randomness when training forests

This commit is contained in:
Joel Therrien 2019-05-10 16:02:33 -07:00
parent 17ae3a9f5a
commit 6f318db79e
10 changed files with 408 additions and 52 deletions

View file

@ -82,7 +82,7 @@ public class Main {
if(settings.getNumberOfThreads() > 1){ if(settings.getNumberOfThreads() > 1){
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
} else{ } else{
forestTrainer.trainSerial(); forestTrainer.trainSerialInMemory();
} }
} }
} }

View file

@ -151,6 +151,7 @@ public class Settings {
private int nodeSize = 5; private int nodeSize = 5;
private int maxNodeDepth = 1000000; // basically no maxNodeDepth private int maxNodeDepth = 1000000; // basically no maxNodeDepth
private boolean checkNodePurity = false; private boolean checkNodePurity = false;
private Long randomSeed;
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);

View file

@ -28,12 +28,11 @@ import lombok.Builder;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.UUID;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -51,8 +50,9 @@ public class ForestTrainer<Y, TO, FO> {
// number of trees to try // number of trees to try
private final int ntree; private final int ntree;
private final boolean displayProgress; private final boolean displayProgress; // whether to print to standard output our progress; not always desirable
private final String saveTreeLocation; private final String saveTreeLocation;
private final long randomSeed;
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){ public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
this.ntree = settings.getNtree(); this.ntree = settings.getNtree();
@ -63,19 +63,25 @@ public class ForestTrainer<Y, TO, FO> {
this.covariates = covariates; this.covariates = covariates;
this.treeResponseCombiner = settings.getTreeCombiner(); this.treeResponseCombiner = settings.getTreeCombiner();
this.treeTrainer = new TreeTrainer<>(settings, covariates); this.treeTrainer = new TreeTrainer<>(settings, covariates);
if(settings.getRandomSeed() != null){
this.randomSeed = settings.getRandomSeed();
}
else{
this.randomSeed = System.nanoTime();
}
} }
public Forest<TO, FO> trainSerial(){ public Forest<TO, FO> trainSerialInMemory(){
final List<Tree<TO>> trees = new ArrayList<>(ntree); final List<Tree<TO>> trees = new ArrayList<>(ntree);
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data); final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
final Random random = new Random();
for(int j=0; j<ntree; j++){ for(int j=0; j<ntree; j++){
if(displayProgress){ if(displayProgress){
System.out.print("\rFinished tree " + j + "/" + ntree + " trees"); System.out.print("\rFinished tree " + j + "/" + ntree + " trees");
} }
final Random random = new Random(this.randomSeed + j);
trees.add(trainTree(bootstrapper, random)); trees.add(trainTree(bootstrapper, random));
} }
@ -96,22 +102,32 @@ public class ForestTrainer<Y, TO, FO> {
public void trainSerialOnDisk(){ public void trainSerialOnDisk(){
// First we need to see how many trees there currently are // First we need to see how many trees there currently are
final File folder = new File(saveTreeLocation); final File folder = new File(saveTreeLocation);
if(!folder.exists()){
folder.mkdir();
}
if(!folder.isDirectory()){ if(!folder.isDirectory()){
throw new IllegalArgumentException("Tree directory must be a directory!"); throw new IllegalArgumentException("Tree directory must be a directory!");
} }
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList());
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
// Using an AtomicInteger is overkill for serial code, but this lets use reuse TreeSavedWorker // Using an AtomicInteger is overkill for serial code, but this lets use reuse TreeSavedWorker
for(int j=treeCount.get(); j<ntree; j++){ for(int j=0; j<ntree; j++){
if(displayProgress) { if(displayProgress) {
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees"); System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
} }
final Runnable worker = new TreeSavedWorker(data, "tree-" + UUID.randomUUID() + ".tree", treeCount); final String treeFileName = "tree-" + (j+1) + ".tree";
if(treeFileNames.contains(treeFileName)){
continue;
}
final Random random = new Random(this.randomSeed + j);
final Runnable worker = new TreeSavedWorker(data, treeFileName, treeCount, random);
worker.run(); worker.run();
} }
@ -132,7 +148,8 @@ public class ForestTrainer<Y, TO, FO> {
final ExecutorService executorService = Executors.newFixedThreadPool(threads); final ExecutorService executorService = Executors.newFixedThreadPool(threads);
for(int j=0; j<ntree; j++){ for(int j=0; j<ntree; j++){
final Runnable worker = new TreeInMemoryWorker(data, j, trees); final Random random = new Random(this.randomSeed + j);
final Runnable worker = new TreeInMemoryWorker(data, j, trees, random);
executorService.execute(worker); executorService.execute(worker);
} }
@ -182,18 +199,28 @@ public class ForestTrainer<Y, TO, FO> {
public void trainParallelOnDisk(int threads){ public void trainParallelOnDisk(int threads){
// First we need to see how many trees there currently are // First we need to see how many trees there currently are
final File folder = new File(saveTreeLocation); final File folder = new File(saveTreeLocation);
if(!folder.exists()){
folder.mkdir();
}
if(!folder.isDirectory()){ if(!folder.isDirectory()){
throw new IllegalArgumentException("Tree directory must be a directory!"); throw new IllegalArgumentException("Tree directory must be a directory!");
} }
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree")); final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
final List<String> treeFileNames = Arrays.stream(treeFiles).map(file -> file.getName()).collect(Collectors.toList());
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
for(int j=treeCount.get(); j<ntree; j++){ final ExecutorService executorService = Executors.newFixedThreadPool(threads);
final Runnable worker = new TreeSavedWorker(data, "tree-" + UUID.randomUUID() + ".tree", treeCount);
for(int j=0; j<ntree; j++){
final String treeFileName = "tree-" + (j+1) + ".tree";
if(treeFileNames.contains(treeFileName)){
continue;
}
final Random random = new Random(this.randomSeed + j);
final Runnable worker = new TreeSavedWorker(data, treeFileName, treeCount, random);
executorService.execute(worker); executorService.execute(worker);
} }
@ -240,18 +267,18 @@ public class ForestTrainer<Y, TO, FO> {
private final Bootstrapper<Row<Y>> bootstrapper; private final Bootstrapper<Row<Y>> bootstrapper;
private final int treeIndex; private final int treeIndex;
private final List<Tree<TO>> treeList; private final List<Tree<TO>> treeList;
private final Random random;
TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) { TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList, final Random random) {
this.bootstrapper = new Bootstrapper<>(data); this.bootstrapper = new Bootstrapper<>(data);
this.treeIndex = treeIndex; this.treeIndex = treeIndex;
this.treeList = treeList; this.treeList = treeList;
this.random = random;
} }
@Override @Override
public void run() { public void run() {
final Tree<TO> tree = trainTree(bootstrapper, random);
// ThreadLocalRandom should make sure we don't duplicate seeds
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
// should be okay as the list structure isn't changing // should be okay as the list structure isn't changing
treeList.set(treeIndex, tree); treeList.set(treeIndex, tree);
@ -265,18 +292,18 @@ public class ForestTrainer<Y, TO, FO> {
private final Bootstrapper<Row<Y>> bootstrapper; private final Bootstrapper<Row<Y>> bootstrapper;
private final String filename; private final String filename;
private final AtomicInteger treeCount; private final AtomicInteger treeCount;
private final Random random;
public TreeSavedWorker(final List<Row<Y>> data, final String filename, final AtomicInteger treeCount) { public TreeSavedWorker(final List<Row<Y>> data, final String filename, final AtomicInteger treeCount, final Random random) {
this.bootstrapper = new Bootstrapper<>(data); this.bootstrapper = new Bootstrapper<>(data);
this.filename = filename; this.filename = filename;
this.treeCount = treeCount; this.treeCount = treeCount;
this.random = random;
} }
@Override @Override
public void run() { public void run() {
final Tree<TO> tree = trainTree(bootstrapper, random);
// ThreadLocalRandom should make sure we don't duplicate seeds
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
try { try {
DataUtils.saveObject(tree, saveTreeLocation + "/" + filename); DataUtils.saveObject(tree, saveTreeLocation + "/" + filename);

View file

@ -0,0 +1,298 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class TestDeterministicForests {
private final String saveTreeLocation = "src/test/resources/trees/";
private List<Covariate> generateCovariates(){
final List<Covariate> covariateList = new ArrayList<>();
int index = 0;
for(int j=0; j<5; j++){
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index);
covariateList.add(numericCovariate);
index++;
}
for(int j=0; j<5; j++){
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index);
covariateList.add(booleanCovariate);
index++;
}
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
for(int j=0; j<5; j++){
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels);
covariateList.add(factorCovariate);
index++;
}
return covariateList;
}
private Covariate.Value generateRandomValue(Covariate covariate, Random random){
if(covariate instanceof NumericCovariate){
return covariate.createValue(random.nextGaussian());
}
if(covariate instanceof BooleanCovariate){
return covariate.createValue(random.nextBoolean());
}
if(covariate instanceof FactorCovariate){
final double itemSelection = random.nextDouble();
final String item;
if(itemSelection < 1.0/3.0){
item = "cat";
}
else if(itemSelection < 2.0/3.0){
item = "dog";
}
else{
item = "mouse";
}
return covariate.createValue(item);
}
else{
throw new IllegalArgumentException("Unknown covariate type of class " + covariate.getClass().getName());
}
}
private List<Row<Double>> generateTestData(List<Covariate> covariateList, int n, Random random){
final List<Row<Double>> rowList = new ArrayList<>();
for(int i=0; i<n; i++){
final double response = random.nextGaussian();
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
for(int j=0; j<covariateList.size(); j++){
valueArray[j] = generateRandomValue(covariateList.get(j), random);
}
rowList.add(new Row(valueArray, i, response));
}
return rowList;
}
@Test
public void testResultsAlwaysSame() throws IOException, ClassNotFoundException {
final List<Covariate> covariateList = generateCovariates();
final Random dataGeneratingRandom = new Random();
final List<Row<Double>> trainingData = generateTestData(covariateList, 100, dataGeneratingRandom);
final List<Row<Double>> testData = generateTestData(covariateList, 10, dataGeneratingRandom);
// pick a new seed at random
final long trainingSeed = dataGeneratingRandom.nextLong();
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.checkNodePurity(false)
.covariates(covariateList)
.maxNodeDepth(100)
.mtry(1)
.nodeSize(10)
.numberOfSplits(1) // want results to be dominated by randomness
.responseCombiner(new MeanResponseCombiner())
.splitFinder(new WeightedVarianceSplitFinder())
.build();
final ForestTrainer<Double, Double, Double> forestTrainer = ForestTrainer.<Double, Double, Double>builder()
.treeTrainer(treeTrainer)
.covariates(covariateList)
.data(trainingData)
.displayProgress(false)
.ntree(10)
.randomSeed(trainingSeed)
.treeResponseCombiner(new MeanResponseCombiner())
.saveTreeLocation(saveTreeLocation)
.build();
// By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
final Forest<Double, Double> referenceForest = forestTrainer.trainSerialInMemory();
verifySerialInMemoryTraining(referenceForest, forestTrainer, testData);
verifyParallelInMemoryTraining(referenceForest, forestTrainer, testData);
verifySerialOnDiskTraining(referenceForest, forestTrainer, testData);
verifyParallelOnDiskTraining(referenceForest, forestTrainer, testData);
}
/**
* Tests that if we train a forest under a specified seed for 10 trees, that it is equal to training a forest
* for 5 trees only, and then starting from that point to train the last 5.
*
* @throws IOException
* @throws ClassNotFoundException
*/
@Test
public void testInterupptedTrainingProducesSameResults() throws IOException, ClassNotFoundException {
final List<Covariate> covariateList = generateCovariates();
final Random dataGeneratingRandom = new Random();
final List<Row<Double>> trainingData = generateTestData(covariateList, 100, dataGeneratingRandom);
final List<Row<Double>> testData = generateTestData(covariateList, 10, dataGeneratingRandom);
// pick a new seed at random
final long trainingSeed = dataGeneratingRandom.nextLong();
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.checkNodePurity(false)
.covariates(covariateList)
.maxNodeDepth(100)
.mtry(1)
.nodeSize(10)
.numberOfSplits(1) // want results to be dominated by randomness
.responseCombiner(new MeanResponseCombiner())
.splitFinder(new WeightedVarianceSplitFinder())
.build();
final ForestTrainer<Double, Double, Double> forestTrainer5Trees = ForestTrainer.<Double, Double, Double>builder()
.treeTrainer(treeTrainer)
.covariates(covariateList)
.data(trainingData)
.displayProgress(false)
.ntree(5)
.randomSeed(trainingSeed)
.treeResponseCombiner(new MeanResponseCombiner())
.saveTreeLocation(saveTreeLocation)
.build();
final ForestTrainer<Double, Double, Double> forestTrainer10Trees = ForestTrainer.<Double, Double, Double>builder()
.treeTrainer(treeTrainer)
.covariates(covariateList)
.data(trainingData)
.displayProgress(false)
.ntree(10)
.randomSeed(trainingSeed)
.treeResponseCombiner(new MeanResponseCombiner())
.saveTreeLocation(saveTreeLocation)
.build();
// By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
final Forest<Double, Double> referenceForest = forestTrainer10Trees.trainSerialInMemory();
final File saveTreeFile = new File(saveTreeLocation);
forestTrainer5Trees.trainSerialOnDisk();
forestTrainer10Trees.trainSerialOnDisk();
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
forestTrainer5Trees.trainParallelOnDisk(4);
forestTrainer10Trees.trainParallelOnDisk(4);
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
}
private void verifySerialInMemoryTraining(
final Forest<Double, Double> referenceForest,
ForestTrainer<Double, Double, Double> forestTrainer,
List<Row<Double>> testData){
for(int k=0; k<3; k++){
final Forest<Double, Double> replicantForest = forestTrainer.trainSerialInMemory();
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
}
}
private void verifyParallelInMemoryTraining(
Forest<Double, Double> referenceForest,
ForestTrainer<Double, Double, Double> forestTrainer,
List<Row<Double>> testData){
for(int k=0; k<3; k++){
final Forest<Double, Double> replicantForest = forestTrainer.trainParallelInMemory(4);
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
}
}
private void verifySerialOnDiskTraining(
Forest<Double, Double> referenceForest,
ForestTrainer<Double, Double, Double> forestTrainer,
List<Row<Double>> testData) throws IOException, ClassNotFoundException {
final MeanResponseCombiner responseCombiner = new MeanResponseCombiner();
final File saveTreeFile = new File(saveTreeLocation);
for(int k=0; k<3; k++){
forestTrainer.trainSerialOnDisk();
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
}
}
private void verifyParallelOnDiskTraining(
final Forest<Double, Double> referenceForest, ForestTrainer<Double, Double, Double> forestTrainer,
List<Row<Double>> testData) throws IOException, ClassNotFoundException {
final MeanResponseCombiner responseCombiner = new MeanResponseCombiner();
final File saveTreeFile = new File(saveTreeLocation);
for(int k=0; k<3; k++){
forestTrainer.trainParallelOnDisk(4);
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
}
}
// Technically verifies the two forests give equal predictions on a given test dataset
private void verifyTwoForestsEqual(final List<Row<Double>> testData,
final Forest<Double, Double> forest1,
final Forest<Double, Double> forest2){
for(Row row : testData){
final Double prediction1 = forest1.evaluate(row);
final Double prediction2 = forest2.evaluate(row);
// I've noticed that results aren't necessarily always *identical*
TestUtils.closeEnough(prediction1, prediction2, 0.0000000001);
}
}
}

View file

@ -116,7 +116,7 @@ public class TestSavingLoading {
final File directory = new File(settings.getSaveTreeLocation()); final File directory = new File(settings.getSaveTreeLocation());
if(directory.exists()){ if(directory.exists()){
cleanup(directory); TestUtils.removeFolder(directory);
} }
assertFalse(directory.exists()); assertFalse(directory.exists());
directory.mkdir(); directory.mkdir();
@ -142,7 +142,7 @@ public class TestSavingLoading {
assertEquals(NTREE, forest.getTrees().size()); assertEquals(NTREE, forest.getTrees().size());
cleanup(directory); TestUtils.removeFolder(directory);
assertFalse(directory.exists()); assertFalse(directory.exists());
@ -157,7 +157,7 @@ public class TestSavingLoading {
final File directory = new File(settings.getSaveTreeLocation()); final File directory = new File(settings.getSaveTreeLocation());
if(directory.exists()){ if(directory.exists()){
cleanup(directory); TestUtils.removeFolder(directory);
} }
assertFalse(directory.exists()); assertFalse(directory.exists());
directory.mkdir(); directory.mkdir();
@ -183,24 +183,12 @@ public class TestSavingLoading {
assertEquals(NTREE, forest.getTrees().size()); assertEquals(NTREE, forest.getTrees().size());
cleanup(directory); TestUtils.removeFolder(directory);
assertFalse(directory.exists()); assertFalse(directory.exists());
} }
private void cleanup(File file){
if(file.isFile()){
file.delete();
}
else{
for(final File inner : file.listFiles()){
cleanup(inner);
}
file.delete();
}
}
} }

View file

@ -20,6 +20,7 @@ import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
@ -170,4 +171,16 @@ public class TestUtils {
} }
} }
public static void removeFolder(File file){
if(file.isFile()){
file.delete();
}
else{
for(final File inner : file.listFiles()){
removeFolder(inner);
}
file.delete();
}
}
} }

View file

@ -235,7 +235,7 @@ public class TestCompetingRisk {
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory();
// prediction row // prediction row
// time status ageatfda idu black cd4nadir // time status ageatfda idu black cd4nadir
@ -328,7 +328,7 @@ public class TestCompetingRisk {
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
for(int i=0; i<50; i++){ for(int i=0; i<50; i++){
forestTrainer.trainSerial(); forestTrainer.trainSerialInMemory();
} }
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();
@ -346,7 +346,7 @@ public class TestCompetingRisk {
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
settings.getTrainingDataLocation()); settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory();
// prediction row // prediction row
// time status ageatfda idu black cd4nadir // time status ageatfda idu black cd4nadir

View file

@ -23,6 +23,7 @@ import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.Data;
import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.SingletonIterator; import ca.joeltherrien.randomforest.utils.SingletonIterator;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
@ -89,11 +90,4 @@ public class TestLogRankSplitFinder {
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual); assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
} }
@lombok.Data
@AllArgsConstructor
public static class Data<Y> {
private List<Row<Y>> rows;
private List<Covariate> covariateList;
}
} }

View file

@ -0,0 +1,35 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.AllArgsConstructor;
import java.util.List;
/**
* Convenience class for unit tests
*
* @param <Y> The type of response.
*/
@lombok.Data
@AllArgsConstructor
public class Data<Y> {
private List<Row<Y>> rows;
private List<Covariate> covariateList;
}

View file

@ -88,7 +88,7 @@ public class TrainForest {
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
//final Forest<Double> forest = forestTrainer.trainSerial(); //final Forest<Double> forest = forestTrainer.trainSerialInMemory();
//final Forest<Double> forest = forestTrainer.trainParallelInMemory(3); //final Forest<Double> forest = forestTrainer.trainParallelInMemory(3);
forestTrainer.trainParallelOnDisk(3); forestTrainer.trainParallelOnDisk(3);