Compare commits
No commits in common. "master" and "ibs" have entirely different histories.
125 changed files with 2266 additions and 2020 deletions
34
README.md
@@ -1,18 +1,21 @@
 # README
 
-This repository contains the largeRCRF Java library, containing all of the logic used for training the random forests. This provides the Jar file used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
+This repository contains two Java projects:
 
-Most users interested in training random competing risks forests should use the [R package component](https://github.com/jatherrien/largeRCRF); the content in this repository is only useful for advanced users.
+* The first is the largeRCRF library (`library/`) containing all of the logic used for training the random forests. This part provides the Jar file used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
+* The second is a small executable (`executable/`) Java project that uses the library and can be run directly outside of R. It's still in its early stages and isn't polished, nor is it yet well documented; but you can take a look if you want.
+
+Most users interested in training random competing risks forests should use the [R package component](https://github.com/jatherrien/largeRCRF); the content in this repository will only be useful to advanced users.
 
 ## License
 
-You're free to use / modify / redistribute the project, as long as you follow the terms of the GPL-3 license.
+You're free to use / modify / redistribute either of the two projects above, as long as you follow the terms of the GPL-3 license.
 
 ## Extending
 
 Documentation on how to extend the library to add support for other types of random forests will eventually be added, but for now if you're interested in that I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes to see how some basic regression random forests were introduced.
 
-If you've made a modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package`, then just copy `target/largeRCRF-library-1.0-SNAPSHOT.jar` into the `inst/java/` directory for the R package, replacing the previous Jar file there. Then build the R package, possibly with your modifications to the code there, with `R> devtools::build()`.
+If you've made a modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package` (in the same directory as this `README` file), then just copy `library/target/largeRCRF-library-1.0-SNAPSHOT.jar` into the `inst/java/` directory for the R package, replacing the previous Jar file there. Then build the R package, possibly with your modifications to the code there, with `R> devtools::build()`.
 
 Please don't take the current lack of documentation as a sign that I oppose others extending or modifying the project; if you have any questions on running, extending, integrating with R, or anything else related to this project, please don't hesitate to either [email me](mailto:joelt@sfu.ca) or create an Issue. Most likely my answers to your questions will end up forming the basis for any documentation written.
 
@@ -23,3 +26,26 @@ You need:
 * A Java runtime version 1.8 or greater
 * Maven to build the project
 
+## Troubleshooting (Running `executable`)
+
+Some of these Troubleshooting items can also apply if you are integrating the library classes into your own Java project.
+
+### I get an `OutOfMemoryError` but I have plenty of RAM
+
+By default the Java virtual machine only uses a quarter of the available system memory. When launching the Jar file you can manually specify the memory available like below:
+```
+java -jar -Xmx15G -Xms15G largeRCRF-executable-1.0-SNAPSHOT.jar settings.yaml train
+```
+
+with `15G` replaced with a little less than your available system memory.
+
+### I get an `OutOfMemoryError` and I'm short on RAM
+
+Try reducing the number of trees being trained simultaneously by reducing the number of threads in the settings file.
+
+### Training stalls immediately at 0 trees and the CPU is idle
+
+This issue has been observed before on one particular system (and only on that system) but it's not clear what causes it. If you encounter this, please open an Issue and describe what operating system you're running on, what cloud system (if relevant) you're running on, and the entire output of `java -version`.
+
+From my observation this issue occurs randomly but only at 0 trees; so as an imperfect workaround you can cancel the training and try again. Another imperfect workaround is to set the number of threads to 1; this causes the code to not use Java's parallel code capabilities, which will bypass the problem (at the cost of slower training).
+
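The heap-size advice in the Troubleshooting section above can be checked directly. The following is a minimal sketch, not part of the repository: the class name `HeapReport` and its helper `toMiB` are hypothetical, and it relies only on the standard `java.lang.Runtime` API to print the heap limits the running JVM actually received.

```java
// Hypothetical helper (not part of largeRCRF): prints the heap limits
// of the JVM it runs in, so the effect of -Xmx / -Xms can be verified.
final class HeapReport {

    /** Converts a byte count to whole mebibytes, rounding down. */
    static long toMiB(long bytes) {
        return bytes / (1024L * 1024L);
    }

    public static void main(String[] args) {
        Runtime runtime = Runtime.getRuntime();
        // maxMemory(): the most memory the JVM will attempt to use (governed by -Xmx).
        System.out.println("Max heap (MiB):       " + toMiB(runtime.maxMemory()));
        // totalMemory(): memory currently reserved by the JVM (starts near -Xms).
        System.out.println("Committed heap (MiB): " + toMiB(runtime.totalMemory()));
        // freeMemory(): unused portion of the currently reserved memory.
        System.out.println("Free in committed:    " + toMiB(runtime.freeMemory()));
    }
}
```

Compiling it with `javac HeapReport.java` and running `java -Xmx15G HeapReport` should report a maximum close to the requested value; running it without `-Xmx` shows the default that the JVM picked for the machine.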
126
executable/pom.xml
Normal file
@@ -0,0 +1,126 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <parent>
+        <artifactId>largeRCRF</artifactId>
+        <groupId>ca.joeltherrien</groupId>
+        <version>1.0-SNAPSHOT</version>
+    </parent>
+    <modelVersion>4.0.0</modelVersion>
+
+    <artifactId>largeRCRF-executable</artifactId>
+
+    <properties>
+        <java.version>1.8</java.version>
+        <maven.compiler.target>1.8</maven.compiler.target>
+        <maven.compiler.source>1.8</maven.compiler.source>
+        <jackson.version>2.9.9</jackson.version>
+
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.projectlombok</groupId>
+            <artifactId>lombok</artifactId>
+            <version>1.18.0</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-csv</artifactId>
+            <version>1.5</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+            <!--<version>${jackson.version}</version> 2.9.9.1 fixes some vulnerability; other Jackson dependencies are still on 2.9.9-->
+            <version>2.9.9.1</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.dataformat</groupId>
+            <artifactId>jackson-dataformat-yaml</artifactId>
+            <version>${jackson.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.junit.jupiter</groupId>
+            <artifactId>junit-jupiter-api</artifactId>
+            <version>5.2.0</version>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.junit.jupiter</groupId>
+            <artifactId>junit-jupiter-engine</artifactId>
+            <version>5.2.0</version>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>2.20.0</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>ca.joeltherrien</groupId>
+            <artifactId>largeRCRF-library</artifactId>
+            <version>1.0-SNAPSHOT</version>
+        </dependency>
+
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-shade-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>shade</goal>
+                        </goals>
+                        <configuration>
+                            <transformers>
+                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
+                                    <mainClass>
+                                        ca.joeltherrien.randomforest.Main
+                                    </mainClass>
+                                </transformer>
+                            </transformers>
+                            <minimizeJar>true</minimizeJar>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-pmd-plugin</artifactId>
+                <version>3.11.0</version>
+                <executions>
+                    <execution>
+                        <phase>package</phase> <!-- bind to the packaging phase -->
+                        <goals>
+                            <goal>check</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <rulesets>
+                        <!-- Custom local file system rule set -->
+                        <ruleset>${project.basedir}/../pmd-rules.xml</ruleset>
+                    </rulesets>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+
+
+</project>
252
executable/src/main/java/ca/joeltherrien/randomforest/Main.java
Normal file
@@ -0,0 +1,252 @@
+/*
+ * Copyright (c) 2019 Joel Therrien.
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package ca.joeltherrien.randomforest;
+
+import ca.joeltherrien.randomforest.covariates.Covariate;
+import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
+import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
+import ca.joeltherrien.randomforest.tree.Forest;
+import ca.joeltherrien.randomforest.tree.ForestTrainer;
+import ca.joeltherrien.randomforest.tree.ResponseCombiner;
+import ca.joeltherrien.randomforest.tree.TreeTrainer;
+import ca.joeltherrien.randomforest.utils.CSVUtils;
+import ca.joeltherrien.randomforest.utils.DataUtils;
+import ca.joeltherrien.randomforest.utils.StepFunction;
+import ca.joeltherrien.randomforest.utils.Utils;
+import com.fasterxml.jackson.databind.node.JsonNodeFactory;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.fasterxml.jackson.databind.node.TextNode;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.List;
+import java.util.Optional;
+import java.util.Random;
+
+public class Main {
+
+    public static void main(String[] args) throws IOException, ClassNotFoundException {
+        if(args.length < 2){
+            System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
+            System.out.println("Note that analyzing only supports competing risk data, and that you must then specify a sample size for testing errors.");
+            if(args.length == 0){
+                final File templateFile = new File("template.yaml");
+                if(templateFile.exists()){
+                    System.out.println("Template file exists; not creating a new one");
+                }
+                else{
+                    System.out.println("Generating template file.");
+                    defaultTemplate().save(templateFile);
+                }
+
+
+            }
+            return;
+        }
+        final File settingsFile = new File(args[0]);
+        final Settings settings = Settings.load(settingsFile);
+
+        final List<Covariate> covariates = settings.getCovariates();
+
+        if(args[1].equalsIgnoreCase("train")){
+            final List<Row<Object>> dataset = CSVUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
+
+            final ForestTrainer forestTrainer = constructForestTrainer(settings, dataset, covariates);
+
+            if(settings.isSaveProgress()){
+                if(settings.getNumberOfThreads() > 1){
+                    forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads());
+                } else{
+                    forestTrainer.trainSerialOnDisk(Optional.empty());
+                }
+            }
+            else{
+                if(settings.getNumberOfThreads() > 1){
+                    forestTrainer.trainParallelInMemory(Optional.empty(), settings.getNumberOfThreads());
+                } else{
+                    forestTrainer.trainSerialInMemory(Optional.empty());
+                }
+            }
+        }
+        else if(args[1].equalsIgnoreCase("analyze")){
+            // Perform different prediction measures
+
+            if(args.length < 3){
+                System.out.println("Specify error sample size");
+                return;
+            }
+
+            final String yVarType = settings.getYVarSettings().get("type").asText();
+
+            if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
+                System.out.println("Analyze currently only works on competing risk data");
+                return;
+            }
+
+            final ResponseCombiner<?, CompetingRiskFunctions> responseCombiner = settings.getTreeCombiner();
+            final int[] events;
+
+            if(responseCombiner instanceof CompetingRiskFunctionCombiner){
+                events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents();
+            }
+            else{
+                System.out.println("Unsupported tree combiner");
+                return;
+            }
+
+            final List<Row<CompetingRiskResponse>> dataset = CSVUtils.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
+
+            // Let's reduce this down to n
+            final int n = Integer.parseInt(args[2]);
+            Utils.reduceListToSize(dataset, n, new Random());
+
+            final File folder = new File(settings.getSaveTreeLocation());
+            final Forest<?, CompetingRiskFunctions> forest = DataUtils.loadForest(folder, responseCombiner);
+
+            final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
+
+            if(useBootstrapPredictions){
+                System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
+            }
+            else{
+                System.out.println("Finished loading trees + dataset; creating calculator and evaluating predictions");
+            }
+
+
+            final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, useBootstrapPredictions);
+            final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
+
+            System.out.println("Running Naive Concordance");
+
+            final double[] naiveConcordance = errorRateCalculator.calculateConcordance(events);
+            printWriter.write("Naive concordance:\n");
+            for(int i=0; i<events.length; i++){
+                printWriter.write('\t');
+                printWriter.write(Integer.toString(events[i]));
+                printWriter.write(": ");
+                printWriter.write(Double.toString(naiveConcordance[i]));
+                printWriter.write('\n');
+            }
+
+            if(yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
+                System.out.println("Running IPCW Concordance - creating censor distribution");
+
+                final double[] censorTimes = dataset.stream()
+                        .mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC())
+                        .toArray();
+                final StepFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes);
+
+                System.out.println("Finished generating censor distribution - running concordance");
+
+                final double[] ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(events, censorDistribution);
+                printWriter.write("IPCW concordance:\n");
+                for(int i=0; i<events.length; i++){
+                    printWriter.write('\t');
+                    printWriter.write(Integer.toString(events[i]));
+                    printWriter.write(": ");
+                    printWriter.write(Double.toString(ipcwConcordance[i]));
+                    printWriter.write('\n');
+                }
+            }
+
+
+            printWriter.close();
+        }
+        else{
+            System.out.println("Invalid instruction; use either train or analyze.");
+            System.out.println("Note that analyzing only supports competing risk data.");
+        }
+
+    }
+
+    private static ForestTrainer constructForestTrainer(final Settings settings, final List<Row<Object>> data, final List<Covariate> covariates){
+
+        return ForestTrainer.builder()
+                .ntree(settings.getNtree())
+                .data(data)
+                .displayProgress(true)
+                .saveTreeLocation(settings.getSaveTreeLocation())
+                .covariates(covariates)
+                .treeResponseCombiner(settings.getTreeCombiner())
+                .treeTrainer(constructTreeTrainer(settings, covariates))
+                .randomSeed(settings.getRandomSeed() != null ? settings.getRandomSeed() : System.nanoTime())
+                .build();
+
+    }
+
+    private static TreeTrainer constructTreeTrainer(final Settings settings, final List<Covariate> covariates) {
+
+        return TreeTrainer.builder()
+                .numberOfSplits(settings.getNumberOfSplits())
+                .nodeSize(settings.getNodeSize())
+                .maxNodeDepth(settings.getMaxNodeDepth())
+                .mtry(settings.getMtry())
+                .checkNodePurity(settings.isCheckNodePurity())
+                .responseCombiner(settings.getResponseCombiner())
+                .splitFinder(settings.getSplitFinder())
+                .covariates(covariates)
+                .build();
+    }
+
+
+    private static Settings defaultTemplate(){
+
+        final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
+        splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
+
+        final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
+        responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
+
+        final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
+        treeCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
+
+        final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
+        yVarSettings.set("type", new TextNode("Double"));
+        yVarSettings.set("name", new TextNode("y"));
+
+        return Settings.builder()
+                .covariateSettings(Utils.easyList(
+                        new NumericCovariateSettings("x1"),
+                        new BooleanCovariateSettings("x2"),
+                        new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
+                        )
+                )
+                .trainingDataLocation("training_data.csv")
+                .validationDataLocation("validation_data.csv")
+                .responseCombinerSettings(responseCombinerSettings)
+                .treeCombinerSettings(treeCombinerSettings)
+                .splitFinderSettings(splitFinderSettings)
+                .yVarSettings(yVarSettings)
+                .maxNodeDepth(100000)
+                .mtry(2)
+                .nodeSize(5)
+                .ntree(500)
+                .numberOfSplits(5)
+                .numberOfThreads(1)
+                .saveProgress(true)
+                .saveTreeLocation("trees/")
+                .build();
+    }
+
+}
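`Main.java` above builds its censoring distribution with `Utils.estimateOneMinusECDF(censorTimes)`, whose implementation is not shown in this diff. As a rough illustration only, here is a self-contained sketch of what "one minus the empirical CDF" computes: the fraction of observed values strictly greater than `t`. The class name `OneMinusEcdf`, its method names, and the strict-inequality convention are assumptions and may differ from the library's `StepFunction`.

```java
import java.util.Arrays;

// Hypothetical sketch (not the library's implementation): a step function
// S(t) = 1 - ECDF(t), i.e. the fraction of samples strictly greater than t.
final class OneMinusEcdf {
    private final double[] sorted;

    OneMinusEcdf(double[] samples) {
        if (samples.length == 0) {
            throw new IllegalArgumentException("At least one sample is required");
        }
        this.sorted = samples.clone(); // defensive copy; the caller's array is untouched
        Arrays.sort(this.sorted);
    }

    /** Returns the fraction of samples strictly greater than t. */
    double evaluate(double t) {
        // Binary search for the number of samples <= t (an upper bound search).
        int lo = 0;
        int hi = sorted.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (sorted[mid] <= t) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return 1.0 - (double) lo / sorted.length;
    }

    public static void main(String[] args) {
        OneMinusEcdf s = new OneMinusEcdf(new double[]{2.0, 5.0, 5.0, 9.0});
        System.out.println(s.evaluate(1.0)); // prints 1.0
        System.out.println(s.evaluate(5.0)); // prints 0.25
        System.out.println(s.evaluate(9.0)); // prints 0.0
    }
}
```

Sorting once and answering each query with a binary search keeps evaluation at O(log n), which matters when the function is evaluated for every pair of rows in a concordance calculation.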
@@ -0,0 +1,236 @@
+/*
+ * Copyright (c) 2019 Joel Therrien.
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package ca.joeltherrien.randomforest;
+
+import ca.joeltherrien.randomforest.covariates.Covariate;
+import ca.joeltherrien.randomforest.covariates.settings.CovariateSettings;
+import ca.joeltherrien.randomforest.loaders.CompetingResponseLoader;
+import ca.joeltherrien.randomforest.loaders.CompetingResponseWithCensorTimeLoader;
+import ca.joeltherrien.randomforest.loaders.DoubleLoader;
+import ca.joeltherrien.randomforest.loaders.ResponseLoader;
+import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
+import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
+import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogRankSplitFinder;
+import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
+import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
+import ca.joeltherrien.randomforest.tree.ResponseCombiner;
+import ca.joeltherrien.randomforest.tree.SplitFinder;
+import ca.joeltherrien.randomforest.utils.JsonUtils;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.JsonNodeFactory;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+/**
+ * This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
+ */
+@Data
+@Builder
+@AllArgsConstructor
+@EqualsAndHashCode
+public class Settings {
+
+    private static Map<String, Function<ObjectNode, ResponseLoader>> RESPONSE_LOADER_MAP = new HashMap<>();
+    public static Function<ObjectNode, ResponseLoader> getResponseLoaderConstructor(final String name){
+        return RESPONSE_LOADER_MAP.get(name.toLowerCase());
+    }
+    public static void registerResponseLoaderConstructor(final String name, final Function<ObjectNode, ResponseLoader> responseLoaderConstructor){
+        RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor);
+    }
+
+    static{
+        registerResponseLoaderConstructor("double",
+                node -> new DoubleLoader(node)
+        );
+        registerResponseLoaderConstructor("CompetingRiskResponse",
+                node -> new CompetingResponseLoader(node)
+        );
+        registerResponseLoaderConstructor("CompetingRiskResponseWithCensorTime",
+                node -> new CompetingResponseWithCensorTimeLoader(node)
+        );
+    }
+
+    private static Map<String, Function<ObjectNode, SplitFinder>> SPLIT_FINDER_MAP = new HashMap<>();
+    public static Function<ObjectNode, SplitFinder> getSplitFinderConstructor(final String name){
+        return SPLIT_FINDER_MAP.get(name.toLowerCase());
+    }
+    public static void registerSplitFinderConstructor(final String name, final Function<ObjectNode, SplitFinder> splitFinderConstructor){
+        SPLIT_FINDER_MAP.put(name.toLowerCase(), splitFinderConstructor);
+    }
+    static{
+        registerSplitFinderConstructor("WeightedVarianceSplitFinder",
+                (node) -> new WeightedVarianceSplitFinder()
+        );
+        registerSplitFinderConstructor("GrayLogRankSplitFinder",
+                (objectNode) -> {
+                    final int[] eventsOfFocusArray = JsonUtils.jsonToIntArray(objectNode.get("eventsOfFocus"));
+                    final int[] eventArray = JsonUtils.jsonToIntArray(objectNode.get("events"));
+
+                    return new GrayLogRankSplitFinder(eventsOfFocusArray, eventArray);
+                }
+        );
+        registerSplitFinderConstructor("LogRankSplitFinder",
+                (objectNode) -> {
+                    final int[] eventsOfFocusArray = JsonUtils.jsonToIntArray(objectNode.get("eventsOfFocus"));
+                    final int[] eventArray = JsonUtils.jsonToIntArray(objectNode.get("events"));
+
+                    return new LogRankSplitFinder(eventsOfFocusArray, eventArray);
+                }
+        );
+    }
+
+
+
+    private static Map<String, Function<ObjectNode, ResponseCombiner>> RESPONSE_COMBINER_MAP = new HashMap<>();
+    public static Function<ObjectNode, ResponseCombiner> getResponseCombinerConstructor(final String name){
+        return RESPONSE_COMBINER_MAP.get(name.toLowerCase());
+    }
+    public static void registerResponseCombinerConstructor(final String name, final Function<ObjectNode, ResponseCombiner> responseCombinerConstructor){
+        RESPONSE_COMBINER_MAP.put(name.toLowerCase(), responseCombinerConstructor);
+    }
+
+    static{
+
+        registerResponseCombinerConstructor("MeanResponseCombiner",
+                (node) -> new MeanResponseCombiner()
+        );
+        registerResponseCombinerConstructor("CompetingRiskResponseCombiner",
+                (node) -> {
+                    final int[] events = JsonUtils.jsonToIntArray(node.get("events"));
+
+                    return new CompetingRiskResponseCombiner(events);
+
+                }
+        );
+
+        registerResponseCombinerConstructor("CompetingRiskFunctionCombiner",
+                (node) -> {
+                    final int[] events = JsonUtils.jsonToIntArray(node.get("events"));
+
+                    double[] times = null;
+                    if(node.hasNonNull("times")){
+                        times = JsonUtils.jsonToDoubleArray(node.get("times"));
+                    }
+
+                    return new CompetingRiskFunctionCombiner(events, times);
+
+                }
+        );
+
+
+    }
+
+    private int numberOfSplits = 5;
+    private int nodeSize = 5;
+    private int maxNodeDepth = 1000000; // basically no maxNodeDepth
+    private boolean checkNodePurity = false;
+    private Long randomSeed;
+
+    private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
+    private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
+    private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
+
+    private List<CovariateSettings> covariateSettings = new ArrayList<>();
+    private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
+
+    // number of covariates to randomly try
+    private int mtry = 0;
+
+    // number of trees to try
+    private int ntree = 500;
+
+    private String trainingDataLocation = "data.csv";
+    private String validationDataLocation = "data.csv";
+    private String saveTreeLocation = "trees/";
+
+    private int numberOfThreads = 1;
+    private boolean saveProgress = false;
+
+    public Settings(){
+    } // required for Jackson
+
+    public static Settings load(File file) throws IOException {
+        final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
+        //mapper.enableDefaultTyping();
+
+        return mapper.readValue(file, Settings.class);
+
+    }
+
+    public void save(File file) throws IOException {
+        final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
+        //mapper.enableDefaultTyping();
+
+        // Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
+        this.covariateSettings = new ArrayList<>(this.covariateSettings);
+
+        mapper.writeValue(file, this);
+    }
+
+    @JsonIgnore
+    public SplitFinder getSplitFinder(){
+        final String type = splitFinderSettings.get("type").asText();
+
+        return getSplitFinderConstructor(type).apply(splitFinderSettings);
+    }
+
+    @JsonIgnore
+    public ResponseLoader getResponseLoader(){
+        final String type = yVarSettings.get("type").asText();
+
+        return getResponseLoaderConstructor(type).apply(yVarSettings);
+    }
+
+    @JsonIgnore
+    public ResponseCombiner getResponseCombiner(){
+        final String type = responseCombinerSettings.get("type").asText();
+
+        return getResponseCombinerConstructor(type).apply(responseCombinerSettings);
+    }
+
+    @JsonIgnore
+    public ResponseCombiner getTreeCombiner(){
+        final String type = treeCombinerSettings.get("type").asText();
+
+        return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
+    }
+
+    @JsonIgnore
+    public List<Covariate> getCovariates(){
+        final List<CovariateSettings> covariateSettingsList = this.getCovariateSettings();
+        final List<Covariate> covariates = new ArrayList<>(covariateSettingsList.size());
+        for(int i = 0; i < covariateSettingsList.size(); i++){
+            covariates.add(covariateSettingsList.get(i).build(i));
+        }
+        return covariates;
+    }
+
+}
|
@ -14,17 +14,22 @@
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
/**
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||||
* Similar to ResponseCombiner, but an IntermediateCombinedResponse represents the intermediate state of a single output in the process of being combined.
|
import lombok.Data;
|
||||||
* This class is only used in OfflineForests where we can only load one Tree in memory at a time.
|
import lombok.NoArgsConstructor;
|
||||||
*
|
|
||||||
*/
|
|
||||||
public interface IntermediateCombinedResponse<I, O> {
|
|
||||||
|
|
||||||
void processNewInput(I input);
|
@NoArgsConstructor // required by Jackson
|
||||||
|
@Data
|
||||||
|
public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
|
||||||
|
|
||||||
O transformToOutput();
|
public BooleanCovariateSettings(String name){
|
||||||
|
super(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BooleanCovariate build(int index) {
|
||||||
|
return new BooleanCovariate(name, index);
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Nuisance class to work with Jackson for persisting settings.
|
||||||
|
*
|
||||||
|
* @param <V>
|
||||||
|
*/
|
||||||
|
@NoArgsConstructor // required for Jackson
|
||||||
|
@Getter
|
||||||
|
@JsonTypeInfo(
|
||||||
|
use = JsonTypeInfo.Id.NAME,
|
||||||
|
property = "type")
|
||||||
|
@JsonSubTypes({
|
||||||
|
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),
|
||||||
|
@JsonSubTypes.Type(value = NumericCovariateSettings.class, name = "numeric"),
|
||||||
|
@JsonSubTypes.Type(value = FactorCovariateSettings.class, name = "factor")
|
||||||
|
})
|
||||||
|
public abstract class CovariateSettings<V> {
|
||||||
|
|
||||||
|
String name;
|
||||||
|
|
||||||
|
CovariateSettings(String name){
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract Covariate<V> build(int index);
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@NoArgsConstructor // required by Jackson
|
||||||
|
@Data
|
||||||
|
public final class FactorCovariateSettings extends CovariateSettings<String> {
|
||||||
|
|
||||||
|
private List<String> levels;
|
||||||
|
|
||||||
|
public FactorCovariateSettings(String name, List<String> levels){
|
||||||
|
super(name);
|
||||||
|
this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FactorCovariate build(int index) {
|
||||||
|
return new FactorCovariate(name, index, levels);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@NoArgsConstructor // required by Jackson
|
||||||
|
@Data
|
||||||
|
public final class NumericCovariateSettings extends CovariateSettings<Double> {
|
||||||
|
|
||||||
|
public NumericCovariateSettings(String name){
|
||||||
|
super(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NumericCovariate build(int index) {
|
||||||
|
return new NumericCovariate(name, index);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.loaders;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CompetingResponseLoader implements ResponseLoader<CompetingRiskResponse> {
|
||||||
|
|
||||||
|
private final String deltaName;
|
||||||
|
private final String uName;
|
||||||
|
|
||||||
|
public CompetingResponseLoader(ObjectNode node){
|
||||||
|
this.deltaName = node.get("delta").asText();
|
||||||
|
this.uName = node.get("u").asText();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingRiskResponse parse(CSVRecord record) {
|
||||||
|
final int delta = Integer.parseInt(record.get(deltaName));
|
||||||
|
final double u = Double.parseDouble(record.get(uName));
|
||||||
|
|
||||||
|
return new CompetingRiskResponse(delta, u);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.loaders;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CompetingResponseWithCensorTimeLoader implements ResponseLoader<CompetingRiskResponseWithCensorTime>{
|
||||||
|
|
||||||
|
private final String deltaName;
|
||||||
|
private final String uName;
|
||||||
|
private final String cName;
|
||||||
|
|
||||||
|
public CompetingResponseWithCensorTimeLoader(ObjectNode node){
|
||||||
|
this.deltaName = node.get("delta").asText();
|
||||||
|
this.uName = node.get("u").asText();
|
||||||
|
this.cName = node.get("c").asText();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingRiskResponseWithCensorTime parse(CSVRecord record) {
|
||||||
|
final int delta = Integer.parseInt(record.get(deltaName));
|
||||||
|
final double u = Double.parseDouble(record.get(uName));
|
||||||
|
final double c = Double.parseDouble(record.get(cName));
|
||||||
|
|
||||||
|
return new CompetingRiskResponseWithCensorTime(delta, u, c);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.loaders;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class DoubleLoader implements ResponseLoader<Double> {
|
||||||
|
|
||||||
|
private final String yName;
|
||||||
|
|
||||||
|
public DoubleLoader(final ObjectNode node){
|
||||||
|
this.yName = node.get("name").asText();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double parse(CSVRecord record) {
|
||||||
|
return Double.parseDouble(record.get(yName));
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,10 +14,13 @@
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.loaders;
|
||||||
|
|
||||||
public interface ForestResponseCombiner<I, O> extends ResponseCombiner<I, O>{
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
IntermediateCombinedResponse<I, O> startIntermediateCombinedResponse(int countInputs);
|
@FunctionalInterface
|
||||||
|
public interface ResponseLoader<Y> {
|
||||||
|
|
||||||
|
Y parse(CSVRecord record);
|
||||||
|
|
||||||
}
|
}
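Because `ResponseLoader` is a `@FunctionalInterface`, a loader can also be written as a lambda rather than a dedicated class. The sketch below illustrates this with a hypothetical `RowParser<Y>` interface that takes a `Map<String, String>` in place of `CSVRecord`, so it runs without the commons-csv dependency. The names `RowParser`, `doubleColumn`, and `ResponseLoaderSketch` are assumptions made for illustration and are not part of the library.

```java
import java.util.HashMap;
import java.util.Map;

public class ResponseLoaderSketch {

    /**
     * Hypothetical stand-in for ResponseLoader<Y>. A Map replaces CSVRecord
     * so that the sketch needs nothing beyond the JDK.
     */
    @FunctionalInterface
    public interface RowParser<Y> {
        Y parse(Map<String, String> record);
    }

    /** Mirrors what DoubleLoader does: read one named column and parse it as a double. */
    public static RowParser<Double> doubleColumn(String yName) {
        return record -> Double.parseDouble(record.get(yName));
    }

    public static void main(String[] args) {
        Map<String, String> record = new HashMap<>();
        record.put("y", "5");
        record.put("x1", "3.0");

        RowParser<Double> loader = doubleColumn("y");
        System.out.println(loader.parse(record)); // prints 5.0
    }
}
```

The class-based loaders in the diff remain useful when the column names come from an `ObjectNode`; the lambda form is mainly convenient in tests.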
|
|
@ -0,0 +1,70 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.loaders.ResponseLoader;
|
||||||
|
import org.apache.commons.csv.CSVFormat;
|
||||||
|
import org.apache.commons.csv.CSVParser;
|
||||||
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
import java.io.*;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.zip.GZIPInputStream;
|
||||||
|
|
||||||
|
public class CSVUtils {
|
||||||
|
|
||||||
|
public static <Y> List<Row<Y>> loadData(final List<Covariate> covariates, final ResponseLoader<Y> responseLoader, String filename) throws IOException {
|
||||||
|
|
||||||
|
final List<Row<Y>> dataset = new ArrayList<>();
|
||||||
|
|
||||||
|
final Reader input;
|
||||||
|
if(filename.endsWith(".gz")){
|
||||||
|
final FileInputStream inputStream = new FileInputStream(filename);
|
||||||
|
final GZIPInputStream gzipInputStream = new GZIPInputStream(inputStream);
|
||||||
|
|
||||||
|
input = new InputStreamReader(gzipInputStream);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
input = new FileReader(filename);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
|
||||||
|
|
||||||
|
|
||||||
|
int id = 1;
|
||||||
|
for(final CSVRecord record : parser){
|
||||||
|
final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()];
|
||||||
|
|
||||||
|
for(final Covariate<?> covariate : covariates){
|
||||||
|
valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
final Y y = responseLoader.parse(record);
|
||||||
|
|
||||||
|
dataset.add(new Row<>(valueArray, id++, y));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.close(); // release the file handle; closing the parser also closes the underlying reader
return dataset;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
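The `.gz` branch in `loadData` could also be written with try-with-resources and an explicit charset, so that the file handle is released even if parsing fails. The following is a minimal sketch under those assumptions, using only the JDK; `GzAwareReaderSketch`, `open`, `readLines`, and `writeTempGz` are hypothetical names, and UTF-8 is an assumed encoding rather than something the original code specifies.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

public class GzAwareReaderSketch {

    /** Opens a reader, decompressing when the file name ends in ".gz" (the same rule loadData uses). */
    public static Reader open(Path path) throws IOException {
        InputStream in = Files.newInputStream(path);
        try {
            if (path.getFileName().toString().endsWith(".gz")) {
                in = new GZIPInputStream(in);
            }
            return new InputStreamReader(in, StandardCharsets.UTF_8); // assumed encoding
        } catch (IOException | RuntimeException e) {
            in.close(); // do not leak the raw stream if the gzip header is invalid
            throw e;
        }
    }

    /** Reads every line; the reader is closed automatically, even on failure. */
    public static List<String> readLines(Path path) {
        List<String> lines = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(open(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return lines;
    }

    /** Demo helper: writes the text to a temporary gzip-compressed file. */
    public static Path writeTempGz(String text) {
        try {
            Path path = Files.createTempFile("sketch", ".csv.gz");
            try (OutputStream out = new GZIPOutputStream(Files.newOutputStream(path))) {
                out.write(text.getBytes(StandardCharsets.UTF_8));
            }
            return path;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        Path path = writeTempGz("y,x1\n5,3.0\n");
        System.out.println(readLines(path)); // prints [y,x1, 5,3.0]
    }
}
```

In the real loader the `Reader` returned by `open` would be handed to `CSVFormat.RFC4180.withFirstRecordAsHeader().parse(...)` inside the same try-with-resources block.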
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class JsonUtils {
|
||||||
|
|
||||||
|
public static int[] jsonToIntArray(final JsonNode node){
|
||||||
|
final Iterator<JsonNode> elements = node.elements();
|
||||||
|
final List<JsonNode> elementList = new ArrayList<>();
|
||||||
|
elements.forEachRemaining(n -> elementList.add(n));
|
||||||
|
|
||||||
|
final int[] array = elementList.stream().mapToInt(n -> n.asInt()).toArray();
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double[] jsonToDoubleArray(final JsonNode node){
|
||||||
|
final Iterator<JsonNode> elements = node.elements();
|
||||||
|
final List<JsonNode> elementList = new ArrayList<>();
|
||||||
|
elements.forEachRemaining(n -> elementList.add(n));
|
||||||
|
|
||||||
|
final double[] array = elementList.stream().mapToDouble(n -> n.asDouble()).toArray();
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
}
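Both `JsonUtils` methods copy the iterator into a temporary list before streaming it. As an alternative sketch, an iterator can be streamed directly through `Spliterators.spliteratorUnknownSize`. To stay dependency-free, the example uses `Iterator<? extends Number>` as a stand-in for Jackson's `node.elements()`; `IteratorToArraySketch` and its method names are hypothetical.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.StreamSupport;

public class IteratorToArraySketch {

    /** Streams the iterator in order and collects the values into a double[]. */
    public static double[] toDoubleArray(Iterator<? extends Number> elements) {
        return StreamSupport
                .stream(Spliterators.<Number>spliteratorUnknownSize(elements, Spliterator.ORDERED), false)
                .mapToDouble(Number::doubleValue)
                .toArray();
    }

    /** Same idea for int[]; fractional values are truncated by Number.intValue(). */
    public static int[] toIntArray(Iterator<? extends Number> elements) {
        return StreamSupport
                .stream(Spliterators.<Number>spliteratorUnknownSize(elements, Spliterator.ORDERED), false)
                .mapToInt(Number::intValue)
                .toArray();
    }

    public static void main(String[] args) {
        List<Double> times = Arrays.asList(0.5, 1.0, 2.5);
        System.out.println(Arrays.toString(toDoubleArray(times.iterator()))); // prints [0.5, 1.0, 2.5]
    }
}
```

With Jackson, the mapping step would call `JsonNode::asDouble` or `JsonNode::asInt` instead of the `Number` accessors.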
|
|
@ -0,0 +1,84 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.fasterxml.jackson.databind.node.TextNode;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class TestPersistence {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSaving() throws IOException {
|
||||||
|
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
|
||||||
|
|
||||||
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
||||||
|
|
||||||
|
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
treeCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
|
||||||
|
|
||||||
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
yVarSettings.set("type", new TextNode("Double"));
|
||||||
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
|
final Settings settingsOriginal = Settings.builder()
|
||||||
|
.covariateSettings(Utils.easyList(
|
||||||
|
new NumericCovariateSettings("x1"),
|
||||||
|
new BooleanCovariateSettings("x2"),
|
||||||
|
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.trainingDataLocation("training_data.csv")
|
||||||
|
.validationDataLocation("validation_data.csv")
|
||||||
|
.responseCombinerSettings(responseCombinerSettings)
|
||||||
|
.treeCombinerSettings(treeCombinerSettings)
|
||||||
|
.splitFinderSettings(splitFinderSettings)
|
||||||
|
.yVarSettings(yVarSettings)
|
||||||
|
.maxNodeDepth(100000)
|
||||||
|
.mtry(2)
|
||||||
|
.nodeSize(5)
|
||||||
|
.ntree(500)
|
||||||
|
.numberOfSplits(5)
|
||||||
|
.numberOfThreads(1)
|
||||||
|
.saveProgress(true)
|
||||||
|
.saveTreeLocation("trees/")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final File templateFile = new File("template.yaml");
|
||||||
|
settingsOriginal.save(templateFile);
|
||||||
|
|
||||||
|
final Settings reloadedSettings = Settings.load(templateFile);
|
||||||
|
|
||||||
|
assertEquals(settingsOriginal, reloadedSettings);
|
||||||
|
|
||||||
|
//templateFile.delete();
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,120 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.csv;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.loaders.ResponseLoader;
|
||||||
|
import ca.joeltherrien.randomforest.utils.CSVUtils;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.fasterxml.jackson.databind.node.TextNode;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
public class TestLoadingCSV {
|
||||||
|
|
||||||
|
/*
|
||||||
|
y,x1,x2,x3
|
||||||
|
5,3.0,"mouse",true
|
||||||
|
2,1.0,"dog",false
|
||||||
|
9,1.5,"cat",true
|
||||||
|
-3,NA,NA,NA
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
public List<Row<Double>> loadData(String filename) throws IOException {
|
||||||
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
yVarSettings.set("type", new TextNode("Double"));
|
||||||
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
|
final Settings settings = Settings.builder()
|
||||||
|
.trainingDataLocation(filename)
|
||||||
|
.covariateSettings(
|
||||||
|
Utils.easyList(new NumericCovariateSettings("x1"),
|
||||||
|
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
||||||
|
new BooleanCovariateSettings("x3"))
|
||||||
|
)
|
||||||
|
.yVarSettings(yVarSettings)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
|
|
||||||
|
|
||||||
|
final ResponseLoader loader = settings.getResponseLoader();
|
||||||
|
|
||||||
|
return CSVUtils.loadData(covariates, loader, settings.getTrainingDataLocation());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void verifyLoadingNormal(final List<Covariate> covariates) throws IOException {
|
||||||
|
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
|
||||||
|
|
||||||
|
assertData(data, covariates);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void verifyLoadingGz(final List<Covariate> covariates) throws IOException {
|
||||||
|
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
|
||||||
|
|
||||||
|
assertData(data, covariates);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private void assertData(final List<Row<Double>> data, final List<Covariate> covariates){
|
||||||
|
final Covariate x1 = covariates.get(0);
|
||||||
|
final Covariate x2 = covariates.get(1);
|
||||||
|
final Covariate x3 = covariates.get(2);
|
||||||
|
|
||||||
|
assertEquals(4, data.size());
|
||||||
|
|
||||||
|
Row<Double> row = data.get(0);
|
||||||
|
assertEquals(5.0, (double)row.getResponse());
|
||||||
|
assertEquals(3.0, row.getCovariateValue(x1).getValue());
|
||||||
|
assertEquals("mouse", row.getCovariateValue(x2).getValue());
|
||||||
|
assertEquals(true, row.getCovariateValue(x3).getValue());
|
||||||
|
|
||||||
|
row = data.get(1);
|
||||||
|
assertEquals(2.0, (double)row.getResponse());
|
||||||
|
assertEquals(1.0, row.getCovariateValue(x1).getValue());
|
||||||
|
assertEquals("dog", row.getCovariateValue(x2).getValue());
|
||||||
|
assertEquals(false, row.getCovariateValue(x3).getValue());
|
||||||
|
|
||||||
|
row = data.get(2);
|
||||||
|
assertEquals(9.0, (double)row.getResponse());
|
||||||
|
assertEquals(1.5, row.getCovariateValue(x1).getValue());
|
||||||
|
assertEquals("cat", row.getCovariateValue(x2).getValue());
|
||||||
|
assertEquals(true, row.getCovariateValue(x3).getValue());
|
||||||
|
|
||||||
|
row = data.get(3);
|
||||||
|
assertEquals(-3.0, (double)row.getResponse());
|
||||||
|
assertTrue(row.getCovariateValue(x1).isNA());
|
||||||
|
assertTrue(row.getCovariateValue(x2).isNA());
|
||||||
|
assertTrue(row.getCovariateValue(x3).isNA());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
102
library/pom.xml
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<artifactId>largeRCRF-library</artifactId>
|
||||||
|
|
||||||
|
<!-- parent pom -->
|
||||||
|
<parent>
|
||||||
|
<groupId>ca.joeltherrien</groupId>
|
||||||
|
<artifactId>largeRCRF</artifactId>
|
||||||
|
<version>1.0-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<java.version>1.8</java.version>
|
||||||
|
<maven.compiler.target>1.8</maven.compiler.target>
|
||||||
|
<maven.compiler.source>1.8</maven.compiler.source>
|
||||||
|
|
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
|
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.projectlombok</groupId>
|
||||||
|
<artifactId>lombok</artifactId>
|
||||||
|
<version>1.18.0</version>
|
||||||
|
<scope>provided</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-csv</artifactId>
|
||||||
|
<version>1.5</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
|
<version>5.2.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
|
<version>5.2.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.mockito</groupId>
|
||||||
|
<artifactId>mockito-core</artifactId>
|
||||||
|
<version>2.20.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-shade-plugin</artifactId>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>shade</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-pmd-plugin</artifactId>
|
||||||
|
<version>3.11.0</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<phase>package</phase> <!-- bind to the packaging phase -->
|
||||||
|
<goals>
|
||||||
|
<goal>check</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
<configuration>
|
||||||
|
<rulesets>
|
||||||
|
<!-- Custom local file system rule set -->
|
||||||
|
<ruleset>${project.basedir}/../pmd-rules.xml</ruleset>
|
||||||
|
</rulesets>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
</project>
|
|
@ -49,8 +49,6 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
|
||||||
return getIndex() - other.getIndex();
|
return getIndex() - other.getIndex();
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean haveNASplitPenalty();
|
|
||||||
|
|
||||||
interface Value<V> extends Serializable{
|
interface Value<V> extends Serializable{
|
||||||
|
|
||||||
Covariate<V> getParent();
|
Covariate<V> getParent();
|
|
@ -25,7 +25,6 @@ import lombok.Getter;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
public final class BooleanCovariate implements Covariate<Boolean> {
|
public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
|
@ -41,26 +40,14 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
|
private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
|
||||||
|
|
||||||
private final boolean haveNASplitPenalty;
|
public BooleanCovariate(String name, int index){
|
||||||
@Override
|
|
||||||
public boolean haveNASplitPenalty(){
|
|
||||||
// penalty would add worthless computational time if there are no NAs
|
|
||||||
return hasNAs && haveNASplitPenalty;
|
|
||||||
}
|
|
||||||
|
|
||||||
public BooleanCovariate(String name, int index, boolean haveNASplitPenalty){
|
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.index = index;
|
this.index = index;
|
||||||
this.splitRule = new BooleanSplitRule(this);
|
splitRule = new BooleanSplitRule(this);
|
||||||
this.haveNASplitPenalty = haveNASplitPenalty;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
if(hasNAs){
|
|
||||||
data = data.stream().filter(row -> !row.getValueByIndex(index).isNA()).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
return new SingletonIterator<>(this.splitRule.applyRule(data));
|
return new SingletonIterator<>(this.splitRule.applyRule(data));
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,6 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
public final class FactorCovariate implements Covariate<String> {
|
public final class FactorCovariate implements Covariate<String> {
|
||||||
|
|
||||||
|
@ -41,15 +40,8 @@ public final class FactorCovariate implements Covariate<String> {
|
||||||
|
|
||||||
private boolean hasNAs;
|
private boolean hasNAs;
|
||||||
|
|
||||||
private final boolean haveNASplitPenalty;
|
|
||||||
@Override
|
|
||||||
public boolean haveNASplitPenalty(){
|
|
||||||
// penalty would add worthless computational time if there are no NAs
|
|
||||||
return hasNAs && haveNASplitPenalty;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
public FactorCovariate(final String name, final int index, List<String> levels){
|
||||||
public FactorCovariate(final String name, final int index, List<String> levels, final boolean haveNASplitPenalty){
|
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.index = index;
|
this.index = index;
|
||||||
this.factorLevels = new HashMap<>();
|
this.factorLevels = new HashMap<>();
|
||||||
|
@ -71,22 +63,12 @@ public final class FactorCovariate implements Covariate<String> {
|
||||||
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
||||||
|
|
||||||
this.naValue = new FactorValue(null);
|
this.naValue = new FactorValue(null);
|
||||||
|
|
||||||
this.haveNASplitPenalty = haveNASplitPenalty;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
if(hasNAs()){
|
|
||||||
data = data.stream().filter(row -> !row.getCovariateValue(this).isNA()).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
if(number == 0){ // nsplit = 0 => try every possibility, although we limit it to the number of observations.
|
|
||||||
number = data.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
final Set<Split<Y, String>> splits = new HashSet<>();
|
final Set<Split<Y, String>> splits = new HashSet<>();
|
||||||
|
|
||||||
// This is to ensure we don't get stuck in an infinite loop for small factors
|
// This is to ensure we don't get stuck in an infinite loop for small factors
|
|
@ -47,13 +47,6 @@ public final class NumericCovariate implements Covariate<Double> {
|
||||||
|
|
||||||
private boolean hasNAs = false;
|
private boolean hasNAs = false;
|
||||||
|
|
||||||
private final boolean haveNASplitPenalty;
|
|
||||||
@Override
|
|
||||||
public boolean haveNASplitPenalty(){
|
|
||||||
// penalty would add worthless computational time if there are no NAs
|
|
||||||
return hasNAs && haveNASplitPenalty;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
Stream<Row<Y>> stream = data.stream();
|
Stream<Row<Y>> stream = data.stream();
|
|
@ -0,0 +1,128 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
private final int[] events;
|
||||||
|
private final double[] times; // We may restrict ourselves to specific times.
|
||||||
|
|
||||||
|
public int[] getEvents(){
|
||||||
|
return events.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] getTimes(){
|
||||||
|
return times.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingRiskFunctions combine(List<CompetingRiskFunctions> responses) {
|
||||||
|
|
||||||
|
final double[] timesToUse;
|
||||||
|
if(times != null){
|
||||||
|
timesToUse = times;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
timesToUse = responses.stream()
|
||||||
|
.map(functions -> functions.getSurvivalCurve())
|
||||||
|
.flatMapToDouble(
|
||||||
|
function -> Arrays.stream(function.getX())
|
||||||
|
).sorted().distinct().toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
final double n = responses.size();
|
||||||
|
|
||||||
|
final double[] survivalY = new double[timesToUse.length];
|
||||||
|
final double[][] csCHFY = new double[events.length][timesToUse.length];
|
||||||
|
final double[][] cifY = new double[events.length][timesToUse.length];
|
||||||
|
|
||||||
|
/*
|
||||||
|
We're going to try to efficiently put our predictions together -
|
||||||
|
Assumptions - for each event on a response, the hazard and CIF functions share the same x points
|
||||||
|
|
||||||
|
Plan - go through the time on each response and make use of that so that when we search for a time index
|
||||||
|
to evaluate the function at, we don't need to re-search the earlier times.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
for(final CompetingRiskFunctions currentFunctions : responses){
|
||||||
|
final double[] survivalXPoints = currentFunctions.getSurvivalCurve().getX();
|
||||||
|
final double[][] eventSpecificXPoints = new double[events.length][];
|
||||||
|
|
||||||
|
for(final int event : events){
|
||||||
|
eventSpecificXPoints[event-1] = currentFunctions.getCumulativeIncidenceFunction(event)
|
||||||
|
.getX();
|
||||||
|
}
|
||||||
|
|
||||||
|
int previousSurvivalIndex = 0;
|
||||||
|
final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
|
||||||
|
|
||||||
|
for(int i=0; i<timesToUse.length; i++){
|
||||||
|
final double time = timesToUse[i];
|
||||||
|
|
||||||
|
// Survival curve
|
||||||
|
final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
|
||||||
|
survivalY[i] = survivalY[i] + currentFunctions.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / n;
|
||||||
|
previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
|
||||||
|
// -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
|
||||||
|
|
||||||
|
// CHFs and CIFs
|
||||||
|
for(final int event : events){
|
||||||
|
final double[] xPoints = eventSpecificXPoints[event-1];
|
||||||
|
final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
|
||||||
|
xPoints, time);
|
||||||
|
csCHFY[event-1][i] = csCHFY[event-1][i] + currentFunctions.getCauseSpecificHazardFunction(event)
|
||||||
|
.evaluateByIndex(eventTimeIndex) / n;
|
||||||
|
cifY[event-1][i] = cifY[event-1][i] + currentFunctions.getCumulativeIncidenceFunction(event)
|
||||||
|
.evaluateByIndex(eventTimeIndex) / n;
|
||||||
|
|
||||||
|
previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
|
||||||
|
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||||
|
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||||
|
|
||||||
|
for(final int event : events){
|
||||||
|
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
|
||||||
|
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
return CompetingRiskFunctions.builder()
|
||||||
|
.causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
|
||||||
|
.cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
|
||||||
|
.survivalCurve(survivalFunction)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
|
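The comment inside `combine` describes the plan: because the evaluation times are sorted, the position found for one time can be reused as the starting point for the next. A minimal sketch of that idea follows, assuming a right-continuous step function that takes a default value before its first x point. `StepFunctionAverageSketch`, `Step` and `indexAtOrBefore` are hypothetical names, not part of the library, and the forward linear scan here stands in for `Utils.binarySearchLessThan`, whose exact contract is not shown in this diff.

```java
import java.util.Arrays;
import java.util.List;

final class StepFunctionAverageSketch {

    /** Hypothetical step function: defaultY before x[0], then y[i] on [x[i], x[i+1]). */
    static final class Step {
        final double[] x;
        final double[] y;
        final double defaultY;

        Step(double[] x, double[] y, double defaultY) {
            this.x = x;
            this.y = y;
            this.defaultY = defaultY;
        }
    }

    /**
     * Moves forward from the previously found index (or -1) to the largest
     * index whose x value is <= time. Returns -1 if time is before x[0].
     */
    static int indexAtOrBefore(double[] x, int previousIndex, double time) {
        int index = previousIndex;
        while (index + 1 < x.length && x[index + 1] <= time) {
            index++;
        }
        return index;
    }

    /** Averages the functions at each of the (ascending) times. */
    static double[] average(List<Step> functions, double[] sortedTimes) {
        final double[] result = new double[sortedTimes.length];
        final double n = functions.size();

        for (final Step function : functions) {
            int index = -1; // nothing found yet for this function
            for (int i = 0; i < sortedTimes.length; i++) {
                // Earlier x points are never revisited because times only increase.
                index = indexAtOrBefore(function.x, index, sortedTimes[i]);
                final double value = index < 0 ? function.defaultY : function.y[index];
                result[i] += value / n;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        final Step first = new Step(new double[]{1, 3}, new double[]{0.8, 0.5}, 1.0);
        final Step second = new Step(new double[]{2}, new double[]{0.6}, 1.0);
        final double[] times = {0.5, 1, 2, 3, 4};
        System.out.println(Arrays.toString(average(Arrays.asList(first, second), times)));
    }
}
```

Each function is walked once per call, so the cost per function is linear in the number of times plus the number of x points.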
@ -0,0 +1,39 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.responses.regression;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the Mean value of a group of Doubles.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double combine(List<Double> responses) {
|
||||||
|
final double size = responses.size();
|
||||||
|
|
||||||
|
return responses.stream().mapToDouble(db -> db/size).sum();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
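The README points to `MeanResponseCombiner` as a starting point for extending the library. As a rough illustration of what a second combiner could look like, the sketch below pairs a mean combiner with a median one. `CombinerSketch` is a hypothetical stand-in for `ResponseCombiner`, of which only the `combine` method is visible in this diff, and `MedianCombinerSketch` is an assumption for illustration, not something the library provides.

```java
import java.io.Serializable;
import java.util.List;

/** Hypothetical stand-in for the library's ResponseCombiner interface. */
interface CombinerSketch<I, O> extends Serializable {
    O combine(List<I> responses);
}

/** Same arithmetic as MeanResponseCombiner: sum of value / size. */
final class MeanCombinerSketch implements CombinerSketch<Double, Double> {
    private static final long serialVersionUID = 1L;

    @Override
    public Double combine(List<Double> responses) {
        final double size = responses.size();
        // An empty list produces 0.0 here, because the sum of an empty stream is 0.0.
        return responses.stream().mapToDouble(value -> value / size).sum();
    }
}

/** Illustrative alternative: the median of the responses. */
final class MedianCombinerSketch implements CombinerSketch<Double, Double> {
    private static final long serialVersionUID = 1L;

    @Override
    public Double combine(List<Double> responses) {
        if (responses.isEmpty()) {
            throw new IllegalArgumentException("Cannot take the median of no responses");
        }
        final double[] sorted = responses.stream()
                .mapToDouble(Double::doubleValue)
                .sorted()
                .toArray();
        final int middle = sorted.length / 2;
        if (sorted.length % 2 == 1) {
            return sorted[middle];
        }
        // Even count: average the two central values.
        return (sorted[middle - 1] + sorted[middle]) / 2.0;
    }
}
```

Unlike the mean, the median sketch rejects an empty list outright instead of returning a default value; which behaviour suits a forest is a design choice this diff does not settle.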
@ -17,18 +17,31 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public abstract class Forest<O, FO> {
|
@Builder
|
||||||
|
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
||||||
|
|
||||||
public abstract FO evaluate(CovariateRow row);
|
private final List<Tree<O>> trees;
|
||||||
public abstract FO evaluateOOB(CovariateRow row);
|
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||||
public abstract Iterable<Tree<O>> getTrees();
|
private final List<Covariate> covariateList;
|
||||||
public abstract int getNumberOfTrees();
|
|
||||||
|
public FO evaluate(CovariateRow row){
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(
|
||||||
|
trees.stream()
|
||||||
|
.map(node -> node.evaluate(row))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Used primarily in the R package interface to avoid R loops; and for easier parallelization.
|
* Used primarily in the R package interface to avoid R loops; and for easier parallelization.
|
||||||
|
@ -80,6 +93,21 @@ public abstract class Forest<O, FO> {
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public FO evaluateOOB(CovariateRow row){
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(
|
||||||
|
trees.stream()
|
||||||
|
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
||||||
|
.map(node -> node.evaluate(row))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Tree<O>> getTrees(){
|
||||||
|
return Collections.unmodifiableList(trees);
|
||||||
|
}
|
||||||
|
|
||||||
public Map<Integer, Integer> findSplitsByCovariate(){
|
public Map<Integer, Integer> findSplitsByCovariate(){
|
||||||
final Map<Integer, Integer> countMap = new TreeMap<>();
|
final Map<Integer, Integer> countMap = new TreeMap<>();
|
||||||
|
|
||||||
|
@ -130,5 +158,4 @@ public abstract class Forest<O, FO> {
|
||||||
return countTerminalNodes;
|
return countTerminalNodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
|
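`evaluateOOB` differs from `evaluate` only in that it first drops every tree whose bootstrap sample contained the row. The sketch below shows that filter-then-combine shape in isolation, assuming each tree is reduced to a constant prediction and a set of in-bag row ids. `OobSketch` and `TreeSketch` are hypothetical names, and the mean stands in for whatever `ResponseCombiner` the forest was built with.

```java
import java.util.List;
import java.util.OptionalDouble;
import java.util.Set;

final class OobSketch {

    /** Hypothetical tree: remembers which row ids it was trained on. */
    static final class TreeSketch {
        final Set<Integer> bootstrapIds;
        final double prediction; // constant prediction, to keep the sketch small

        TreeSketch(Set<Integer> bootstrapIds, double prediction) {
            this.bootstrapIds = bootstrapIds;
            this.prediction = prediction;
        }
    }

    /**
     * Averages the predictions of the trees that did NOT see the row during training.
     * Returns an empty OptionalDouble when every tree had the row in its bootstrap sample.
     */
    static OptionalDouble evaluateOob(List<TreeSketch> trees, int rowId) {
        return trees.stream()
                .filter(tree -> !tree.bootstrapIds.contains(rowId))
                .mapToDouble(tree -> tree.prediction)
                .average();
    }
}
```

Returning `OptionalDouble` makes the no-eligible-trees case explicit. In the diff above, the same situation hands an empty list to the combiner, so the result depends on how that combiner treats empty input.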
@ -38,7 +38,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
private final TreeTrainer<Y, TO> treeTrainer;
|
private final TreeTrainer<Y, TO> treeTrainer;
|
||||||
private final List<Covariate> covariates;
|
private final List<Covariate> covariates;
|
||||||
private final ForestResponseCombiner<TO, FO> treeResponseCombiner;
|
private final ResponseCombiner<TO, FO> treeResponseCombiner;
|
||||||
private final List<Row<Y>> data;
|
private final List<Row<Y>> data;
|
||||||
|
|
||||||
// number of trees to try
|
// number of trees to try
|
||||||
|
@ -57,10 +57,10 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* in which case its trees are combined with the new one.
|
* in which case its trees are combined with the new one.
|
||||||
* @return A trained forest.
|
* @return A trained forest.
|
||||||
*/
|
*/
|
||||||
public OnlineForest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
public Forest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
||||||
|
|
||||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||||
initialForest.ifPresent(forest -> forest.getTrees().forEach(trees::add));
|
initialForest.ifPresent(forest -> trees.addAll(forest.getTrees()));
|
||||||
|
|
||||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||||
|
|
||||||
|
@ -77,9 +77,11 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("Finished");
|
System.out.println("Finished");
|
||||||
}
|
}
|
||||||
|
|
||||||
return OnlineForest.<TO, FO>builder()
|
|
||||||
|
return Forest.<TO, FO>builder()
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.trees(trees)
|
.trees(trees)
|
||||||
|
.covariateList(covariates)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -92,7 +94,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* There cannot be existing trees if the initial forest is
|
* There cannot be existing trees if the initial forest is
|
||||||
* specified.
|
* specified.
|
||||||
*/
|
*/
|
||||||
public OfflineForest<TO, FO> trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
public void trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
||||||
// First we need to see how many trees there currently are
|
// First we need to see how many trees there currently are
|
||||||
final File folder = new File(saveTreeLocation);
|
final File folder = new File(saveTreeLocation);
|
||||||
if(!folder.exists()){
|
if(!folder.exists()){
|
||||||
|
@ -113,14 +115,17 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||||
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
int j=0;
|
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||||
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
|
||||||
|
for(int j=0; j<initialTrees.size(); j++){
|
||||||
final String filename = "tree-" + (j+1) + ".tree";
|
final String filename = "tree-" + (j+1) + ".tree";
|
||||||
|
final Tree<TO> tree = initialTrees.get(j);
|
||||||
|
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
j++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
treeCount = new AtomicInteger(j);
|
treeCount = new AtomicInteger(initialTrees.size());
|
||||||
} else{
|
} else{
|
||||||
treeCount = new AtomicInteger(treeFiles.length);
|
treeCount = new AtomicInteger(treeFiles.length);
|
||||||
}
|
}
|
||||||
|
@ -148,8 +153,6 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("Finished");
|
System.out.println("Finished");
|
||||||
}
|
}
|
||||||
|
|
||||||
return new OfflineForest<>(folder, treeResponseCombiner);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -159,7 +162,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* in which case its trees are combined with the new one.
|
* in which case its trees are combined with the new one.
|
||||||
* @param threads The number of trees to train at once.
|
* @param threads The number of trees to train at once.
|
||||||
*/
|
*/
|
||||||
public OnlineForest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
public Forest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||||
|
|
||||||
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
|
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
|
||||||
// the earlier indexes being filled.
|
// the earlier indexes being filled.
|
||||||
|
@ -167,12 +170,11 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
final int startingCount;
|
final int startingCount;
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
int j = 0;
|
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||||
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
for(int j=0; j<initialTrees.size(); j++) {
|
||||||
trees.set(j, tree);
|
trees.set(j, initialTrees.get(j));
|
||||||
j++;
|
|
||||||
}
|
}
|
||||||
startingCount = initialForest.get().getNumberOfTrees();
|
startingCount = initialTrees.size();
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
startingCount = 0;
|
startingCount = 0;
|
||||||
|
@ -217,7 +219,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("\nFinished");
|
System.out.println("\nFinished");
|
||||||
}
|
}
|
||||||
|
|
||||||
return OnlineForest.<TO, FO>builder()
|
return Forest.<TO, FO>builder()
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.trees(trees)
|
.trees(trees)
|
||||||
.build();
|
.build();
|
||||||
|
@ -233,7 +235,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* specified.
|
* specified.
|
||||||
* @param threads The number of trees to train at once.
|
* @param threads The number of trees to train at once.
|
||||||
*/
|
*/
|
||||||
public OfflineForest<TO, FO> trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
public void trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||||
// First we need to see how many trees there currently are
|
// First we need to see how many trees there currently are
|
||||||
final File folder = new File(saveTreeLocation);
|
final File folder = new File(saveTreeLocation);
|
||||||
if(!folder.exists()){
|
if(!folder.exists()){
|
||||||
|
@ -253,14 +255,17 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
int j=0;
|
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||||
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
|
||||||
|
for(int j=0; j<initialTrees.size(); j++){
|
||||||
final String filename = "tree-" + (j+1) + ".tree";
|
final String filename = "tree-" + (j+1) + ".tree";
|
||||||
|
final Tree<TO> tree = initialTrees.get(j);
|
||||||
|
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
j++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
treeCount = new AtomicInteger(j);
|
treeCount = new AtomicInteger(initialTrees.size());
|
||||||
} else{
|
} else{
|
||||||
treeCount = new AtomicInteger(treeFiles.length);
|
treeCount = new AtomicInteger(treeFiles.length);
|
||||||
}
|
}
|
||||||
|
@ -304,8 +309,6 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("\nFinished");
|
System.out.println("\nFinished");
|
||||||
}
|
}
|
||||||
|
|
||||||
return new OfflineForest<>(folder, treeResponseCombiner);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
|
@ -17,13 +17,15 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Getter;
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Data
|
|
||||||
public class SplitAndScore<Y, V> {
|
public class SplitAndScore<Y, V> {
|
||||||
|
|
||||||
private Split<Y, V> split;
|
@Getter
|
||||||
private Double score;
|
private final Split<Y, V> split;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final Double score;
|
||||||
|
|
||||||
}
|
}
|
|
@ -17,9 +17,7 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.VisibleForTesting;
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
|
||||||
import lombok.AccessLevel;
|
import lombok.AccessLevel;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
@ -74,12 +72,31 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Now that we have the best split; we need to handle any NAs that were dropped off
|
// Now that we have the best split; we need to handle any NAs that were dropped off
|
||||||
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
||||||
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
||||||
|
|
||||||
// Assign missing values to the split if necessary
|
// Assign missing values to the split if necessary
|
||||||
bestSplit = randomlyAssignNAs(data, bestSplit, random);
|
if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
|
||||||
|
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
||||||
|
|
||||||
|
for(Row<Y> row : data) {
|
||||||
|
final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex();
|
||||||
|
|
||||||
|
if(row.getValueByIndex(covariateIndex).isNA()) {
|
||||||
|
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||||
|
|
||||||
|
if(randomDecision){
|
||||||
|
bestSplit.getLeftHand().add(row);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
bestSplit.getRightHand().add(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
final Node<O> leftNode;
|
final Node<O> leftNode;
|
||||||
final Node<O> rightNode;
|
final Node<O> rightNode;
|
||||||
|
@ -127,8 +144,7 @@ public class TreeTrainer<Y, O> {
|
||||||
return splitCovariates;
|
return splitCovariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
||||||
public Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
|
||||||
|
|
||||||
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
||||||
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
|
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
|
||||||
|
@ -141,32 +157,10 @@ public class TreeTrainer<Y, O> {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
|
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
|
||||||
|
|
||||||
|
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
|
||||||
if(candidateSplitAndScore == null){
|
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// This score was based on splitting only non-NA values. However, there might be a similar covariate we are also considering
|
|
||||||
// that is just as good at splitting but has less NAs; we should thus penalize the split score for variables with NAs
|
|
||||||
// We do this by randomly assigning the NAs and then recalculating the split score on the best split we already have.
|
|
||||||
//
|
|
||||||
// We only have to penalize the score though if we know it's possible that this might be the best split. If it's not,
|
|
||||||
// then we can skip the computations.
|
|
||||||
final boolean mayBeGoodSplit = bestSplitAndScore == null ||
|
|
||||||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore();
|
|
||||||
if(mayBeGoodSplit && covariate.haveNASplitPenalty()){
|
|
||||||
Split<Y, ?> candiateSplitWithNAs = randomlyAssignNAs(data, candidateSplitAndScore.getSplit(), random);
|
|
||||||
final Iterator<Split<Y, ?>> newSplitWithRandomNAs = new SingletonIterator<>(candiateSplitWithNAs);
|
|
||||||
final double newScore = splitFinder.findBestSplit(newSplitWithRandomNAs).getScore();
|
|
||||||
|
|
||||||
// There's a chance that NAs might add noise to *improve* the score; but we want to ensure we penalize it.
|
|
||||||
// Thus we only change the score if its worse.
|
|
||||||
candidateSplitAndScore.setScore(Math.min(newScore, candidateSplitAndScore.getScore()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if(bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
|
|
||||||
bestSplitAndScore = candidateSplitAndScore;
|
bestSplitAndScore = candidateSplitAndScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,38 +174,6 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private <V> Split<Y, V> randomlyAssignNAs(List<Row<Y>> data, Split<Y, V> existingSplit, Random random){
|
|
||||||
|
|
||||||
// Now that we have the best split; we need to handle any NAs that were dropped off
|
|
||||||
final double probabilityLeftHand = (double) existingSplit.leftHand.size() /
|
|
||||||
(double) (existingSplit.leftHand.size() + existingSplit.rightHand.size());
|
|
||||||
|
|
||||||
|
|
||||||
final int covariateIndex = existingSplit.getSplitRule().getParentCovariateIndex();
|
|
||||||
|
|
||||||
// Assign missing values to the split if necessary
|
|
||||||
if(covariates.get(existingSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
|
|
||||||
existingSplit = existingSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
|
||||||
|
|
||||||
for(Row<Y> row : data) {
|
|
||||||
if(row.getValueByIndex(covariateIndex).isNA()) {
|
|
||||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
|
||||||
|
|
||||||
if(randomDecision){
|
|
||||||
existingSplit.getLeftHand().add(row);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
existingSplit.getRightHand().add(row);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return existingSplit;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean nodeIsPure(List<Row<Y>> data){
|
private boolean nodeIsPure(List<Row<Y>> data){
|
||||||
if(!checkNodePurity){
|
if(!checkNodePurity){
|
||||||
return false;
|
return false;
|
|
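The NA handling that appears on one side of the `TreeTrainer` hunks above can be illustrated in isolation. What follows is a minimal sketch, not the library's code: it works on plain lists of responses rather than `Row` and `Split` objects, and the class name `NaPenaltySketch`, its method names, and the `splitScore` rule (squared difference of group means) are all hypothetical stand-ins for the real split finder.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical illustration; names and the scoring rule are assumptions, not largeRCRF API.
final class NaPenaltySketch {

    // Stand-in split score: squared difference between the two group means (larger is better).
    static double splitScore(List<Double> left, List<Double> right) {
        final double diff = mean(left) - mean(right);
        return diff * diff;
    }

    static double mean(List<Double> values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.size();
    }

    // Re-scores a split after randomly assigning the NA rows, and never lets the score improve.
    static double penalizedScore(List<Double> left, List<Double> right,
                                 List<Double> naResponses, Random random) {
        final double originalScore = splitScore(left, right);

        // The probability is fixed from the non-NA split sizes before any NA row is added.
        final double probabilityLeft =
                (double) left.size() / (double) (left.size() + right.size());

        // Copy the lists so the caller's lists are not modified.
        final List<Double> newLeft = new ArrayList<>(left);
        final List<Double> newRight = new ArrayList<>(right);
        for (double response : naResponses) {
            if (random.nextDouble() <= probabilityLeft) {
                newLeft.add(response);
            } else {
                newRight.add(response);
            }
        }

        // Random NAs may improve the score by chance; taking the minimum keeps it a penalty.
        final double newScore = splitScore(newLeft, newRight);
        return Math.min(newScore, originalScore);
    }
}
```

Because of the `Math.min`, the returned score can only stay the same or drop relative to the score computed on non-NA rows, which matches the intent stated in the comments of the hunk.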
@ -0,0 +1,65 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
|
||||||
|
public class VariableImportanceCalculator<Y, P> {
|
||||||
|
|
||||||
|
private final ErrorCalculator<Y, P> errorCalculator;
|
||||||
|
private final Forest<Y, P> forest;
|
||||||
|
private final List<Row<Y>> observations;
|
||||||
|
private final List<Y> observedResponses;
|
||||||
|
|
||||||
|
private final boolean isTrainingSet; // If true, then we use out-of-bag predictions
|
||||||
|
private final double baselineError;
|
||||||
|
|
||||||
|
public VariableImportanceCalculator(
|
||||||
|
ErrorCalculator<Y, P> errorCalculator,
|
||||||
|
Forest<Y, P> forest,
|
||||||
|
List<Row<Y>> observations,
|
||||||
|
boolean isTrainingSet
|
||||||
|
){
|
||||||
|
this.errorCalculator = errorCalculator;
|
||||||
|
this.forest = forest;
|
||||||
|
this.observations = observations;
|
||||||
|
this.isTrainingSet = isTrainingSet;
|
||||||
|
|
||||||
|
this.observedResponses = observations.stream()
|
||||||
|
.map(row -> row.getResponse()).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<P> baselinePredictions = makePredictions(observations);
|
||||||
|
this.baselineError = errorCalculator.averageError(observedResponses, baselinePredictions);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculateVariableImportance(Covariate covariate, Optional<Random> random){
|
||||||
|
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(this.observations, covariate, random);
|
||||||
|
final List<P> alternatePredictions = makePredictions(scrambledValues);
|
||||||
|
final double newError = errorCalculator.averageError(this.observedResponses, alternatePredictions);
|
||||||
|
|
||||||
|
return newError - this.baselineError;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] calculateVariableImportance(List<Covariate> covariates, Optional<Random> random){
|
||||||
|
return covariates.stream()
|
||||||
|
.mapToDouble(covariate -> calculateVariableImportance(covariate, random))
|
||||||
|
.toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<P> makePredictions(List<? extends CovariateRow> rowList){
|
||||||
|
if(isTrainingSet){
|
||||||
|
return forest.evaluateOOB(rowList);
|
||||||
|
} else{
|
||||||
|
return forest.evaluate(rowList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
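The calculation in `VariableImportanceCalculator` above follows the usual permutation-importance pattern: compute a baseline error, scramble one covariate, predict again, and report the change in error. A minimal standalone sketch of that idea follows. It assumes a plain shuffle for the scrambling step (how `CovariateRow.scrambleCovariateValues` actually scrambles is not shown in this diff), uses mean squared error in place of the `ErrorCalculator`, and every name in it is hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.ToDoubleFunction;

// Hypothetical illustration; not part of the largeRCRF library.
final class PermutationImportanceSketch {

    // Mean squared error of the model's predictions against the observed responses.
    static double averageError(double[] observed, double[][] rows,
                               ToDoubleFunction<double[]> model) {
        double sum = 0.0;
        for (int i = 0; i < rows.length; i++) {
            final double residual = observed[i] - model.applyAsDouble(rows[i]);
            sum += residual * residual;
        }
        return sum / rows.length;
    }

    // Importance = error after shuffling one covariate column minus the baseline error.
    static double importance(double[][] rows, double[] observed,
                             ToDoubleFunction<double[]> model,
                             int covariateIndex, Random random) {
        final double baselineError = averageError(observed, rows, model);

        // Collect and shuffle the values of the chosen covariate (assumed scrambling rule).
        final List<Double> column = new ArrayList<>();
        for (double[] row : rows) {
            column.add(row[covariateIndex]);
        }
        Collections.shuffle(column, random);

        // Copy every row so the original data is left untouched.
        final double[][] scrambled = new double[rows.length][];
        for (int i = 0; i < rows.length; i++) {
            scrambled[i] = rows[i].clone();
            scrambled[i][covariateIndex] = column.get(i);
        }

        return averageError(observed, scrambled, model) - baselineError;
    }
}
```

A covariate the model never reads yields an importance of exactly zero under this sketch, since scrambling it leaves every prediction unchanged; a positive value means the error grew once the covariate's values were shuffled.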
@ -16,7 +16,9 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.*;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
@ -25,17 +27,12 @@ import java.util.zip.GZIPOutputStream;
|
||||||
|
|
||||||
public class DataUtils {
|
public class DataUtils {
|
||||||
|
|
||||||
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||||
if(!folder.isDirectory()){
|
if(!folder.isDirectory()){
|
||||||
throw new IllegalArgumentException("Tree directory must be a directory!");
|
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||||
}
|
}
|
||||||
|
|
||||||
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
||||||
|
|
||||||
return loadOnlineForest(treeFiles, treeResponseCombiner);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File[] treeFiles, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
|
||||||
final List<File> treeFileList = Arrays.asList(treeFiles);
|
final List<File> treeFileList = Arrays.asList(treeFiles);
|
||||||
|
|
||||||
Collections.sort(treeFileList, Comparator.comparing(File::getName));
|
Collections.sort(treeFileList, Comparator.comparing(File::getName));
|
||||||
|
@ -51,16 +48,16 @@ public class DataUtils {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return OnlineForest.<O, FO>builder()
|
return Forest.<O, FO>builder()
|
||||||
.trees(treeList)
|
.trees(treeList)
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
public static <O, FO> Forest<O, FO> loadForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||||
final File directory = new File(folder);
|
final File directory = new File(folder);
|
||||||
return loadOnlineForest(directory, treeResponseCombiner);
|
return loadForest(directory, treeResponseCombiner);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void saveObject(Serializable object, String filename) throws IOException {
|
public static void saveObject(Serializable object, String filename) throws IOException {
|
|
@ -198,14 +198,4 @@ public final class RUtils {
|
||||||
return newList;
|
return newList;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static File[] getTreeFileArray(String folderPath, int endingId){
|
|
||||||
final File[] fileArray = new File[endingId];
|
|
||||||
|
|
||||||
for(int i = 1; i <= endingId; i++){
|
|
||||||
fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree");
|
|
||||||
}
|
|
||||||
|
|
||||||
return fileArray;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
|
@ -16,8 +16,6 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -41,7 +39,6 @@ public final class RightContinuousStepFunction extends StepFunction {
|
||||||
*
|
*
|
||||||
* May not be null.
|
* May not be null.
|
||||||
*/
|
*/
|
||||||
@Getter
|
|
||||||
private final double defaultY;
|
private final double defaultY;
|
||||||
|
|
||||||
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
||||||
|
@ -139,11 +136,6 @@ public final class RightContinuousStepFunction extends StepFunction {
|
||||||
return -integrate(to, from);
|
return -integrate(to, from);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Edge case - no points; just defaultY
|
|
||||||
if(this.x.length == 0){
|
|
||||||
return (to - from) * this.defaultY;
|
|
||||||
}
|
|
||||||
|
|
||||||
double summation = 0.0;
|
double summation = 0.0;
|
||||||
final double[] xPoints = getX();
|
final double[] xPoints = getX();
|
||||||
final int startingIndex;
|
final int startingIndex;
|
|
@ -45,20 +45,20 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for(int j=0; j<5; j++){
|
for(int j=0; j<5; j++){
|
||||||
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index, false);
|
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index);
|
||||||
covariateList.add(numericCovariate);
|
covariateList.add(numericCovariate);
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int j=0; j<5; j++){
|
for(int j=0; j<5; j++){
|
||||||
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index, false);
|
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index);
|
||||||
covariateList.add(booleanCovariate);
|
covariateList.add(booleanCovariate);
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
|
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
|
||||||
for(int j=0; j<5; j++){
|
for(int j=0; j<5; j++){
|
||||||
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels, false);
|
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels);
|
||||||
covariateList.add(factorCovariate);
|
covariateList.add(factorCovariate);
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
|
@ -214,14 +214,14 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
|
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
|
||||||
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
|
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
|
||||||
final Forest<Double, Double> forestSerial = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
|
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
||||||
|
|
||||||
|
|
||||||
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
final Forest<Double, Double> forestParallel = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
|
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
for(int k=0; k<3; k++){
|
for(int k=0; k<3; k++){
|
||||||
forestTrainer.trainSerialOnDisk(Optional.empty());
|
forestTrainer.trainSerialOnDisk(Optional.empty());
|
||||||
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
|
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||||
}
|
}
|
||||||
|
@ -274,7 +274,7 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
for(int k=0; k<3; k++){
|
for(int k=0; k<3; k++){
|
||||||
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
|
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||||
}
|
}
|
|
@ -20,8 +20,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
import ca.joeltherrien.randomforest.tree.OnlineForest;
|
|
||||||
import ca.joeltherrien.randomforest.tree.Tree;
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||||
|
@ -39,12 +39,12 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestProvidingInitialForest {
|
public class TestProvidingInitialForest {
|
||||||
|
|
||||||
private OnlineForest<Double, Double> initialForest;
|
private Forest<Double, Double> initialForest;
|
||||||
private List<Covariate> covariateList;
|
private List<Covariate> covariateList;
|
||||||
private List<Row<Double>> data;
|
private List<Row<Double>> data;
|
||||||
|
|
||||||
public TestProvidingInitialForest(){
|
public TestProvidingInitialForest(){
|
||||||
covariateList = Collections.singletonList(new NumericCovariate("x", 0, false));
|
covariateList = Collections.singletonList(new NumericCovariate("x", 0));
|
||||||
|
|
||||||
data = Utils.easyList(
|
data = Utils.easyList(
|
||||||
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
|
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
|
||||||
|
@ -107,8 +107,8 @@ public class TestProvidingInitialForest {
|
||||||
public void testSerialInMemory(){
|
public void testSerialInMemory(){
|
||||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||||
|
|
||||||
final OnlineForest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
final Forest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
||||||
assertEquals(20, newForest.getNumberOfTrees());
|
assertEquals(20, newForest.getTrees().size());
|
||||||
|
|
||||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||||
assertTrue(newForest.getTrees().contains(initialTree));
|
assertTrue(newForest.getTrees().contains(initialTree));
|
||||||
|
@ -124,8 +124,8 @@ public class TestProvidingInitialForest {
|
||||||
public void testParallelInMemory(){
|
public void testParallelInMemory(){
|
||||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||||
|
|
||||||
final OnlineForest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
final Forest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
||||||
assertEquals(20, newForest.getNumberOfTrees());
|
assertEquals(20, newForest.getTrees().size());
|
||||||
|
|
||||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||||
assertTrue(newForest.getTrees().contains(initialTree));
|
assertTrue(newForest.getTrees().contains(initialTree));
|
||||||
|
@ -149,11 +149,11 @@ public class TestProvidingInitialForest {
|
||||||
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
|
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
|
||||||
|
|
||||||
assertEquals(20, directory.listFiles().length);
|
assertEquals(20, directory.listFiles().length);
|
||||||
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
|
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
assertEquals(20, newForest.getNumberOfTrees());
|
assertEquals(20, newForest.getTrees().size());
|
||||||
|
|
||||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||||
|
@ -179,9 +179,9 @@ public class TestProvidingInitialForest {
|
||||||
|
|
||||||
assertEquals(20, directory.listFiles().length);
|
assertEquals(20, directory.listFiles().length);
|
||||||
|
|
||||||
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
|
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
||||||
|
|
||||||
assertEquals(20, newForest.getNumberOfTrees());
|
assertEquals(20, newForest.getTrees().size());
|
||||||
|
|
||||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||||
|
@ -198,7 +198,7 @@ public class TestProvidingInitialForest {
|
||||||
it's not clear if the forest being provided is the same one that trees were saved from.
|
it's not clear if the forest being provided is the same one that trees were saved from.
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testExceptions(){
|
public void verifyExceptions(){
|
||||||
final String filePath = "src/test/resources/trees/";
|
final String filePath = "src/test/resources/trees/";
|
||||||
final File directory = new File(filePath);
|
final File directory = new File(filePath);
|
||||||
if(directory.exists()){
|
if(directory.exists()){
|
|
@ -24,10 +24,11 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.*;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||||
import ca.joeltherrien.randomforest.utils.ResponseLoader;
|
import ca.joeltherrien.randomforest.utils.ResponseLoader;
|
||||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
@ -46,10 +47,10 @@ public class TestSavingLoading {
|
||||||
|
|
||||||
public List<Covariate> getCovariates(){
|
public List<Covariate> getCovariates(){
|
||||||
return Utils.easyList(
|
return Utils.easyList(
|
||||||
new NumericCovariate("ageatfda", 0, false),
|
new NumericCovariate("ageatfda", 0),
|
||||||
new BooleanCovariate("idu", 1, false),
|
new BooleanCovariate("idu", 1),
|
||||||
new BooleanCovariate("black", 2, false),
|
new BooleanCovariate("black", 2),
|
||||||
new NumericCovariate("cd4nadir", 3, false)
|
new NumericCovariate("cd4nadir", 3)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,21 +119,16 @@ public class TestSavingLoading {
|
||||||
assertTrue(directory.isDirectory());
|
assertTrue(directory.isDirectory());
|
||||||
assertEquals(NTREE, directory.listFiles().length);
|
assertEquals(NTREE, directory.listFiles().length);
|
||||||
|
|
||||||
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||||
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
|
|
||||||
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
|
|
||||||
|
|
||||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
|
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
||||||
assertNotNull(functionsOnline);
|
assertNotNull(functions);
|
||||||
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
|
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||||
|
|
||||||
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
|
|
||||||
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
|
|
||||||
|
|
||||||
|
|
||||||
assertEquals(NTREE, onlineForest.getTrees().size());
|
assertEquals(NTREE, forest.getTrees().size());
|
||||||
|
|
||||||
TestUtils.removeFolder(directory);
|
TestUtils.removeFolder(directory);
|
||||||
|
|
||||||
|
@ -163,22 +159,17 @@ public class TestSavingLoading {
|
||||||
assertEquals(NTREE, directory.listFiles().length);
|
assertEquals(NTREE, directory.listFiles().length);
|
||||||
|
|
||||||
|
|
||||||
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
|
||||||
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||||
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
|
|
||||||
|
|
||||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
|
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
||||||
assertNotNull(functionsOnline);
|
assertNotNull(functions);
|
||||||
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
|
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||||
|
|
||||||
|
|
||||||
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
|
assertEquals(NTREE, forest.getTrees().size());
|
||||||
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
|
|
||||||
|
|
||||||
|
|
||||||
assertEquals(NTREE, onlineForest.getTrees().size());
|
|
||||||
|
|
||||||
TestUtils.removeFolder(directory);
|
TestUtils.removeFolder(directory);
|
||||||
|
|
||||||
|
@ -186,64 +177,6 @@ public class TestSavingLoading {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
We don't implement equals() methods on the below mentioned classes because then we'd need to implement an
|
|
||||||
appropriate hashCode() method that's consistent with the equals(), and we only need plain equals() for
|
|
||||||
these tests.
|
|
||||||
*/
|
|
||||||
|
|
||||||
private boolean competingFunctionsEqual(CompetingRiskFunctions f1 ,CompetingRiskFunctions f2){
|
|
||||||
if(!functionsEqual(f1.getSurvivalCurve(), f2.getSurvivalCurve())){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(int i=1; i<=2; i++){
|
|
||||||
if(!functionsEqual(f1.getCauseSpecificHazardFunction(i), f2.getCauseSpecificHazardFunction(i))){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if(!functionsEqual(f1.getCumulativeIncidenceFunction(i), f2.getCumulativeIncidenceFunction(i))){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean functionsEqual(RightContinuousStepFunction f1, RightContinuousStepFunction f2){
|
|
||||||
|
|
||||||
final double[] f1X = f1.getX();
|
|
||||||
final double[] f2X = f2.getX();
|
|
||||||
|
|
||||||
final double[] f1Y = f1.getY();
|
|
||||||
final double[] f2Y = f2.getY();
|
|
||||||
|
|
||||||
// first compare array lengths
|
|
||||||
if(f1X.length != f2X.length){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if(f1Y.length != f2Y.length){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO - better comparisons of doubles. I don't really care too much though as this equals method is only being used in tests
|
|
||||||
final double delta = 0.000001;
|
|
||||||
|
|
||||||
if(Math.abs(f1.getDefaultY() - f2.getDefaultY()) > delta){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(int i=0; i < f1X.length; i++){
|
|
||||||
if(Math.abs(f1X[i] - f2X[i]) > delta){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if(Math.abs(f1Y[i] - f2Y[i]) > delta){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
|
@ -156,7 +156,7 @@ public class TestUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReduceListToSize(){
|
public void reduceListToSize(){
|
||||||
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||||
final Random random = new Random();
|
final Random random = new Random();
|
||||||
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
|
@ -52,7 +52,7 @@ public class IBSCalculatorTest {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testResultsWithoutCensoringDistribution(){
|
public void resultsWithoutCensoringDistribution(){
|
||||||
final IBSCalculator calculator = new IBSCalculator();
|
final IBSCalculator calculator = new IBSCalculator();
|
||||||
|
|
||||||
final double errorDifferentEvent = calculator.calculateError(
|
final double errorDifferentEvent = calculator.calculateError(
|
||||||
|
@ -74,7 +74,7 @@ public class IBSCalculatorTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testResultsWithCensoringDistribution(){
|
public void resultsWithCensoringDistribution(){
|
||||||
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||||
Utils.easyList(
|
Utils.easyList(
|
||||||
new Point(0.0, 0.75),
|
new Point(0.0, 0.75),
|
|
@ -53,10 +53,10 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
public List<Covariate> getCovariates(){
|
public List<Covariate> getCovariates(){
|
||||||
return Utils.easyList(
|
return Utils.easyList(
|
||||||
new NumericCovariate("ageatfda", 0, false),
|
new NumericCovariate("ageatfda", 0),
|
||||||
new BooleanCovariate("idu", 1, false),
|
new BooleanCovariate("idu", 1),
|
||||||
new BooleanCovariate("black", 2, false),
|
new BooleanCovariate("black", 2),
|
||||||
new NumericCovariate("cd4nadir", 3, false)
|
new NumericCovariate("cd4nadir", 3)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,8 +109,8 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||||
final List<Covariate> covariates = Utils.easyList(
|
final List<Covariate> covariates = Utils.easyList(
|
||||||
new BooleanCovariate("idu", 0, false),
|
new BooleanCovariate("idu", 0),
|
||||||
new BooleanCovariate("black", 1, false)
|
new BooleanCovariate("black", 1)
|
||||||
);
|
);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv");
|
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv");
|
||||||
|
@ -210,8 +210,8 @@ public class TestCompetingRisk {
|
||||||
public void testLogRankSplitFinderTwoBooleans() throws IOException {
|
public void testLogRankSplitFinderTwoBooleans() throws IOException {
|
||||||
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||||
final List<Covariate> covariates = Utils.easyList(
|
final List<Covariate> covariates = Utils.easyList(
|
||||||
new BooleanCovariate("idu", 0, false),
|
new BooleanCovariate("idu", 0),
|
||||||
new BooleanCovariate("black", 1, false)
|
new BooleanCovariate("black", 1)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ public class TestCompetingRisk {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDataset() throws IOException {
|
public void verifyDataset() throws IOException {
|
||||||
final List<Covariate> covariates = getCovariates();
|
final List<Covariate> covariates = getCovariates();
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
|
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
|
|
@ -16,11 +16,10 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.tree.OnlineForest;
|
|
||||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
@ -31,6 +30,8 @@ import java.util.List;
|
||||||
|
|
||||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
public class TestCompetingRiskErrorRateCalculator {
|
public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
|
@ -47,7 +48,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
final int event = 1;
|
final int event = 1;
|
||||||
|
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = OnlineForest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||||
|
|
||||||
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
||||||
|
|
|
@ -46,7 +46,7 @@ public class TestLogRankSplitFinder {
|
||||||
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
|
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
|
||||||
|
|
||||||
final List<Covariate> covariates = Utils.easyList(
|
final List<Covariate> covariates = Utils.easyList(
|
||||||
new NumericCovariate("x2", 0, false)
|
new NumericCovariate("x2", 0)
|
||||||
);
|
);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename);
|
final List<Row<CompetingRiskResponse>> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename);
|
|
@ -0,0 +1,86 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.function.Executable;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
public class FactorCovariateTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void verifyEqualLevels() {
|
||||||
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
|
||||||
|
final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
|
||||||
|
final FactorCovariate.FactorValue dog2 = petCovariate.createValue("DO" + "G");
|
||||||
|
|
||||||
|
assertSame(dog1, dog2);
|
||||||
|
|
||||||
|
final FactorCovariate.FactorValue cat1 = petCovariate.createValue("CAT");
|
||||||
|
final FactorCovariate.FactorValue cat2 = petCovariate.createValue("CA" + "T");
|
||||||
|
|
||||||
|
assertSame(cat1, cat2);
|
||||||
|
|
||||||
|
final FactorCovariate.FactorValue mouse1 = petCovariate.createValue("MOUSE");
|
||||||
|
final FactorCovariate.FactorValue mouse2 = petCovariate.createValue("MOUS" + "E");
|
||||||
|
|
||||||
|
assertSame(mouse1, mouse2);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void verifyBadLevelException(){
|
||||||
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
final Executable badCode = () -> petCovariate.createValue("vulcan");
|
||||||
|
|
||||||
|
assertThrows(IllegalArgumentException.class, badCode, "vulcan is not a level in FactorCovariate pet");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testAllSubsets(){
|
||||||
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
|
||||||
|
final List<SplitRule<String>> splitRules = new ArrayList<>();
|
||||||
|
|
||||||
|
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
|
||||||
|
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
|
||||||
|
|
||||||
|
assertEquals(3, splitRules.size());
|
||||||
|
|
||||||
|
// TODO verify the contents of the split rules
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private FactorCovariate createTestCovariate(){
|
||||||
|
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
|
||||||
|
|
||||||
|
return new FactorCovariate("pet", 0, levels);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
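The `assertSame` checks in `FactorCovariateTest` above only pass if `createValue` hands back one shared object per level. Below is a minimal sketch of that idea, assuming a map-backed lookup. `FactorInterningSketch` and `Level` are hypothetical names used for illustration; this is not the library's `FactorCovariate` implementation.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for a factor covariate; not the library's code.
public class FactorInterningSketch {

    /** One canonical, immutable value object per level. */
    public static final class Level {
        private final String name;

        private Level(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }
    }

    private final String covariateName;
    private final Map<String, Level> levels = new LinkedHashMap<>();

    public FactorInterningSketch(String covariateName, List<String> levelNames) {
        this.covariateName = covariateName;
        // Create every Level exactly once, up front.
        for (String levelName : levelNames) {
            levels.put(levelName, new Level(levelName));
        }
    }

    /** Returns the shared Level for a raw string, or fails on an unknown level. */
    public Level createValue(String raw) {
        final Level level = levels.get(raw); // lookup uses String.equals, not identity
        if (level == null) {
            throw new IllegalArgumentException(raw + " is not a level in factor " + covariateName);
        }
        return level;
    }

    public static void main(String[] args) {
        final FactorInterningSketch pet =
                new FactorInterningSketch("pet", Arrays.asList("DOG", "CAT", "MOUSE"));

        // new String(...) forces a distinct String instance, yet the same Level comes back.
        System.out.println(pet.createValue("DOG") == pet.createValue(new String("DOG")));
    }
}
```

Because the lookup goes through `String.equals`, two different `String` instances with the same characters resolve to the same `Level`, which is what lets callers compare values by identity.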
|
@ -70,7 +70,7 @@ public class NumericCovariateTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNumericCovariateDeterministic(){
|
public void testNumericCovariateDeterministic(){
|
||||||
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
||||||
|
|
||||||
final List<Row<Double>> dataset = createTestDataset(covariate);
|
final List<Row<Double>> dataset = createTestDataset(covariate);
|
||||||
|
|
||||||
|
@ -158,7 +158,7 @@ public class NumericCovariateTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNumericSplitRuleUpdaterWithIndexes(){
|
public void testNumericSplitRuleUpdaterWithIndexes(){
|
||||||
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
||||||
|
|
||||||
final List<Row<Double>> dataset = createTestDataset(covariate);
|
final List<Row<Double>> dataset = createTestDataset(covariate);
|
||||||
|
|
||||||
|
@ -223,7 +223,7 @@ public class NumericCovariateTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
|
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
|
||||||
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
||||||
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
|
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
|
||||||
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());
|
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.nas;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
public class TestNAs {
|
||||||
|
|
||||||
|
private List<Row<Double>> generateData(List<Covariate> covariates){
|
||||||
|
final List<Row<Double>> dataList = new ArrayList<>();
|
||||||
|
|
||||||
|
|
||||||
|
// We must include an NA for one of the values
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0));
|
||||||
|
|
||||||
|
|
||||||
|
return dataList;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testException(){
|
||||||
|
// There was a bug with NAs: we tried to randomly assign NAs during a split to the best split produced by NumericSplitRuleUpdater,
|
||||||
|
// but NumericSplitRuleUpdater had created that split with unmodifiable lists.
|
||||||
|
// This test verifies that this no longer causes a crash.
|
||||||
|
|
||||||
|
final List<Covariate> covariates = Collections.singletonList(new NumericCovariate("x", 0));
|
||||||
|
final List<Row<Double>> dataset = generateData(covariates);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
|
.checkNodePurity(false)
|
||||||
|
.covariates(covariates)
|
||||||
|
.numberOfSplits(0)
|
||||||
|
.nodeSize(1)
|
||||||
|
.maxNodeDepth(1000)
|
||||||
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
treeTrainer.growTree(dataset, new Random(123));
|
||||||
|
|
||||||
|
// As long as no exception occurs, we passed
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
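The comment in `testException` above describes a failure mode where rows with NA values are appended to the two sides of a split whose lists are unmodifiable. Below is a minimal sketch of that failure and one possible fix (copying into mutable lists first). `NaAssignmentSketch` and `assignNas` are hypothetical names, and the rows are simplified to plain `Double` values; how the library actually resolved the bug is not shown in this diff.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration; not the library's TreeTrainer or NumericSplitRuleUpdater.
public class NaAssignmentSketch {

    /** Randomly sends each NA row to the left or right hand of a split. */
    public static void assignNas(List<Double> leftHand, List<Double> rightHand,
                                 List<Double> naRows, double probabilityLeft, Random random) {
        for (Double row : naRows) {
            if (random.nextDouble() <= probabilityLeft) {
                leftHand.add(row);   // throws if leftHand is unmodifiable
            } else {
                rightHand.add(row);  // throws if rightHand is unmodifiable
            }
        }
    }

    public static void main(String[] args) {
        final List<Double> left = Collections.unmodifiableList(Arrays.asList(6.0, 5.5));
        final List<Double> right = Collections.unmodifiableList(Arrays.asList(0.0, 1.0, 1.0));
        final List<Double> naRows = Collections.singletonList(5.0);

        try {
            assignNas(left, right, naRows, 0.5, new Random(123));
        } catch (UnsupportedOperationException e) {
            System.out.println("unmodifiable lists reject the NA row");
        }

        // Defensive copies make the assignment safe.
        final List<Double> mutableLeft = new ArrayList<>(left);
        final List<Double> mutableRight = new ArrayList<>(right);
        assignNas(mutableLeft, mutableRight, naRows, 0.5, new Random(123));
        System.out.println(mutableLeft.size() + mutableRight.size());
    }
}
```

A test like `testException` needs no assertion for this kind of regression: simply growing a tree on data that contains an NA is enough to trigger the exception if the lists are still unmodifiable.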
|
@ -0,0 +1,359 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanSplitRule;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericSplitRule;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
|
import ca.joeltherrien.randomforest.tree.*;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class VariableImportanceCalculatorTest {
|
||||||
|
|
||||||
|
/*
|
||||||
|
Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
|
||||||
|
setting.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// We'll have a very simple Forest of two trees
|
||||||
|
private final Forest<Double, Double> forest;
|
||||||
|
|
||||||
|
|
||||||
|
private final List<Covariate> covariates;
|
||||||
|
private final List<Row<Double>> rowList;
|
||||||
|
|
||||||
|
/*
|
||||||
|
Long setup process; the forest is manually constructed so that we can be exactly sure of our variable importance.
|
||||||
|
|
||||||
|
*/
|
||||||
|
public VariableImportanceCalculatorTest(){
|
||||||
|
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0);
|
||||||
|
final NumericCovariate numericCovariate = new NumericCovariate("y", 1);
|
||||||
|
final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
|
||||||
|
Utils.easyList("red", "blue", "green"));
|
||||||
|
|
||||||
|
this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
|
.numberOfSplits(0)
|
||||||
|
.nodeSize(1)
|
||||||
|
.maxNodeDepth(100)
|
||||||
|
.mtry(3)
|
||||||
|
.checkNodePurity(false)
|
||||||
|
.covariates(this.covariates)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
/*
|
||||||
|
Plan for data - BooleanCovariate is split on first and has the largest impact.
|
||||||
|
NumericCovariate is at second level and has more minimal impact.
|
||||||
|
FactorCovariate is useless and never used.
|
||||||
|
Our tree (we'll duplicate it for testing OOB errors) will have a depth of 1. (0 based).
|
||||||
|
*/
|
||||||
|
|
||||||
|
final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
|
||||||
|
final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});
|
||||||
|
|
||||||
|
this.forest = Forest.<Double, Double>builder()
|
||||||
|
.trees(Utils.easyList(tree1, tree2))
|
||||||
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
|
.covariateList(this.covariates)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// formula; boolean high adds 100; high numeric adds 10
|
||||||
|
// This row list should have a baseline error of 0.0
|
||||||
|
this.rowList = Utils.easyList(
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 1, 0.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "blue"),
|
||||||
|
covariates, 2, 10.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 3, 100.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "green"),
|
||||||
|
covariates, 4, 110.0
|
||||||
|
),
|
||||||
|
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 5, 0.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "blue"),
|
||||||
|
covariates, 6, 10.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 7, 100.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "green"),
|
||||||
|
covariates, 8, 110.0
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Tree<Double> makeTree(List<Covariate> covariates, double offset, int[] indices){
|
||||||
|
// Naming convention - xyTerminal where x and y are low/high denotes whether BooleanCovariate(x) is low/high and
|
||||||
|
// whether NumericCovariate(y) is low/high.
|
||||||
|
final TerminalNode<Double> lowLowTerminal = new TerminalNode<>(0.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> lowHighTerminal = new TerminalNode<>(10.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> highLowTerminal = new TerminalNode<>(100.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> highHighTerminal = new TerminalNode<>(110.0 + offset, 5);
|
||||||
|
|
||||||
|
final SplitNode<Double> lowSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(lowLowTerminal)
|
||||||
|
.rightHand(lowHighTerminal)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final SplitNode<Double> highSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(highLowTerminal)
|
||||||
|
.rightHand(highHighTerminal)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final SplitNode<Double> rootSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(lowSplitNode)
|
||||||
|
.rightHand(highSplitNode)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new BooleanSplitRule((BooleanCovariate) covariates.get(0)))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return new Tree<>(rootSplitNode, indices);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Experiment with random seeds to first examine what a split does so we know what to expect
|
||||||
|
/*
|
||||||
|
public static void main(String[] args){
|
||||||
|
final List<Integer> ints1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||||
|
final List<Integer> ints2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||||
|
|
||||||
|
final Random random = new Random(123);
|
||||||
|
Collections.shuffle(ints1, random);
|
||||||
|
Collections.shuffle(ints2, random);
|
||||||
|
|
||||||
|
System.out.println(ints1);
|
||||||
|
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||||
|
|
||||||
|
System.out.println(ints2);
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnXNoOOB(){
|
||||||
|
// x is the BooleanCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));
|
||||||
|
|
||||||
|
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
||||||
|
|
||||||
|
final List<Double> permutedPredictions = Utils.easyList(
|
||||||
|
1., 111., 101., 11., 1., 111., 101., 11.
|
||||||
|
);
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
||||||
|
|
||||||
|
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnXOOB(){
|
||||||
|
// x is the BooleanCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));
|
||||||
|
|
||||||
|
// First 4 observations are off by 2, last 4 are off by 0
|
||||||
|
final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;
|
||||||
|
|
||||||
|
// Remember we are working with OOB predictions
|
||||||
|
final List<Double> permutedPredictions = Utils.easyList(
|
||||||
|
2., 112., 102., 12., 0., 110., 100., 10.
|
||||||
|
);
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
||||||
|
|
||||||
|
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnYNoOOB(){
|
||||||
|
// y is the NumericCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));
|
||||||
|
|
||||||
|
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
||||||
|
|
||||||
|
final List<Double> permutedPredictions = Utils.easyList(
|
||||||
|
1., 11., 111., 111., 1., 1., 101., 111.
|
||||||
|
);
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
||||||
|
|
||||||
|
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnYOOB(){
|
||||||
|
// y is the NumericCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));
|
||||||
|
|
||||||
|
// First 4 observations are off by 2, last 4 are off by 0
|
||||||
|
final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;
|
||||||
|
|
||||||
|
// Remember we are working with OOB predictions
|
||||||
|
final List<Double> permutedPredictions = Utils.easyList(
|
||||||
|
2., 12., 112., 112., 0., 0., 100., 110.
|
||||||
|
);
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
||||||
|
|
||||||
|
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnZNoOOB(){
|
||||||
|
// z is the useless FactorCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));
|
||||||
|
|
||||||
|
// The FactorCovariate is never used by the trees, so permuting it makes no difference to the baseline error
|
||||||
|
assertEquals(0, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnZOOB(){
|
||||||
|
// z is the useless FactorCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));
|
||||||
|
|
||||||
|
// The FactorCovariate is never used by the trees, so permuting it makes no difference to the baseline error
|
||||||
|
assertEquals(0, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceMultiple(){
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
double[] importance = calculator.calculateVariableImportance(covariates, Optional.of(random));
|
||||||
|
|
||||||
|
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
||||||
|
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<Double> permutedPredictionsX = Utils.easyList(
|
||||||
|
1., 111., 101., 11., 1., 111., 101., 11.
|
||||||
|
);
|
||||||
|
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
final List<Double> permutedPredictionsY = Utils.easyList(
|
||||||
|
11., 1., 111., 101., 1., 11., 111., 101.
|
||||||
|
);
|
||||||
|
|
||||||
|
final double expectedErrorX = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsX);
|
||||||
|
final double expectedErrorY = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsY);
|
||||||
|
|
||||||
|
assertEquals(expectedErrorX - expectedBaselineError, importance[0], 0.0000001);
|
||||||
|
assertEquals(expectedErrorY - expectedBaselineError, importance[1], 0.0000001);
|
||||||
|
assertEquals(0, importance[2], 0.0000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
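The tests in `VariableImportanceCalculatorTest` above all follow the same recipe: compute a baseline error, permute one covariate across the rows, predict again, and report the increase in error. Below is a minimal, self-contained sketch of that recipe. `PermutationImportanceSketch`, `predict` and `importance` are hypothetical names; the model is a plain function rather than a `Forest`, rows are simplified to `double[]`, and mean squared error is assumed as the error measure. It is not the library's `VariableImportanceCalculator`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.ToDoubleFunction;

// Hypothetical illustration of permutation importance; not the library's implementation.
public class PermutationImportanceSketch {

    // Toy model mirroring the test's formula: a "high" boolean (column 0) adds 100,
    // a "high" numeric (column 1) adds 10. Column 2 is deliberately ignored.
    public static double predict(double[] row) {
        return (row[0] > 0.5 ? 100.0 : 0.0) + (row[1] > 5.0 ? 10.0 : 0.0);
    }

    public static double meanSquaredError(double[] observed, double[] predicted) {
        double sum = 0.0;
        for (int i = 0; i < observed.length; i++) {
            final double difference = observed[i] - predicted[i];
            sum += difference * difference;
        }
        return sum / observed.length;
    }

    /** Error after permuting one column, minus the baseline error. */
    public static double importance(ToDoubleFunction<double[]> model, double[][] rows,
                                    double[] responses, int column, Random random) {
        final int n = rows.length;

        final double[] baseline = new double[n];
        for (int i = 0; i < n; i++) {
            baseline[i] = model.applyAsDouble(rows[i]);
        }

        // Shuffle row indices; row i borrows the column value of row order.get(i).
        final List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        Collections.shuffle(order, random);

        final double[] permuted = new double[n];
        for (int i = 0; i < n; i++) {
            final double[] copy = rows[i].clone(); // leave the original data untouched
            copy[column] = rows[order.get(i)][column];
            permuted[i] = model.applyAsDouble(copy);
        }

        return meanSquaredError(responses, permuted) - meanSquaredError(responses, baseline);
    }

    public static void main(String[] args) {
        final double[][] rows = {{0, 0, 0}, {0, 10, 1}, {1, 0, 0}, {1, 10, 2}};
        final double[] responses = {0, 10, 100, 110};
        for (int column = 0; column < 3; column++) {
            System.out.println("column " + column + ": "
                    + importance(PermutationImportanceSketch::predict, rows, responses, column, new Random(123)));
        }
    }
}
```

Seeding the `Random` makes the permutation reproducible, which is what allows the tests above to hard-code the expected permuted predictions. A column the model never reads always gets an importance of exactly zero, matching the `z` tests.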
|
@ -138,18 +138,4 @@ public class RightContinuousStepFunctionIntegrationTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIntegratingEmptyFunction(){
|
|
||||||
// A function might have no points, but we'll still need to integrate it.
|
|
||||||
|
|
||||||
final RightContinuousStepFunction function = new RightContinuousStepFunction(
|
|
||||||
new double[]{}, new double[]{}, 1.0
|
|
||||||
);
|
|
||||||
|
|
||||||
final double area = function.integrate(1.0 ,3.0);
|
|
||||||
assertEquals(2.0, area, 0.000001);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
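The `testIntegratingEmptyFunction` case shown on the left-hand side above expects a step function with no points and a default value of 1.0 to integrate to 2.0 over [1, 3]. Below is a minimal sketch of how such an integral could be computed. `StepFunctionIntegrationSketch` and its `integrate` signature are hypothetical, and it assumes that the default value applies before the first point, that `x` is sorted ascending, and that `from <= to`. It is not the library's `RightContinuousStepFunction`.

```java
// Hypothetical illustration; not the library's RightContinuousStepFunction.
public class StepFunctionIntegrationSketch {

    /**
     * Integrates a right-continuous step function over [from, to].
     * The function equals defaultY before x[0] and y[i] on [x[i], x[i+1]).
     */
    public static double integrate(double[] x, double[] y, double defaultY,
                                   double from, double to) {
        double area = 0.0;
        double currentTime = from;
        double currentValue = defaultY;

        for (int i = 0; i < x.length; i++) {
            if (x[i] <= from) {
                // Jump happened at or before the lower bound; only the value matters.
                currentValue = y[i];
                continue;
            }
            if (x[i] >= to) {
                break; // later jumps lie outside the interval
            }
            area += currentValue * (x[i] - currentTime);
            currentTime = x[i];
            currentValue = y[i];
        }

        // Final piece from the last jump (or from the lower bound) up to the upper bound.
        area += currentValue * (to - currentTime);
        return area;
    }

    public static void main(String[] args) {
        // No points at all: the function is the constant default value.
        System.out.println(integrate(new double[]{}, new double[]{}, 1.0, 1.0, 3.0));
    }
}
```

With no points, the loop body never runs and the result is simply the default value times the interval length, which is the behaviour the removed test checked for.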
Some files were not shown because too many files have changed in this diff.