Compare commits
17 commits
Author | SHA1 | Date | |
---|---|---|---|
|
044bf08e3d | ||
|
c4bab39245 | ||
f3a4ef01ed | |||
|
54af805d4d | ||
79a9522ba7 | |||
|
c24626ff61 | ||
|
51696e2546 | ||
|
f1c5b292ed | ||
|
a56ad4433d | ||
|
f23ee21ef3 | ||
|
186de413ed | ||
|
aa1f544ea2 | ||
|
86f6c195d7 | ||
|
9258f75e4e | ||
|
7371dab4f1 | ||
|
ae9a6b9a3f | ||
d7cdc9f6e7 |
68 changed files with 2718 additions and 625 deletions
7
.gitignore
vendored
7
.gitignore
vendored
|
@ -2,7 +2,10 @@
|
||||||
.settings
|
.settings
|
||||||
.project
|
.project
|
||||||
target/
|
target/
|
||||||
|
library/target/
|
||||||
|
executable/target/
|
||||||
*.iml
|
*.iml
|
||||||
.idea
|
.idea
|
||||||
template.yaml
|
library/dependency-reduced-pom.xml
|
||||||
dependency-reduced-pom.xml
|
executable/dependency-reduced-pom.xml
|
||||||
|
executable/template.yaml
|
||||||
|
|
18
README.md
18
README.md
|
@ -1,14 +1,20 @@
|
||||||
# README
|
# README
|
||||||
|
|
||||||
This Java software package contains the backend classes used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
|
This repository contains the largeRCRF Java library, which implements all of the logic used for training the random forests. It provides the Jar file used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
|
||||||
|
|
||||||
On its own it's not useful, but you're free to integrate it into your own projects (as long as you follow the terms of the GPL-3 license), or extend it. More documentation will be added later on how to extend it, but for now if you want an idea I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes, which is a small example of a regression random forest implementation.
|
Most users interested in training random competing risks forests should use the [R package component](https://github.com/jatherrien/largeRCRF); the content in this repository is only useful for advanced users.
|
||||||
|
|
||||||
If you've made an extension or modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package` and copy the `largeRCRF-1.0-SNAPSHOT.jar` file now found in the `target/` directory into the `inst/java/` directory for the R package (delete the previous jar file). Then just build the R package, possibly with your modifications in the R code, with `R> devtools::build()`.
|
## License
|
||||||
|
|
||||||
If you have any questions on how to integrate this code with your own, how to integrate it with the R project, or anything else related to this project, please feel free to either [email me](mailto:joelt@sfu.ca) or create an Issue.
|
You're free to use / modify / redistribute the project, as long as you follow the terms of the GPL-3 license.
|
||||||
|
|
||||||
A small project allowing this code to be called directly outside of R will be released soon.
|
## Extending
|
||||||
|
|
||||||
|
Documentation on how to extend the library to add support for other types of random forests will eventually be added, but for now, if you're interested in doing so, I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes to see how a basic regression random forest was implemented.
|
||||||
|
|
||||||
|
If you've made a modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package`, then just copy `target/largeRCRF-library-1.0-SNAPSHOT.jar` into the `inst/java/` directory for the R package, replacing the previous Jar file there. Then build the R package, possibly with your modifications to the code there, with `R> devtools::build()`.
|
||||||
|
|
||||||
|
Please don't take the current lack of documentation as a sign that I oppose others extending or modifying the project; if you have any questions on running, extending, integrating with R, or anything else related to this project, please don't hesitate to either [email me](mailto:joelt@sfu.ca) or create an Issue. Most likely my answers to your questions will end up forming the basis for any documentation written.
|
||||||
|
|
||||||
## System Requirements
|
## System Requirements
|
||||||
|
|
||||||
|
@ -17,5 +23,3 @@ You need:
|
||||||
* A Java runtime version 1.8 or greater
|
* A Java runtime version 1.8 or greater
|
||||||
* Maven to build the project
|
* Maven to build the project
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
11
pom.xml
11
pom.xml
|
@ -4,8 +4,8 @@
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
<groupId>ca.joeltherrien</groupId>
|
<groupId>ca.joeltherrien.ca</groupId>
|
||||||
<artifactId>largeRCRF</artifactId>
|
<artifactId>largeRCRF-library</artifactId>
|
||||||
<version>1.0-SNAPSHOT</version>
|
<version>1.0-SNAPSHOT</version>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
|
@ -60,7 +60,8 @@
|
||||||
<plugins>
|
<plugins>
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-shade-plugin</artifactId>
|
<artifactId>maven-shade-plugin</artifactId>
|
||||||
|
<version>3.2.1</version>
|
||||||
<executions>
|
<executions>
|
||||||
<execution>
|
<execution>
|
||||||
<phase>package</phase>
|
<phase>package</phase>
|
||||||
|
@ -85,7 +86,7 @@
|
||||||
<configuration>
|
<configuration>
|
||||||
<rulesets>
|
<rulesets>
|
||||||
<!-- Custom local file system rule set -->
|
<!-- Custom local file system rule set -->
|
||||||
<ruleset>${project.basedir}/pmd-rules.xml</ruleset>
|
<ruleset>pmd-rules.xml</ruleset>
|
||||||
</rulesets>
|
</rulesets>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
|
@ -94,4 +95,4 @@
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -21,12 +21,13 @@ import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.HashMap;
|
import java.util.*;
|
||||||
import java.util.List;
|
import java.util.stream.Collectors;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class CovariateRow implements Serializable {
|
public class CovariateRow implements Serializable, Cloneable {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final Covariate.Value[] valueArray;
|
private final Covariate.Value[] valueArray;
|
||||||
|
|
||||||
|
@ -46,6 +47,14 @@ public class CovariateRow implements Serializable {
|
||||||
return "CovariateRow " + this.id;
|
return "CovariateRow " + this.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CovariateRow clone() {
|
||||||
|
// shallow clone, which is fine. I want a new array, but the values don't need to be copied
|
||||||
|
final Covariate.Value[] copyValueArray = this.valueArray.clone();
|
||||||
|
|
||||||
|
return new CovariateRow(copyValueArray, this.id);
|
||||||
|
}
|
||||||
|
|
||||||
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
||||||
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||||
final Map<String, Covariate> covariateMap = new HashMap<>();
|
final Map<String, Covariate> covariateMap = new HashMap<>();
|
||||||
|
@ -64,4 +73,27 @@ public class CovariateRow implements Serializable {
|
||||||
return new CovariateRow(valueArray, id);
|
return new CovariateRow(valueArray, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used for variable importance; takes a List of CovariateRows and permute one of the Covariates.
|
||||||
|
*
|
||||||
|
* @param covariateRows The List of CovariateRows to scramble. Note that the originals won't be modified.
|
||||||
|
* @param covariateToScramble The Covariate to scramble on.
|
||||||
|
* @param random The source of randomness to use. If not present, one will be created.
|
||||||
|
* @return A List of CovariateRows where the specified covariate was scrambled. These are different objects from the ones provided.
|
||||||
|
*/
|
||||||
|
public static List<CovariateRow> scrambleCovariateValues(List<? extends CovariateRow> covariateRows, Covariate covariateToScramble, Optional<Random> random){
|
||||||
|
final List<CovariateRow> permutedCovariateRowList = new ArrayList<>(covariateRows);
|
||||||
|
Collections.shuffle(permutedCovariateRowList, random.orElse(new Random())); // without replacement
|
||||||
|
|
||||||
|
final List<CovariateRow> clonedRowList = covariateRows.stream().map(CovariateRow::clone).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final int covariateToScrambleIndex = covariateToScramble.getIndex();
|
||||||
|
for(int i=0; i < covariateRows.size(); i++){
|
||||||
|
clonedRowList.get(i).valueArray[covariateToScrambleIndex] = permutedCovariateRowList.get(i).valueArray[covariateToScrambleIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
return clonedRowList;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,9 +32,6 @@ public class Row<Y> extends CovariateRow {
|
||||||
this.response = response;
|
this.response = response;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public Y getResponse() {
|
public Y getResponse() {
|
||||||
return this.response;
|
return this.response;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,6 +49,8 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
|
||||||
return getIndex() - other.getIndex();
|
return getIndex() - other.getIndex();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean haveNASplitPenalty();
|
||||||
|
|
||||||
interface Value<V> extends Serializable{
|
interface Value<V> extends Serializable{
|
||||||
|
|
||||||
Covariate<V> getParent();
|
Covariate<V> getParent();
|
||||||
|
|
|
@ -25,9 +25,12 @@ import lombok.Getter;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public final class BooleanCovariate implements Covariate<Boolean> {
|
public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
|
@ -38,14 +41,26 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
|
private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
|
||||||
|
|
||||||
public BooleanCovariate(String name, int index){
|
private final boolean haveNASplitPenalty;
|
||||||
|
@Override
|
||||||
|
public boolean haveNASplitPenalty(){
|
||||||
|
// penalty would add worthless computational time if there are no NAs
|
||||||
|
return hasNAs && haveNASplitPenalty;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BooleanCovariate(String name, int index, boolean haveNASplitPenalty){
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.index = index;
|
this.index = index;
|
||||||
splitRule = new BooleanSplitRule(this);
|
this.splitRule = new BooleanSplitRule(this);
|
||||||
|
this.haveNASplitPenalty = haveNASplitPenalty;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
|
if(hasNAs){
|
||||||
|
data = data.stream().filter(row -> !row.getValueByIndex(index).isNA()).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
return new SingletonIterator<>(this.splitRule.applyRule(data));
|
return new SingletonIterator<>(this.splitRule.applyRule(data));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,6 +99,8 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
public class BooleanValue implements Value<Boolean>{
|
public class BooleanValue implements Value<Boolean>{
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final Boolean value;
|
private final Boolean value;
|
||||||
|
|
||||||
private BooleanValue(final Boolean value){
|
private BooleanValue(final Boolean value){
|
||||||
|
|
|
@ -21,6 +21,8 @@ import ca.joeltherrien.randomforest.covariates.SplitRule;
|
||||||
|
|
||||||
public class BooleanSplitRule implements SplitRule<Boolean> {
|
public class BooleanSplitRule implements SplitRule<Boolean> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final int parentCovariateIndex;
|
private final int parentCovariateIndex;
|
||||||
|
|
||||||
public BooleanSplitRule(BooleanCovariate parent){
|
public BooleanSplitRule(BooleanCovariate parent){
|
||||||
|
|
|
@ -23,9 +23,12 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public final class FactorCovariate implements Covariate<String> {
|
public final class FactorCovariate implements Covariate<String> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
|
@ -38,8 +41,15 @@ public final class FactorCovariate implements Covariate<String> {
|
||||||
|
|
||||||
private boolean hasNAs;
|
private boolean hasNAs;
|
||||||
|
|
||||||
|
private final boolean haveNASplitPenalty;
|
||||||
|
@Override
|
||||||
|
public boolean haveNASplitPenalty(){
|
||||||
|
// penalty would add worthless computational time if there are no NAs
|
||||||
|
return hasNAs && haveNASplitPenalty;
|
||||||
|
}
|
||||||
|
|
||||||
public FactorCovariate(final String name, final int index, List<String> levels){
|
|
||||||
|
public FactorCovariate(final String name, final int index, List<String> levels, final boolean haveNASplitPenalty){
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.index = index;
|
this.index = index;
|
||||||
this.factorLevels = new HashMap<>();
|
this.factorLevels = new HashMap<>();
|
||||||
|
@ -61,12 +71,22 @@ public final class FactorCovariate implements Covariate<String> {
|
||||||
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
||||||
|
|
||||||
this.naValue = new FactorValue(null);
|
this.naValue = new FactorValue(null);
|
||||||
|
|
||||||
|
this.haveNASplitPenalty = haveNASplitPenalty;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
|
if(hasNAs()){
|
||||||
|
data = data.stream().filter(row -> !row.getCovariateValue(this).isNA()).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
if(number == 0){ // nsplit = 0 => try every possibility, although we limit it to the number of observations.
|
||||||
|
number = data.size();
|
||||||
|
}
|
||||||
|
|
||||||
final Set<Split<Y, String>> splits = new HashSet<>();
|
final Set<Split<Y, String>> splits = new HashSet<>();
|
||||||
|
|
||||||
// This is to ensure we don't get stuck in an infinite loop for small factors
|
// This is to ensure we don't get stuck in an infinite loop for small factors
|
||||||
|
@ -122,6 +142,8 @@ public final class FactorCovariate implements Covariate<String> {
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public final class FactorValue implements Covariate.Value<String>{
|
public final class FactorValue implements Covariate.Value<String>{
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final String value;
|
private final String value;
|
||||||
|
|
||||||
private FactorValue(final String value){
|
private FactorValue(final String value){
|
||||||
|
|
|
@ -25,6 +25,8 @@ import java.util.Set;
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public final class FactorSplitRule implements SplitRule<String> {
|
public final class FactorSplitRule implements SplitRule<String> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final int parentCovariateIndex;
|
private final int parentCovariateIndex;
|
||||||
private final Set<String> leftSideValues;
|
private final Set<String> leftSideValues;
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,8 @@ import java.util.stream.Stream;
|
||||||
@ToString
|
@ToString
|
||||||
public final class NumericCovariate implements Covariate<Double> {
|
public final class NumericCovariate implements Covariate<Double> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
|
@ -45,6 +47,13 @@ public final class NumericCovariate implements Covariate<Double> {
|
||||||
|
|
||||||
private boolean hasNAs = false;
|
private boolean hasNAs = false;
|
||||||
|
|
||||||
|
private final boolean haveNASplitPenalty;
|
||||||
|
@Override
|
||||||
|
public boolean haveNASplitPenalty(){
|
||||||
|
// penalty would add worthless computational time if there are no NAs
|
||||||
|
return hasNAs && haveNASplitPenalty;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
Stream<Row<Y>> stream = data.stream();
|
Stream<Row<Y>> stream = data.stream();
|
||||||
|
@ -122,6 +131,8 @@ public final class NumericCovariate implements Covariate<Double> {
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public class NumericValue implements Covariate.Value<Double>{
|
public class NumericValue implements Covariate.Value<Double>{
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final Double value; // may be null
|
private final Double value; // may be null
|
||||||
|
|
||||||
private NumericValue(final Double value){
|
private NumericValue(final Double value){
|
||||||
|
|
|
@ -23,10 +23,12 @@ import lombok.EqualsAndHashCode;
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public class NumericSplitRule implements SplitRule<Double> {
|
public class NumericSplitRule implements SplitRule<Double> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final int parentCovariateIndex;
|
private final int parentCovariateIndex;
|
||||||
private final double threshold;
|
private final double threshold;
|
||||||
|
|
||||||
NumericSplitRule(NumericCovariate parent, final double threshold){
|
public NumericSplitRule(NumericCovariate parent, final double threshold){
|
||||||
this.parentCovariateIndex = parent.getIndex();
|
this.parentCovariateIndex = parent.getIndex();
|
||||||
this.threshold = threshold;
|
this.threshold = threshold;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,8 @@ import java.util.List;
|
||||||
@Builder
|
@Builder
|
||||||
public class CompetingRiskFunctions implements Serializable {
|
public class CompetingRiskFunctions implements Serializable {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final List<RightContinuousStepFunction> causeSpecificHazards;
|
private final List<RightContinuousStepFunction> causeSpecificHazards;
|
||||||
private final List<RightContinuousStepFunction> cumulativeIncidenceCurves;
|
private final List<RightContinuousStepFunction> cumulativeIncidenceCurves;
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,8 @@ import java.io.Serializable;
|
||||||
@Data
|
@Data
|
||||||
public class CompetingRiskResponse implements Serializable {
|
public class CompetingRiskResponse implements Serializable {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final int delta;
|
private final int delta;
|
||||||
private final double u;
|
private final double u;
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,9 @@ import lombok.EqualsAndHashCode;
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final double c;
|
private final double c;
|
||||||
|
|
||||||
public CompetingRiskResponseWithCensorTime(int delta, double u, double c) {
|
public CompetingRiskResponseWithCensorTime(int delta, double u, double c) {
|
||||||
|
|
|
@ -16,9 +16,11 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
public class CompetingRiskUtils {
|
public class CompetingRiskUtils {
|
||||||
|
@ -116,6 +118,44 @@ public class CompetingRiskUtils {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the Integrated Brier Score error on a list of responses and predictions.
|
||||||
|
*
|
||||||
|
* @param responses A List of responses
|
||||||
|
* @param predictions The corresponding List of predictions.
|
||||||
|
* @param censoringDistribution The censoring distribution.
|
||||||
|
* @param eventOfFocus The event we are calculating the error for.
|
||||||
|
* @param integrationUpperBound The upper bound to integrate to.
|
||||||
|
* @param isParallel Whether we should use parallel streams or not (provided because of bugs on a particular system).
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static double[] calculateIBSError(final List<CompetingRiskResponse> responses,
|
||||||
|
List<CompetingRiskFunctions> predictions,
|
||||||
|
Optional<RightContinuousStepFunction> censoringDistribution,
|
||||||
|
int eventOfFocus,
|
||||||
|
double integrationUpperBound,
|
||||||
|
boolean isParallel){
|
||||||
|
|
||||||
|
if(responses.size() != predictions.size()){
|
||||||
|
throw new IllegalArgumentException("Length of responses and predictions must be equal.");
|
||||||
|
}
|
||||||
|
|
||||||
|
final IBSCalculator calculator = new IBSCalculator(censoringDistribution);
|
||||||
|
|
||||||
|
IntStream stream = IntStream.range(0, responses.size());
|
||||||
|
|
||||||
|
if(isParallel){
|
||||||
|
stream = stream.parallel();
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream.mapToDouble(i -> {
|
||||||
|
CompetingRiskResponse response = responses.get(i);
|
||||||
|
RightContinuousStepFunction cif = predictions.get(i).getCumulativeIncidenceFunction(eventOfFocus);
|
||||||
|
|
||||||
|
return calculator.calculateError(response, cif, eventOfFocus, integrationUpperBound);
|
||||||
|
}).toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
|
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
|
||||||
final List<CompetingRiskResponse> initialRightHand,
|
final List<CompetingRiskResponse> initialRightHand,
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to calculate the Integrated Brier Score. See Section 4.2 of "Random survival forests for competing risks" by Ishwaran.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class IBSCalculator {
|
||||||
|
|
||||||
|
private final Optional<RightContinuousStepFunction> censoringDistribution;
|
||||||
|
|
||||||
|
public IBSCalculator(RightContinuousStepFunction censoringDistribution){
|
||||||
|
this.censoringDistribution = Optional.of(censoringDistribution);
|
||||||
|
}
|
||||||
|
|
||||||
|
public IBSCalculator(){
|
||||||
|
this.censoringDistribution = Optional.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public IBSCalculator(Optional<RightContinuousStepFunction> censoringDistribution){
|
||||||
|
this.censoringDistribution = censoringDistribution;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculateError(CompetingRiskResponse response, RightContinuousStepFunction cif, int eventOfInterest, double integrationUpperBound){
|
||||||
|
|
||||||
|
// return integral of weights*(I(response.getU() <= times & response.getDelta() == eventOfInterest) - cif(times))^2
|
||||||
|
// Note that if we don't have weights, just treat them all as one (i.e. don't bother multiplying)
|
||||||
|
|
||||||
|
RightContinuousStepFunction functionToIntegrate = cif;
|
||||||
|
|
||||||
|
if(response.getDelta() == eventOfInterest){
|
||||||
|
final RightContinuousStepFunction observedFunction = new RightContinuousStepFunction(new double[]{response.getU()}, new double[]{1.0}, 0.0);
|
||||||
|
functionToIntegrate = RightContinuousStepFunction.biOperation(observedFunction, functionToIntegrate, (a, b) -> (a - b) * (a - b));
|
||||||
|
} else{
|
||||||
|
functionToIntegrate = functionToIntegrate.unaryOperation(a -> a*a);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(censoringDistribution.isPresent()){
|
||||||
|
final RightContinuousStepFunction weights = calculateWeights(response, censoringDistribution.get());
|
||||||
|
functionToIntegrate = RightContinuousStepFunction.biOperation(weights, functionToIntegrate, (a, b) -> a*b);
|
||||||
|
|
||||||
|
// the censoring weights go to 0 after the response is censored, so we can speed up results by only integrating
|
||||||
|
// prior to the censor times
|
||||||
|
if(response.isCensored()){
|
||||||
|
integrationUpperBound = Math.min(integrationUpperBound, response.getU());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return functionToIntegrate.integrate(0.0, integrationUpperBound);
|
||||||
|
}
|
||||||
|
|
||||||
|
private RightContinuousStepFunction calculateWeights(CompetingRiskResponse response, RightContinuousStepFunction censoringDistribution){
|
||||||
|
final double recordedTime = response.getU();
|
||||||
|
|
||||||
|
// Function(t) = firstPart(t) + secondPart(t)/thirdPart(t) where:
|
||||||
|
// firstPart(t) = I(recordedTime <= t & !response.isCensored()) / censoringDistribution.evaluate(recordedTime);
|
||||||
|
// secondPart(t) = I(recordedTime > t) = 1 - I(recordedTime <= t)
|
||||||
|
// thirdPart(t) = censoringDistribution.evaluate(t)
|
||||||
|
|
||||||
|
final RightContinuousStepFunction secondPart = new RightContinuousStepFunction(new double[]{recordedTime}, new double[]{0.0}, 1.0);
|
||||||
|
RightContinuousStepFunction result = RightContinuousStepFunction.biOperation(secondPart, censoringDistribution,
|
||||||
|
(second, third) -> second / third);
|
||||||
|
|
||||||
|
if(!response.isCensored()){
|
||||||
|
final RightContinuousStepFunction firstPart = new RightContinuousStepFunction(
|
||||||
|
new double[]{recordedTime},
|
||||||
|
new double[]{1.0 / censoringDistribution.evaluate(recordedTime)},
|
||||||
|
0.0);
|
||||||
|
|
||||||
|
result = RightContinuousStepFunction.biOperation(firstPart, result, Double::sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -17,17 +17,17 @@
|
||||||
package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
public class CompetingRiskFunctionCombiner implements ForestResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
private final double[] times; // We may restrict ourselves to specific times.
|
private final double[] times; // We may restrict ourselves to specific times.
|
||||||
|
@ -55,72 +55,22 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
|
||||||
).sorted().distinct().toArray();
|
).sorted().distinct().toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
final double n = responses.size();
|
final IntermediateCompetingRisksFunctionsTimesKnown intermediateResult = new IntermediateCompetingRisksFunctionsTimesKnown(responses.size(), this.events, timesToUse);
|
||||||
|
|
||||||
final double[] survivalY = new double[timesToUse.length];
|
|
||||||
final double[][] csCHFY = new double[events.length][timesToUse.length];
|
|
||||||
final double[][] cifY = new double[events.length][timesToUse.length];
|
|
||||||
|
|
||||||
/*
|
|
||||||
We're going to try to efficiently put our predictions together -
|
|
||||||
Assumptions - for each event on a response, the hazard and CIF functions share the same x points
|
|
||||||
|
|
||||||
Plan - go through the time on each response and make use of that so that when we search for a time index
|
|
||||||
to evaluate the function at, we don't need to re-search the earlier times.
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
for(final CompetingRiskFunctions currentFunctions : responses){
|
|
||||||
final double[] survivalXPoints = currentFunctions.getSurvivalCurve().getX();
|
|
||||||
final double[][] eventSpecificXPoints = new double[events.length][];
|
|
||||||
|
|
||||||
for(final int event : events){
|
|
||||||
eventSpecificXPoints[event-1] = currentFunctions.getCumulativeIncidenceFunction(event)
|
|
||||||
.getX();
|
|
||||||
}
|
|
||||||
|
|
||||||
int previousSurvivalIndex = 0;
|
|
||||||
final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
|
|
||||||
|
|
||||||
for(int i=0; i<timesToUse.length; i++){
|
|
||||||
final double time = timesToUse[i];
|
|
||||||
|
|
||||||
// Survival curve
|
|
||||||
final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
|
|
||||||
survivalY[i] = survivalY[i] + currentFunctions.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / n;
|
|
||||||
previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
|
|
||||||
// -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
|
|
||||||
|
|
||||||
// CHFs and CIFs
|
|
||||||
for(final int event : events){
|
|
||||||
final double[] xPoints = eventSpecificXPoints[event-1];
|
|
||||||
final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
|
|
||||||
xPoints, time);
|
|
||||||
csCHFY[event-1][i] = csCHFY[event-1][i] + currentFunctions.getCauseSpecificHazardFunction(event)
|
|
||||||
.evaluateByIndex(eventTimeIndex) / n;
|
|
||||||
cifY[event-1][i] = cifY[event-1][i] + currentFunctions.getCumulativeIncidenceFunction(event)
|
|
||||||
.evaluateByIndex(eventTimeIndex) / n;
|
|
||||||
|
|
||||||
previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
for(CompetingRiskFunctions input : responses){
|
||||||
|
intermediateResult.processNewInput(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
|
return intermediateResult.transformToOutput();
|
||||||
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
}
|
||||||
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
|
||||||
|
|
||||||
for(final int event : events){
|
@Override
|
||||||
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
|
public IntermediateCombinedResponse<CompetingRiskFunctions, CompetingRiskFunctions> startIntermediateCombinedResponse(int countInputs) {
|
||||||
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
|
if(this.times != null){
|
||||||
|
return new IntermediateCompetingRisksFunctionsTimesKnown(countInputs, this.events, this.times);
|
||||||
}
|
}
|
||||||
|
|
||||||
return CompetingRiskFunctions.builder()
|
// TODO - implement
|
||||||
.causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
|
throw new RuntimeException("startIntermediateCombinedResponse when times is unknown is not yet implemented");
|
||||||
.cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
|
|
||||||
.survivalCurve(survivalFunction)
|
|
||||||
.build();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,8 @@ import java.util.List;
|
||||||
*/
|
*/
|
||||||
public class CompetingRiskResponseCombiner implements ResponseCombiner<CompetingRiskResponse, CompetingRiskFunctions> {
|
public class CompetingRiskResponseCombiner implements ResponseCombiner<CompetingRiskResponse, CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
public CompetingRiskResponseCombiner(final int[] events){
|
public CompetingRiskResponseCombiner(final int[] events){
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class IntermediateCompetingRisksFunctionsTimesKnown implements IntermediateCombinedResponse<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
private double expectedN;
|
||||||
|
private final int[] events;
|
||||||
|
private final double[] timesToUse;
|
||||||
|
private int actualN;
|
||||||
|
|
||||||
|
private final double[] survivalY;
|
||||||
|
private final double[][] csCHFY;
|
||||||
|
private final double[][] cifY;
|
||||||
|
|
||||||
|
public IntermediateCompetingRisksFunctionsTimesKnown(int n, int[] events, double[] timesToUse){
|
||||||
|
this.expectedN = n;
|
||||||
|
this.events = events;
|
||||||
|
this.timesToUse = timesToUse;
|
||||||
|
this.actualN = 0;
|
||||||
|
|
||||||
|
this.survivalY = new double[timesToUse.length];
|
||||||
|
this.csCHFY = new double[events.length][timesToUse.length];
|
||||||
|
this.cifY = new double[events.length][timesToUse.length];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void processNewInput(CompetingRiskFunctions input) {
|
||||||
|
/*
|
||||||
|
We're going to try to efficiently put our predictions together -
|
||||||
|
Assumptions - for each event on a response, the hazard and CIF functions share the same x points
|
||||||
|
|
||||||
|
Plan - go through the time on each response and make use of that so that when we search for a time index
|
||||||
|
to evaluate the function at, we don't need to re-search the earlier times.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
this.actualN++;
|
||||||
|
|
||||||
|
final double[] survivalXPoints = input.getSurvivalCurve().getX();
|
||||||
|
final double[][] eventSpecificXPoints = new double[events.length][];
|
||||||
|
|
||||||
|
for(final int event : events){
|
||||||
|
eventSpecificXPoints[event-1] = input.getCumulativeIncidenceFunction(event)
|
||||||
|
.getX();
|
||||||
|
}
|
||||||
|
|
||||||
|
int previousSurvivalIndex = 0;
|
||||||
|
final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
|
||||||
|
|
||||||
|
for(int i=0; i<timesToUse.length; i++){
|
||||||
|
final double time = timesToUse[i];
|
||||||
|
|
||||||
|
// Survival curve
|
||||||
|
final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
|
||||||
|
survivalY[i] = survivalY[i] + input.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / expectedN;
|
||||||
|
previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
|
||||||
|
// -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
|
||||||
|
|
||||||
|
// CHFs and CIFs
|
||||||
|
for(final int event : events){
|
||||||
|
final double[] xPoints = eventSpecificXPoints[event-1];
|
||||||
|
final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
|
||||||
|
xPoints, time);
|
||||||
|
csCHFY[event-1][i] = csCHFY[event-1][i] + input.getCauseSpecificHazardFunction(event)
|
||||||
|
.evaluateByIndex(eventTimeIndex) / expectedN;
|
||||||
|
cifY[event-1][i] = cifY[event-1][i] + input.getCumulativeIncidenceFunction(event)
|
||||||
|
.evaluateByIndex(eventTimeIndex) / expectedN;
|
||||||
|
|
||||||
|
previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CompetingRiskFunctions transformToOutput() {
|
||||||
|
rescaleOutput();
|
||||||
|
|
||||||
|
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
|
||||||
|
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||||
|
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||||
|
|
||||||
|
for(final int event : events){
|
||||||
|
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
|
||||||
|
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
return CompetingRiskFunctions.builder()
|
||||||
|
.causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
|
||||||
|
.cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
|
||||||
|
.survivalCurve(survivalFunction)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void rescaleOutput() {
|
||||||
|
rescaleArray(actualN, this.survivalY);
|
||||||
|
|
||||||
|
for(int event : events){
|
||||||
|
rescaleArray(actualN, this.cifY[event - 1]);
|
||||||
|
rescaleArray(actualN, this.csCHFY[event - 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.expectedN = actualN;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void rescaleArray(double newN, double[] array){
|
||||||
|
for(int i=0; i<array.length; i++){
|
||||||
|
array[i] = array[i] * (this.expectedN / newN);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -28,6 +28,7 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponseWithCensorTime> {
|
public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponseWithCensorTime> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final int[] eventsOfFocus;
|
private final int[] eventsOfFocus;
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
|
@ -28,6 +28,7 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public class LogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponse> {
|
public class LogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponse> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final int[] eventsOfFocus;
|
private final int[] eventsOfFocus;
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
|
@ -16,7 +16,8 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.responses.regression;
|
package ca.joeltherrien.randomforest.responses.regression;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -24,7 +25,8 @@ import java.util.List;
|
||||||
* Returns the Mean value of a group of Doubles.
|
* Returns the Mean value of a group of Doubles.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
|
public class MeanResponseCombiner implements ForestResponseCombiner<Double, Double> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double combine(List<Double> responses) {
|
public Double combine(List<Double> responses) {
|
||||||
|
@ -34,5 +36,39 @@ public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IntermediateCombinedResponse<Double, Double> startIntermediateCombinedResponse(int countInputs) {
|
||||||
|
return new MeanIntermediateCombinedResponse(countInputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MeanIntermediateCombinedResponse implements IntermediateCombinedResponse<Double, Double>{
|
||||||
|
|
||||||
|
private double expectedN;
|
||||||
|
private int actualN;
|
||||||
|
private double currentMean;
|
||||||
|
|
||||||
|
public MeanIntermediateCombinedResponse(int n){
|
||||||
|
this.expectedN = n;
|
||||||
|
this.actualN = 0;
|
||||||
|
this.currentMean = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void processNewInput(Double input) {
|
||||||
|
this.currentMean = this.currentMean + input / expectedN;
|
||||||
|
this.actualN ++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double transformToOutput() {
|
||||||
|
// rescale if necessary
|
||||||
|
this.currentMean = this.currentMean * (this.expectedN / (double) actualN);
|
||||||
|
this.expectedN = actualN;
|
||||||
|
|
||||||
|
return currentMean;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@ import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class WeightedVarianceSplitFinder implements SplitFinder<Double> {
|
public class WeightedVarianceSplitFinder implements SplitFinder<Double> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private Double getScore(Set leftHand, Set rightHand) {
|
private Double getScore(Set leftHand, Set rightHand) {
|
||||||
|
|
||||||
|
|
|
@ -17,31 +17,18 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
|
||||||
import lombok.Builder;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
public abstract class Forest<O, FO> {
|
||||||
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
|
||||||
|
|
||||||
private final List<Tree<O>> trees;
|
public abstract FO evaluate(CovariateRow row);
|
||||||
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
public abstract FO evaluateOOB(CovariateRow row);
|
||||||
private final List<Covariate> covariateList;
|
public abstract Iterable<Tree<O>> getTrees();
|
||||||
|
public abstract int getNumberOfTrees();
|
||||||
public FO evaluate(CovariateRow row){
|
|
||||||
|
|
||||||
return treeResponseCombiner.combine(
|
|
||||||
trees.stream()
|
|
||||||
.map(node -> node.evaluate(row))
|
|
||||||
.collect(Collectors.toList())
|
|
||||||
);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Used primarily in the R package interface to avoid R loops; and for easier parallelization.
|
* Used primarily in the R package interface to avoid R loops; and for easier parallelization.
|
||||||
|
@ -93,21 +80,6 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public FO evaluateOOB(CovariateRow row){
|
|
||||||
|
|
||||||
return treeResponseCombiner.combine(
|
|
||||||
trees.stream()
|
|
||||||
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
|
||||||
.map(node -> node.evaluate(row))
|
|
||||||
.collect(Collectors.toList())
|
|
||||||
);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Tree<O>> getTrees(){
|
|
||||||
return Collections.unmodifiableList(trees);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<Integer, Integer> findSplitsByCovariate(){
|
public Map<Integer, Integer> findSplitsByCovariate(){
|
||||||
final Map<Integer, Integer> countMap = new TreeMap<>();
|
final Map<Integer, Integer> countMap = new TreeMap<>();
|
||||||
|
|
||||||
|
@ -158,4 +130,5 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
return countTerminalNodes;
|
return countTerminalNodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
public interface ForestResponseCombiner<I, O> extends ResponseCombiner<I, O>{
|
||||||
|
|
||||||
|
IntermediateCombinedResponse<I, O> startIntermediateCombinedResponse(int countInputs);
|
||||||
|
|
||||||
|
}
|
|
@ -38,7 +38,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
private final TreeTrainer<Y, TO> treeTrainer;
|
private final TreeTrainer<Y, TO> treeTrainer;
|
||||||
private final List<Covariate> covariates;
|
private final List<Covariate> covariates;
|
||||||
private final ResponseCombiner<TO, FO> treeResponseCombiner;
|
private final ForestResponseCombiner<TO, FO> treeResponseCombiner;
|
||||||
private final List<Row<Y>> data;
|
private final List<Row<Y>> data;
|
||||||
|
|
||||||
// number of trees to try
|
// number of trees to try
|
||||||
|
@ -57,10 +57,10 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* in which case its trees are combined with the new one.
|
* in which case its trees are combined with the new one.
|
||||||
* @return A trained forest.
|
* @return A trained forest.
|
||||||
*/
|
*/
|
||||||
public Forest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
public OnlineForest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
||||||
|
|
||||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||||
initialForest.ifPresent(forest -> trees.addAll(forest.getTrees()));
|
initialForest.ifPresent(forest -> forest.getTrees().forEach(trees::add));
|
||||||
|
|
||||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||||
|
|
||||||
|
@ -77,11 +77,9 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("Finished");
|
System.out.println("Finished");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return OnlineForest.<TO, FO>builder()
|
||||||
return Forest.<TO, FO>builder()
|
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.trees(trees)
|
.trees(trees)
|
||||||
.covariateList(covariates)
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -94,7 +92,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* There cannot be existing trees if the initial forest is
|
* There cannot be existing trees if the initial forest is
|
||||||
* specified.
|
* specified.
|
||||||
*/
|
*/
|
||||||
public void trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
public OfflineForest<TO, FO> trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
||||||
// First we need to see how many trees there currently are
|
// First we need to see how many trees there currently are
|
||||||
final File folder = new File(saveTreeLocation);
|
final File folder = new File(saveTreeLocation);
|
||||||
if(!folder.exists()){
|
if(!folder.exists()){
|
||||||
|
@ -115,17 +113,14 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||||
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
int j=0;
|
||||||
|
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||||
for(int j=0; j<initialTrees.size(); j++){
|
|
||||||
final String filename = "tree-" + (j+1) + ".tree";
|
final String filename = "tree-" + (j+1) + ".tree";
|
||||||
final Tree<TO> tree = initialTrees.get(j);
|
|
||||||
|
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
|
j++;
|
||||||
}
|
}
|
||||||
|
|
||||||
treeCount = new AtomicInteger(initialTrees.size());
|
treeCount = new AtomicInteger(j);
|
||||||
} else{
|
} else{
|
||||||
treeCount = new AtomicInteger(treeFiles.length);
|
treeCount = new AtomicInteger(treeFiles.length);
|
||||||
}
|
}
|
||||||
|
@ -153,6 +148,8 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("Finished");
|
System.out.println("Finished");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return new OfflineForest<>(folder, treeResponseCombiner);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -162,7 +159,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* in which case its trees are combined with the new one.
|
* in which case its trees are combined with the new one.
|
||||||
* @param threads The number of trees to train at once.
|
* @param threads The number of trees to train at once.
|
||||||
*/
|
*/
|
||||||
public Forest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
public OnlineForest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||||
|
|
||||||
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
|
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
|
||||||
// the earlier indexes being filled.
|
// the earlier indexes being filled.
|
||||||
|
@ -170,11 +167,12 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
final int startingCount;
|
final int startingCount;
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
int j = 0;
|
||||||
for(int j=0; j<initialTrees.size(); j++) {
|
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||||
trees.set(j, initialTrees.get(j));
|
trees.set(j, tree);
|
||||||
|
j++;
|
||||||
}
|
}
|
||||||
startingCount = initialTrees.size();
|
startingCount = initialForest.get().getNumberOfTrees();
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
startingCount = 0;
|
startingCount = 0;
|
||||||
|
@ -219,7 +217,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("\nFinished");
|
System.out.println("\nFinished");
|
||||||
}
|
}
|
||||||
|
|
||||||
return Forest.<TO, FO>builder()
|
return OnlineForest.<TO, FO>builder()
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.trees(trees)
|
.trees(trees)
|
||||||
.build();
|
.build();
|
||||||
|
@ -235,7 +233,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
* specified.
|
* specified.
|
||||||
* @param threads The number of trees to train at once.
|
* @param threads The number of trees to train at once.
|
||||||
*/
|
*/
|
||||||
public void trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
public OfflineForest<TO, FO> trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||||
// First we need to see how many trees there currently are
|
// First we need to see how many trees there currently are
|
||||||
final File folder = new File(saveTreeLocation);
|
final File folder = new File(saveTreeLocation);
|
||||||
if(!folder.exists()){
|
if(!folder.exists()){
|
||||||
|
@ -255,17 +253,14 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||||
if(initialForest.isPresent()){
|
if(initialForest.isPresent()){
|
||||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
int j=0;
|
||||||
|
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||||
for(int j=0; j<initialTrees.size(); j++){
|
|
||||||
final String filename = "tree-" + (j+1) + ".tree";
|
final String filename = "tree-" + (j+1) + ".tree";
|
||||||
final Tree<TO> tree = initialTrees.get(j);
|
|
||||||
|
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
|
j++;
|
||||||
}
|
}
|
||||||
|
|
||||||
treeCount = new AtomicInteger(initialTrees.size());
|
treeCount = new AtomicInteger(j);
|
||||||
} else{
|
} else{
|
||||||
treeCount = new AtomicInteger(treeFiles.length);
|
treeCount = new AtomicInteger(treeFiles.length);
|
||||||
}
|
}
|
||||||
|
@ -309,6 +304,8 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
System.out.println("\nFinished");
|
System.out.println("\nFinished");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return new OfflineForest<>(folder, treeResponseCombiner);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
/**
 * Similar to ResponseCombiner, but an IntermediateCombinedResponse represents the intermediate state of a
 * single output that is in the process of being combined: inputs are supplied one at a time and the
 * combined output is requested once they have all been supplied.
 *
 * This is what OfflineForests need, since they can only load one Tree in memory at a time.
 *
 * @param <I> the type of each input being combined
 * @param <O> the type of the combined output
 */
public interface IntermediateCombinedResponse<I, O> {

    /**
     * Folds one more input into the intermediate state.
     *
     * @param input the input to add
     */
    void processNewInput(I input);

    /**
     * Produces the combined output from the inputs processed so far.
     *
     * @return the combined output
     */
    O transformToOutput();

}
|
|
@ -0,0 +1,198 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.utils.IterableOfflineTree;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class OfflineForest<O, FO> extends Forest<O, FO> {
|
||||||
|
|
||||||
|
private final File[] treeFiles;
|
||||||
|
private final ForestResponseCombiner<O, FO> treeResponseCombiner;
|
||||||
|
|
||||||
|
public OfflineForest(File treeDirectoryPath, ForestResponseCombiner<O, FO> treeResponseCombiner){
|
||||||
|
this.treeResponseCombiner = treeResponseCombiner;
|
||||||
|
|
||||||
|
if(!treeDirectoryPath.isDirectory()){
|
||||||
|
throw new IllegalArgumentException("treeDirectoryPath must point to a directory!");
|
||||||
|
}
|
||||||
|
|
||||||
|
this.treeFiles = treeDirectoryPath.listFiles((file, s) -> s.endsWith(".tree"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FO evaluate(CovariateRow row) {
|
||||||
|
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
|
||||||
|
for(final Tree<O> tree : getTrees()){
|
||||||
|
final O prediction = tree.evaluate(row);
|
||||||
|
predictedOutputs.add(prediction);
|
||||||
|
}
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(predictedOutputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FO evaluateOOB(CovariateRow row) {
|
||||||
|
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
|
||||||
|
for(final Tree<O> tree : getTrees()){
|
||||||
|
if(!tree.idInBootstrapSample(row.getId())){
|
||||||
|
final O prediction = tree.evaluate(row);
|
||||||
|
predictedOutputs.add(prediction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(predictedOutputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FO> evaluate(List<? extends CovariateRow> rowList){
|
||||||
|
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
|
||||||
|
IntStream.range(0, rowList.size())
|
||||||
|
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||||
|
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||||
|
final Tree<O> currentTree = treeIterator.next();
|
||||||
|
|
||||||
|
IntStream.range(0, rowList.size()).parallel().forEach(
|
||||||
|
rowId -> {
|
||||||
|
final CovariateRow row = rowList.get(rowId);
|
||||||
|
final O prediction = currentTree.evaluate(row);
|
||||||
|
intermediatePredictions.get(rowId).processNewInput(prediction);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return intermediatePredictions.stream().parallel()
|
||||||
|
.map(intPred -> intPred.transformToOutput())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FO> evaluateSerial(List<? extends CovariateRow> rowList){
|
||||||
|
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
|
||||||
|
IntStream.range(0, rowList.size())
|
||||||
|
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||||
|
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||||
|
final Tree<O> currentTree = treeIterator.next();
|
||||||
|
|
||||||
|
IntStream.range(0, rowList.size()).sequential().forEach(
|
||||||
|
rowId -> {
|
||||||
|
final CovariateRow row = rowList.get(rowId);
|
||||||
|
final O prediction = currentTree.evaluate(row);
|
||||||
|
intermediatePredictions.get(rowId).processNewInput(prediction);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return intermediatePredictions.stream().sequential()
|
||||||
|
.map(intPred -> intPred.transformToOutput())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FO> evaluateOOB(List<? extends CovariateRow> rowList){
|
||||||
|
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
|
||||||
|
IntStream.range(0, rowList.size())
|
||||||
|
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||||
|
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||||
|
final Tree<O> currentTree = treeIterator.next();
|
||||||
|
|
||||||
|
IntStream.range(0, rowList.size()).parallel().forEach(
|
||||||
|
rowId -> {
|
||||||
|
final CovariateRow row = rowList.get(rowId);
|
||||||
|
if(!currentTree.idInBootstrapSample(row.getId())){
|
||||||
|
final O prediction = currentTree.evaluate(row);
|
||||||
|
intermediatePredictions.get(rowId).processNewInput(prediction);
|
||||||
|
}
|
||||||
|
// else do nothing; when we get the final output it will get scaled for the smaller N
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return intermediatePredictions.stream().parallel()
|
||||||
|
.map(intPred -> intPred.transformToOutput())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FO> evaluateSerialOOB(List<? extends CovariateRow> rowList){
|
||||||
|
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
|
||||||
|
IntStream.range(0, rowList.size())
|
||||||
|
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||||
|
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||||
|
final Tree<O> currentTree = treeIterator.next();
|
||||||
|
|
||||||
|
IntStream.range(0, rowList.size()).sequential().forEach(
|
||||||
|
rowId -> {
|
||||||
|
final CovariateRow row = rowList.get(rowId);
|
||||||
|
if(!currentTree.idInBootstrapSample(row.getId())){
|
||||||
|
final O prediction = currentTree.evaluate(row);
|
||||||
|
intermediatePredictions.get(rowId).processNewInput(prediction);
|
||||||
|
}
|
||||||
|
// else do nothing; when we get the final output it will get scaled for the smaller N
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return intermediatePredictions.stream().sequential()
|
||||||
|
.map(intPred -> intPred.transformToOutput())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Iterable<Tree<O>> getTrees() {
|
||||||
|
return new IterableOfflineTree<>(treeFiles);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumberOfTrees() {
|
||||||
|
return treeFiles.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OnlineForest<O, FO> createOnlineCopy(){
|
||||||
|
final List<Tree<O>> allTrees = new ArrayList<>(getNumberOfTrees());
|
||||||
|
getTrees().forEach(allTrees::add);
|
||||||
|
|
||||||
|
return OnlineForest.<O, FO>builder()
|
||||||
|
.trees(allTrees)
|
||||||
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public class OnlineForest<O, FO> extends Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
||||||
|
|
||||||
|
private final List<Tree<O>> trees;
|
||||||
|
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FO evaluate(CovariateRow row){
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(
|
||||||
|
trees.stream()
|
||||||
|
.map(node -> node.evaluate(row))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FO evaluateOOB(CovariateRow row){
|
||||||
|
|
||||||
|
return treeResponseCombiner.combine(
|
||||||
|
trees.stream()
|
||||||
|
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
||||||
|
.map(node -> node.evaluate(row))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Tree<O>> getTrees(){
|
||||||
|
return Collections.unmodifiableList(trees);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumberOfTrees() {
|
||||||
|
return trees.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -17,15 +17,13 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Getter;
|
import lombok.Data;
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
|
@Data
|
||||||
public class SplitAndScore<Y, V> {
|
public class SplitAndScore<Y, V> {
|
||||||
|
|
||||||
@Getter
|
private Split<Y, V> split;
|
||||||
private final Split<Y, V> split;
|
private Double score;
|
||||||
|
|
||||||
@Getter
|
|
||||||
private final Double score;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ import java.util.List;
|
||||||
@ToString
|
@ToString
|
||||||
@Getter
|
@Getter
|
||||||
public class SplitNode<Y> implements Node<Y> {
|
public class SplitNode<Y> implements Node<Y> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final Node<Y> leftHand;
|
private final Node<Y> leftHand;
|
||||||
private final Node<Y> rightHand;
|
private final Node<Y> rightHand;
|
||||||
|
|
|
@ -27,6 +27,7 @@ import java.util.List;
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@ToString
|
@ToString
|
||||||
public class TerminalNode<Y> implements Node<Y> {
|
public class TerminalNode<Y> implements Node<Y> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final Y responseValue;
|
private final Y responseValue;
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class Tree<Y> implements Node<Y> {
|
public class Tree<Y> implements Node<Y> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final Node<Y> rootNode;
|
private final Node<Y> rootNode;
|
||||||
|
|
|
@ -17,7 +17,9 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.VisibleForTesting;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||||
import lombok.AccessLevel;
|
import lombok.AccessLevel;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
@ -72,31 +74,12 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Now that we have the best split; we need to handle any NAs that were dropped off
|
// Now that we have the best split; we need to handle any NAs that were dropped off
|
||||||
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
||||||
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
||||||
|
|
||||||
// Assign missing values to the split if necessary
|
// Assign missing values to the split if necessary
|
||||||
if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
|
bestSplit = randomlyAssignNAs(data, bestSplit, random);
|
||||||
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
|
||||||
|
|
||||||
for(Row<Y> row : data) {
|
|
||||||
final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex();
|
|
||||||
|
|
||||||
if(row.getValueByIndex(covariateIndex).isNA()) {
|
|
||||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
|
||||||
|
|
||||||
if(randomDecision){
|
|
||||||
bestSplit.getLeftHand().add(row);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
bestSplit.getRightHand().add(row);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
final Node<O> leftNode;
|
final Node<O> leftNode;
|
||||||
final Node<O> rightNode;
|
final Node<O> rightNode;
|
||||||
|
@ -144,7 +127,8 @@ public class TreeTrainer<Y, O> {
|
||||||
return splitCovariates;
|
return splitCovariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
@VisibleForTesting
|
||||||
|
public Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
||||||
|
|
||||||
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
||||||
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
|
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
|
||||||
|
@ -157,10 +141,32 @@ public class TreeTrainer<Y, O> {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
|
SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
|
||||||
|
|
||||||
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
|
|
||||||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {
|
if(candidateSplitAndScore == null){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This score was based on splitting only non-NA values. However, there might be a similar covariate we are also considering
|
||||||
|
// that is just as good at splitting but has less NAs; we should thus penalize the split score for variables with NAs
|
||||||
|
// We do this by randomly assigning the NAs and then recalculating the split score on the best split we already have.
|
||||||
|
//
|
||||||
|
// We only have to penalize the score though if we know it's possible that this might be the best split. If it's not,
|
||||||
|
// then we can skip the computations.
|
||||||
|
final boolean mayBeGoodSplit = bestSplitAndScore == null ||
|
||||||
|
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore();
|
||||||
|
if(mayBeGoodSplit && covariate.haveNASplitPenalty()){
|
||||||
|
Split<Y, ?> candiateSplitWithNAs = randomlyAssignNAs(data, candidateSplitAndScore.getSplit(), random);
|
||||||
|
final Iterator<Split<Y, ?>> newSplitWithRandomNAs = new SingletonIterator<>(candiateSplitWithNAs);
|
||||||
|
final double newScore = splitFinder.findBestSplit(newSplitWithRandomNAs).getScore();
|
||||||
|
|
||||||
|
// There's a chance that NAs might add noise to *improve* the score; but we want to ensure we penalize it.
|
||||||
|
// Thus we only change the score if its worse.
|
||||||
|
candidateSplitAndScore.setScore(Math.min(newScore, candidateSplitAndScore.getScore()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if(bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
|
||||||
bestSplitAndScore = candidateSplitAndScore;
|
bestSplitAndScore = candidateSplitAndScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,6 +180,38 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private <V> Split<Y, V> randomlyAssignNAs(List<Row<Y>> data, Split<Y, V> existingSplit, Random random){
|
||||||
|
|
||||||
|
// Now that we have the best split; we need to handle any NAs that were dropped off
|
||||||
|
final double probabilityLeftHand = (double) existingSplit.leftHand.size() /
|
||||||
|
(double) (existingSplit.leftHand.size() + existingSplit.rightHand.size());
|
||||||
|
|
||||||
|
|
||||||
|
final int covariateIndex = existingSplit.getSplitRule().getParentCovariateIndex();
|
||||||
|
|
||||||
|
// Assign missing values to the split if necessary
|
||||||
|
if(covariates.get(existingSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
|
||||||
|
existingSplit = existingSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
||||||
|
|
||||||
|
for(Row<Y> row : data) {
|
||||||
|
if(row.getValueByIndex(covariateIndex).isNA()) {
|
||||||
|
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||||
|
|
||||||
|
if(randomDecision){
|
||||||
|
existingSplit.getLeftHand().add(row);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
existingSplit.getRightHand().add(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return existingSplit;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
private boolean nodeIsPure(List<Row<Y>> data){
|
private boolean nodeIsPure(List<Row<Y>> data){
|
||||||
if(!checkNodePurity){
|
if(!checkNodePurity){
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
 * Error measure used by VariableImportanceCalculator: given the observed responses and
 * the matching predictions, an implementation reduces them to a single average error.
 *
 * @param <Y> The class of the responses.
 * @param <P> The class of the predictions.
 */
public interface ErrorCalculator<Y, P>{

    /**
     * Compares the observed responses with the predictions to produce an average error measure.
     * Lower errors should indicate a better model fit.
     *
     * @param responses   the observed responses
     * @param predictions the predictions to compare against {@code responses}
     * @return the average error of the predictions
     */
    double averageError(List<Y> responses, List<P> predictions);

}
|
|
@ -0,0 +1,63 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implements ErrorCalculator; essentially just wraps around IBSCalculator to fit into VariableImportanceCalculator.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class IBSErrorCalculatorWrapper implements ErrorCalculator<CompetingRiskResponse, CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
private final IBSCalculator calculator;
|
||||||
|
private final int[] events;
|
||||||
|
private final double integrationUpperBound;
|
||||||
|
private final double[] eventWeights;
|
||||||
|
|
||||||
|
public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound, double[] eventWeights) {
|
||||||
|
this.calculator = calculator;
|
||||||
|
this.events = events;
|
||||||
|
this.integrationUpperBound = integrationUpperBound;
|
||||||
|
this.eventWeights = eventWeights;
|
||||||
|
}
|
||||||
|
|
||||||
|
public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound) {
|
||||||
|
this.calculator = calculator;
|
||||||
|
this.events = events;
|
||||||
|
this.integrationUpperBound = integrationUpperBound;
|
||||||
|
this.eventWeights = new double[events.length];
|
||||||
|
|
||||||
|
Arrays.fill(this.eventWeights, 1.0); // default is to just sum all errors together
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double averageError(List<CompetingRiskResponse> responses, List<CompetingRiskFunctions> predictions) {
|
||||||
|
final double[] errors = new double[events.length];
|
||||||
|
final double n = responses.size();
|
||||||
|
|
||||||
|
for(int i=0; i < responses.size(); i++){
|
||||||
|
final CompetingRiskResponse response = responses.get(i);
|
||||||
|
final CompetingRiskFunctions prediction = predictions.get(i);
|
||||||
|
|
||||||
|
for(int k=0; k < this.events.length; k++){
|
||||||
|
final int event = this.events[k];
|
||||||
|
final RightContinuousStepFunction cif = prediction.getCumulativeIncidenceFunction(event);
|
||||||
|
errors[k] += calculator.calculateError(response, cif, event, integrationUpperBound) / n;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
double totalError = 0.0;
|
||||||
|
for(int k=0; k < this.events.length; k++){
|
||||||
|
totalError += this.eventWeights[k] * errors[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
return totalError;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class RegressionErrorCalculator implements ErrorCalculator<Double, Double>{
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double averageError(List<Double> responses, List<Double> predictions) {
|
||||||
|
double mean = 0.0;
|
||||||
|
final double n = responses.size();
|
||||||
|
|
||||||
|
for(int i=0; i<responses.size(); i++){
|
||||||
|
final double response = responses.get(i);
|
||||||
|
final double prediction = predictions.get(i);
|
||||||
|
|
||||||
|
final double difference = response - prediction;
|
||||||
|
|
||||||
|
mean += difference * difference / n;
|
||||||
|
}
|
||||||
|
|
||||||
|
return mean;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,117 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
|
||||||
|
public class VariableImportanceCalculator<Y, P> {
|
||||||
|
|
||||||
|
private final ErrorCalculator<Y, P> errorCalculator;
|
||||||
|
private final List<Tree<P>> trees;
|
||||||
|
private final List<Row<Y>> observations;
|
||||||
|
|
||||||
|
private final boolean isTrainingSet; // If true, then we use out-of-bag predictions
|
||||||
|
private final double[] baselineErrors;
|
||||||
|
|
||||||
|
public VariableImportanceCalculator(
|
||||||
|
ErrorCalculator<Y, P> errorCalculator,
|
||||||
|
List<Tree<P>> trees,
|
||||||
|
List<Row<Y>> observations,
|
||||||
|
boolean isTrainingSet
|
||||||
|
){
|
||||||
|
this.errorCalculator = errorCalculator;
|
||||||
|
this.trees = trees;
|
||||||
|
this.observations = observations;
|
||||||
|
this.isTrainingSet = isTrainingSet;
|
||||||
|
|
||||||
|
|
||||||
|
try {
|
||||||
|
|
||||||
|
this.baselineErrors = new double[trees.size()];
|
||||||
|
for (int i = 0; i < baselineErrors.length; i++) {
|
||||||
|
final Tree<P> tree = trees.get(i);
|
||||||
|
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
|
||||||
|
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
this.baselineErrors[i] = errorCalculator.averageError(responses, makePredictions(oobSubset, tree));
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch(Exception e){
|
||||||
|
e.printStackTrace();
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an array of importance values for every Tree for the given Covariate.
|
||||||
|
*
|
||||||
|
* @param covariate The Covariate to scramble.
|
||||||
|
* @param random
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public double[] calculateVariableImportanceRaw(Covariate covariate, Optional<Random> random){
|
||||||
|
|
||||||
|
final double[] vimp = new double[trees.size()];
|
||||||
|
for(int i = 0; i < vimp.length; i++){
|
||||||
|
final Tree<P> tree = trees.get(i);
|
||||||
|
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
|
||||||
|
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(oobSubset, covariate, random);
|
||||||
|
|
||||||
|
final double error = errorCalculator.averageError(responses, makePredictions(scrambledValues, tree));
|
||||||
|
|
||||||
|
vimp[i] = error - this.baselineErrors[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return vimp;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculateVariableImportanceZScore(Covariate covariate, Optional<Random> random){
|
||||||
|
final double[] vimpArray = calculateVariableImportanceRaw(covariate, random);
|
||||||
|
|
||||||
|
double mean = 0.0;
|
||||||
|
double variance = 0.0;
|
||||||
|
final double numTrees = vimpArray.length;
|
||||||
|
|
||||||
|
for(double vimp : vimpArray){
|
||||||
|
mean += vimp / numTrees;
|
||||||
|
}
|
||||||
|
for(double vimp : vimpArray){
|
||||||
|
variance += (vimp - mean)*(vimp - mean) / (numTrees - 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
final double standardError = Math.sqrt(variance / numTrees);
|
||||||
|
|
||||||
|
return mean / standardError;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Assume rowList has already been filtered for OOB
|
||||||
|
private List<P> makePredictions(List<? extends CovariateRow> rowList, Tree<P> tree){
|
||||||
|
return rowList.stream()
|
||||||
|
.map(tree::evaluate)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Row<Y>> getAppropriateSubset(List<Row<Y>> initialList, Tree<P> tree){
|
||||||
|
if(!isTrainingSet){
|
||||||
|
return initialList; // no need to make any subsets
|
||||||
|
}
|
||||||
|
|
||||||
|
return initialList.stream()
|
||||||
|
.filter(row -> !tree.idInBootstrapSample(row.getId()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -16,9 +16,7 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.*;
|
||||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
|
||||||
import ca.joeltherrien.randomforest.tree.Tree;
|
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
@ -27,12 +25,17 @@ import java.util.zip.GZIPOutputStream;
|
||||||
|
|
||||||
public class DataUtils {
|
public class DataUtils {
|
||||||
|
|
||||||
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||||
if(!folder.isDirectory()){
|
if(!folder.isDirectory()){
|
||||||
throw new IllegalArgumentException("Tree directory must be a directory!");
|
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||||
}
|
}
|
||||||
|
|
||||||
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
||||||
|
|
||||||
|
return loadOnlineForest(treeFiles, treeResponseCombiner);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File[] treeFiles, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||||
final List<File> treeFileList = Arrays.asList(treeFiles);
|
final List<File> treeFileList = Arrays.asList(treeFiles);
|
||||||
|
|
||||||
Collections.sort(treeFileList, Comparator.comparing(File::getName));
|
Collections.sort(treeFileList, Comparator.comparing(File::getName));
|
||||||
|
@ -48,16 +51,16 @@ public class DataUtils {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Forest.<O, FO>builder()
|
return OnlineForest.<O, FO>builder()
|
||||||
.trees(treeList)
|
.trees(treeList)
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O, FO> Forest<O, FO> loadForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||||
final File directory = new File(folder);
|
final File directory = new File(folder);
|
||||||
return loadForest(directory, treeResponseCombiner);
|
return loadOnlineForest(directory, treeResponseCombiner);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void saveObject(Serializable object, String filename) throws IOException {
|
public static void saveObject(Serializable object, String filename) throws IOException {
|
||||||
|
|
|
@ -1,113 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright (c) 2019 Joel Therrien.
|
|
||||||
* This program is free software: you can redistribute it and/or modify
|
|
||||||
* it under the terms of the GNU General Public License as published by
|
|
||||||
* the Free Software Foundation, either version 3 of the License, or
|
|
||||||
* (at your option) any later version.
|
|
||||||
*
|
|
||||||
* This program is distributed in the hope that it will be useful,
|
|
||||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
* GNU General Public License for more details.
|
|
||||||
*
|
|
||||||
* You should have received a copy of the GNU General Public License
|
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a function represented by discrete points. However, the function may be right-continuous or left-continuous
|
|
||||||
* at a given point, with no consistency. This function tracks that.
|
|
||||||
*/
|
|
||||||
public final class DiscontinuousStepFunction extends StepFunction {
|
|
||||||
|
|
||||||
private final double[] y;
|
|
||||||
private final boolean[] isLeftContinuous;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
|
||||||
*
|
|
||||||
* Map be null.
|
|
||||||
*/
|
|
||||||
private final double defaultY;
|
|
||||||
|
|
||||||
public DiscontinuousStepFunction(double[] x, double[] y, boolean[] isLeftContinuous, double defaultY) {
|
|
||||||
super(x);
|
|
||||||
this.y = y;
|
|
||||||
this.isLeftContinuous = isLeftContinuous;
|
|
||||||
this.defaultY = defaultY;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double evaluate(double time){
|
|
||||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
|
||||||
if(index < 0){
|
|
||||||
return defaultY;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
if(x[index] == time){
|
|
||||||
return evaluateByIndex(index);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
return y[index];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double evaluatePrevious(double time){
|
|
||||||
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
|
|
||||||
if(index < 0){
|
|
||||||
return defaultY;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
if(x[index] == time){
|
|
||||||
return evaluateByIndex(index);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
return y[index];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double evaluateByIndex(int i) {
|
|
||||||
if(isLeftContinuous[i]){
|
|
||||||
i -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(i < 0){
|
|
||||||
return defaultY;
|
|
||||||
}
|
|
||||||
|
|
||||||
return y[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString(){
|
|
||||||
final StringBuilder builder = new StringBuilder();
|
|
||||||
builder.append("Default point: ");
|
|
||||||
builder.append(defaultY);
|
|
||||||
builder.append("\n");
|
|
||||||
|
|
||||||
for(int i=0; i<x.length; i++){
|
|
||||||
builder.append("x:");
|
|
||||||
builder.append(x[i]);
|
|
||||||
builder.append('\t');
|
|
||||||
|
|
||||||
if(isLeftContinuous[i]){
|
|
||||||
builder.append("*y:");
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
builder.append("y*:");
|
|
||||||
}
|
|
||||||
builder.append(y[i]);
|
|
||||||
builder.append("\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
return builder.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.zip.GZIPInputStream;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class IterableOfflineTree<Y> implements Iterable<Tree<Y>> {
|
||||||
|
|
||||||
|
private final File[] treeFiles;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Iterator<Tree<Y>> iterator() {
|
||||||
|
return new OfflineTreeIterator<>(treeFiles);
|
||||||
|
}
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public static class OfflineTreeIterator<Y> implements Iterator<Tree<Y>>{
|
||||||
|
private final File[] treeFiles;
|
||||||
|
private int position = 0;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return position < treeFiles.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Tree<Y> next() {
|
||||||
|
final File treeFile = treeFiles[position];
|
||||||
|
position++;
|
||||||
|
|
||||||
|
|
||||||
|
try {
|
||||||
|
final ObjectInputStream inputStream= new ObjectInputStream(new GZIPInputStream(new FileInputStream(treeFile)));
|
||||||
|
final Tree<Y> tree = (Tree) inputStream.readObject();
|
||||||
|
return tree;
|
||||||
|
} catch (IOException | ClassNotFoundException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
throw new RuntimeException("Failed to load tree for " + treeFile.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -1,129 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright (c) 2019 Joel Therrien.
|
|
||||||
* This program is free software: you can redistribute it and/or modify
|
|
||||||
* it under the terms of the GNU General Public License as published by
|
|
||||||
* the Free Software Foundation, either version 3 of the License, or
|
|
||||||
* (at your option) any later version.
|
|
||||||
*
|
|
||||||
* This program is distributed in the hope that it will be useful,
|
|
||||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
* GNU General Public License for more details.
|
|
||||||
*
|
|
||||||
* You should have received a copy of the GNU General Public License
|
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.ListIterator;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a function represented by discrete points. We assume that the function is a stepwise left-continuous
|
|
||||||
* function, constant at the value of the previous encountered point.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public final class LeftContinuousStepFunction extends StepFunction {
|
|
||||||
|
|
||||||
private final double[] y;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
|
||||||
*
|
|
||||||
* Map be null.
|
|
||||||
*/
|
|
||||||
private final double defaultY;
|
|
||||||
|
|
||||||
public LeftContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
|
||||||
super(x);
|
|
||||||
this.y = y;
|
|
||||||
this.defaultY = defaultY;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This isn't a formal constructor because of limitations with abstract classes.
|
|
||||||
*
|
|
||||||
* @param pointList
|
|
||||||
* @param defaultY
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static LeftContinuousStepFunction constructFromPoints(final List<Point> pointList, final double defaultY){
|
|
||||||
|
|
||||||
final double[] x = new double[pointList.size()];
|
|
||||||
final double[] y = new double[pointList.size()];
|
|
||||||
|
|
||||||
final ListIterator<Point> pointIterator = pointList.listIterator();
|
|
||||||
while(pointIterator.hasNext()){
|
|
||||||
final int index = pointIterator.nextIndex();
|
|
||||||
final Point currentPoint = pointIterator.next();
|
|
||||||
|
|
||||||
x[index] = currentPoint.getTime();
|
|
||||||
y[index] = currentPoint.getY();
|
|
||||||
}
|
|
||||||
|
|
||||||
return new LeftContinuousStepFunction(x, y, defaultY);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double evaluate(double time){
|
|
||||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
|
||||||
if(index < 0){
|
|
||||||
return defaultY;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
if(x[index] == time){
|
|
||||||
return evaluateByIndex(index-1);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
return y[index];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double evaluatePrevious(double time){
|
|
||||||
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
|
|
||||||
if(index < 0){
|
|
||||||
return defaultY;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
if(x[index] == time){
|
|
||||||
return evaluateByIndex(index-1);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
return y[index];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double evaluateByIndex(int i) {
|
|
||||||
if(i < 0){
|
|
||||||
return defaultY;
|
|
||||||
}
|
|
||||||
|
|
||||||
return y[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString(){
|
|
||||||
final StringBuilder builder = new StringBuilder();
|
|
||||||
builder.append("Default point: ");
|
|
||||||
builder.append(defaultY);
|
|
||||||
builder.append("\n");
|
|
||||||
|
|
||||||
for(int i=0; i<x.length; i++){
|
|
||||||
builder.append("x:");
|
|
||||||
builder.append(x[i]);
|
|
||||||
builder.append("\ty:");
|
|
||||||
builder.append(y[i]);
|
|
||||||
builder.append("\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
return builder.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -26,6 +26,8 @@ import java.io.Serializable;
|
||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
public class Point implements Serializable {
|
public class Point implements Serializable {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final double time;
|
private final double time;
|
||||||
private final double y;
|
private final double y;
|
||||||
}
|
}
|
||||||
|
|
|
@ -188,4 +188,24 @@ public final class RUtils {
|
||||||
return responses;
|
return responses;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static List<Object> produceSublist(List<Object> initialList, int[] indices){
|
||||||
|
final List<Object> newList = new ArrayList<>(indices.length);
|
||||||
|
|
||||||
|
for(int i : indices){
|
||||||
|
newList.add(initialList.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
return newList;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static File[] getTreeFileArray(String folderPath, int endingId){
|
||||||
|
final File[] fileArray = new File[endingId];
|
||||||
|
|
||||||
|
for(int i = 1; i <= endingId; i++){
|
||||||
|
fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree");
|
||||||
|
}
|
||||||
|
|
||||||
|
return fileArray;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,8 +16,14 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.ListIterator;
|
import java.util.ListIterator;
|
||||||
|
import java.util.function.DoubleBinaryOperator;
|
||||||
|
import java.util.function.DoubleUnaryOperator;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents a function represented by discrete points. We assume that the function is a stepwise right-continuous
|
* Represents a function represented by discrete points. We assume that the function is a stepwise right-continuous
|
||||||
|
@ -26,13 +32,16 @@ import java.util.ListIterator;
|
||||||
*/
|
*/
|
||||||
public final class RightContinuousStepFunction extends StepFunction {
|
public final class RightContinuousStepFunction extends StepFunction {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private final double[] y;
|
private final double[] y;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
* Represents the value that should be returned by evaluate if there are no points prior to the time the function is being evaluated at.
|
||||||
*
|
*
|
||||||
* Map be null.
|
* May not be null.
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
private final double defaultY;
|
private final double defaultY;
|
||||||
|
|
||||||
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
||||||
|
@ -127,7 +136,12 @@ public final class RightContinuousStepFunction extends StepFunction {
|
||||||
}
|
}
|
||||||
|
|
||||||
if(to < from){
|
if(to < from){
|
||||||
return integrate(to, from);
|
return -integrate(to, from);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edge case - no points; just defaultY
|
||||||
|
if(this.x.length == 0){
|
||||||
|
return (to - from) * this.defaultY;
|
||||||
}
|
}
|
||||||
|
|
||||||
double summation = 0.0;
|
double summation = 0.0;
|
||||||
|
@ -170,7 +184,7 @@ public final class RightContinuousStepFunction extends StepFunction {
|
||||||
final double currentTime = xPoints[i];
|
final double currentTime = xPoints[i];
|
||||||
final double currentHeight = evaluateByIndex(i);
|
final double currentHeight = evaluateByIndex(i);
|
||||||
|
|
||||||
if(i == xPoints.length-1 || xPoints[i+1] > to){
|
if(i == xPoints.length-1 || xPoints[i+1] >= to){
|
||||||
summation += currentHeight * (to - currentTime);
|
summation += currentHeight * (to - currentTime);
|
||||||
return summation;
|
return summation;
|
||||||
}
|
}
|
||||||
|
@ -186,5 +200,76 @@ public final class RightContinuousStepFunction extends StepFunction {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public RightContinuousStepFunction unaryOperation(DoubleUnaryOperator operator){
|
||||||
|
final double newDefaultY = operator.applyAsDouble(this.defaultY);
|
||||||
|
final double[] newY = Arrays.stream(this.getY()).map(operator).toArray();
|
||||||
|
|
||||||
|
return new RightContinuousStepFunction(this.getX(), newY, newDefaultY);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static RightContinuousStepFunction biOperation(RightContinuousStepFunction funLeft,
|
||||||
|
RightContinuousStepFunction funRight,
|
||||||
|
DoubleBinaryOperator operator){
|
||||||
|
|
||||||
|
final double newDefaultY = operator.applyAsDouble(funLeft.defaultY, funRight.defaultY);
|
||||||
|
final double[] leftX = funLeft.x;
|
||||||
|
final double[] rightX = funRight.x;
|
||||||
|
|
||||||
|
final List<Point> combinedPoints = new ArrayList<>(leftX.length + rightX.length);
|
||||||
|
|
||||||
|
// These indexes represent the times that have *already* been processed.
|
||||||
|
// They start at -1 because we already processed the defaultY values.
|
||||||
|
int indexLeft = -1;
|
||||||
|
int indexRight = -1;
|
||||||
|
|
||||||
|
// This while-loop will keep going until one of the functions reaches the ends of its points
|
||||||
|
while(indexLeft < leftX.length-1 && indexRight < rightX.length-1){
|
||||||
|
final double time;
|
||||||
|
if(leftX[indexLeft+1] < rightX[indexRight+1]){
|
||||||
|
indexLeft += 1;
|
||||||
|
|
||||||
|
time = leftX[indexLeft];
|
||||||
|
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||||
|
}
|
||||||
|
else if(leftX[indexLeft+1] > rightX[indexRight+1]){
|
||||||
|
indexRight += 1;
|
||||||
|
|
||||||
|
time = rightX[indexRight];
|
||||||
|
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||||
|
}
|
||||||
|
else{ // equal times
|
||||||
|
indexLeft += 1;
|
||||||
|
indexRight += 1;
|
||||||
|
|
||||||
|
time = leftX[indexLeft];
|
||||||
|
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, at least one of function left or function right has reached the end of its points
|
||||||
|
|
||||||
|
// This while-loop occurring implies that functionRight can not move forward anymore
|
||||||
|
while(indexLeft < leftX.length-1){
|
||||||
|
indexLeft += 1;
|
||||||
|
|
||||||
|
final double time = leftX[indexLeft];
|
||||||
|
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// This while-loop occurring implies that functionLeft can not move forward anymore
|
||||||
|
while(indexRight < rightX.length-1){
|
||||||
|
indexRight += 1;
|
||||||
|
|
||||||
|
final double time = rightX[indexRight];
|
||||||
|
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return RightContinuousStepFunction.constructFromPoints(combinedPoints, newDefaultY);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ import java.util.*;
|
||||||
|
|
||||||
public final class Utils {
|
public final class Utils {
|
||||||
|
|
||||||
public static StepFunction estimateOneMinusECDF(final double[] times){
|
public static RightContinuousStepFunction estimateOneMinusECDF(final double[] times){
|
||||||
Arrays.sort(times);
|
Arrays.sort(times);
|
||||||
|
|
||||||
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
||||||
|
|
|
@ -1,80 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright (c) 2019 Joel Therrien.
|
|
||||||
* This program is free software: you can redistribute it and/or modify
|
|
||||||
* it under the terms of the GNU General Public License as published by
|
|
||||||
* the Free Software Foundation, either version 3 of the License, or
|
|
||||||
* (at your option) any later version.
|
|
||||||
*
|
|
||||||
* This program is distributed in the hope that it will be useful,
|
|
||||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
* GNU General Public License for more details.
|
|
||||||
*
|
|
||||||
* You should have received a copy of the GNU General Public License
|
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.utils;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a step function represented by discrete points. However, there may be individual time values that has
|
|
||||||
* a y value that doesn't belong to a particular 'step'.
|
|
||||||
*/
|
|
||||||
public final class VeryDiscontinuousStepFunction implements MathFunction {
|
|
||||||
|
|
||||||
private final double[] x;
|
|
||||||
private final double[] yAt;
|
|
||||||
private final double[] yRight;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
|
||||||
*
|
|
||||||
* Map be null.
|
|
||||||
*/
|
|
||||||
private final double defaultY;
|
|
||||||
|
|
||||||
public VeryDiscontinuousStepFunction(double[] x, double[] yAt, double[] yRight, double defaultY) {
|
|
||||||
this.x = x;
|
|
||||||
this.yAt = yAt;
|
|
||||||
this.yRight = yRight;
|
|
||||||
this.defaultY = defaultY;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double evaluate(double time){
|
|
||||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
|
||||||
if(index < 0){
|
|
||||||
return defaultY;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
if(x[index] == time){
|
|
||||||
return yAt[index];
|
|
||||||
}
|
|
||||||
else{ // time > x[index]
|
|
||||||
return yRight[index];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString(){
|
|
||||||
final StringBuilder builder = new StringBuilder();
|
|
||||||
builder.append("Default point: ");
|
|
||||||
builder.append(defaultY);
|
|
||||||
builder.append("\n");
|
|
||||||
|
|
||||||
for(int i=0; i<x.length; i++){
|
|
||||||
builder.append("x:");
|
|
||||||
builder.append(x[i]);
|
|
||||||
builder.append("\tyAt:");
|
|
||||||
builder.append(yAt[i]);
|
|
||||||
builder.append("\tyRight:");
|
|
||||||
builder.append(yRight[i]);
|
|
||||||
builder.append("\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
return builder.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -45,20 +45,20 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for(int j=0; j<5; j++){
|
for(int j=0; j<5; j++){
|
||||||
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index);
|
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index, false);
|
||||||
covariateList.add(numericCovariate);
|
covariateList.add(numericCovariate);
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int j=0; j<5; j++){
|
for(int j=0; j<5; j++){
|
||||||
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index);
|
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index, false);
|
||||||
covariateList.add(booleanCovariate);
|
covariateList.add(booleanCovariate);
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
|
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
|
||||||
for(int j=0; j<5; j++){
|
for(int j=0; j<5; j++){
|
||||||
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels);
|
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels, false);
|
||||||
covariateList.add(factorCovariate);
|
covariateList.add(factorCovariate);
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
|
@ -214,14 +214,14 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
|
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
|
||||||
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
|
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
|
||||||
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
final Forest<Double, Double> forestSerial = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
||||||
|
|
||||||
|
|
||||||
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
final Forest<Double, Double> forestParallel = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
for(int k=0; k<3; k++){
|
for(int k=0; k<3; k++){
|
||||||
forestTrainer.trainSerialOnDisk(Optional.empty());
|
forestTrainer.trainSerialOnDisk(Optional.empty());
|
||||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||||
}
|
}
|
||||||
|
@ -274,7 +274,7 @@ public class TestDeterministicForests {
|
||||||
|
|
||||||
for(int k=0; k<3; k++){
|
for(int k=0; k<3; k++){
|
||||||
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
|
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
|
||||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
|
||||||
TestUtils.removeFolder(saveTreeFile);
|
TestUtils.removeFolder(saveTreeFile);
|
||||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,8 +20,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.tree.OnlineForest;
|
||||||
import ca.joeltherrien.randomforest.tree.Tree;
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||||
|
@ -39,12 +39,12 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestProvidingInitialForest {
|
public class TestProvidingInitialForest {
|
||||||
|
|
||||||
private Forest<Double, Double> initialForest;
|
private OnlineForest<Double, Double> initialForest;
|
||||||
private List<Covariate> covariateList;
|
private List<Covariate> covariateList;
|
||||||
private List<Row<Double>> data;
|
private List<Row<Double>> data;
|
||||||
|
|
||||||
public TestProvidingInitialForest(){
|
public TestProvidingInitialForest(){
|
||||||
covariateList = Collections.singletonList(new NumericCovariate("x", 0));
|
covariateList = Collections.singletonList(new NumericCovariate("x", 0, false));
|
||||||
|
|
||||||
data = Utils.easyList(
|
data = Utils.easyList(
|
||||||
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
|
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
|
||||||
|
@ -107,8 +107,8 @@ public class TestProvidingInitialForest {
|
||||||
public void testSerialInMemory(){
|
public void testSerialInMemory(){
|
||||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||||
|
|
||||||
final Forest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
final OnlineForest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
||||||
assertEquals(20, newForest.getTrees().size());
|
assertEquals(20, newForest.getNumberOfTrees());
|
||||||
|
|
||||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||||
assertTrue(newForest.getTrees().contains(initialTree));
|
assertTrue(newForest.getTrees().contains(initialTree));
|
||||||
|
@ -124,8 +124,8 @@ public class TestProvidingInitialForest {
|
||||||
public void testParallelInMemory(){
|
public void testParallelInMemory(){
|
||||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||||
|
|
||||||
final Forest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
final OnlineForest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
||||||
assertEquals(20, newForest.getTrees().size());
|
assertEquals(20, newForest.getNumberOfTrees());
|
||||||
|
|
||||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||||
assertTrue(newForest.getTrees().contains(initialTree));
|
assertTrue(newForest.getTrees().contains(initialTree));
|
||||||
|
@ -149,11 +149,11 @@ public class TestProvidingInitialForest {
|
||||||
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
|
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
|
||||||
|
|
||||||
assertEquals(20, directory.listFiles().length);
|
assertEquals(20, directory.listFiles().length);
|
||||||
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
assertEquals(20, newForest.getTrees().size());
|
assertEquals(20, newForest.getNumberOfTrees());
|
||||||
|
|
||||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||||
|
@ -179,9 +179,9 @@ public class TestProvidingInitialForest {
|
||||||
|
|
||||||
assertEquals(20, directory.listFiles().length);
|
assertEquals(20, directory.listFiles().length);
|
||||||
|
|
||||||
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
|
||||||
|
|
||||||
assertEquals(20, newForest.getTrees().size());
|
assertEquals(20, newForest.getNumberOfTrees());
|
||||||
|
|
||||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||||
|
@ -198,7 +198,7 @@ public class TestProvidingInitialForest {
|
||||||
it's not clear if the forest being provided is the same one that trees were saved from.
|
it's not clear if the forest being provided is the same one that trees were saved from.
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void verifyExceptions(){
|
public void testExceptions(){
|
||||||
final String filePath = "src/test/resources/trees/";
|
final String filePath = "src/test/resources/trees/";
|
||||||
final File directory = new File(filePath);
|
final File directory = new File(filePath);
|
||||||
if(directory.exists()){
|
if(directory.exists()){
|
||||||
|
|
|
@ -24,11 +24,10 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.*;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
|
||||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||||
import ca.joeltherrien.randomforest.utils.ResponseLoader;
|
import ca.joeltherrien.randomforest.utils.ResponseLoader;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
@ -47,10 +46,10 @@ public class TestSavingLoading {
|
||||||
|
|
||||||
public List<Covariate> getCovariates(){
|
public List<Covariate> getCovariates(){
|
||||||
return Utils.easyList(
|
return Utils.easyList(
|
||||||
new NumericCovariate("ageatfda", 0),
|
new NumericCovariate("ageatfda", 0, false),
|
||||||
new BooleanCovariate("idu", 1),
|
new BooleanCovariate("idu", 1, false),
|
||||||
new BooleanCovariate("black", 2),
|
new BooleanCovariate("black", 2, false),
|
||||||
new NumericCovariate("cd4nadir", 3)
|
new NumericCovariate("cd4nadir", 3, false)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,16 +118,21 @@ public class TestSavingLoading {
|
||||||
assertTrue(directory.isDirectory());
|
assertTrue(directory.isDirectory());
|
||||||
assertEquals(NTREE, directory.listFiles().length);
|
assertEquals(NTREE, directory.listFiles().length);
|
||||||
|
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
||||||
|
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
|
||||||
|
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
|
||||||
|
|
||||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
|
||||||
assertNotNull(functions);
|
assertNotNull(functionsOnline);
|
||||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
|
||||||
|
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
|
||||||
|
|
||||||
|
|
||||||
assertEquals(NTREE, forest.getTrees().size());
|
assertEquals(NTREE, onlineForest.getTrees().size());
|
||||||
|
|
||||||
TestUtils.removeFolder(directory);
|
TestUtils.removeFolder(directory);
|
||||||
|
|
||||||
|
@ -159,17 +163,22 @@ public class TestSavingLoading {
|
||||||
assertEquals(NTREE, directory.listFiles().length);
|
assertEquals(NTREE, directory.listFiles().length);
|
||||||
|
|
||||||
|
|
||||||
|
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
|
||||||
|
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
|
||||||
|
|
||||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
|
||||||
assertNotNull(functions);
|
assertNotNull(functionsOnline);
|
||||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||||
|
|
||||||
|
|
||||||
assertEquals(NTREE, forest.getTrees().size());
|
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
|
||||||
|
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
|
||||||
|
|
||||||
|
|
||||||
|
assertEquals(NTREE, onlineForest.getTrees().size());
|
||||||
|
|
||||||
TestUtils.removeFolder(directory);
|
TestUtils.removeFolder(directory);
|
||||||
|
|
||||||
|
@ -177,6 +186,64 @@ public class TestSavingLoading {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
We don't implement equals() methods on the below mentioned classes because then we'd need to implement an
|
||||||
|
appropriate hashCode() method that's consistent with the equals(), and we only need plain equals() for
|
||||||
|
these tests.
|
||||||
|
*/
|
||||||
|
|
||||||
|
private boolean competingFunctionsEqual(CompetingRiskFunctions f1 ,CompetingRiskFunctions f2){
|
||||||
|
if(!functionsEqual(f1.getSurvivalCurve(), f2.getSurvivalCurve())){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i=1; i<=2; i++){
|
||||||
|
if(!functionsEqual(f1.getCauseSpecificHazardFunction(i), f2.getCauseSpecificHazardFunction(i))){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if(!functionsEqual(f1.getCumulativeIncidenceFunction(i), f2.getCumulativeIncidenceFunction(i))){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean functionsEqual(RightContinuousStepFunction f1, RightContinuousStepFunction f2){
|
||||||
|
|
||||||
|
final double[] f1X = f1.getX();
|
||||||
|
final double[] f2X = f2.getX();
|
||||||
|
|
||||||
|
final double[] f1Y = f1.getY();
|
||||||
|
final double[] f2Y = f2.getY();
|
||||||
|
|
||||||
|
// first compare array lengths
|
||||||
|
if(f1X.length != f2X.length){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if(f1Y.length != f2Y.length){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - better comparisons of doubles. I don't really care too much though as this equals method is only being used in tests
|
||||||
|
final double delta = 0.000001;
|
||||||
|
|
||||||
|
if(Math.abs(f1.getDefaultY() - f2.getDefaultY()) > delta){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i=0; i < f1X.length; i++){
|
||||||
|
if(Math.abs(f1X[i] - f2X[i]) > delta){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if(Math.abs(f1Y[i] - f2Y[i]) > delta){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -156,7 +156,7 @@ public class TestUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void reduceListToSize(){
|
public void testReduceListToSize(){
|
||||||
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||||
final Random random = new Random();
|
final Random random = new Random();
|
||||||
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
||||||
|
|
|
@ -0,0 +1,164 @@
|
||||||
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class IBSCalculatorTest {
|
||||||
|
|
||||||
|
private final RightContinuousStepFunction cif;
|
||||||
|
|
||||||
|
public IBSCalculatorTest(){
|
||||||
|
this.cif = RightContinuousStepFunction.constructFromPoints(
|
||||||
|
Utils.easyList(
|
||||||
|
new Point(1.0, 0.1),
|
||||||
|
new Point(2.0, 0.2),
|
||||||
|
new Point(3.0, 0.3),
|
||||||
|
new Point(4.0, 0.8)
|
||||||
|
), 0.0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
R code to get these results:
|
||||||
|
|
||||||
|
predicted_cif <- stepfun(1:4, c(0, 0.1, 0.2, 0.3, 0.8))
|
||||||
|
weights <- 1
|
||||||
|
recorded_time <- 2.0
|
||||||
|
recorded_status <- 1.0
|
||||||
|
event_of_interest <- 2
|
||||||
|
times <- 0:4
|
||||||
|
|
||||||
|
errors <- weights * ( as.integer(recorded_time <= times & recorded_status == event_of_interest) - predicted_cif(times))^2
|
||||||
|
sum(errors)
|
||||||
|
|
||||||
|
|
||||||
|
and run again with event_of_interest <- 1
|
||||||
|
|
||||||
|
|
||||||
|
Note that in the R code I only evaluate up to 4, while in the Java code I integrate up to 5
|
||||||
|
This is because the evaluation at 4 is giving the area of the rectangle from 4 to 5.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testResultsWithoutCensoringDistribution(){
|
||||||
|
final IBSCalculator calculator = new IBSCalculator();
|
||||||
|
|
||||||
|
final double errorDifferentEvent = calculator.calculateError(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
this.cif,
|
||||||
|
2,
|
||||||
|
5.0);
|
||||||
|
|
||||||
|
assertEquals(0.78, errorDifferentEvent, 0.000001);
|
||||||
|
|
||||||
|
final double errorSameEvent = calculator.calculateError(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
this.cif,
|
||||||
|
1,
|
||||||
|
5.0);
|
||||||
|
|
||||||
|
assertEquals(1.18, errorSameEvent, 0.000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testResultsWithCensoringDistribution(){
|
||||||
|
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||||
|
Utils.easyList(
|
||||||
|
new Point(0.0, 0.75),
|
||||||
|
new Point(1.0, 0.5),
|
||||||
|
new Point(3.0, 0.25),
|
||||||
|
new Point(5.0, 0)
|
||||||
|
), 1.0
|
||||||
|
);
|
||||||
|
|
||||||
|
final IBSCalculator calculator = new IBSCalculator(censorSurvivalFunction);
|
||||||
|
|
||||||
|
final double errorDifferentEvent = calculator.calculateError(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
this.cif,
|
||||||
|
2,
|
||||||
|
5.0);
|
||||||
|
|
||||||
|
assertEquals(1.56, errorDifferentEvent, 0.000001);
|
||||||
|
|
||||||
|
final double errorSameEvent = calculator.calculateError(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
this.cif,
|
||||||
|
1,
|
||||||
|
5.0);
|
||||||
|
|
||||||
|
assertEquals(2.36, errorSameEvent, 0.000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testStaticFunction(){
|
||||||
|
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||||
|
Utils.easyList(
|
||||||
|
new Point(0.0, 0.75),
|
||||||
|
new Point(1.0, 0.5),
|
||||||
|
new Point(3.0, 0.25),
|
||||||
|
new Point(5.0, 0)
|
||||||
|
), 1.0
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<CompetingRiskResponse> responseList = Utils.easyList(
|
||||||
|
new CompetingRiskResponse(1, 2.0),
|
||||||
|
new CompetingRiskResponse(1, 2.0));
|
||||||
|
|
||||||
|
// for predictions; we'll construct an improper CompetingRisksFunctions
|
||||||
|
final RightContinuousStepFunction trivialFunction = RightContinuousStepFunction.constructFromPoints(
|
||||||
|
Utils.easyList(new Point(1.0, 0.0)),
|
||||||
|
1.0);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions prediction = CompetingRiskFunctions.builder()
|
||||||
|
.survivalCurve(trivialFunction)
|
||||||
|
.causeSpecificHazards(Utils.easyList(trivialFunction, trivialFunction))
|
||||||
|
.cumulativeIncidenceCurves(Utils.easyList(this.cif, trivialFunction))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final List<CompetingRiskFunctions> predictionList = Utils.easyList(prediction, prediction);
|
||||||
|
|
||||||
|
double[] errorParallel = CompetingRiskUtils.calculateIBSError(
|
||||||
|
responseList,
|
||||||
|
predictionList,
|
||||||
|
Optional.of(censorSurvivalFunction),
|
||||||
|
1,
|
||||||
|
5.0,
|
||||||
|
true);
|
||||||
|
|
||||||
|
double[] errorSerial = CompetingRiskUtils.calculateIBSError(
|
||||||
|
responseList,
|
||||||
|
predictionList,
|
||||||
|
Optional.of(censorSurvivalFunction),
|
||||||
|
1,
|
||||||
|
5.0,
|
||||||
|
false);
|
||||||
|
|
||||||
|
assertEquals(responseList.size(), errorParallel.length);
|
||||||
|
assertEquals(responseList.size(), errorSerial.length);
|
||||||
|
|
||||||
|
assertEquals(2.36, errorParallel[0], 0.000001);
|
||||||
|
assertEquals(2.36, errorParallel[1], 0.000001);
|
||||||
|
|
||||||
|
assertEquals(2.36, errorSerial[0], 0.000001);
|
||||||
|
assertEquals(2.36, errorSerial[1], 0.000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -53,10 +53,10 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
public List<Covariate> getCovariates(){
|
public List<Covariate> getCovariates(){
|
||||||
return Utils.easyList(
|
return Utils.easyList(
|
||||||
new NumericCovariate("ageatfda", 0),
|
new NumericCovariate("ageatfda", 0, false),
|
||||||
new BooleanCovariate("idu", 1),
|
new BooleanCovariate("idu", 1, false),
|
||||||
new BooleanCovariate("black", 2),
|
new BooleanCovariate("black", 2, false),
|
||||||
new NumericCovariate("cd4nadir", 3)
|
new NumericCovariate("cd4nadir", 3, false)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,8 +109,8 @@ public class TestCompetingRisk {
|
||||||
|
|
||||||
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||||
final List<Covariate> covariates = Utils.easyList(
|
final List<Covariate> covariates = Utils.easyList(
|
||||||
new BooleanCovariate("idu", 0),
|
new BooleanCovariate("idu", 0, false),
|
||||||
new BooleanCovariate("black", 1)
|
new BooleanCovariate("black", 1, false)
|
||||||
);
|
);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv");
|
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv");
|
||||||
|
@ -210,8 +210,8 @@ public class TestCompetingRisk {
|
||||||
public void testLogRankSplitFinderTwoBooleans() throws IOException {
|
public void testLogRankSplitFinderTwoBooleans() throws IOException {
|
||||||
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||||
final List<Covariate> covariates = Utils.easyList(
|
final List<Covariate> covariates = Utils.easyList(
|
||||||
new BooleanCovariate("idu", 0),
|
new BooleanCovariate("idu", 0, false),
|
||||||
new BooleanCovariate("black", 1)
|
new BooleanCovariate("black", 1, false)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ public class TestCompetingRisk {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void verifyDataset() throws IOException {
|
public void testDataset() throws IOException {
|
||||||
final List<Covariate> covariates = getCovariates();
|
final List<Covariate> covariates = getCovariates();
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
|
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
|
||||||
|
|
|
@ -16,10 +16,11 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.OnlineForest;
|
||||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
@ -30,8 +31,6 @@ import java.util.List;
|
||||||
|
|
||||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class TestCompetingRiskErrorRateCalculator {
|
public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
|
@ -48,7 +47,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
|
|
||||||
final int event = 1;
|
final int event = 1;
|
||||||
|
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = OnlineForest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||||
|
|
||||||
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@ public class TestLogRankSplitFinder {
|
||||||
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
|
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
|
||||||
|
|
||||||
final List<Covariate> covariates = Utils.easyList(
|
final List<Covariate> covariates = Utils.easyList(
|
||||||
new NumericCovariate("x2", 0)
|
new NumericCovariate("x2", 0, false)
|
||||||
);
|
);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename);
|
final List<Row<CompetingRiskResponse>> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename);
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
|
|
||||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
@ -31,13 +30,6 @@ public class TestMathFunctions {
|
||||||
return new RightContinuousStepFunction(time, y, 0.1);
|
return new RightContinuousStepFunction(time, y, 0.1);
|
||||||
}
|
}
|
||||||
|
|
||||||
private LeftContinuousStepFunction generateLeftContinuousStepFunction(){
|
|
||||||
final double[] time = new double[]{1.0, 2.0, 3.0};
|
|
||||||
final double[] y = new double[]{-1.0, 1.0, 0.5};
|
|
||||||
|
|
||||||
return new LeftContinuousStepFunction(time, y, 0.1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRightContinuousStepFunction(){
|
public void testRightContinuousStepFunction(){
|
||||||
final RightContinuousStepFunction function = generateRightContinuousStepFunction();
|
final RightContinuousStepFunction function = generateRightContinuousStepFunction();
|
||||||
|
@ -56,21 +48,5 @@ public class TestMathFunctions {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testLeftContinuousStepFunction(){
|
|
||||||
final LeftContinuousStepFunction function = generateLeftContinuousStepFunction();
|
|
||||||
|
|
||||||
assertEquals(0.1, function.evaluate(0.5));
|
|
||||||
assertEquals(0.1, function.evaluate(1.0));
|
|
||||||
assertEquals(-1.0, function.evaluate(2.0));
|
|
||||||
assertEquals(1.0, function.evaluate(3.0));
|
|
||||||
|
|
||||||
|
|
||||||
assertEquals(0.1, function.evaluate(0.6));
|
|
||||||
assertEquals(-1.0, function.evaluate(1.1));
|
|
||||||
assertEquals(1.0, function.evaluate(2.1));
|
|
||||||
assertEquals(0.5, function.evaluate(3.1));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,12 +17,15 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.function.Executable;
|
import org.junit.jupiter.api.function.Executable;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
@ -31,7 +34,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
public class FactorCovariateTest {
|
public class FactorCovariateTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void verifyEqualLevels() {
|
public void testEqualLevels() {
|
||||||
final FactorCovariate petCovariate = createTestCovariate();
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
|
||||||
final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
|
final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
|
||||||
|
@ -53,7 +56,7 @@ public class FactorCovariateTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void verifyBadLevelException(){
|
public void testBadLevelException(){
|
||||||
final FactorCovariate petCovariate = createTestCovariate();
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
final Executable badCode = () -> petCovariate.createValue("vulcan");
|
final Executable badCode = () -> petCovariate.createValue("vulcan");
|
||||||
|
|
||||||
|
@ -61,25 +64,169 @@ public class FactorCovariateTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAllSubsets(){
|
public void testAllSubsets(){
|
||||||
|
final int n = 2*3; // ensure that n is a multiple of 3 for the test
|
||||||
final FactorCovariate petCovariate = createTestCovariate();
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
final List<Row<Double>> data = generateSampleData(petCovariate, n);
|
||||||
|
|
||||||
final List<SplitRule<String>> splitRules = new ArrayList<>();
|
final List<Split<Double, String>> splits = new ArrayList<>();
|
||||||
|
|
||||||
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
|
petCovariate.generateSplitRuleUpdater(data, 100, new Random())
|
||||||
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
|
.forEachRemaining(split -> splits.add(split));
|
||||||
|
|
||||||
assertEquals(splitRules.size(), 3);
|
assertEquals(splits.size(), 3);
|
||||||
|
|
||||||
// TODO verify the contents of the split rules
|
// These are the 3 possibilities
|
||||||
|
boolean dog_catmouse = false;
|
||||||
|
boolean cat_dogmouse = false;
|
||||||
|
boolean mouse_dogcat = false;
|
||||||
|
|
||||||
|
for(Split<Double, String> split : splits){
|
||||||
|
List<Row<Double>> smallerHand;
|
||||||
|
List<Row<Double>> largerHand;
|
||||||
|
|
||||||
|
if(split.getLeftHand().size() < split.getRightHand().size()){
|
||||||
|
smallerHand = split.getLeftHand();
|
||||||
|
largerHand = split.getRightHand();
|
||||||
|
} else{
|
||||||
|
smallerHand = split.getRightHand();
|
||||||
|
largerHand = split.getLeftHand();
|
||||||
|
}
|
||||||
|
|
||||||
|
// There should be exactly one distinct value in the smaller list
|
||||||
|
assertEquals(n/3, smallerHand.size());
|
||||||
|
assertEquals(1,
|
||||||
|
smallerHand.stream()
|
||||||
|
.map(row -> row.getCovariateValue(petCovariate).getValue())
|
||||||
|
.distinct()
|
||||||
|
.count()
|
||||||
|
);
|
||||||
|
|
||||||
|
// There should be exactly two distinct values in the smaller list
|
||||||
|
assertEquals(2*n/3, largerHand.size());
|
||||||
|
assertEquals(2,
|
||||||
|
largerHand.stream()
|
||||||
|
.map(row -> row.getCovariateValue(petCovariate).getValue())
|
||||||
|
.distinct()
|
||||||
|
.count()
|
||||||
|
);
|
||||||
|
|
||||||
|
switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){
|
||||||
|
case "DOG":
|
||||||
|
dog_catmouse = true;
|
||||||
|
case "CAT":
|
||||||
|
cat_dogmouse = true;
|
||||||
|
case "MOUSE":
|
||||||
|
mouse_dogcat = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
assertTrue(dog_catmouse);
|
||||||
|
assertTrue(cat_dogmouse);
|
||||||
|
assertTrue(mouse_dogcat);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* There was a bug where if number==0 in generateSplitRuleUpdater, then the result was empty.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testNumber0Subsets(){
|
||||||
|
final int n = 2*3; // ensure that n is a multiple of 3 for the test
|
||||||
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
final List<Row<Double>> data = generateSampleData(petCovariate, n);
|
||||||
|
|
||||||
|
final List<Split<Double, String>> splits = new ArrayList<>();
|
||||||
|
|
||||||
|
petCovariate.generateSplitRuleUpdater(data, 0, new Random())
|
||||||
|
.forEachRemaining(split -> splits.add(split));
|
||||||
|
|
||||||
|
assertEquals(splits.size(), 3);
|
||||||
|
|
||||||
|
// These are the 3 possibilities
|
||||||
|
boolean dog_catmouse = false;
|
||||||
|
boolean cat_dogmouse = false;
|
||||||
|
boolean mouse_dogcat = false;
|
||||||
|
|
||||||
|
for(Split<Double, String> split : splits){
|
||||||
|
List<Row<Double>> smallerHand;
|
||||||
|
List<Row<Double>> largerHand;
|
||||||
|
|
||||||
|
if(split.getLeftHand().size() < split.getRightHand().size()){
|
||||||
|
smallerHand = split.getLeftHand();
|
||||||
|
largerHand = split.getRightHand();
|
||||||
|
} else{
|
||||||
|
smallerHand = split.getRightHand();
|
||||||
|
largerHand = split.getLeftHand();
|
||||||
|
}
|
||||||
|
|
||||||
|
// There should be exactly one distinct value in the smaller list
|
||||||
|
assertEquals(n/3, smallerHand.size());
|
||||||
|
assertEquals(1,
|
||||||
|
smallerHand.stream()
|
||||||
|
.map(row -> row.getCovariateValue(petCovariate).getValue())
|
||||||
|
.distinct()
|
||||||
|
.count()
|
||||||
|
);
|
||||||
|
|
||||||
|
// There should be exactly two distinct values in the smaller list
|
||||||
|
assertEquals(2*n/3, largerHand.size());
|
||||||
|
assertEquals(2,
|
||||||
|
largerHand.stream()
|
||||||
|
.map(row -> row.getCovariateValue(petCovariate).getValue())
|
||||||
|
.distinct()
|
||||||
|
.count()
|
||||||
|
);
|
||||||
|
|
||||||
|
switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){
|
||||||
|
case "DOG":
|
||||||
|
dog_catmouse = true;
|
||||||
|
case "CAT":
|
||||||
|
cat_dogmouse = true;
|
||||||
|
case "MOUSE":
|
||||||
|
mouse_dogcat = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
assertTrue(dog_catmouse);
|
||||||
|
assertTrue(cat_dogmouse);
|
||||||
|
assertTrue(mouse_dogcat);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSpitRuleUpdaterWithNAs(){
|
||||||
|
// When some NAs were present calling generateSplitRuleUpdater caused an exception.
|
||||||
|
|
||||||
|
final FactorCovariate covariate = createTestCovariate();
|
||||||
|
final List<Row<Double>> sampleData = generateSampleData(covariate, 10);
|
||||||
|
sampleData.add(Row.createSimple(Utils.easyMap("pet", "NA"), Collections.singletonList(covariate), 11, 5.0));
|
||||||
|
|
||||||
|
covariate.generateSplitRuleUpdater(sampleData, 0, new Random());
|
||||||
|
|
||||||
|
// Test passes if no exception has occurred.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private FactorCovariate createTestCovariate(){
|
private FactorCovariate createTestCovariate(){
|
||||||
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
|
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
|
||||||
|
|
||||||
return new FactorCovariate("pet", 0, levels);
|
return new FactorCovariate("pet", 0, levels, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Row<Double>> generateSampleData(Covariate covariate, int n){
|
||||||
|
final List<Covariate> covariateList = Collections.singletonList(covariate);
|
||||||
|
final List<Row<Double>> dataList = new ArrayList<>(n);
|
||||||
|
|
||||||
|
final String[] levels = new String[]{"DOG", "CAT", "MOUSE"};
|
||||||
|
|
||||||
|
for(int i=0; i<n; i++){
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("pet", levels[i % 3]), covariateList, 1, 1.0));
|
||||||
|
}
|
||||||
|
|
||||||
|
return dataList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -70,7 +70,7 @@ public class NumericCovariateTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNumericCovariateDeterministic(){
|
public void testNumericCovariateDeterministic(){
|
||||||
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
||||||
|
|
||||||
final List<Row<Double>> dataset = createTestDataset(covariate);
|
final List<Row<Double>> dataset = createTestDataset(covariate);
|
||||||
|
|
||||||
|
@ -158,7 +158,7 @@ public class NumericCovariateTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNumericSplitRuleUpdaterWithIndexes(){
|
public void testNumericSplitRuleUpdaterWithIndexes(){
|
||||||
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
||||||
|
|
||||||
final List<Row<Double>> dataset = createTestDataset(covariate);
|
final List<Row<Double>> dataset = createTestDataset(covariate);
|
||||||
|
|
||||||
|
@ -223,7 +223,7 @@ public class NumericCovariateTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
|
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
|
||||||
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
||||||
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
|
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
|
||||||
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());
|
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());
|
||||||
|
|
||||||
|
|
|
@ -18,31 +18,34 @@ package ca.joeltherrien.randomforest.nas;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestNAs {
|
public class TestNAs {
|
||||||
|
|
||||||
private List<Row<Double>> generateData(List<Covariate> covariates){
|
private List<Row<Double>> generateData1(List<Covariate> covariates){
|
||||||
final List<Row<Double>> dataList = new ArrayList<>();
|
final List<Row<Double>> dataList = new ArrayList<>();
|
||||||
|
|
||||||
|
|
||||||
// We must include an NA for one of the values
|
// We must include an NA for one of the values
|
||||||
dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0));
|
dataList.add(Row.createSimple(Utils.easyMap("x", "NA", "y", "true", "z", "green"), covariates, 1, 5.0));
|
||||||
dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0));
|
dataList.add(Row.createSimple(Utils.easyMap("x", "1", "y", "NA", "z", "blue"), covariates, 2, 6.0));
|
||||||
dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5));
|
dataList.add(Row.createSimple(Utils.easyMap("x", "2", "y", "true", "z", "NA"), covariates, 3, 5.5));
|
||||||
dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0));
|
dataList.add(Row.createSimple(Utils.easyMap("x", "7", "y", "false", "z", "green"), covariates, 4, 0.0));
|
||||||
dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0));
|
dataList.add(Row.createSimple(Utils.easyMap("x", "8", "y", "true", "z", "blue"), covariates, 5, 1.0));
|
||||||
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0));
|
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4", "y", "false", "z", "yellow"), covariates, 6, 1.0));
|
||||||
|
|
||||||
|
|
||||||
return dataList;
|
return dataList;
|
||||||
|
@ -54,14 +57,19 @@ public class TestNAs {
|
||||||
// but NumericSplitRuleUpdater had unmodifiable lists when creating the split.
|
// but NumericSplitRuleUpdater had unmodifiable lists when creating the split.
|
||||||
// This test verifies that this no longer causes a crash
|
// This test verifies that this no longer causes a crash
|
||||||
|
|
||||||
final List<Covariate> covariates = Collections.singletonList(new NumericCovariate("x", 0));
|
final List<Covariate> covariates = Utils.easyList(
|
||||||
final List<Row<Double>> dataset = generateData(covariates);
|
new NumericCovariate("x", 0, false),
|
||||||
|
new BooleanCovariate("y", 1, true),
|
||||||
|
new FactorCovariate("z", 2, Utils.easyList("green", "blue", "yellow"), true)
|
||||||
|
);
|
||||||
|
final List<Row<Double>> dataset = generateData1(covariates);
|
||||||
|
|
||||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.checkNodePurity(false)
|
.checkNodePurity(false)
|
||||||
.covariates(covariates)
|
.covariates(covariates)
|
||||||
.numberOfSplits(0)
|
.numberOfSplits(0)
|
||||||
.nodeSize(1)
|
.nodeSize(1)
|
||||||
|
.mtry(3)
|
||||||
.maxNodeDepth(1000)
|
.maxNodeDepth(1000)
|
||||||
.splitFinder(new WeightedVarianceSplitFinder())
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
@ -70,6 +78,87 @@ public class TestNAs {
|
||||||
treeTrainer.growTree(dataset, new Random(123));
|
treeTrainer.growTree(dataset, new Random(123));
|
||||||
|
|
||||||
// As long as no exception occurs, we passed
|
// As long as no exception occurs, we passed
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Row<Double>> generateData2(List<Covariate> covariates){
|
||||||
|
final List<Row<Double>> dataList = new ArrayList<>();
|
||||||
|
// Idea - when ignoring NAs, BadVar gives a perfect split.
|
||||||
|
// GoodVar is slightly worse than BadVar when NAs are excluded.
|
||||||
|
// However, BadVar has a ton of NAs that will degrade its performance.
|
||||||
|
dataList.add(Row.createSimple(
|
||||||
|
Utils.easyMap("BadVar", "-1.0", "GoodVar", "true") // GoodVar's one error: it is true here, yet the response is 5.0
|
||||||
|
, covariates, 1, 5.0)
|
||||||
|
);
|
||||||
|
dataList.add(Row.createSimple(
|
||||||
|
Utils.easyMap("BadVar", "NA", "GoodVar", "false")
|
||||||
|
, covariates, 2, 5.0)
|
||||||
|
);
|
||||||
|
dataList.add(Row.createSimple(
|
||||||
|
Utils.easyMap("BadVar", "NA", "GoodVar", "false")
|
||||||
|
, covariates, 3, 5.0)
|
||||||
|
);
|
||||||
|
dataList.add(Row.createSimple(
|
||||||
|
Utils.easyMap("BadVar", "0.5", "GoodVar", "true")
|
||||||
|
, covariates, 4, 10.0)
|
||||||
|
);
|
||||||
|
dataList.add(Row.createSimple(
|
||||||
|
Utils.easyMap("BadVar", "NA", "GoodVar", "true")
|
||||||
|
, covariates, 5, 10.0)
|
||||||
|
);
|
||||||
|
dataList.add(Row.createSimple(
|
||||||
|
Utils.easyMap("BadVar", "NA", "GoodVar", "true")
|
||||||
|
, covariates, 6, 10.0)
|
||||||
|
);
|
||||||
|
|
||||||
|
return dataList;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
// Test that the NA penalty works when selecting a best split.
|
||||||
|
public void testNAPenalty(){
|
||||||
|
final List<Covariate> covariates1 = Utils.easyList(
|
||||||
|
new NumericCovariate("BadVar", 0, true),
|
||||||
|
new BooleanCovariate("GoodVar", 1, false)
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<Row<Double>> dataList1 = generateData2(covariates1);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer1 = TreeTrainer.<Double, Double>builder()
|
||||||
|
.checkNodePurity(false)
|
||||||
|
.covariates(covariates1)
|
||||||
|
.numberOfSplits(0)
|
||||||
|
.nodeSize(1)
|
||||||
|
.mtry(2)
|
||||||
|
.maxNodeDepth(1000)
|
||||||
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final Split<Double, ?> bestSplit1 = treeTrainer1.findBestSplitRule(dataList1, covariates1, new Random(123));
|
||||||
|
assertEquals(1, bestSplit1.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar
|
||||||
|
|
||||||
|
// Run again without the penalty; verify that we get different results
|
||||||
|
|
||||||
|
final List<Covariate> covariates2 = Utils.easyList(
|
||||||
|
new NumericCovariate("BadVar", 0, false),
|
||||||
|
new BooleanCovariate("GoodVar", 1, false)
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<Row<Double>> dataList2 = generateData2(covariates2);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer2 = TreeTrainer.<Double, Double>builder()
|
||||||
|
.checkNodePurity(false)
|
||||||
|
.covariates(covariates2)
|
||||||
|
.numberOfSplits(0)
|
||||||
|
.nodeSize(1)
|
||||||
|
.mtry(2)
|
||||||
|
.maxNodeDepth(1000)
|
||||||
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final Split<Double, ?> bestSplit2 = treeTrainer2.findBestSplitRule(dataList2, covariates2, new Random(123));
|
||||||
|
assertEquals(0, bestSplit2.getSplitRule().getParentCovariateIndex()); // 0 corresponds to BadVar
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,143 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class IBSErrorCalculatorWrapperTest {
|
||||||
|
|
||||||
|
/*
|
||||||
|
We already have tests for the IBSCalculator, so these tests are concerned with making sure we correctly average
|
||||||
|
the errors together, not that we fully test the production of each error under different scenarios (like
|
||||||
|
providing / not providing a censoring distribution).
|
||||||
|
*/
|
||||||
|
|
||||||
|
private final double integrationUpperBound = 5.0;
|
||||||
|
|
||||||
|
private final List<CompetingRiskResponse> responses;
|
||||||
|
private final List<CompetingRiskFunctions> functions;
|
||||||
|
|
||||||
|
|
||||||
|
private final double[][] errors;
|
||||||
|
|
||||||
|
public IBSErrorCalculatorWrapperTest(){
|
||||||
|
this.responses = Utils.easyList(
|
||||||
|
new CompetingRiskResponse(0, 2.0),
|
||||||
|
new CompetingRiskResponse(0, 3.0),
|
||||||
|
new CompetingRiskResponse(1, 1.0),
|
||||||
|
new CompetingRiskResponse(1, 1.5),
|
||||||
|
new CompetingRiskResponse(2, 3.0),
|
||||||
|
new CompetingRiskResponse(2, 4.0)
|
||||||
|
);
|
||||||
|
|
||||||
|
final RightContinuousStepFunction cif1 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(1.0, 0.25),
|
||||||
|
new Point(1.5, 0.45)
|
||||||
|
), 0.0);
|
||||||
|
|
||||||
|
final RightContinuousStepFunction cif2 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(3.0, 0.25),
|
||||||
|
new Point(4.0, 0.45)
|
||||||
|
), 0.0);
|
||||||
|
|
||||||
|
// This function is for the unused CHFs and survival curve
|
||||||
|
// If we see infinities or NaNs popping up in our output we should look here.
|
||||||
|
final RightContinuousStepFunction emptyFun = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(0.0, Double.NaN)
|
||||||
|
), Double.NEGATIVE_INFINITY
|
||||||
|
);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions function = CompetingRiskFunctions.builder()
|
||||||
|
.cumulativeIncidenceCurves(Utils.easyList(cif1, cif2))
|
||||||
|
.causeSpecificHazards(Utils.easyList(emptyFun, emptyFun))
|
||||||
|
.survivalCurve(emptyFun)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Same prediction for every response.
|
||||||
|
this.functions = Utils.easyList(function, function, function, function, function, function);
|
||||||
|
|
||||||
|
final IBSCalculator calculator = new IBSCalculator();
|
||||||
|
this.errors = new double[2][6];
|
||||||
|
|
||||||
|
for(int event : new int[]{1, 2}){
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
this.errors[event-1][i] = calculator.calculateError(
|
||||||
|
responses.get(i), function.getCumulativeIncidenceFunction(event),
|
||||||
|
event, integrationUpperBound
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOneEventOne(){
|
||||||
|
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1},
|
||||||
|
this.integrationUpperBound);
|
||||||
|
|
||||||
|
final double error = wrapper.averageError(this.responses, this.functions);
|
||||||
|
double expectedError = 0.0;
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
expectedError += errors[0][i] / 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(expectedError, error, 0.00000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOneEventTwo(){
|
||||||
|
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{2},
|
||||||
|
this.integrationUpperBound);
|
||||||
|
|
||||||
|
final double error = wrapper.averageError(this.responses, this.functions);
|
||||||
|
double expectedError = 0.0;
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
expectedError += errors[1][i] / 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(expectedError, error, 0.00000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTwoEventsNoWeights(){
|
||||||
|
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
|
||||||
|
this.integrationUpperBound);
|
||||||
|
|
||||||
|
final double error = wrapper.averageError(this.responses, this.functions);
|
||||||
|
double expectedError1 = 0.0;
|
||||||
|
double expectedError2 = 0.0;
|
||||||
|
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
expectedError1 += errors[0][i] / 6.0;
|
||||||
|
expectedError2 += errors[1][i] / 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(expectedError1 + expectedError2, error, 0.00000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTwoEventsWithWeights(){
|
||||||
|
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
|
||||||
|
this.integrationUpperBound, new double[]{1.0, 2.0});
|
||||||
|
|
||||||
|
final double error = wrapper.averageError(this.responses, this.functions);
|
||||||
|
double expectedError1 = 0.0;
|
||||||
|
double expectedError2 = 0.0;
|
||||||
|
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
expectedError1 += errors[0][i] / 6.0;
|
||||||
|
expectedError2 += errors[1][i] / 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(1.0 * expectedError1 + 2.0 * expectedError2, error, 0.00000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class RegressionErrorCalculatorTest {
|
||||||
|
|
||||||
|
private final RegressionErrorCalculator calculator = new RegressionErrorCalculator();
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRegressionErrorCalculator(){
|
||||||
|
final List<Double> responses = Utils.easyList(1.0, 1.5, 0.0, 3.0);
|
||||||
|
final List<Double> predictions = Utils.easyList(1.5, 1.7, 0.1, 2.9);
|
||||||
|
|
||||||
|
// Differences (prediction - response) are 0.5, 0.2, 0.1, -0.1
|
||||||
|
// Squared: 0.25, 0.04, 0.01, 0.01
|
||||||
|
|
||||||
|
assertEquals((0.25 + 0.04 + 0.01 + 0.01)/4.0, calculator.averageError(responses, predictions), 0.000000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,484 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanSplitRule;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericSplitRule;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
|
import ca.joeltherrien.randomforest.tree.*;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class TestVariableImportanceCalculator {
|
||||||
|
|
||||||
|
/*
|
||||||
|
Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
|
||||||
|
setting.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// We'll have a very simple Forest of two trees
|
||||||
|
private final OnlineForest<Double, Double> forest;
|
||||||
|
|
||||||
|
|
||||||
|
private final List<Covariate> covariates;
|
||||||
|
private final List<Row<Double>> rowList;
|
||||||
|
|
||||||
|
/*
|
||||||
|
Long setup process; the forest is manually constructed so that we know exactly what the variable importance should be.
|
||||||
|
|
||||||
|
*/
|
||||||
|
public TestVariableImportanceCalculator(){
|
||||||
|
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false);
|
||||||
|
final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false);
|
||||||
|
final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
|
||||||
|
Utils.easyList("red", "blue", "green"), false);
|
||||||
|
|
||||||
|
this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
|
.numberOfSplits(0)
|
||||||
|
.nodeSize(1)
|
||||||
|
.maxNodeDepth(100)
|
||||||
|
.mtry(3)
|
||||||
|
.checkNodePurity(false)
|
||||||
|
.covariates(this.covariates)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
/*
|
||||||
|
Plan for data - BooleanCovariate is split on first and has the largest impact.
|
||||||
|
NumericCovariate is at second level and has more minimal impact.
|
||||||
|
FactorCovariate is useless and never used.
|
||||||
|
Our tree (we'll duplicate it for testing OOB errors) will have a depth of 1. (0 based).
|
||||||
|
*/
|
||||||
|
|
||||||
|
final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
|
||||||
|
final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});
|
||||||
|
|
||||||
|
this.forest = OnlineForest.<Double, Double>builder()
|
||||||
|
.trees(Utils.easyList(tree1, tree2))
|
||||||
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// formula; boolean high adds 100; high numeric adds 10
|
||||||
|
// This row list should have a baseline error of 0.0
|
||||||
|
this.rowList = Utils.easyList(
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 1, 0.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "blue"),
|
||||||
|
covariates, 2, 10.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 3, 100.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "green"),
|
||||||
|
covariates, 4, 110.0
|
||||||
|
),
|
||||||
|
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 5, 0.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "blue"),
|
||||||
|
covariates, 6, 10.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 7, 100.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "green"),
|
||||||
|
covariates, 8, 110.0
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Tree<Double> makeTree(List<Covariate> covariates, double offset, int[] indices){
|
||||||
|
// Naming convention - xyTerminal where x and y are low/high denotes whether BooleanCovariate(x) is low/high and
|
||||||
|
// whether NumericCovariate(y) is low/high.
|
||||||
|
final TerminalNode<Double> lowLowTerminal = new TerminalNode<>(0.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> lowHighTerminal = new TerminalNode<>(10.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> highLowTerminal = new TerminalNode<>(100.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> highHighTerminal = new TerminalNode<>(110.0 + offset, 5);
|
||||||
|
|
||||||
|
final SplitNode<Double> lowSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(lowLowTerminal)
|
||||||
|
.rightHand(lowHighTerminal)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final SplitNode<Double> highSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(highLowTerminal)
|
||||||
|
.rightHand(highHighTerminal)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final SplitNode<Double> rootSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(lowSplitNode)
|
||||||
|
.rightHand(highSplitNode)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new BooleanSplitRule((BooleanCovariate) covariates.get(0)))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return new Tree<>(rootSplitNode, indices);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Experiment with random seeds to first examine what a split does so we know what to expect
|
||||||
|
/*
|
||||||
|
public static void main(String[] args){
|
||||||
|
|
||||||
|
// Behaviour for OOB
|
||||||
|
final List<Integer> ints1 = IntStream.range(5, 9).boxed().collect(Collectors.toList());
|
||||||
|
final List<Integer> ints2 = IntStream.range(1, 5).boxed().collect(Collectors.toList());
|
||||||
|
|
||||||
|
final Random random = new Random(123);
|
||||||
|
Collections.shuffle(ints1, random);
|
||||||
|
Collections.shuffle(ints2, random);
|
||||||
|
|
||||||
|
System.out.println(ints1);
|
||||||
|
System.out.println(ints2);
|
||||||
|
// [5, 6, 8, 7]
|
||||||
|
// [3, 4, 1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
// Behaviour for no-OOB
|
||||||
|
final List<Integer> fullInts1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||||
|
final List<Integer> fullInts2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||||
|
final Random fullIntsRandom = new Random(123);
|
||||||
|
|
||||||
|
Collections.shuffle(fullInts1, fullIntsRandom);
|
||||||
|
Collections.shuffle(fullInts2, fullIntsRandom);
|
||||||
|
System.out.println(fullInts1);
|
||||||
|
System.out.println(fullInts2);
|
||||||
|
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
private double[] difference(double[] a, double[] b){
|
||||||
|
final double[] results = new double[a.length];
|
||||||
|
|
||||||
|
for(int i = 0; i < a.length; i++){
|
||||||
|
results[i] = a[i] - b[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertDoubleEquals(double[] expected, double[] actual){
|
||||||
|
assertEquals(expected.length, actual.length, "Lengths of arrays should be equal");
|
||||||
|
|
||||||
|
for(int i=0; i < expected.length; i++){
|
||||||
|
assertEquals(expected[i], actual[i], 0.0000001, "Difference at " + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnXNoOOB(){
|
||||||
|
// x is the BooleanCovariate
|
||||||
|
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest.getTrees(),
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
final Covariate covariate = this.covariates.get(0);
|
||||||
|
|
||||||
|
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||||
|
|
||||||
|
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||||
|
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||||
|
0., 110., 100., 10., 0., 110., 100., 10.
|
||||||
|
);
|
||||||
|
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
// Actual: [F, F, T, T, F, F, T, T]
|
||||||
|
// Seen: [F, F, T, T, F, F, T, T]
|
||||||
|
// Difference: 0 all around; random chance
|
||||||
|
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||||
|
2., 12., 102., 112., 2., 12., 102., 112.
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double[] expectedError = new double[2];
|
||||||
|
|
||||||
|
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
|
||||||
|
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
|
||||||
|
|
||||||
|
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||||
|
|
||||||
|
assertDoubleEquals(expectedVimp, importance);
|
||||||
|
|
||||||
|
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||||
|
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||||
|
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||||
|
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||||
|
|
||||||
|
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnXOOB(){
|
||||||
|
// x is the BooleanCovariate
|
||||||
|
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest.getTrees(),
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
final Covariate covariate = this.covariates.get(0);
|
||||||
|
|
||||||
|
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||||
|
|
||||||
|
// [5, 6, 8, 7]
|
||||||
|
// Actual: [F, F, T, T]
|
||||||
|
// Seen: [F, F, T, T]
|
||||||
|
// Difference: No differences
|
||||||
|
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||||
|
0., 10., 100., 110.
|
||||||
|
);
|
||||||
|
|
||||||
|
// [3, 4, 1, 2]
|
||||||
|
// Actual: [F, F, T, T]
|
||||||
|
// Seen: [T, T, F, F]
|
||||||
|
// Difference: +100, +100, -100, -100
|
||||||
|
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||||
|
102., 112., 2., 12.
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
|
||||||
|
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
|
||||||
|
|
||||||
|
final double[] expectedError = new double[2];
|
||||||
|
|
||||||
|
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
|
||||||
|
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
|
||||||
|
|
||||||
|
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||||
|
|
||||||
|
assertDoubleEquals(expectedVimp, importance);
|
||||||
|
|
||||||
|
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||||
|
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||||
|
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||||
|
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||||
|
|
||||||
|
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnYNoOOB(){
|
||||||
|
// y is the NumericCovariate
|
||||||
|
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest.getTrees(),
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
final Covariate covariate = this.covariates.get(1);
|
||||||
|
|
||||||
|
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||||
|
|
||||||
|
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||||
|
// Actual: [F, T, F, T, F, T, F, T]
|
||||||
|
// Seen: [F, T, T, T, F, F, F, T]
|
||||||
|
// Difference: [=, =, +, =, =, -, =, =]x10
|
||||||
|
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||||
|
0., 10., 110., 110., 0., 0., 100., 110.
|
||||||
|
);
|
||||||
|
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
// Actual: [F, T, F, T, F, T, F, T]
|
||||||
|
// Seen: [T, F, T, F, F, T, T, F]
|
||||||
|
// Difference: [+, -, +, -, =, =, +, -]
|
||||||
|
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||||
|
12., 2., 112., 102., 2., 12., 112., 102.
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double[] expectedError = new double[2];
|
||||||
|
|
||||||
|
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
|
||||||
|
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
|
||||||
|
|
||||||
|
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||||
|
|
||||||
|
assertDoubleEquals(expectedVimp, importance);
|
||||||
|
|
||||||
|
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||||
|
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||||
|
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||||
|
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||||
|
|
||||||
|
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnYOOB(){
|
||||||
|
// y is the NumericCovariate
|
||||||
|
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest.getTrees(),
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
final Covariate covariate = this.covariates.get(1);
|
||||||
|
|
||||||
|
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||||
|
|
||||||
|
// [5, 6, 8, 7]
|
||||||
|
// Actual: [F, T, F, T]
|
||||||
|
// Seen: [F, T, T, F]
|
||||||
|
// Difference: [=, =, +, -]x10
|
||||||
|
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||||
|
0., 10., 110., 100.
|
||||||
|
);
|
||||||
|
|
||||||
|
// [3, 4, 1, 2]
|
||||||
|
// Actual: [F, T, F, T]
|
||||||
|
// Seen: [F, T, F, T]
|
||||||
|
// Difference: [=, =, =, =]x10 no change
|
||||||
|
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||||
|
2., 12., 102., 112.
|
||||||
|
);
|
||||||
|
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
|
||||||
|
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
|
||||||
|
|
||||||
|
final double[] expectedError = new double[2];
|
||||||
|
|
||||||
|
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
|
||||||
|
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
|
||||||
|
|
||||||
|
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||||
|
|
||||||
|
assertDoubleEquals(expectedVimp, importance);
|
||||||
|
|
||||||
|
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||||
|
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||||
|
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||||
|
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||||
|
|
||||||
|
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnZNoOOB(){
|
||||||
|
// z is the useless FactorCovariate
|
||||||
|
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest.getTrees(),
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
|
||||||
|
final double[] expectedImportance = {0.0, 0.0};
|
||||||
|
|
||||||
|
|
||||||
|
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
||||||
|
assertDoubleEquals(expectedImportance, importance);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnZOOB(){
|
||||||
|
// z is the useless FactorCovariate
|
||||||
|
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest.getTrees(),
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
|
||||||
|
final double[] expectedImportance = {0.0, 0.0};
|
||||||
|
|
||||||
|
|
||||||
|
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
||||||
|
assertDoubleEquals(expectedImportance, importance);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -19,6 +19,9 @@ package ca.joeltherrien.randomforest.utils;
|
||||||
import ca.joeltherrien.randomforest.TestUtils;
|
import ca.joeltherrien.randomforest.TestUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class RightContinuousStepFunctionIntegrationTest {
|
public class RightContinuousStepFunctionIntegrationTest {
|
||||||
|
|
||||||
private RightContinuousStepFunction createTestFunction(){
|
private RightContinuousStepFunction createTestFunction(){
|
||||||
|
@ -75,5 +78,78 @@ public class RightContinuousStepFunctionIntegrationTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvertedFromTo(){
|
||||||
|
final RightContinuousStepFunction function = createTestFunction();
|
||||||
|
|
||||||
|
final double area1 = function.integrate(0, 3.0);
|
||||||
|
final double area2 = function.integrate(3.0, 0.0);
|
||||||
|
|
||||||
|
assertEquals(area1, -area2, 0.0000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIntegratingUpToNan(){
|
||||||
|
// Idea here - you have a function that is valid up to point x where it becomes NaN
|
||||||
|
// You should be able to integrate *up to* point x and not get an NaN
|
||||||
|
|
||||||
|
final RightContinuousStepFunction function1 = new RightContinuousStepFunction(
|
||||||
|
new double[]{1.0, 2.0, 3.0, 4.0},
|
||||||
|
new double[]{1.0, 1.0, 1.0, Double.NaN},
|
||||||
|
0.0);
|
||||||
|
|
||||||
|
|
||||||
|
final double area1 = function1.integrate(0.0, 4.0);
|
||||||
|
assertEquals(3.0, area1, 0.000000001);
|
||||||
|
|
||||||
|
final double nanArea1 = function1.integrate(0.0, 4.0001);
|
||||||
|
assertTrue(Double.isNaN(nanArea1));
|
||||||
|
|
||||||
|
|
||||||
|
// This tests integrating over the defaultY up to the NaN point
|
||||||
|
final RightContinuousStepFunction function2 = new RightContinuousStepFunction(
|
||||||
|
new double[]{1.0, 2.0, 3.0, 4.0},
|
||||||
|
new double[]{Double.NaN, 1.0, 1.0, Double.NaN},
|
||||||
|
1.0);
|
||||||
|
|
||||||
|
|
||||||
|
final double area2 = function2.integrate(0.0, 1.0);
|
||||||
|
assertEquals(1.0, area2, 0.000000001);
|
||||||
|
|
||||||
|
final double nanArea2 = function2.integrate(0.0, 4.0);
|
||||||
|
assertTrue(Double.isNaN(nanArea2));
|
||||||
|
|
||||||
|
|
||||||
|
// This tests integrating between two NaN points. Note that of course for RightContinuousValues carry the previous
|
||||||
|
// value until the next x point, so this is just making sure the code works if the x-value we pass over is NaN
|
||||||
|
final RightContinuousStepFunction function3 = new RightContinuousStepFunction(
|
||||||
|
new double[]{1.0, 2.0, 3.0, 4.0},
|
||||||
|
new double[]{Double.NaN, 1.0, 1.0, Double.NaN},
|
||||||
|
0.0);
|
||||||
|
|
||||||
|
|
||||||
|
final double area3 = function3.integrate(2.0, 4.0);
|
||||||
|
assertEquals(2.0, area3, 0.000000001);
|
||||||
|
|
||||||
|
final double nanArea3 = function3.integrate(0.0, 4.0001);
|
||||||
|
assertTrue(Double.isNaN(nanArea3));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIntegratingEmptyFunction(){
|
||||||
|
// A function might have no points, but we'll still need to integrate it.
|
||||||
|
|
||||||
|
final RightContinuousStepFunction function = new RightContinuousStepFunction(
|
||||||
|
new double[]{}, new double[]{}, 1.0
|
||||||
|
);
|
||||||
|
|
||||||
|
final double area = function.integrate(1.0 ,3.0);
|
||||||
|
assertEquals(2.0, area, 0.000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,180 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.function.DoubleBinaryOperator;
|
||||||
|
import java.util.function.DoubleUnaryOperator;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class RightContinuousStepFunctionOperatorTests {
|
||||||
|
|
||||||
|
|
||||||
|
// Idea - small and middle slightly overlap; middle and large slightly overlap.
|
||||||
|
// small and large never overlap (i.e. small's x values always occur before large's)
|
||||||
|
private final RightContinuousStepFunction smallNumbers;
|
||||||
|
private final RightContinuousStepFunction middleNumbers;
|
||||||
|
private final RightContinuousStepFunction largeNumbers;
|
||||||
|
|
||||||
|
private final double delta = 0.0000000001;
|
||||||
|
|
||||||
|
public RightContinuousStepFunctionOperatorTests(){
|
||||||
|
smallNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(1.0, 1.0),
|
||||||
|
new Point(2.0, 3.0),
|
||||||
|
new Point(3.0, 2.0),
|
||||||
|
new Point(4.0, 1.0)
|
||||||
|
), 0.0);
|
||||||
|
|
||||||
|
middleNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(3.5, 4.0),
|
||||||
|
new Point(4.0, 3.0),
|
||||||
|
new Point(5.0, 2.0),
|
||||||
|
new Point(6.0, 1.0)
|
||||||
|
), 5.0);
|
||||||
|
|
||||||
|
largeNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(5.0, 5.0),
|
||||||
|
new Point(6.0, 6.0),
|
||||||
|
new Point(7.0, 3.0),
|
||||||
|
new Point(8.0, 2.0)
|
||||||
|
), 3.0);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDifferenceNoOverlapLargeMinusSmall(){
|
||||||
|
DoubleBinaryOperator operator = (a, b) -> a - b;
|
||||||
|
|
||||||
|
final RightContinuousStepFunction largeSmallDifference = RightContinuousStepFunction.biOperation(
|
||||||
|
largeNumbers,
|
||||||
|
smallNumbers,
|
||||||
|
operator);
|
||||||
|
|
||||||
|
assertEquals(8, largeSmallDifference.getX().length);
|
||||||
|
assertEquals(8, largeSmallDifference.getY().length);
|
||||||
|
|
||||||
|
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||||
|
|
||||||
|
for(int time = 1; time <= 9; time++){
|
||||||
|
for(double offsetTime : offsetTimes){
|
||||||
|
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||||
|
|
||||||
|
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
|
||||||
|
final double smallFunEvaluation = smallNumbers.evaluate(timeToEvaluateAt);
|
||||||
|
final double expectedDifference = operator.applyAsDouble(largeFunEvaluation, smallFunEvaluation);
|
||||||
|
|
||||||
|
final double actualEvaluation = largeSmallDifference.evaluate(timeToEvaluateAt);
|
||||||
|
|
||||||
|
assertEquals(expectedDifference, actualEvaluation, delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDifferenceNoOverlapSmallMinusLarge(){
|
||||||
|
DoubleBinaryOperator operator = (a, b) -> a - b;
|
||||||
|
|
||||||
|
final RightContinuousStepFunction smallLargeDifference = RightContinuousStepFunction.biOperation(
|
||||||
|
smallNumbers,
|
||||||
|
largeNumbers,
|
||||||
|
operator);
|
||||||
|
|
||||||
|
assertEquals(8, smallLargeDifference.getX().length);
|
||||||
|
assertEquals(8, smallLargeDifference.getY().length);
|
||||||
|
|
||||||
|
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||||
|
|
||||||
|
for(int time = 1; time <= 9; time++){
|
||||||
|
for(double offsetTime : offsetTimes){
|
||||||
|
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||||
|
|
||||||
|
final double smallFunEvaluation = smallNumbers.evaluate(timeToEvaluateAt);
|
||||||
|
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
|
||||||
|
final double expectedDifference = operator.applyAsDouble(smallFunEvaluation, largeFunEvaluation);
|
||||||
|
|
||||||
|
final double actualEvaluation = smallLargeDifference.evaluate(timeToEvaluateAt);
|
||||||
|
|
||||||
|
assertEquals(expectedDifference, actualEvaluation, delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDifferenceSomeOverlapLargeMinusMiddle(){
|
||||||
|
DoubleBinaryOperator operator = (a, b) -> a - b;
|
||||||
|
|
||||||
|
final RightContinuousStepFunction combinedFunction = RightContinuousStepFunction.biOperation(
|
||||||
|
largeNumbers,
|
||||||
|
middleNumbers,
|
||||||
|
operator);
|
||||||
|
|
||||||
|
assertEquals(6, combinedFunction.getX().length);
|
||||||
|
assertEquals(6, combinedFunction.getY().length);
|
||||||
|
|
||||||
|
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||||
|
|
||||||
|
for(int time = 1; time <= 9; time++){
|
||||||
|
for(double offsetTime : offsetTimes){
|
||||||
|
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||||
|
|
||||||
|
final double middleFunEvaluation = middleNumbers.evaluate(timeToEvaluateAt);
|
||||||
|
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
|
||||||
|
final double expectedDifference = operator.applyAsDouble(largeFunEvaluation, middleFunEvaluation);
|
||||||
|
|
||||||
|
final double actualEvaluation = combinedFunction.evaluate(timeToEvaluateAt);
|
||||||
|
|
||||||
|
assertEquals(expectedDifference, actualEvaluation, delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDifferenceCompleteOverlap(){
|
||||||
|
DoubleBinaryOperator operator = (a, b) -> a - b;
|
||||||
|
|
||||||
|
final RightContinuousStepFunction combinedFunction = RightContinuousStepFunction.biOperation(
|
||||||
|
middleNumbers,
|
||||||
|
middleNumbers,
|
||||||
|
operator);
|
||||||
|
|
||||||
|
assertEquals(4, combinedFunction.getX().length);
|
||||||
|
assertEquals(4, combinedFunction.getY().length);
|
||||||
|
|
||||||
|
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||||
|
|
||||||
|
for(int time = 1; time <= 9; time++){
|
||||||
|
for(double offsetTime : offsetTimes){
|
||||||
|
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||||
|
|
||||||
|
final double actualEvaluation = combinedFunction.evaluate(timeToEvaluateAt);
|
||||||
|
|
||||||
|
assertEquals(0.0, actualEvaluation, delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPowerFunction(){
|
||||||
|
final DoubleUnaryOperator operator = d -> d*d;
|
||||||
|
|
||||||
|
final RightContinuousStepFunction squaredFunction = smallNumbers.unaryOperation(operator);
|
||||||
|
|
||||||
|
assertEquals(4, squaredFunction.getX().length);
|
||||||
|
assertEquals(4, squaredFunction.getY().length);
|
||||||
|
|
||||||
|
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||||
|
|
||||||
|
for(int time = 1; time <= 9; time++){
|
||||||
|
for(double offsetTime : offsetTimes){
|
||||||
|
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||||
|
|
||||||
|
final double expectedEvaluation = operator.applyAsDouble(smallNumbers.evaluate(timeToEvaluateAt));
|
||||||
|
final double actualEvaluation = squaredFunction.evaluate(timeToEvaluateAt);
|
||||||
|
|
||||||
|
assertEquals(expectedEvaluation, actualEvaluation, delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -43,7 +43,7 @@ public class TrainForest {
|
||||||
|
|
||||||
final List<Covariate> covariateList = new ArrayList<>(p);
|
final List<Covariate> covariateList = new ArrayList<>(p);
|
||||||
for(int j =0; j < p; j++){
|
for(int j =0; j < p; j++){
|
||||||
final NumericCovariate covariate = new NumericCovariate("x"+j, j);
|
final NumericCovariate covariate = new NumericCovariate("x"+j, j, false);
|
||||||
covariateList.add(covariate);
|
covariateList.add(covariate);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,8 +39,8 @@ public class TrainSingleTree {
|
||||||
final int n = 1000;
|
final int n = 1000;
|
||||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||||
|
|
||||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
|
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0, false);
|
||||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
|
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1, false);
|
||||||
|
|
||||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||||
.generate(() -> random.nextDouble()*10.0)
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
|
|
@ -41,9 +41,9 @@ public class TrainSingleTreeFactor {
|
||||||
final int n = 10000;
|
final int n = 10000;
|
||||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||||
|
|
||||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
|
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0, false);
|
||||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
|
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1, false);
|
||||||
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"));
|
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"), false);
|
||||||
|
|
||||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||||
.generate(() -> random.nextDouble()*10.0)
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
|
Loading…
Reference in a new issue