Remove dependencies from project

This project is now purely a library; the code for running it directly from the command line will be
moved into a new project. This matters because the R package was bundling large dependencies
that it didn't need and that created some minor licensing inconveniences.
Joel Therrien 2019-07-02 16:46:40 -07:00
parent bc2c240823
commit ee4b513298
21 changed files with 266 additions and 1237 deletions


@ -1,11 +1,14 @@
# README # README
This Java software package contains the backend classes used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF). Most users won't directly use this project, but it can be directly run by configuring a yaml settings file specifying all of the attributes about the random forest and dataset that you can. You're also free to integrate it into your own projects (as long as you follow the terms of the GPL-3 license), or to extend it. More documentation will be added later on how to extend it, but for now if you want an idea I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes, which is a small example of a regression random forest implementation. This Java software package contains the backend classes used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
On its own it's not useful, but you're free to integrate it into your own projects (as long as you follow the terms of the GPL-3 license), or to extend it. More documentation on how to extend it will be added later, but for now, if you want an idea, I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes, which together form a small example of a regression random forest implementation.
If you've made an extension or modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package`, extract the contents of `largeRCRF-1.0-SNAPSHOT.jar` now found in the `target/` directory into the `inst/java/` directory for the R package (delete all the files previously there). Delete the `META-INF/` directory that was also extracted as that's meta information for the jar file and isn't relevant. Then just build the R package, possibly with your modifications in the R code, with `R> devtools::build()`. If you've made an extension or modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package` and copy the `largeRCRF-1.0-SNAPSHOT.jar` file now found in the `target/` directory into the `inst/java/` directory for the R package (delete the previous jar file). Then just build the R package, possibly with your modifications in the R code, with `R> devtools::build()`.
If you have any questions on how to run this project, how to extend it, how to integrate it with R, or anything else related to this project, please feel free to either [email me](mailto:joelt@sfu.ca) or create an Issue. If you have any questions on how to integrate this code with your own, how to integrate it with the R project, or anything else related to this project, please feel free to either [email me](mailto:joelt@sfu.ca) or create an Issue.
A small project allowing this code to be called directly outside of R will be released soon.
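As a rough illustration of what such an extension involves, here is a minimal, self-contained sketch of a mean combiner and a weighted-variance split score. Every type below (`SketchCombiner`, `SketchMeanCombiner`, `SketchVarianceScore`) is hypothetical and only mimics the roles of `MeanResponseCombiner` and `WeightedVarianceSplitFinder`; the real interfaces in this package may look different.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for the project's combiner abstraction; the real
// ResponseCombiner interface may have a different shape.
interface SketchCombiner<I, O> {
    O combine(List<I> responses);
}

// Averages numeric responses, the role a mean combiner plays in regression.
// Note: an empty list yields NaN (0.0 / 0).
class SketchMeanCombiner implements SketchCombiner<Double, Double> {
    @Override
    public Double combine(List<Double> responses) {
        double sum = 0.0;
        for (double r : responses) {
            sum += r;
        }
        return sum / responses.size();
    }
}

class SketchVarianceScore {
    // Sum of squared deviations from the mean, i.e. n times the population variance.
    static double sumOfSquares(List<Double> values) {
        final double mean = new SketchMeanCombiner().combine(values);
        double total = 0.0;
        for (double v : values) {
            total += (v - mean) * (v - mean);
        }
        return total;
    }

    // Size-weighted variance of the two children; lower means a better split.
    static double weightedVariance(List<Double> left, List<Double> right) {
        return sumOfSquares(left) + sumOfSquares(right);
    }

    public static void main(String[] args) {
        final List<Double> left = Arrays.asList(1.0, 1.0);
        final List<Double> right = Arrays.asList(5.0, 5.0);
        System.out.println("Split score: " + weightedVariance(left, right));
    }
}
```

A split finder built on this idea would evaluate `weightedVariance` for each candidate split and keep the one with the lowest score.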
## System Requirements ## System Requirements
@ -14,25 +17,5 @@ You need:
* A Java runtime version 1.8 or greater * A Java runtime version 1.8 or greater
* Maven to build the project * Maven to build the project
## Troubleshooting (Running directly)
### I get an `OutOfMemoryError` but I have plenty of RAM
By default the Java virtual machine caps its heap at roughly a quarter of the system's physical memory. When launching the jar file you can manually specify the memory available like below:
```
java -jar -Xmx15G -Xms15G largeRCRF-1.0-SNAPSHOT.jar settings.yaml
```
with `15G` replaced with a little less than your available system memory.
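To check how much heap the JVM was actually given, a small sketch like the following can help. The class name `HeapCheck` is hypothetical and not part of this project; it relies only on the standard `Runtime.maxMemory()` call.

```java
// Hypothetical helper, not part of largeRCRF: reports the JVM's maximum heap.
class HeapCheck {

    // Maximum amount of memory the JVM will attempt to use, in bytes.
    static long maxHeapBytes() {
        return Runtime.getRuntime().maxMemory();
    }

    // Converts bytes to whole mebibytes, rounding down.
    static long toMebibytes(long bytes) {
        return bytes / (1024L * 1024L);
    }

    public static void main(String[] args) {
        System.out.println("Max heap: " + toMebibytes(maxHeapBytes()) + " MiB");
    }
}
```

Running it once with no options and once with `-Xmx15G` should show whether the flag is taking effect.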
### I get an `OutOfMemoryError` and I'm short on RAM
Try reducing the number of trees being trained simultaneously by reducing the number of threads in the settings file.
### Training stalls immediately at 0 trees and the CPU is idle
This issue has been observed before on one particular system (and only on that system) but it's not clear what causes it. If you encounter this, please open an Issue and describe what operating system you're running on, what cloud system (if relevant) you're running on, and the entire output of `java -version`.
From my observation this issue occurs randomly but only at 0 trees, so as an imperfect workaround you can cancel the training and try again. Another imperfect workaround is to set the number of threads to 1; this stops the code from using Java's parallel capabilities, which bypasses the problem (at the cost of slower training).

pom.xml

@ -12,7 +12,6 @@
<java.version>1.8</java.version> <java.version>1.8</java.version>
<maven.compiler.target>1.8</maven.compiler.target> <maven.compiler.target>1.8</maven.compiler.target>
<maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.source>1.8</maven.compiler.source>
<jackson.version>2.9.9</jackson.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding> <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
@ -31,18 +30,7 @@
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId> <artifactId>commons-csv</artifactId>
<version>1.5</version> <version>1.5</version>
</dependency> <scope>test</scope>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-yaml</artifactId>
<version>${jackson.version}</version>
</dependency> </dependency>
<dependency> <dependency>
@ -79,16 +67,6 @@
<goals> <goals>
<goal>shade</goal> <goal>shade</goal>
</goals> </goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>
ca.joeltherrien.randomforest.Main
</mainClass>
</transformer>
</transformers>
<minimizeJar>true</minimizeJar>
</configuration>
</execution> </execution>
</executions> </executions>
</plugin> </plugin>


@ -1,221 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
import java.util.Optional;
import java.util.Random;
public class Main {
public static void main(String[] args) throws IOException, ClassNotFoundException {
if(args.length < 2){
System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
System.out.println("Note that analyzing only supports competing risk data, and that you must then specify a sample size for testing errors.");
if(args.length == 0){
final File templateFile = new File("template.yaml");
if(templateFile.exists()){
System.out.println("Template file exists; not creating a new one");
}
else{
System.out.println("Generating template file.");
defaultTemplate().save(templateFile);
}
}
return;
}
final File settingsFile = new File(args[0]);
final Settings settings = Settings.load(settingsFile);
final List<Covariate> covariates = settings.getCovariates();
if(args[1].equalsIgnoreCase("train")){
final List<Row> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
if(settings.isSaveProgress()){
if(settings.getNumberOfThreads() > 1){
forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads());
} else{
forestTrainer.trainSerialOnDisk(Optional.empty());
}
}
else{
if(settings.getNumberOfThreads() > 1){
forestTrainer.trainParallelInMemory(Optional.empty(), settings.getNumberOfThreads());
} else{
forestTrainer.trainSerialInMemory(Optional.empty());
}
}
}
else if(args[1].equalsIgnoreCase("analyze")){
// Perform different prediction measures
if(args.length < 3){
System.out.println("Specify error sample size");
return;
}
final String yVarType = settings.getYVarSettings().get("type").asText();
if(!yVarType.equalsIgnoreCase("CompetingRiskResponse") && !yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
System.out.println("Analyze currently only works on competing risk data");
return;
}
final ResponseCombiner<?, CompetingRiskFunctions> responseCombiner = settings.getTreeCombiner();
final int[] events;
if(responseCombiner instanceof CompetingRiskFunctionCombiner){
events = ((CompetingRiskFunctionCombiner) responseCombiner).getEvents();
}
else{
System.out.println("Unsupported tree combiner");
return;
}
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
// Let's reduce this down to n
final int n = Integer.parseInt(args[2]);
Utils.reduceListToSize(dataset, n, new Random());
final File folder = new File(settings.getSaveTreeLocation());
final Forest<?, CompetingRiskFunctions> forest = DataUtils.loadForest(folder, responseCombiner);
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
if(useBootstrapPredictions){
System.out.println("Finished loading trees + dataset; creating calculator and evaluating OOB predictions");
}
else{
System.out.println("Finished loading trees + dataset; creating calculator and evaluating predictions");
}
final CompetingRiskErrorRateCalculator errorRateCalculator = new CompetingRiskErrorRateCalculator(dataset, forest, useBootstrapPredictions);
final PrintWriter printWriter = new PrintWriter(settings.getSaveTreeLocation() + "/errors.txt");
System.out.println("Running Naive Concordance");
final double[] naiveConcordance = errorRateCalculator.calculateConcordance(events);
printWriter.write("Naive concordance:\n");
for(int i=0; i<events.length; i++){
printWriter.write('\t');
printWriter.write(Integer.toString(events[i]));
printWriter.write(": ");
printWriter.write(Double.toString(naiveConcordance[i]));
printWriter.write('\n');
}
if(yVarType.equalsIgnoreCase("CompetingRiskResponseWithCensorTime")){
System.out.println("Running IPCW Concordance - creating censor distribution");
final double[] censorTimes = dataset.stream()
.mapToDouble(row -> ((CompetingRiskResponseWithCensorTime) row.getResponse()).getC())
.toArray();
final StepFunction censorDistribution = Utils.estimateOneMinusECDF(censorTimes);
System.out.println("Finished generating censor distribution - running concordance");
final double[] ipcwConcordance = errorRateCalculator.calculateIPCWConcordance(events, censorDistribution);
printWriter.write("IPCW concordance:\n");
for(int i=0; i<events.length; i++){
printWriter.write('\t');
printWriter.write(Integer.toString(events[i]));
printWriter.write(": ");
printWriter.write(Double.toString(ipcwConcordance[i]));
printWriter.write('\n');
}
}
printWriter.close();
}
else{
System.out.println("Invalid instruction; use either train or analyze.");
System.out.println("Note that analyzing only supports competing risk data.");
}
}
private static Settings defaultTemplate(){
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
treeCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
return Settings.builder()
.covariateSettings(Utils.easyList(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
)
)
.trainingDataLocation("training_data.csv")
.validationDataLocation("validation_data.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
.mtry(2)
.nodeSize(5)
.ntree(500)
.numberOfSplits(5)
.numberOfThreads(1)
.saveProgress(true)
.saveTreeLocation("trees/")
.build();
}
}


@ -1,235 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.GrayLogRankSplitFinder;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.SplitFinder;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
/**
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
*/
@Data
@Builder
@AllArgsConstructor
@EqualsAndHashCode
public class Settings {
private static Map<String, Function<ObjectNode, DataUtils.ResponseLoader>> RESPONSE_LOADER_MAP = new HashMap<>();
public static Function<ObjectNode, DataUtils.ResponseLoader> getResponseLoaderConstructor(final String name){
return RESPONSE_LOADER_MAP.get(name.toLowerCase());
}
public static void registerResponseLoaderConstructor(final String name, final Function<ObjectNode, DataUtils.ResponseLoader> responseLoaderConstructor){
RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor);
}
static{
registerResponseLoaderConstructor("double",
node -> new DataUtils.DoubleLoader(node)
);
registerResponseLoaderConstructor("CompetingRiskResponse",
node -> new CompetingRiskResponse.CompetingResponseLoader(node)
);
registerResponseLoaderConstructor("CompetingRiskResponseWithCensorTime",
node -> new CompetingRiskResponseWithCensorTime.CompetingResponseWithCensorTimeLoader(node)
);
}
private static Map<String, Function<ObjectNode, SplitFinder>> SPLIT_FINDER_MAP = new HashMap<>();
public static Function<ObjectNode, SplitFinder> getSplitFinderConstructor(final String name){
return SPLIT_FINDER_MAP.get(name.toLowerCase());
}
public static void registerSplitFinderConstructor(final String name, final Function<ObjectNode, SplitFinder> splitFinderConstructor){
SPLIT_FINDER_MAP.put(name.toLowerCase(), splitFinderConstructor);
}
static{
registerSplitFinderConstructor("WeightedVarianceSplitFinder",
(node) -> new WeightedVarianceSplitFinder()
);
registerSplitFinderConstructor("GrayLogRankSplitFinder",
(objectNode) -> {
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
return new GrayLogRankSplitFinder(eventsOfFocusArray, eventArray);
}
);
registerSplitFinderConstructor("LogRankSplitFinder",
(objectNode) -> {
final int[] eventsOfFocusArray = Utils.jsonToIntArray(objectNode.get("eventsOfFocus"));
final int[] eventArray = Utils.jsonToIntArray(objectNode.get("events"));
return new LogRankSplitFinder(eventsOfFocusArray, eventArray);
}
);
}
private static Map<String, Function<ObjectNode, ResponseCombiner>> RESPONSE_COMBINER_MAP = new HashMap<>();
public static Function<ObjectNode, ResponseCombiner> getResponseCombinerConstructor(final String name){
return RESPONSE_COMBINER_MAP.get(name.toLowerCase());
}
public static void registerResponseCombinerConstructor(final String name, final Function<ObjectNode, ResponseCombiner> responseCombinerConstructor){
RESPONSE_COMBINER_MAP.put(name.toLowerCase(), responseCombinerConstructor);
}
static{
registerResponseCombinerConstructor("MeanResponseCombiner",
(node) -> new MeanResponseCombiner()
);
registerResponseCombinerConstructor("CompetingRiskResponseCombiner",
(node) -> {
final int[] events = Utils.jsonToIntArray(node.get("events"));
return new CompetingRiskResponseCombiner(events);
}
);
registerResponseCombinerConstructor("CompetingRiskFunctionCombiner",
(node) -> {
final int[] events = Utils.jsonToIntArray(node.get("events"));
double[] times = null;
if(node.hasNonNull("times")){
times = Utils.jsonToDoubleArray(node.get("times"));
}
return new CompetingRiskFunctionCombiner(events, times);
}
);
}
private int numberOfSplits = 5;
private int nodeSize = 5;
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
private boolean checkNodePurity = false;
private Long randomSeed;
private ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private List<CovariateSettings> covariateSettings = new ArrayList<>();
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
// number of covariates to randomly try
private int mtry = 0;
// number of trees to try
private int ntree = 500;
private String trainingDataLocation = "data.csv";
private String validationDataLocation = "data.csv";
private String saveTreeLocation = "trees/";
private int numberOfThreads = 1;
private boolean saveProgress = false;
public Settings(){
} // required for Jackson
public static Settings load(File file) throws IOException {
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
//mapper.enableDefaultTyping();
return mapper.readValue(file, Settings.class);
}
public void save(File file) throws IOException {
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
//mapper.enableDefaultTyping();
// Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
this.covariateSettings = new ArrayList<>(this.covariateSettings);
mapper.writeValue(file, this);
}
@JsonIgnore
public SplitFinder getSplitFinder(){
final String type = splitFinderSettings.get("type").asText();
return getSplitFinderConstructor(type).apply(splitFinderSettings);
}
@JsonIgnore
public DataUtils.ResponseLoader getResponseLoader(){
final String type = yVarSettings.get("type").asText();
return getResponseLoaderConstructor(type).apply(yVarSettings);
}
@JsonIgnore
public ResponseCombiner getResponseCombiner(){
final String type = responseCombinerSettings.get("type").asText();
return getResponseCombinerConstructor(type).apply(responseCombinerSettings);
}
@JsonIgnore
public ResponseCombiner getTreeCombiner(){
final String type = treeCombinerSettings.get("type").asText();
return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
}
@JsonIgnore
public List<Covariate> getCovariates(){
final List<CovariateSettings> covariateSettingsList = this.getCovariateSettings();
final List<Covariate> covariates = new ArrayList<>(covariateSettingsList.size());
for(int i = 0; i < covariateSettingsList.size(); i++){
covariates.add(covariateSettingsList.get(i).build(i));
}
return covariates;
}
}
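The `Settings` class above keeps three registries that map a lower-cased type name to a constructor function. Looking up an unregistered name returns `null`, so a call such as `getSplitFinder()` would fail with a `NullPointerException`. Below is a minimal sketch of the same pattern with an explicit error for unknown names. `SketchRegistry` is a hypothetical class, and it takes a plain `Map<String, String>` as its config instead of Jackson's `ObjectNode` so that it stays dependency-free.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, simplified version of the name-to-constructor registry idea;
// the config type is a plain Map here rather than Jackson's ObjectNode.
class SketchRegistry<T> {
    private final Map<String, Function<Map<String, String>, T>> constructors = new HashMap<>();

    // Keys are normalised with Locale.ROOT so lookups do not depend on the default locale.
    private static String key(String name) {
        return name.toLowerCase(Locale.ROOT);
    }

    void register(String name, Function<Map<String, String>, T> constructor) {
        constructors.put(key(name), constructor);
    }

    // Builds an instance for the given type name, failing loudly if it is unknown.
    T create(String name, Map<String, String> config) {
        final Function<Map<String, String>, T> constructor = constructors.get(key(name));
        if (constructor == null) {
            throw new IllegalArgumentException("No constructor registered for type: " + name);
        }
        return constructor.apply(config);
    }
}
```

Code that registers its own split finders or combiners could use a structure like this without pulling in a YAML or JSON dependency.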


@ -1,35 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;
@NoArgsConstructor // required by Jackson
@Data
public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
public BooleanCovariateSettings(String name){
super(name);
}
@Override
public BooleanCovariate build(int index) {
return new BooleanCovariate(name, index);
}
}


@ -1,49 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import lombok.Getter;
import lombok.NoArgsConstructor;
/**
* Nuisance class to work with Jackson for persisting settings.
*
* @param <V>
*/
@NoArgsConstructor // required for Jackson
@Getter
@JsonTypeInfo(
use = JsonTypeInfo.Id.NAME,
property = "type")
@JsonSubTypes({
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),
@JsonSubTypes.Type(value = NumericCovariateSettings.class, name = "numeric"),
@JsonSubTypes.Type(value = FactorCovariateSettings.class, name = "factor")
})
public abstract class CovariateSettings<V> {
String name;
CovariateSettings(String name){
this.name = name;
}
public abstract Covariate<V> build(int index);
}


@ -1,41 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
@NoArgsConstructor // required by Jackson
@Data
public final class FactorCovariateSettings extends CovariateSettings<String> {
private List<String> levels;
public FactorCovariateSettings(String name, List<String> levels){
super(name);
this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...)
}
@Override
public FactorCovariate build(int index) {
return new FactorCovariate(name, index, levels);
}
}


@ -1,35 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;
@NoArgsConstructor // required by Jackson
@Data
public final class NumericCovariateSettings extends CovariateSettings<Double> {
public NumericCovariateSettings(String name){
super(name);
}
@Override
public NumericCovariate build(int index) {
return new NumericCovariate(name, index);
}
}


@ -16,11 +16,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.DataUtils;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.Data; import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord;
import java.io.Serializable; import java.io.Serializable;
@ -35,24 +31,4 @@ public class CompetingRiskResponse implements Serializable {
} }
@RequiredArgsConstructor
public static class CompetingResponseLoader implements DataUtils.ResponseLoader<CompetingRiskResponse>{
private final String deltaName;
private final String uName;
public CompetingResponseLoader(ObjectNode node){
this.deltaName = node.get("delta").asText();
this.uName = node.get("u").asText();
}
@Override
public CompetingRiskResponse parse(CSVRecord record) {
final int delta = Integer.parseInt(record.get(deltaName));
final double u = Double.parseDouble(record.get(uName));
return new CompetingRiskResponse(delta, u);
}
}
} }


@ -16,12 +16,8 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.DataUtils;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord;
/** /**
* See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times. * See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times.
@ -37,26 +33,4 @@ public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResp
this.c = c; this.c = c;
} }
@RequiredArgsConstructor
public static class CompetingResponseWithCensorTimeLoader implements DataUtils.ResponseLoader<CompetingRiskResponseWithCensorTime>{
private final String deltaName;
private final String uName;
private final String cName;
public CompetingResponseWithCensorTimeLoader(ObjectNode node){
this.deltaName = node.get("delta").asText();
this.uName = node.get("u").asText();
this.cName = node.get("c").asText();
}
@Override
public CompetingRiskResponseWithCensorTime parse(CSVRecord record) {
final int delta = Integer.parseInt(record.get(deltaName));
final double u = Double.parseDouble(record.get(uName));
final double c = Double.parseDouble(record.get(cName));
return new CompetingRiskResponseWithCensorTime(delta, u, c);
}
}
} }


@ -18,7 +18,6 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Bootstrapper; import ca.joeltherrien.randomforest.Bootstrapper;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.DataUtils;
import lombok.*; import lombok.*;
@ -51,24 +50,6 @@ public class ForestTrainer<Y, TO, FO> {
private final String saveTreeLocation; private final String saveTreeLocation;
private final long randomSeed; private final long randomSeed;
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
this.ntree = settings.getNtree();
this.data = data;
this.displayProgress = true;
this.saveTreeLocation = settings.getSaveTreeLocation();
this.covariates = covariates;
this.treeResponseCombiner = settings.getTreeCombiner();
this.treeTrainer = new TreeTrainer<>(settings, covariates);
if(settings.getRandomSeed() != null){
this.randomSeed = settings.getRandomSeed();
}
else{
this.randomSeed = System.nanoTime();
}
}
/** /**
* Train a forest in memory using a single core * Train a forest in memory using a single core
* *


@ -17,7 +17,6 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.AccessLevel; import lombok.AccessLevel;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@ -50,18 +49,6 @@ public class TreeTrainer<Y, O> {
private final List<Covariate> covariates; private final List<Covariate> covariates;
public TreeTrainer(final Settings settings, final List<Covariate> covariates){
this.numberOfSplits = settings.getNumberOfSplits();
this.nodeSize = settings.getNodeSize();
this.maxNodeDepth = settings.getMaxNodeDepth();
this.mtry = settings.getMtry();
this.checkNodePurity = settings.isCheckNodePurity();
this.responseCombiner = settings.getResponseCombiner();
this.splitFinder = settings.getSplitFinder();
this.covariates = covariates;
}
public Tree<O> growTree(List<Row<Y>> data, Random random){ public Tree<O> growTree(List<Row<Y>> data, Random random){
final Node<O> rootNode = growNode(data, 0, random); final Node<O> rootNode = growNode(data, 0, random);
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray()); return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());


@ -16,16 +16,9 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.tree.Tree; import ca.joeltherrien.randomforest.tree.Tree;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.*; import java.io.*;
import java.util.*; import java.util.*;
@ -34,43 +27,6 @@ import java.util.zip.GZIPOutputStream;
public class DataUtils { public class DataUtils {
public static <Y> List<Row<Y>> loadData(final List<Covariate> covariates, final ResponseLoader<Y> responseLoader, String filename) throws IOException {
final List<Row<Y>> dataset = new ArrayList<>();
final Reader input;
if(filename.endsWith(".gz")){
final FileInputStream inputStream = new FileInputStream(filename);
final GZIPInputStream gzipInputStream = new GZIPInputStream(inputStream);
input = new InputStreamReader(gzipInputStream);
}
else{
input = new FileReader(filename);
}
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
int id = 1;
for(final CSVRecord record : parser){
final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()];
for(final Covariate<?> covariate : covariates){
valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName()));
}
final Y y = responseLoader.parse(record);
dataset.add(new Row<>(valueArray, id++, y));
}
return dataset;
}
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException { public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
if(!folder.isDirectory()){ if(!folder.isDirectory()){
throw new IllegalArgumentException("Tree directory must be a directory!"); throw new IllegalArgumentException("Tree directory must be a directory!");
@ -119,23 +75,5 @@ public class DataUtils {
} }
@FunctionalInterface
public interface ResponseLoader<Y>{
Y parse(CSVRecord record);
}
@RequiredArgsConstructor
public static class DoubleLoader implements ResponseLoader<Double> {
private final String yName;
public DoubleLoader(final ObjectNode node){
this.yName = node.get("name").asText();
}
@Override
public Double parse(CSVRecord record) {
return Double.parseDouble(record.get(yName));
}
}
} }


@ -16,8 +16,6 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import com.fasterxml.jackson.databind.JsonNode;
import java.util.*; import java.util.*;
public final class Utils { public final class Utils {
@ -210,22 +208,4 @@ public final class Utils {
return map; return map;
} }
public static int[] jsonToIntArray(final JsonNode node){
final Iterator<JsonNode> elements = node.elements();
final List<JsonNode> elementList = new ArrayList<>();
elements.forEachRemaining(n -> elementList.add(n));
final int[] array = elementList.stream().mapToInt(n -> n.asInt()).toArray();
return array;
}
public static double[] jsonToDoubleArray(final JsonNode node){
final Iterator<JsonNode> elements = node.elements();
final List<JsonNode> elementList = new ArrayList<>();
elements.forEachRemaining(n -> elementList.add(n));
final double[] array = elementList.stream().mapToDouble(n -> n.asDouble()).toArray();
return array;
}
} }


@ -16,90 +16,78 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.ResponseLoader;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import static org.junit.jupiter.api.Assertions.*;
public class TestSavingLoading { public class TestSavingLoading {
private final int NTREE = 10; private static final int NTREE = 10;
private static final String DEFAULT_FILEPATH = "src/test/resources/wihs.csv";
private static final String SAVE_TREE_LOCATION = "src/test/resources/trees/";
/** public List<Covariate> getCovariates(){
* By default uses single log-rank test. return Utils.easyList(
* new NumericCovariate("ageatfda", 0),
* @return new BooleanCovariate("idu", 1),
*/ new BooleanCovariate("black", 2),
public Settings getSettings(){ new NumericCovariate("cd4nadir", 3)
final ObjectNode splitRuleSettings = new ObjectNode(JsonNodeFactory.instance);
splitRuleSettings.set("type", new TextNode("LogRankSplitFinder"));
splitRuleSettings.set("eventsOfFocus",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
);
splitRuleSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
); );
}
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); public ForestTrainer.ForestTrainerBuilder<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> getForestBuilder(
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); List<Covariate> covariates,
responseCombinerSettings.set("events", List<Row<CompetingRiskResponse>> data,
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer) {
);
// not setting times
return ForestTrainer.<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions>builder()
.treeResponseCombiner(new CompetingRiskFunctionCombiner(new int[]{1,2}, null))
.ntree(NTREE)
.saveTreeLocation("src/test/resources/trees/")
.displayProgress(false)
.covariates(covariates)
.data(data)
.treeTrainer(treeTrainer);
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); }
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
treeCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
);
// not setting times
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); public List<Row<CompetingRiskResponse>> getData(List<Covariate> covariates, String filepath) throws IOException {
yVarSettings.set("type", new TextNode("CompetingRiskResponse")); return TestUtils.loadData(
yVarSettings.set("u", new TextNode("time")); covariates, new ResponseLoader.CompetingRisksResponseLoader("status", "time"),
yVarSettings.set("delta", new TextNode("status")); filepath);
}
return Settings.builder() public TreeTrainer.TreeTrainerBuilder<CompetingRiskResponse, CompetingRiskFunctions> getTreeTrainerBuilder(List<Covariate> covariates){
.covariateSettings(Utils.easyList( return TreeTrainer.<CompetingRiskResponse, CompetingRiskFunctions>builder()
new NumericCovariateSettings("ageatfda"), .covariates(covariates)
new BooleanCovariateSettings("idu"), .splitFinder(new LogRankSplitFinder(new int[]{1}, new int[]{1,2}))
new BooleanCovariateSettings("black"), .responseCombiner(new CompetingRiskResponseCombiner(new int[]{1,2}))
new NumericCovariateSettings("cd4nadir")
)
)
.trainingDataLocation("src/test/resources/wihs.csv")
.validationDataLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.splitFinderSettings(splitRuleSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000) .maxNodeDepth(100000)
// TODO fill in these settings
.mtry(2) .mtry(2)
.nodeSize(6) .nodeSize(6)
.ntree(NTREE) .numberOfSplits(5);
.numberOfSplits(5)
.numberOfThreads(3)
.saveProgress(true)
.saveTreeLocation("src/test/resources/trees/")
.build();
} }
public CovariateRow getPredictionRow(List<Covariate> covariates){ public CovariateRow getPredictionRow(List<Covariate> covariates){
return CovariateRow.createSimple(Utils.easyMap( return CovariateRow.createSimple(Utils.easyMap(
"ageatfda", "35", "ageatfda", "35",
@ -111,18 +99,19 @@ public class TestSavingLoading {
@Test @Test
public void testSavingLoadingSerial() throws IOException, ClassNotFoundException { public void testSavingLoadingSerial() throws IOException, ClassNotFoundException {
final Settings settings = getSettings(); final List<Covariate> covariates = getCovariates();
final List<Covariate> covariates = settings.getCovariates(); final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final File directory = new File(settings.getSaveTreeLocation());
final File directory = new File(SAVE_TREE_LOCATION);
if(directory.exists()){ if(directory.exists()){
TestUtils.removeFolder(directory); TestUtils.removeFolder(directory);
} }
assertFalse(directory.exists()); assertFalse(directory.exists());
directory.mkdir(); directory.mkdir();
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer =
getForestBuilder(covariates, dataset, getTreeTrainerBuilder(covariates).build()).build();
forestTrainer.trainSerialOnDisk(Optional.empty()); forestTrainer.trainSerialOnDisk(Optional.empty());
@ -130,8 +119,6 @@ public class TestSavingLoading {
assertTrue(directory.isDirectory()); assertTrue(directory.isDirectory());
assertEquals(NTREE, directory.listFiles().length); assertEquals(NTREE, directory.listFiles().length);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
final CovariateRow predictionRow = getPredictionRow(covariates); final CovariateRow predictionRow = getPredictionRow(covariates);
@ -152,20 +139,20 @@ public class TestSavingLoading {
@Test @Test
public void testSavingLoadingParallel() throws IOException, ClassNotFoundException { public void testSavingLoadingParallel() throws IOException, ClassNotFoundException {
final Settings settings = getSettings(); final List<Covariate> covariates = getCovariates();
final List<Covariate> covariates = settings.getCovariates(); final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final File directory = new File(settings.getSaveTreeLocation()); final File directory = new File(SAVE_TREE_LOCATION);
if(directory.exists()){ if(directory.exists()){
TestUtils.removeFolder(directory); TestUtils.removeFolder(directory);
} }
assertFalse(directory.exists()); assertFalse(directory.exists());
directory.mkdir(); directory.mkdir();
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer =
getForestBuilder(covariates, dataset, getTreeTrainerBuilder(covariates).build()).build();
forestTrainer.trainParallelOnDisk(Optional.empty(), settings.getNumberOfThreads()); forestTrainer.trainParallelOnDisk(Optional.empty(), 2);
assertTrue(directory.exists()); assertTrue(directory.exists());
assertTrue(directory.isDirectory()); assertTrue(directory.isDirectory());


@ -16,23 +16,70 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.utils.ResponseLoader;
import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.File; import java.io.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.function.DoubleSupplier; import java.util.function.DoubleSupplier;
import java.util.stream.DoubleStream; import java.util.stream.DoubleStream;
import java.util.zip.GZIPInputStream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestUtils { public class TestUtils {
/*
This code is copied from the runnable part of largeRCRF-Java; it's not included in the non-test code of the library
because we want to avoid packaging dependencies into the library that the R code won't use.
*/
public static <Y> List<Row<Y>> loadData(final List<Covariate> covariates, final ResponseLoader<Y> responseLoader, String filename) throws IOException {
final List<Row<Y>> dataset = new ArrayList<>();
final Reader input;
if(filename.endsWith(".gz")){
final FileInputStream inputStream = new FileInputStream(filename);
final GZIPInputStream gzipInputStream = new GZIPInputStream(inputStream);
input = new InputStreamReader(gzipInputStream);
}
else{
input = new FileReader(filename);
}
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
int id = 1;
for(final CSVRecord record : parser){
final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()];
for(final Covariate<?> covariate : covariates){
valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName()));
}
final Y y = responseLoader.parse(record);
dataset.add(new Row<>(valueArray, id++, y));
}
return dataset;
}
public static void closeEnough(double expected, double actual, double margin){ public static void closeEnough(double expected, double actual, double margin){
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual); assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
} }
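Review note: the `loadData` helper copied into `TestUtils` picks its reader from the `.gz` suffix, but it never closes that reader and it decodes with the platform default charset. The sketch below shows one way the same suffix check could be written with try-with-resources and an explicit UTF-8 charset. It uses only the JDK, so the CSV parsing is replaced by a simple line count. `GzipAwareReader`, `open`, and `countDataLines` are hypothetical names for illustration and are not part of largeRCRF.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

public class GzipAwareReader {

    // Hypothetical helper: opens a plain or gzipped text file as UTF-8.
    public static BufferedReader open(Path path) throws IOException {
        InputStream in = Files.newInputStream(path);
        try {
            if (path.getFileName().toString().endsWith(".gz")) {
                in = new GZIPInputStream(in);
            }
            return new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8));
        } catch (IOException e) {
            // Do not leak the file handle if the gzip header is invalid.
            in.close();
            throw e;
        }
    }

    // Counts the lines that follow the header row; stands in for the CSV parsing step.
    public static int countDataLines(Path path) throws IOException {
        try (BufferedReader reader = open(path)) {
            reader.readLine(); // skip the header row
            int count = 0;
            while (reader.readLine() != null) {
                count++;
            }
            return count;
        }
    }

    public static void main(String[] args) throws IOException {
        // Write a tiny gzipped CSV, then read it back through the same helper.
        Path tmp = Files.createTempFile("sample", ".csv.gz");
        try (Writer w = new OutputStreamWriter(
                new GZIPOutputStream(Files.newOutputStream(tmp)), StandardCharsets.UTF_8)) {
            w.write("status,time\n1,2.5\n0,3.0\n");
        }
        System.out.println(countDataLines(tmp)); // two data rows after the header
        Files.delete(tmp);
    }
}
```

For test-only code the leak is harmless in practice, so this is a suggestion rather than a required change.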


@ -17,22 +17,24 @@
package ca.joeltherrien.randomforest.competingrisk; package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.TestUtils;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.ResponseLoader;
import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
@ -47,68 +49,52 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestCompetingRisk { public class TestCompetingRisk {
private static final String DEFAULT_FILEPATH = "src/test/resources/wihs.csv";
/** public List<Covariate> getCovariates(){
* By default uses single log-rank test. return Utils.easyList(
* new NumericCovariate("ageatfda", 0),
* @return new BooleanCovariate("idu", 1),
*/ new BooleanCovariate("black", 2),
public Settings getSettings(){ new NumericCovariate("cd4nadir", 3)
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
splitFinderSettings.set("type", new TextNode("LogRankSplitFinder"));
splitFinderSettings.set("eventsOfFocus",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1)))
);
splitFinderSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
); );
}
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); public ForestTrainer.ForestTrainerBuilder<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> getForestBuilder(
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); List<Covariate> covariates,
responseCombinerSettings.set("events", List<Row<CompetingRiskResponse>> data,
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2))) TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer) {
);
// not setting times
return ForestTrainer.<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions>builder()
.treeResponseCombiner(new CompetingRiskFunctionCombiner(new int[]{1,2}, null))
.ntree(100)
.saveTreeLocation("trees/")
.displayProgress(false)
.covariates(covariates)
.data(data)
.treeTrainer(treeTrainer);
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); }
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
treeCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
);
// not setting times
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); public List<Row<CompetingRiskResponse>> getData(List<Covariate> covariates, String filepath) throws IOException {
yVarSettings.set("type", new TextNode("CompetingRiskResponse")); return TestUtils.loadData(
yVarSettings.set("u", new TextNode("time")); covariates, new ResponseLoader.CompetingRisksResponseLoader("status", "time"),
yVarSettings.set("delta", new TextNode("status")); filepath);
}
return Settings.builder() public TreeTrainer.TreeTrainerBuilder<CompetingRiskResponse, CompetingRiskFunctions> getTreeTrainerBuilder(List<Covariate> covariates){
.covariateSettings(Utils.easyList( return TreeTrainer.<CompetingRiskResponse, CompetingRiskFunctions>builder()
new NumericCovariateSettings("ageatfda"), .covariates(covariates)
new BooleanCovariateSettings("idu"), .splitFinder(new LogRankSplitFinder(new int[]{1}, new int[]{1,2}))
new BooleanCovariateSettings("black"), .responseCombiner(new CompetingRiskResponseCombiner(new int[]{1,2}))
new NumericCovariateSettings("cd4nadir")
)
)
.trainingDataLocation("src/test/resources/wihs.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000) .maxNodeDepth(100000)
// TODO fill in these settings
.mtry(2) .mtry(2)
.nodeSize(6) .nodeSize(6)
.ntree(100) .numberOfSplits(5);
.numberOfSplits(5)
.numberOfThreads(3)
.saveProgress(true)
.saveTreeLocation("trees/")
.build();
} }
public CovariateRow getPredictionRow(List<Covariate> covariates){ public CovariateRow getPredictionRow(List<Covariate> covariates){
return CovariateRow.createSimple(Utils.easyMap( return CovariateRow.createSimple(Utils.easyMap(
"ageatfda", "35", "ageatfda", "35",
@ -120,18 +106,17 @@ public class TestCompetingRisk {
@Test @Test
public void testSingleTree() throws IOException { public void testSingleTree() throws IOException {
final Settings settings = getSettings();
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
settings.setCovariateSettings(Utils.easyList(
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black")
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = settings.getCovariates(); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = Utils.easyList(
new BooleanCovariate("idu", 0),
new BooleanCovariate("black", 1)
);
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv");
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = getTreeTrainerBuilder(covariates).build();
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random()); final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
final CovariateRow newRow = getPredictionRow(covariates); final CovariateRow newRow = getPredictionRow(covariates);
@ -175,16 +160,15 @@ public class TestCompetingRisk {
*/ */
@Test @Test
public void testSingleTree2() throws IOException { public void testSingleTree2() throws IOException {
final Settings settings = getSettings();
settings.setMtry(4);
settings.setNumberOfSplits(0);
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = getCovariates();
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped2.csv");
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = getTreeTrainerBuilder(covariates)
.mtry(4)
.numberOfSplits(0)
.build();
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random()); final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
final CovariateRow newRow = getPredictionRow(covariates); final CovariateRow newRow = getPredictionRow(covariates);
@ -224,17 +208,17 @@ public class TestCompetingRisk {
@Test @Test
public void testLogRankSplitFinderTwoBooleans() throws IOException { public void testLogRankSplitFinderTwoBooleans() throws IOException {
final Settings settings = getSettings(); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
settings.setCovariateSettings(Utils.easyList( final List<Covariate> covariates = Utils.easyList(
new BooleanCovariateSettings("idu"), new BooleanCovariate("idu", 0),
new BooleanCovariateSettings("black") new BooleanCovariate("black", 1)
)); );
final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer =
getForestBuilder(covariates, dataset, getTreeTrainerBuilder(covariates).build()).build();
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(Optional.empty()); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(Optional.empty());
@ -276,11 +260,9 @@ public class TestCompetingRisk {
@Test @Test
public void verifyDataset() throws IOException { public void verifyDataset() throws IOException {
final Settings settings = getSettings(); final List<Covariate> covariates = getCovariates();
final List<Covariate> covariates = settings.getCovariates(); final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
// Let's count the events and make sure the data was correctly read. // Let's count the events and make sure the data was correctly read.
int countCensored = 0; int countCensored = 0;
@ -309,44 +291,18 @@ public class TestCompetingRisk {
         assertEquals(359, countEventTwo);
     }
-    /**
-     * Used to time how long the algorithm takes
-     *
-     * @param args Not used.
-     * @throws IOException
-     */
-    public static void main(String[] args) throws IOException {
-        // timing
-        final TestCompetingRisk tcr = new TestCompetingRisk();
-        final Settings settings = tcr.getSettings();
-        settings.setNtree(300); // results are too variable at 100
-        final List<Covariate> covariates = settings.getCovariates();
-        final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
-                settings.getTrainingDataLocation());
-        final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
-        final long startTime = System.currentTimeMillis();
-        for(int i=0; i<50; i++){
-            forestTrainer.trainSerialInMemory(Optional.empty());
-        }
-        final long endTime = System.currentTimeMillis();
-        final double diffTime = endTime - startTime;
-        System.out.println(diffTime / 1000.0 / 50.0);
-    }
     @Test
     public void testLogRankSplitFinderAllCovariates() throws IOException {
-        final Settings settings = getSettings();
-        settings.setNtree(300); // results are too variable at 100
-        final List<Covariate> covariates = settings.getCovariates();
-        final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
-                settings.getTrainingDataLocation());
-        final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
+        final List<Covariate> covariates = getCovariates();
+        final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
+        final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer =
+                getForestBuilder(covariates, dataset, getTreeTrainerBuilder(covariates).build())
+                        .ntree(300) // results are too variable at 100
+                        .build();
         final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerialInMemory(Optional.empty());
         // prediction row


@ -17,20 +17,16 @@
 package ca.joeltherrien.randomforest.competingrisk;
 import ca.joeltherrien.randomforest.Row;
-import ca.joeltherrien.randomforest.Settings;
+import ca.joeltherrien.randomforest.TestUtils;
 import ca.joeltherrien.randomforest.covariates.Covariate;
-import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
 import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
 import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
 import ca.joeltherrien.randomforest.tree.Split;
 import ca.joeltherrien.randomforest.utils.Data;
-import ca.joeltherrien.randomforest.utils.DataUtils;
+import ca.joeltherrien.randomforest.utils.ResponseLoader;
 import ca.joeltherrien.randomforest.utils.SingletonIterator;
 import ca.joeltherrien.randomforest.utils.Utils;
-import com.fasterxml.jackson.databind.node.JsonNodeFactory;
-import com.fasterxml.jackson.databind.node.ObjectNode;
-import com.fasterxml.jackson.databind.node.TextNode;
-import lombok.AllArgsConstructor;
 import org.junit.jupiter.api.Test;
 import java.io.IOException;
@ -48,23 +44,12 @@ public class TestLogRankSplitFinder {
     }
     public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
-        final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
-        yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
-        yVarSettings.set("delta", new TextNode("delta"));
-        yVarSettings.set("u", new TextNode("u"));
-        final Settings settings = Settings.builder()
-                .trainingDataLocation(filename)
-                .covariateSettings(
-                        Utils.easyList(new NumericCovariateSettings("x2"))
-                )
-                .yVarSettings(yVarSettings)
-                .build();
-        final List<Covariate> covariates = settings.getCovariates();
-        final DataUtils.ResponseLoader loader = settings.getResponseLoader();
-        final List<Row<CompetingRiskResponse>> rows = DataUtils.loadData(covariates, loader, settings.getTrainingDataLocation());
+        final List<Covariate> covariates = Utils.easyList(
+                new NumericCovariate("x2", 0)
+        );
+        final List<Row<CompetingRiskResponse>> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename);
         return new Data<>(rows, covariates);
     }


@ -1,119 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.csv;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLoadingCSV {
/*
y,x1,x2,x3
5,3.0,"mouse",true
2,1.0,"dog",false
9,1.5,"cat",true
-3,NA,NA,NA
*/
public List<Row<Double>> loadData(String filename) throws IOException {
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
final Settings settings = Settings.builder()
.trainingDataLocation(filename)
.covariateSettings(
Utils.easyList(new NumericCovariateSettings("x1"),
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
new BooleanCovariateSettings("x3"))
)
.yVarSettings(yVarSettings)
.build();
final List<Covariate> covariates = settings.getCovariates();
final DataUtils.ResponseLoader loader = settings.getResponseLoader();
return DataUtils.loadData(covariates, loader, settings.getTrainingDataLocation());
}
@Test
public void verifyLoadingNormal(final List<Covariate> covariates) throws IOException {
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
assertData(data, covariates);
}
@Test
public void verifyLoadingGz(final List<Covariate> covariates) throws IOException {
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
assertData(data, covariates);
}
private void assertData(final List<Row<Double>> data, final List<Covariate> covariates){
final Covariate x1 = covariates.get(0);
final Covariate x2 = covariates.get(0);
final Covariate x3 = covariates.get(0);
assertEquals(4, data.size());
Row<Double> row = data.get(0);
assertEquals(5.0, (double)row.getResponse());
assertEquals(3.0, row.getCovariateValue(x1).getValue());
assertEquals("mouse", row.getCovariateValue(x2).getValue());
assertEquals(true, row.getCovariateValue(x3).getValue());
row = data.get(1);
assertEquals(2.0, (double)row.getResponse());
assertEquals(1.0, row.getCovariateValue(x1).getValue());
assertEquals("dog", row.getCovariateValue(x2).getValue());
assertEquals(false, row.getCovariateValue(x3).getValue());
row = data.get(2);
assertEquals(9.0, (double)row.getResponse());
assertEquals(1.5, row.getCovariateValue(x1).getValue());
assertEquals("cat", row.getCovariateValue(x2).getValue());
assertEquals(true, row.getCovariateValue(x3).getValue());
row = data.get(3);
assertEquals(-3.0, (double)row.getResponse());
assertTrue(row.getCovariateValue(x1).isNA());
assertTrue(row.getCovariateValue(x2).isNA());
assertTrue(row.getCovariateValue(x3).isNA());
}
}


@ -1,85 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.settings;
import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
public class TestPersistence {
@Test
public void testSaving() throws IOException {
final ObjectNode splitFinderSettings = new ObjectNode(JsonNodeFactory.instance);
splitFinderSettings.set("type", new TextNode("WeightedVarianceSplitFinder"));
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
treeCombinerSettings.set("type", new TextNode("MeanResponseCombiner"));
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
final Settings settingsOriginal = Settings.builder()
.covariateSettings(Utils.easyList(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
)
)
.trainingDataLocation("training_data.csv")
.validationDataLocation("validation_data.csv")
.responseCombinerSettings(responseCombinerSettings)
.treeCombinerSettings(treeCombinerSettings)
.splitFinderSettings(splitFinderSettings)
.yVarSettings(yVarSettings)
.maxNodeDepth(100000)
.mtry(2)
.nodeSize(5)
.ntree(500)
.numberOfSplits(5)
.numberOfThreads(1)
.saveProgress(true)
.saveTreeLocation("trees/")
.build();
final File templateFile = new File("template.yaml");
settingsOriginal.save(templateFile);
final Settings reloadedSettings = Settings.load(templateFile);
assertEquals(settingsOriginal, reloadedSettings);
//templateFile.delete();
}
}


@ -0,0 +1,77 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord;
/*
Note - this interface is copied from the runnable largeRCRF-Java project
and is used only for helping some tests load data.
*/
@FunctionalInterface
public interface ResponseLoader<Y> {
Y parse(CSVRecord record);
@RequiredArgsConstructor
class DoubleLoader implements ResponseLoader<Double> {
private final String yName;
@Override
public Double parse(CSVRecord record) {
return Double.parseDouble(record.get(yName));
}
}
@RequiredArgsConstructor
class CompetingRisksResponseLoader implements ResponseLoader<CompetingRiskResponse> {
private final String deltaName;
private final String uName;
@Override
public CompetingRiskResponse parse(CSVRecord record) {
final int delta = Integer.parseInt(record.get(deltaName));
final double u = Double.parseDouble(record.get(uName));
return new CompetingRiskResponse(delta, u);
}
}
@RequiredArgsConstructor
class CompetingRisksResponseWithCensorTimesLoader implements ResponseLoader<CompetingRiskResponseWithCensorTime> {
private final String deltaName;
private final String uName;
private final String cName;
@Override
public CompetingRiskResponseWithCensorTime parse(CSVRecord record) {
final int delta = Integer.parseInt(record.get(deltaName));
final double u = Double.parseDouble(record.get(uName));
final double c = Double.parseDouble(record.get(cName));
return new CompetingRiskResponseWithCensorTime(delta, u, c);
}
}
}
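
Note on the new ResponseLoader interface: a minimal, self-contained sketch of the pattern it follows, a functional interface that turns one parsed row into a response value. This is not the project's code. To keep it runnable without Apache Commons CSV, a `Map<String, String>` keyed by column name stands in for `CSVRecord`; the names `ResponseLoaderSketch`, `RecordParser`, `DoubleParser` and `parseAll` are hypothetical and exist only for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class ResponseLoaderSketch {

    // Hypothetical stand-in for ResponseLoader<Y>; a Map keyed by column
    // name replaces CSVRecord so the sketch needs no external library.
    @FunctionalInterface
    interface RecordParser<Y> {
        Y parse(Map<String, String> record);
    }

    // Mirrors the shape of DoubleLoader: reads one named column as a double.
    static final class DoubleParser implements RecordParser<Double> {
        private final String yName;

        DoubleParser(String yName) {
            this.yName = yName;
        }

        @Override
        public Double parse(Map<String, String> record) {
            return Double.parseDouble(record.get(yName));
        }
    }

    // Applies a parser to every row, in order, and collects the responses.
    static <Y> List<Y> parseAll(List<Map<String, String>> records, RecordParser<Y> parser) {
        final List<Y> responses = new ArrayList<>();
        for (Map<String, String> record : records) {
            responses.add(parser.parse(record));
        }
        return responses;
    }

    public static void main(String[] args) {
        final List<Map<String, String>> records = List.of(
                Map.of("y", "5", "delta", "1"),
                Map.of("y", "-3", "delta", "0"));

        // Class-based parser, analogous to new ResponseLoader.DoubleLoader("y").
        System.out.println(parseAll(records, new DoubleParser("y"))); // prints [5.0, -3.0]

        // Because the interface has a single method, a lambda also works.
        final RecordParser<Integer> deltaParser = r -> Integer.parseInt(r.get("delta"));
        System.out.println(parseAll(records, deltaParser)); // prints [1, 0]
    }
}
```

Keeping the parser a single-method interface means a test can pass either a small named class or a lambda, and the loading code does not need to know which response type it is building.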