diff --git a/executable/src/main/java/ca/joeltherrien/randomforest/Main.java b/executable/src/main/java/ca/joeltherrien/randomforest/Main.java index aeeddc6..569a51b 100644 --- a/executable/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/executable/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -122,7 +122,7 @@ public class Main { Utils.reduceListToSize(dataset, n, new Random()); final File folder = new File(settings.getSaveTreeLocation()); - final Forest forest = DataUtils.loadForest(folder, responseCombiner); + final Forest forest = DataUtils.loadOnlineForest(folder, responseCombiner); final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation()); diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index 86fbbbe..703aed4 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -17,31 +17,18 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.covariates.Covariate; -import lombok.Builder; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.TreeMap; import java.util.stream.Collectors; -@Builder -public class Forest { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings +public abstract class Forest { - private final List> trees; - private final ResponseCombiner treeResponseCombiner; - private final List covariateList; - - public FO evaluate(CovariateRow row){ - - return treeResponseCombiner.combine( - trees.stream() - .map(node -> node.evaluate(row)) - .collect(Collectors.toList()) - ); - - } + public abstract FO evaluate(CovariateRow row); + public abstract FO evaluateOOB(CovariateRow row); + public abstract Iterable> getTrees(); + public abstract int getNumberOfTrees(); /** * Used primarily in the R package interface to avoid R loops; and for easier parallelization. @@ -93,21 +80,6 @@ public class Forest { // O = output of trees, FO = forest output. In prac .collect(Collectors.toList()); } - public FO evaluateOOB(CovariateRow row){ - - return treeResponseCombiner.combine( - trees.stream() - .filter(tree -> !tree.idInBootstrapSample(row.getId())) - .map(node -> node.evaluate(row)) - .collect(Collectors.toList()) - ); - - } - - public List> getTrees(){ - return Collections.unmodifiableList(trees); - } - public Map findSplitsByCovariate(){ final Map countMap = new TreeMap<>(); @@ -158,4 +130,5 @@ public class Forest { // O = output of trees, FO = forest output. In prac return countTerminalNodes; } + } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 1a61929..aa2b3db 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -57,10 +57,10 @@ public class ForestTrainer { * in which case its trees are combined with the new one. * @return A trained forest. */ - public Forest trainSerialInMemory(Optional> initialForest){ + public OnlineForest trainSerialInMemory(Optional> initialForest){ final List> trees = new ArrayList<>(ntree); - initialForest.ifPresent(forest -> trees.addAll(forest.getTrees())); + initialForest.ifPresent(forest -> forest.getTrees().forEach(trees::add)); final Bootstrapper> bootstrapper = new Bootstrapper<>(data); @@ -77,11 +77,9 @@ public class ForestTrainer { System.out.println("Finished"); } - - return Forest.builder() + return OnlineForest.builder() .treeResponseCombiner(treeResponseCombiner) .trees(trees) - .covariateList(covariates) .build(); } @@ -94,7 +92,7 @@ public class ForestTrainer { * There cannot be existing trees if the initial forest is * specified. */ - public void trainSerialOnDisk(Optional> initialForest){ + public OfflineForest trainSerialOnDisk(Optional> initialForest){ // First we need to see how many trees there currently are final File folder = new File(saveTreeLocation); if(!folder.exists()){ @@ -115,17 +113,14 @@ public class ForestTrainer { final AtomicInteger treeCount; // tracks how many trees are finished // Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker if(initialForest.isPresent()){ - final List> initialTrees = initialForest.get().getTrees(); - - for(int j=0; j tree : initialForest.get().getTrees()){ final String filename = "tree-" + (j+1) + ".tree"; - final Tree tree = initialTrees.get(j); - saveTree(tree, filename); - + j++; } - treeCount = new AtomicInteger(initialTrees.size()); + treeCount = new AtomicInteger(j); } else{ treeCount = new AtomicInteger(treeFiles.length); } @@ -153,6 +148,8 @@ public class ForestTrainer { System.out.println("Finished"); } + return new OfflineForest<>(folder, treeResponseCombiner); + } /** @@ -162,7 +159,7 @@ public class ForestTrainer { * in which case its trees are combined with the new one. * @param threads The number of trees to train at once. */ - public Forest trainParallelInMemory(Optional> initialForest, int threads){ + public OnlineForest trainParallelInMemory(Optional> initialForest, int threads){ // create a list that is pre-specified in size (I can call the .set method at any index < ntree without // the earlier indexes being filled. @@ -170,11 +167,12 @@ public class ForestTrainer { final int startingCount; if(initialForest.isPresent()){ - final List> initialTrees = initialForest.get().getTrees(); - for(int j=0; j tree : initialForest.get().getTrees()){ + trees.set(j, tree); + j++; } - startingCount = initialTrees.size(); + startingCount = initialForest.get().getNumberOfTrees(); } else{ startingCount = 0; @@ -219,7 +217,7 @@ public class ForestTrainer { System.out.println("\nFinished"); } - return Forest.builder() + return OnlineForest.builder() .treeResponseCombiner(treeResponseCombiner) .trees(trees) .build(); @@ -235,7 +233,7 @@ public class ForestTrainer { * specified. * @param threads The number of trees to train at once. */ - public void trainParallelOnDisk(Optional> initialForest, int threads){ + public OfflineForest trainParallelOnDisk(Optional> initialForest, int threads){ // First we need to see how many trees there currently are final File folder = new File(saveTreeLocation); if(!folder.exists()){ @@ -255,17 +253,14 @@ public class ForestTrainer { final AtomicInteger treeCount; // tracks how many trees are finished if(initialForest.isPresent()){ - final List> initialTrees = initialForest.get().getTrees(); - - for(int j=0; j tree : initialForest.get().getTrees()){ final String filename = "tree-" + (j+1) + ".tree"; - final Tree tree = initialTrees.get(j); - saveTree(tree, filename); - + j++; } - treeCount = new AtomicInteger(initialTrees.size()); + treeCount = new AtomicInteger(j); } else{ treeCount = new AtomicInteger(treeFiles.length); } @@ -309,6 +304,8 @@ public class ForestTrainer { System.out.println("\nFinished"); } + return new OfflineForest<>(folder, treeResponseCombiner); + } private Tree trainTree(final Bootstrapper> bootstrapper, Random random){ diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/OfflineForest.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/OfflineForest.java new file mode 100644 index 0000000..1c874a1 --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/OfflineForest.java @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.tree; + +import ca.joeltherrien.randomforest.CovariateRow; +import ca.joeltherrien.randomforest.utils.IterableOfflineTree; +import lombok.AllArgsConstructor; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +@AllArgsConstructor +public class OfflineForest extends Forest { + + private final File[] treeFiles; + private final ResponseCombiner treeResponseCombiner; + + public OfflineForest(File treeDirectoryPath, ResponseCombiner treeResponseCombiner){ + this.treeResponseCombiner = treeResponseCombiner; + + if(!treeDirectoryPath.isDirectory()){ + throw new IllegalArgumentException("treeDirectoryPath must point to a directory!"); + } + + this.treeFiles = treeDirectoryPath.listFiles((file, s) -> s.endsWith(".tree")); + + } + + @Override + public FO evaluate(CovariateRow row) { + final List predictedOutputs = new ArrayList<>(treeFiles.length); + for(final Tree tree : getTrees()){ + final O prediction = tree.evaluate(row); + predictedOutputs.add(prediction); + } + + return treeResponseCombiner.combine(predictedOutputs); + } + + @Override + public FO evaluateOOB(CovariateRow row) { + final List predictedOutputs = new ArrayList<>(treeFiles.length); + for(final Tree tree : getTrees()){ + if(!tree.idInBootstrapSample(row.getId())){ + final O prediction = tree.evaluate(row); + predictedOutputs.add(prediction); + } + } + + return treeResponseCombiner.combine(predictedOutputs); + } + + + @Override + public List evaluate(List rowList){ + final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length]; + final Iterator> treeIterator = getTrees().iterator(); + + for(int treeId = 0; treeId < treeFiles.length; treeId++){ + final Tree currentTree = treeIterator.next(); + + final int tempTreeId = treeId; // Java workaround + IntStream.range(0, rowList.size()).parallel().forEach( + rowId -> { + final CovariateRow row = rowList.get(rowId); + final O prediction = currentTree.evaluate(row); + predictions[rowId][tempTreeId] = prediction; + } + ); + } + + return Arrays.stream(predictions).parallel() + .map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray))) + .collect(Collectors.toList()); + } + + @Override + public List evaluateSerial(List rowList){ + final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length]; + final Iterator> treeIterator = getTrees().iterator(); + + for(int treeId = 0; treeId < treeFiles.length; treeId++){ + final Tree currentTree = treeIterator.next(); + + final int tempTreeId = treeId; // Java workaround + IntStream.range(0, rowList.size()).sequential().forEach( + rowId -> { + final CovariateRow row = rowList.get(rowId); + final O prediction = currentTree.evaluate(row); + predictions[rowId][tempTreeId] = prediction; + } + ); + } + + return Arrays.stream(predictions).sequential() + .map(predArray -> treeResponseCombiner.combine(Arrays.asList(predArray))) + .collect(Collectors.toList()); + } + + + @Override + public List evaluateOOB(List rowList){ + final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length]; + final Iterator> treeIterator = getTrees().iterator(); + + for(int treeId = 0; treeId < treeFiles.length; treeId++){ + final Tree currentTree = treeIterator.next(); + + final int tempTreeId = treeId; // Java workaround + IntStream.range(0, rowList.size()).parallel().forEach( + rowId -> { + final CovariateRow row = rowList.get(rowId); + if(!currentTree.idInBootstrapSample(row.getId())){ + final O prediction = currentTree.evaluate(row); + predictions[rowId][tempTreeId] = prediction; + } else{ + predictions[rowId][tempTreeId] = null; + } + + } + ); + } + + return Arrays.stream(predictions).parallel() + .map(predArray -> { + final List predList = Arrays.stream(predArray).parallel() + .filter(pred -> pred != null).collect(Collectors.toList()); + + return treeResponseCombiner.combine(predList); + + }) + .collect(Collectors.toList()); + } + + @Override + public List evaluateSerialOOB(List rowList){ + final O[][] predictions = (O[][])new Object[rowList.size()][treeFiles.length]; + final Iterator> treeIterator = getTrees().iterator(); + + for(int treeId = 0; treeId < treeFiles.length; treeId++){ + final Tree currentTree = treeIterator.next(); + + final int tempTreeId = treeId; // Java workaround + IntStream.range(0, rowList.size()).sequential().forEach( + rowId -> { + final CovariateRow row = rowList.get(rowId); + if(!currentTree.idInBootstrapSample(row.getId())){ + final O prediction = currentTree.evaluate(row); + predictions[rowId][tempTreeId] = prediction; + } else{ + predictions[rowId][tempTreeId] = null; + } + + } + ); + } + + return Arrays.stream(predictions).sequential() + .map(predArray -> { + final List predList = Arrays.stream(predArray).sequential() + .filter(pred -> pred != null).collect(Collectors.toList()); + + return treeResponseCombiner.combine(predList); + + }) + .collect(Collectors.toList()); + } + + @Override + public Iterable> getTrees() { + return new IterableOfflineTree<>(treeFiles); + } + + @Override + public int getNumberOfTrees() { + return treeFiles.length; + } +} + diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/OnlineForest.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/OnlineForest.java new file mode 100644 index 0000000..8bd5c20 --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/OnlineForest.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.tree; + +import ca.joeltherrien.randomforest.CovariateRow; +import lombok.Builder; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +@Builder +public class OnlineForest extends Forest { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings + + private final List> trees; + private final ResponseCombiner treeResponseCombiner; + + @Override + public FO evaluate(CovariateRow row){ + + return treeResponseCombiner.combine( + trees.stream() + .map(node -> node.evaluate(row)) + .collect(Collectors.toList()) + ); + + } + + @Override + public FO evaluateOOB(CovariateRow row){ + + return treeResponseCombiner.combine( + trees.stream() + .filter(tree -> !tree.idInBootstrapSample(row.getId())) + .map(node -> node.evaluate(row)) + .collect(Collectors.toList()) + ); + + } + + @Override + public List> getTrees(){ + return Collections.unmodifiableList(trees); + } + + @Override + public int getNumberOfTrees() { + return trees.size(); + } + + +} diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java index 5d33b98..899c903 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java @@ -16,9 +16,7 @@ package ca.joeltherrien.randomforest.utils; -import ca.joeltherrien.randomforest.tree.Forest; -import ca.joeltherrien.randomforest.tree.ResponseCombiner; -import ca.joeltherrien.randomforest.tree.Tree; +import ca.joeltherrien.randomforest.tree.*; import java.io.*; import java.util.*; @@ -27,7 +25,7 @@ import java.util.zip.GZIPOutputStream; public class DataUtils { - public static Forest loadForest(File folder, ResponseCombiner treeResponseCombiner) throws IOException, ClassNotFoundException { + public static OnlineForest loadOnlineForest(File folder, ResponseCombiner treeResponseCombiner) throws IOException, ClassNotFoundException { if(!folder.isDirectory()){ throw new IllegalArgumentException("Tree directory must be a directory!"); } @@ -48,16 +46,16 @@ public class DataUtils { } - return Forest.builder() + return OnlineForest.builder() .trees(treeList) .treeResponseCombiner(treeResponseCombiner) .build(); } - public static Forest loadForest(String folder, ResponseCombiner treeResponseCombiner) throws IOException, ClassNotFoundException { + public static OnlineForest loadOnlineForest(String folder, ResponseCombiner treeResponseCombiner) throws IOException, ClassNotFoundException { final File directory = new File(folder); - return loadForest(directory, treeResponseCombiner); + return loadOnlineForest(directory, treeResponseCombiner); } public static void saveObject(Serializable object, String filename) throws IOException { diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/IterableOfflineTree.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/IterableOfflineTree.java new file mode 100644 index 0000000..5f3f69b --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/utils/IterableOfflineTree.java @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.utils; + +import ca.joeltherrien.randomforest.tree.Tree; +import lombok.RequiredArgsConstructor; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.util.Iterator; +import java.util.zip.GZIPInputStream; + +@RequiredArgsConstructor +public class IterableOfflineTree implements Iterable> { + + private final File[] treeFiles; + + @Override + public Iterator> iterator() { + return new OfflineTreeIterator<>(treeFiles); + } + + @RequiredArgsConstructor + public static class OfflineTreeIterator implements Iterator>{ + private final File[] treeFiles; + private int position = 0; + + @Override + public boolean hasNext() { + return position < treeFiles.length; + } + + @Override + public Tree next() { + final File treeFile = treeFiles[position]; + position++; + + + try { + final ObjectInputStream inputStream= new ObjectInputStream(new GZIPInputStream(new FileInputStream(treeFile))); + final Tree tree = (Tree) inputStream.readObject(); + return tree; + } catch (IOException | ClassNotFoundException e) { + e.printStackTrace(); + throw new RuntimeException("Failed to load tree for " + treeFile.toString()); + } + + } + } + + +} diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java index d1f4f6f..8edcadb 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java @@ -25,6 +25,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons import java.io.*; import java.util.ArrayList; import java.util.List; +import java.util.stream.IntStream; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; @@ -198,4 +199,12 @@ public final class RUtils { return newList; } + public static File[] getTreeFileArray(String folderPath, int endingId){ + return (File[]) IntStream.rangeClosed(1, endingId).sequential() + .mapToObj(i -> folderPath + "/tree-" + i + ".tree") + .map(File::new) + .toArray(); + + } + } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java index c3ca9d2..5cf5bb2 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java @@ -16,6 +16,8 @@ package ca.joeltherrien.randomforest.utils; +import lombok.Getter; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -39,6 +41,7 @@ public final class RightContinuousStepFunction extends StepFunction { * * May not be null. */ + @Getter private final double defaultY; public RightContinuousStepFunction(double[] x, double[] y, double defaultY) { diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java b/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java index feb9190..f99604e 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestDeterministicForests.java @@ -214,14 +214,14 @@ public class TestDeterministicForests { forestTrainer5Trees.trainSerialOnDisk(Optional.empty()); forestTrainer10Trees.trainSerialOnDisk(Optional.empty()); - final Forest forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); + final Forest forestSerial = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner()); TestUtils.removeFolder(saveTreeFile); verifyTwoForestsEqual(testData, referenceForest, forestSerial); forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4); forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4); - final Forest forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner()); + final Forest forestParallel = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner()); TestUtils.removeFolder(saveTreeFile); verifyTwoForestsEqual(testData, referenceForest, forestParallel); @@ -259,7 +259,7 @@ public class TestDeterministicForests { for(int k=0; k<3; k++){ forestTrainer.trainSerialOnDisk(Optional.empty()); - final Forest replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); + final Forest replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner); TestUtils.removeFolder(saveTreeFile); verifyTwoForestsEqual(testData, referenceForest, replicantForest); } @@ -274,7 +274,7 @@ public class TestDeterministicForests { for(int k=0; k<3; k++){ forestTrainer.trainParallelOnDisk(Optional.empty(), 4); - final Forest replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner); + final Forest replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner); TestUtils.removeFolder(saveTreeFile); verifyTwoForestsEqual(testData, referenceForest, replicantForest); } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java b/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java index 3ff2876..3b50b46 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestProvidingInitialForest.java @@ -20,8 +20,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; -import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; +import ca.joeltherrien.randomforest.tree.OnlineForest; import ca.joeltherrien.randomforest.tree.Tree; import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.utils.DataUtils; @@ -39,7 +39,7 @@ import static org.junit.jupiter.api.Assertions.*; public class TestProvidingInitialForest { - private Forest initialForest; + private OnlineForest initialForest; private List covariateList; private List> data; @@ -107,8 +107,8 @@ public class TestProvidingInitialForest { public void testSerialInMemory(){ final ForestTrainer forestTrainer = getForestTrainer(null, 20); - final Forest newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest)); - assertEquals(20, newForest.getTrees().size()); + final OnlineForest newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest)); + assertEquals(20, newForest.getNumberOfTrees()); for(Tree initialTree : initialForest.getTrees()){ assertTrue(newForest.getTrees().contains(initialTree)); @@ -124,8 +124,8 @@ public class TestProvidingInitialForest { public void testParallelInMemory(){ final ForestTrainer forestTrainer = getForestTrainer(null, 20); - final Forest newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2); - assertEquals(20, newForest.getTrees().size()); + final OnlineForest newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2); + assertEquals(20, newForest.getNumberOfTrees()); for(Tree initialTree : initialForest.getTrees()){ assertTrue(newForest.getTrees().contains(initialTree)); @@ -149,11 +149,11 @@ public class TestProvidingInitialForest { forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2); assertEquals(20, directory.listFiles().length); - final Forest newForest = DataUtils.loadForest(directory, new MeanResponseCombiner()); + final OnlineForest newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner()); - assertEquals(20, newForest.getTrees().size()); + assertEquals(20, newForest.getNumberOfTrees()); final List newForestTreesAsStrings = newForest.getTrees().stream() .map(tree -> tree.toString()).collect(Collectors.toList()); @@ -179,9 +179,9 @@ public class TestProvidingInitialForest { assertEquals(20, directory.listFiles().length); - final Forest newForest = DataUtils.loadForest(directory, new MeanResponseCombiner()); + final OnlineForest newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner()); - assertEquals(20, newForest.getTrees().size()); + assertEquals(20, newForest.getNumberOfTrees()); final List newForestTreesAsStrings = newForest.getTrees().stream() .map(tree -> tree.toString()).collect(Collectors.toList()); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 41c70d4..e49d926 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -24,11 +24,10 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder; -import ca.joeltherrien.randomforest.tree.Forest; -import ca.joeltherrien.randomforest.tree.ForestTrainer; -import ca.joeltherrien.randomforest.tree.TreeTrainer; +import ca.joeltherrien.randomforest.tree.*; import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.ResponseLoader; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; @@ -119,16 +118,21 @@ public class TestSavingLoading { assertTrue(directory.isDirectory()); assertEquals(NTREE, directory.listFiles().length); - final Forest forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); + final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null); + final OnlineForest onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner); + final OfflineForest offlineForest = new OfflineForest<>(directory, treeResponseCombiner); final CovariateRow predictionRow = getPredictionRow(covariates); - final CompetingRiskFunctions functions = forest.evaluate(predictionRow); - assertNotNull(functions); - assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2); + final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow); + assertNotNull(functionsOnline); + assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2); + + final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow); + assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline)); - assertEquals(NTREE, forest.getTrees().size()); + assertEquals(NTREE, onlineForest.getTrees().size()); TestUtils.removeFolder(directory); @@ -159,17 +163,22 @@ public class TestSavingLoading { assertEquals(NTREE, directory.listFiles().length); - - final Forest forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); + final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null); + final OnlineForest onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner); + final OfflineForest offlineForest = new OfflineForest<>(directory, treeResponseCombiner); final CovariateRow predictionRow = getPredictionRow(covariates); - final CompetingRiskFunctions functions = forest.evaluate(predictionRow); - assertNotNull(functions); - assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2); + final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow); + assertNotNull(functionsOnline); + assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2); - assertEquals(NTREE, forest.getTrees().size()); + final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow); + assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline)); + + + assertEquals(NTREE, onlineForest.getTrees().size()); TestUtils.removeFolder(directory); @@ -177,6 +186,64 @@ public class TestSavingLoading { } + /* + We don't implement equals() methods on the below mentioned classes because then we'd need to implement an + appropriate hashCode() method that's consistent with the equals(), and we only need plain equals() for + these tests. + */ + + private boolean competingFunctionsEqual(CompetingRiskFunctions f1 ,CompetingRiskFunctions f2){ + if(!functionsEqual(f1.getSurvivalCurve(), f2.getSurvivalCurve())){ + return false; + } + + for(int i=1; i<=2; i++){ + if(!functionsEqual(f1.getCauseSpecificHazardFunction(i), f2.getCauseSpecificHazardFunction(i))){ + return false; + } + if(!functionsEqual(f1.getCumulativeIncidenceFunction(i), f2.getCumulativeIncidenceFunction(i))){ + return false; + } + } + + return true; + } + + private boolean functionsEqual(RightContinuousStepFunction f1, RightContinuousStepFunction f2){ + + final double[] f1X = f1.getX(); + final double[] f2X = f2.getX(); + + final double[] f1Y = f1.getY(); + final double[] f2Y = f2.getY(); + + // first compare array lengths + if(f1X.length != f2X.length){ + return false; + } + if(f1Y.length != f2Y.length){ + return false; + } + + // TODO - better comparisons of doubles. I don't really care too much though as this equals method is only being used in tests + final double delta = 0.000001; + + if(Math.abs(f1.getDefaultY() - f2.getDefaultY()) > delta){ + return false; + } + + for(int i=0; i < f1X.length; i++){ + if(Math.abs(f1X[i] - f2X[i]) > delta){ + return false; + } + if(Math.abs(f1Y[i] - f2Y[i]) > delta){ + return false; + } + } + + return true; + + } } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java index 5e0c6a9..0d8f0f2 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -16,10 +16,11 @@ package ca.joeltherrien.randomforest.competingrisk; -import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.responses.competingrisk.*; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils; import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.tree.OnlineForest; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; @@ -30,8 +31,6 @@ import java.util.List; import static ca.joeltherrien.randomforest.TestUtils.closeEnough; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class TestCompetingRiskErrorRateCalculator { @@ -48,7 +47,7 @@ public class TestCompetingRiskErrorRateCalculator { final int event = 1; - final Forest fakeForest = Forest.builder().build(); + final Forest fakeForest = OnlineForest.builder().build(); final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event); diff --git a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/TestVariableImportanceCalculator.java similarity index 98% rename from library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java rename to library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/TestVariableImportanceCalculator.java index e8a83a9..273b177 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/TestVariableImportanceCalculator.java @@ -20,7 +20,7 @@ import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertEquals; -public class VariableImportanceCalculatorTest { +public class TestVariableImportanceCalculator { /* Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression @@ -28,7 +28,7 @@ public class VariableImportanceCalculatorTest { */ // We'l have a very simple Forest of two trees - private final Forest forest; + private final OnlineForest forest; private final List covariates; @@ -38,7 +38,7 @@ public class VariableImportanceCalculatorTest { Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance. */ - public VariableImportanceCalculatorTest(){ + public TestVariableImportanceCalculator(){ final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false); final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false); final FactorCovariate factorCovariate = new FactorCovariate("z", 2, @@ -67,10 +67,9 @@ public class VariableImportanceCalculatorTest { final Tree tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4}); final Tree tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8}); - this.forest = Forest.builder() + this.forest = OnlineForest.builder() .trees(Utils.easyList(tree1, tree2)) .treeResponseCombiner(new MeanResponseCombiner()) - .covariateList(this.covariates) .build(); // formula; boolean high adds 100; high numeric adds 10