Compare commits
17 commits
Author | SHA1 | Date | |
---|---|---|---|
|
044bf08e3d | ||
|
c4bab39245 | ||
f3a4ef01ed | |||
|
54af805d4d | ||
79a9522ba7 | |||
|
c24626ff61 | ||
|
51696e2546 | ||
|
f1c5b292ed | ||
|
a56ad4433d | ||
|
f23ee21ef3 | ||
|
186de413ed | ||
|
aa1f544ea2 | ||
|
86f6c195d7 | ||
|
9258f75e4e | ||
|
7371dab4f1 | ||
|
ae9a6b9a3f | ||
d7cdc9f6e7 |
68 changed files with 2718 additions and 625 deletions
7
.gitignore
vendored
7
.gitignore
vendored
|
@ -2,7 +2,10 @@
|
|||
.settings
|
||||
.project
|
||||
target/
|
||||
library/target/
|
||||
executable/target/
|
||||
*.iml
|
||||
.idea
|
||||
template.yaml
|
||||
dependency-reduced-pom.xml
|
||||
library/dependency-reduced-pom.xml
|
||||
executable/dependency-reduced-pom.xml
|
||||
executable/template.yaml
|
||||
|
|
18
README.md
18
README.md
|
@ -1,14 +1,20 @@
|
|||
# README
|
||||
|
||||
This Java software package contains the backend classes used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
|
||||
This repository contains the largeRCRF Java library, containing all of the logic used for training the random forests. This provides the Jar file used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
|
||||
|
||||
On its own it's not useful, but you're free to integrate it into your own projects (as long as you follow the terms of the GPL-3 license), or extend it. More documentation will be added later on how to extend it, but for now if you want an idea I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes, which is a small example of a regression random forest implementation.
|
||||
Most users interested in training random competing risks forests should use the [R package component](https://github.com/jatherrien/largeRCRF); the content in this repository is only useful for advanced users.
|
||||
|
||||
If you've made an extension or modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package` and copy the `largeRCRF-1.0-SNAPSHOT.jar` file now found in the `target/` directory into the `inst/java/` directory for the R package (delete the previous jar file). Then just build the R package, possibly with your modifications in the R code, with `R> devtools::build()`.
|
||||
## License
|
||||
|
||||
If you have any questions on how to integrate this code with your own, how to integrate it with the R project, or anything else related to this project, please feel free to either [email me](mailto:joelt@sfu.ca) or create an Issue.
|
||||
You're free to use / modify / redistribute the project, as long as you follow the terms of the GPL-3 license.
|
||||
|
||||
A small project allowing this code to be called directly outside of R will be released soon.
|
||||
## Extending
|
||||
|
||||
Documentation on how to extend the library to add support for other types of random forests will eventually be added, but for now if you're interested in that I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes to see how some basic regression random forests were introduced.
|
||||
|
||||
If you've made a modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package`, then just copy `target/largeRCRF-library-1.0-SNAPSHOT.jar` into the `inst/java/` directory for the R package, replacing the previous Jar file there. Then build the R package, possibly with your modifications to the code there, with `R> devtools::build()`.
|
||||
|
||||
Please don't take the current lack of documentation as a sign that I oppose others extending or modifying the project; if you have any questions on running, extending, integrating with R, or anything else related to this project, please don't hesitate to either [email me](mailto:joelt@sfu.ca) or create an Issue. Most likely my answers to your questions will end up forming the basis for any documentation written.
|
||||
|
||||
## System Requirements
|
||||
|
||||
|
@ -17,5 +23,3 @@ You need:
|
|||
* A Java runtime version 1.8 or greater
|
||||
* Maven to build the project
|
||||
|
||||
|
||||
|
||||
|
|
7
pom.xml
7
pom.xml
|
@ -4,8 +4,8 @@
|
|||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>ca.joeltherrien</groupId>
|
||||
<artifactId>largeRCRF</artifactId>
|
||||
<groupId>ca.joeltherrien.ca</groupId>
|
||||
<artifactId>largeRCRF-library</artifactId>
|
||||
<version>1.0-SNAPSHOT</version>
|
||||
|
||||
<properties>
|
||||
|
@ -61,6 +61,7 @@
|
|||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>3.2.1</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
|
@ -85,7 +86,7 @@
|
|||
<configuration>
|
||||
<rulesets>
|
||||
<!-- Custom local file system rule set -->
|
||||
<ruleset>${project.basedir}/pmd-rules.xml</ruleset>
|
||||
<ruleset>pmd-rules.xml</ruleset>
|
||||
</rulesets>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
|
|
@ -21,12 +21,13 @@ import lombok.Getter;
|
|||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class CovariateRow implements Serializable {
|
||||
public class CovariateRow implements Serializable, Cloneable {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final Covariate.Value[] valueArray;
|
||||
|
||||
|
@ -46,6 +47,14 @@ public class CovariateRow implements Serializable {
|
|||
return "CovariateRow " + this.id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CovariateRow clone() {
|
||||
// shallow clone, which is fine. I want a new array, but the values don't need to be copied
|
||||
final Covariate.Value[] copyValueArray = this.valueArray.clone();
|
||||
|
||||
return new CovariateRow(copyValueArray, this.id);
|
||||
}
|
||||
|
||||
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||
final Map<String, Covariate> covariateMap = new HashMap<>();
|
||||
|
@ -64,4 +73,27 @@ public class CovariateRow implements Serializable {
|
|||
return new CovariateRow(valueArray, id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Used for variable importance; takes a List of CovariateRows and permute one of the Covariates.
|
||||
*
|
||||
* @param covariateRows The List of CovariateRows to scramble. Note that the originals won't be modified.
|
||||
* @param covariateToScramble The Covariate to scramble on.
|
||||
* @param random The source of randomness to use. If not present, one will be created.
|
||||
* @return A List of CovariateRows where the specified covariate was scrambled. These are different objects from the ones provided.
|
||||
*/
|
||||
public static List<CovariateRow> scrambleCovariateValues(List<? extends CovariateRow> covariateRows, Covariate covariateToScramble, Optional<Random> random){
|
||||
final List<CovariateRow> permutedCovariateRowList = new ArrayList<>(covariateRows);
|
||||
Collections.shuffle(permutedCovariateRowList, random.orElse(new Random())); // without replacement
|
||||
|
||||
final List<CovariateRow> clonedRowList = covariateRows.stream().map(CovariateRow::clone).collect(Collectors.toList());
|
||||
|
||||
final int covariateToScrambleIndex = covariateToScramble.getIndex();
|
||||
for(int i=0; i < covariateRows.size(); i++){
|
||||
clonedRowList.get(i).valueArray[covariateToScrambleIndex] = permutedCovariateRowList.get(i).valueArray[covariateToScrambleIndex];
|
||||
}
|
||||
|
||||
return clonedRowList;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -32,9 +32,6 @@ public class Row<Y> extends CovariateRow {
|
|||
this.response = response;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
public Y getResponse() {
|
||||
return this.response;
|
||||
}
|
||||
|
|
|
@ -49,6 +49,8 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
|
|||
return getIndex() - other.getIndex();
|
||||
}
|
||||
|
||||
boolean haveNASplitPenalty();
|
||||
|
||||
interface Value<V> extends Serializable{
|
||||
|
||||
Covariate<V> getParent();
|
||||
|
|
|
@ -25,9 +25,12 @@ import lombok.Getter;
|
|||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public final class BooleanCovariate implements Covariate<Boolean> {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@Getter
|
||||
private final String name;
|
||||
|
||||
|
@ -38,14 +41,26 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
|||
|
||||
private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
|
||||
|
||||
public BooleanCovariate(String name, int index){
|
||||
private final boolean haveNASplitPenalty;
|
||||
@Override
|
||||
public boolean haveNASplitPenalty(){
|
||||
// penalty would add worthless computational time if there are no NAs
|
||||
return hasNAs && haveNASplitPenalty;
|
||||
}
|
||||
|
||||
public BooleanCovariate(String name, int index, boolean haveNASplitPenalty){
|
||||
this.name = name;
|
||||
this.index = index;
|
||||
splitRule = new BooleanSplitRule(this);
|
||||
this.splitRule = new BooleanSplitRule(this);
|
||||
this.haveNASplitPenalty = haveNASplitPenalty;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||
if(hasNAs){
|
||||
data = data.stream().filter(row -> !row.getValueByIndex(index).isNA()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
return new SingletonIterator<>(this.splitRule.applyRule(data));
|
||||
}
|
||||
|
||||
|
@ -84,6 +99,8 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
|||
|
||||
public class BooleanValue implements Value<Boolean>{
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final Boolean value;
|
||||
|
||||
private BooleanValue(final Boolean value){
|
||||
|
|
|
@ -21,6 +21,8 @@ import ca.joeltherrien.randomforest.covariates.SplitRule;
|
|||
|
||||
public class BooleanSplitRule implements SplitRule<Boolean> {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final int parentCovariateIndex;
|
||||
|
||||
public BooleanSplitRule(BooleanCovariate parent){
|
||||
|
|
|
@ -23,9 +23,12 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.Getter;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public final class FactorCovariate implements Covariate<String> {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@Getter
|
||||
private final String name;
|
||||
|
||||
|
@ -38,8 +41,15 @@ public final class FactorCovariate implements Covariate<String> {
|
|||
|
||||
private boolean hasNAs;
|
||||
|
||||
private final boolean haveNASplitPenalty;
|
||||
@Override
|
||||
public boolean haveNASplitPenalty(){
|
||||
// penalty would add worthless computational time if there are no NAs
|
||||
return hasNAs && haveNASplitPenalty;
|
||||
}
|
||||
|
||||
public FactorCovariate(final String name, final int index, List<String> levels){
|
||||
|
||||
public FactorCovariate(final String name, final int index, List<String> levels, final boolean haveNASplitPenalty){
|
||||
this.name = name;
|
||||
this.index = index;
|
||||
this.factorLevels = new HashMap<>();
|
||||
|
@ -61,12 +71,22 @@ public final class FactorCovariate implements Covariate<String> {
|
|||
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
||||
|
||||
this.naValue = new FactorValue(null);
|
||||
|
||||
this.haveNASplitPenalty = haveNASplitPenalty;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||
if(hasNAs()){
|
||||
data = data.stream().filter(row -> !row.getCovariateValue(this).isNA()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
if(number == 0){ // nsplit = 0 => try every possibility, although we limit it to the number of observations.
|
||||
number = data.size();
|
||||
}
|
||||
|
||||
final Set<Split<Y, String>> splits = new HashSet<>();
|
||||
|
||||
// This is to ensure we don't get stuck in an infinite loop for small factors
|
||||
|
@ -122,6 +142,8 @@ public final class FactorCovariate implements Covariate<String> {
|
|||
@EqualsAndHashCode
|
||||
public final class FactorValue implements Covariate.Value<String>{
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final String value;
|
||||
|
||||
private FactorValue(final String value){
|
||||
|
|
|
@ -25,6 +25,8 @@ import java.util.Set;
|
|||
@EqualsAndHashCode
|
||||
public final class FactorSplitRule implements SplitRule<String> {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final int parentCovariateIndex;
|
||||
private final Set<String> leftSideValues;
|
||||
|
||||
|
|
|
@ -37,6 +37,8 @@ import java.util.stream.Stream;
|
|||
@ToString
|
||||
public final class NumericCovariate implements Covariate<Double> {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@Getter
|
||||
private final String name;
|
||||
|
||||
|
@ -45,6 +47,13 @@ public final class NumericCovariate implements Covariate<Double> {
|
|||
|
||||
private boolean hasNAs = false;
|
||||
|
||||
private final boolean haveNASplitPenalty;
|
||||
@Override
|
||||
public boolean haveNASplitPenalty(){
|
||||
// penalty would add worthless computational time if there are no NAs
|
||||
return hasNAs && haveNASplitPenalty;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||
Stream<Row<Y>> stream = data.stream();
|
||||
|
@ -122,6 +131,8 @@ public final class NumericCovariate implements Covariate<Double> {
|
|||
@EqualsAndHashCode
|
||||
public class NumericValue implements Covariate.Value<Double>{
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final Double value; // may be null
|
||||
|
||||
private NumericValue(final Double value){
|
||||
|
|
|
@ -23,10 +23,12 @@ import lombok.EqualsAndHashCode;
|
|||
@EqualsAndHashCode
|
||||
public class NumericSplitRule implements SplitRule<Double> {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final int parentCovariateIndex;
|
||||
private final double threshold;
|
||||
|
||||
NumericSplitRule(NumericCovariate parent, final double threshold){
|
||||
public NumericSplitRule(NumericCovariate parent, final double threshold){
|
||||
this.parentCovariateIndex = parent.getIndex();
|
||||
this.threshold = threshold;
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ import java.util.List;
|
|||
@Builder
|
||||
public class CompetingRiskFunctions implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final List<RightContinuousStepFunction> causeSpecificHazards;
|
||||
private final List<RightContinuousStepFunction> cumulativeIncidenceCurves;
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ import java.io.Serializable;
|
|||
@Data
|
||||
public class CompetingRiskResponse implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final int delta;
|
||||
private final double u;
|
||||
|
||||
|
|
|
@ -26,6 +26,9 @@ import lombok.EqualsAndHashCode;
|
|||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final double c;
|
||||
|
||||
public CompetingRiskResponseWithCensorTime(int delta, double u, double c) {
|
||||
|
|
|
@ -16,9 +16,11 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.IntStream;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class CompetingRiskUtils {
|
||||
|
@ -116,6 +118,44 @@ public class CompetingRiskUtils {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the Integrated Brier Score error on a list of responses and predictions.
|
||||
*
|
||||
* @param responses A List of responses
|
||||
* @param predictions The corresponding List of predictions.
|
||||
* @param censoringDistribution The censoring distribution.
|
||||
* @param eventOfFocus The event we are calculating the error for.
|
||||
* @param integrationUpperBound The upper bound to integrate to.
|
||||
* @param isParallel Whether we should use parallel streams or not (provided because of bugs on a particular system).
|
||||
* @return
|
||||
*/
|
||||
public static double[] calculateIBSError(final List<CompetingRiskResponse> responses,
|
||||
List<CompetingRiskFunctions> predictions,
|
||||
Optional<RightContinuousStepFunction> censoringDistribution,
|
||||
int eventOfFocus,
|
||||
double integrationUpperBound,
|
||||
boolean isParallel){
|
||||
|
||||
if(responses.size() != predictions.size()){
|
||||
throw new IllegalArgumentException("Length of responses and predictions must be equal.");
|
||||
}
|
||||
|
||||
final IBSCalculator calculator = new IBSCalculator(censoringDistribution);
|
||||
|
||||
IntStream stream = IntStream.range(0, responses.size());
|
||||
|
||||
if(isParallel){
|
||||
stream = stream.parallel();
|
||||
}
|
||||
|
||||
return stream.mapToDouble(i -> {
|
||||
CompetingRiskResponse response = responses.get(i);
|
||||
RightContinuousStepFunction cif = predictions.get(i).getCumulativeIncidenceFunction(eventOfFocus);
|
||||
|
||||
return calculator.calculateError(response, cif, eventOfFocus, integrationUpperBound);
|
||||
}).toArray();
|
||||
}
|
||||
|
||||
|
||||
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
|
||||
final List<CompetingRiskResponse> initialRightHand,
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Used to calculate the Integrated Brier Score. See Section 4.2 of "Random survival forests for competing risks" by Ishwaran.
|
||||
*
|
||||
*/
|
||||
public class IBSCalculator {
|
||||
|
||||
private final Optional<RightContinuousStepFunction> censoringDistribution;
|
||||
|
||||
public IBSCalculator(RightContinuousStepFunction censoringDistribution){
|
||||
this.censoringDistribution = Optional.of(censoringDistribution);
|
||||
}
|
||||
|
||||
public IBSCalculator(){
|
||||
this.censoringDistribution = Optional.empty();
|
||||
}
|
||||
|
||||
public IBSCalculator(Optional<RightContinuousStepFunction> censoringDistribution){
|
||||
this.censoringDistribution = censoringDistribution;
|
||||
}
|
||||
|
||||
public double calculateError(CompetingRiskResponse response, RightContinuousStepFunction cif, int eventOfInterest, double integrationUpperBound){
|
||||
|
||||
// return integral of weights*(I(response.getU() <= times & response.getDelta() == eventOfInterest) - cif(times))^2
|
||||
// Note that if we don't have weights, just treat them all as one (i.e. don't bother multiplying)
|
||||
|
||||
RightContinuousStepFunction functionToIntegrate = cif;
|
||||
|
||||
if(response.getDelta() == eventOfInterest){
|
||||
final RightContinuousStepFunction observedFunction = new RightContinuousStepFunction(new double[]{response.getU()}, new double[]{1.0}, 0.0);
|
||||
functionToIntegrate = RightContinuousStepFunction.biOperation(observedFunction, functionToIntegrate, (a, b) -> (a - b) * (a - b));
|
||||
} else{
|
||||
functionToIntegrate = functionToIntegrate.unaryOperation(a -> a*a);
|
||||
}
|
||||
|
||||
if(censoringDistribution.isPresent()){
|
||||
final RightContinuousStepFunction weights = calculateWeights(response, censoringDistribution.get());
|
||||
functionToIntegrate = RightContinuousStepFunction.biOperation(weights, functionToIntegrate, (a, b) -> a*b);
|
||||
|
||||
// the censoring weights go to 0 after the response is censored, so we can speed up results by only integrating
|
||||
// prior to the censor times
|
||||
if(response.isCensored()){
|
||||
integrationUpperBound = Math.min(integrationUpperBound, response.getU());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return functionToIntegrate.integrate(0.0, integrationUpperBound);
|
||||
}
|
||||
|
||||
private RightContinuousStepFunction calculateWeights(CompetingRiskResponse response, RightContinuousStepFunction censoringDistribution){
|
||||
final double recordedTime = response.getU();
|
||||
|
||||
// Function(t) = firstPart(t) + secondPart(t)/thirdPart(t) where:
|
||||
// firstPart(t) = I(recordedTime <= t & !response.isCensored()) / censoringDistribution.evaluate(recordedTime);
|
||||
// secondPart(t) = I(recordedTime > t) = 1 - I(recordedTime <= t)
|
||||
// thirdPart(t) = censoringDistribution.evaluate(t)
|
||||
|
||||
final RightContinuousStepFunction secondPart = new RightContinuousStepFunction(new double[]{recordedTime}, new double[]{0.0}, 1.0);
|
||||
RightContinuousStepFunction result = RightContinuousStepFunction.biOperation(secondPart, censoringDistribution,
|
||||
(second, third) -> second / third);
|
||||
|
||||
if(!response.isCensored()){
|
||||
final RightContinuousStepFunction firstPart = new RightContinuousStepFunction(
|
||||
new double[]{recordedTime},
|
||||
new double[]{1.0 / censoringDistribution.evaluate(recordedTime)},
|
||||
0.0);
|
||||
|
||||
result = RightContinuousStepFunction.biOperation(firstPart, result, Double::sum);
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -17,17 +17,17 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||
public class CompetingRiskFunctionCombiner implements ForestResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final int[] events;
|
||||
private final double[] times; // We may restrict ourselves to specific times.
|
||||
|
@ -55,72 +55,22 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
|
|||
).sorted().distinct().toArray();
|
||||
}
|
||||
|
||||
final double n = responses.size();
|
||||
final IntermediateCompetingRisksFunctionsTimesKnown intermediateResult = new IntermediateCompetingRisksFunctionsTimesKnown(responses.size(), this.events, timesToUse);
|
||||
|
||||
final double[] survivalY = new double[timesToUse.length];
|
||||
final double[][] csCHFY = new double[events.length][timesToUse.length];
|
||||
final double[][] cifY = new double[events.length][timesToUse.length];
|
||||
|
||||
/*
|
||||
We're going to try to efficiently put our predictions together -
|
||||
Assumptions - for each event on a response, the hazard and CIF functions share the same x points
|
||||
|
||||
Plan - go through the time on each response and make use of that so that when we search for a time index
|
||||
to evaluate the function at, we don't need to re-search the earlier times.
|
||||
|
||||
*/
|
||||
|
||||
|
||||
for(final CompetingRiskFunctions currentFunctions : responses){
|
||||
final double[] survivalXPoints = currentFunctions.getSurvivalCurve().getX();
|
||||
final double[][] eventSpecificXPoints = new double[events.length][];
|
||||
|
||||
for(final int event : events){
|
||||
eventSpecificXPoints[event-1] = currentFunctions.getCumulativeIncidenceFunction(event)
|
||||
.getX();
|
||||
for(CompetingRiskFunctions input : responses){
|
||||
intermediateResult.processNewInput(input);
|
||||
}
|
||||
|
||||
int previousSurvivalIndex = 0;
|
||||
final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
|
||||
|
||||
for(int i=0; i<timesToUse.length; i++){
|
||||
final double time = timesToUse[i];
|
||||
|
||||
// Survival curve
|
||||
final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
|
||||
survivalY[i] = survivalY[i] + currentFunctions.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / n;
|
||||
previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
|
||||
// -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
|
||||
|
||||
// CHFs and CIFs
|
||||
for(final int event : events){
|
||||
final double[] xPoints = eventSpecificXPoints[event-1];
|
||||
final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
|
||||
xPoints, time);
|
||||
csCHFY[event-1][i] = csCHFY[event-1][i] + currentFunctions.getCauseSpecificHazardFunction(event)
|
||||
.evaluateByIndex(eventTimeIndex) / n;
|
||||
cifY[event-1][i] = cifY[event-1][i] + currentFunctions.getCumulativeIncidenceFunction(event)
|
||||
.evaluateByIndex(eventTimeIndex) / n;
|
||||
|
||||
previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
|
||||
}
|
||||
return intermediateResult.transformToOutput();
|
||||
}
|
||||
|
||||
@Override
|
||||
public IntermediateCombinedResponse<CompetingRiskFunctions, CompetingRiskFunctions> startIntermediateCombinedResponse(int countInputs) {
|
||||
if(this.times != null){
|
||||
return new IntermediateCompetingRisksFunctionsTimesKnown(countInputs, this.events, this.times);
|
||||
}
|
||||
|
||||
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
|
||||
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||
|
||||
for(final int event : events){
|
||||
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
|
||||
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
|
||||
}
|
||||
|
||||
return CompetingRiskFunctions.builder()
|
||||
.causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
|
||||
.cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
|
||||
.survivalCurve(survivalFunction)
|
||||
.build();
|
||||
// TODO - implement
|
||||
throw new RuntimeException("startIntermediateCombinedResponse when times is unknown is not yet implemented");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,6 +35,8 @@ import java.util.List;
|
|||
*/
|
||||
public class CompetingRiskResponseCombiner implements ResponseCombiner<CompetingRiskResponse, CompetingRiskFunctions> {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final int[] events;
|
||||
|
||||
public CompetingRiskResponseCombiner(final int[] events){
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class IntermediateCompetingRisksFunctionsTimesKnown implements IntermediateCombinedResponse<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||
|
||||
private double expectedN;
|
||||
private final int[] events;
|
||||
private final double[] timesToUse;
|
||||
private int actualN;
|
||||
|
||||
private final double[] survivalY;
|
||||
private final double[][] csCHFY;
|
||||
private final double[][] cifY;
|
||||
|
||||
public IntermediateCompetingRisksFunctionsTimesKnown(int n, int[] events, double[] timesToUse){
|
||||
this.expectedN = n;
|
||||
this.events = events;
|
||||
this.timesToUse = timesToUse;
|
||||
this.actualN = 0;
|
||||
|
||||
this.survivalY = new double[timesToUse.length];
|
||||
this.csCHFY = new double[events.length][timesToUse.length];
|
||||
this.cifY = new double[events.length][timesToUse.length];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void processNewInput(CompetingRiskFunctions input) {
|
||||
/*
|
||||
We're going to try to efficiently put our predictions together -
|
||||
Assumptions - for each event on a response, the hazard and CIF functions share the same x points
|
||||
|
||||
Plan - go through the time on each response and make use of that so that when we search for a time index
|
||||
to evaluate the function at, we don't need to re-search the earlier times.
|
||||
|
||||
*/
|
||||
|
||||
this.actualN++;
|
||||
|
||||
final double[] survivalXPoints = input.getSurvivalCurve().getX();
|
||||
final double[][] eventSpecificXPoints = new double[events.length][];
|
||||
|
||||
for(final int event : events){
|
||||
eventSpecificXPoints[event-1] = input.getCumulativeIncidenceFunction(event)
|
||||
.getX();
|
||||
}
|
||||
|
||||
int previousSurvivalIndex = 0;
|
||||
final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
|
||||
|
||||
for(int i=0; i<timesToUse.length; i++){
|
||||
final double time = timesToUse[i];
|
||||
|
||||
// Survival curve
|
||||
final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
|
||||
survivalY[i] = survivalY[i] + input.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / expectedN;
|
||||
previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
|
||||
// -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
|
||||
|
||||
// CHFs and CIFs
|
||||
for(final int event : events){
|
||||
final double[] xPoints = eventSpecificXPoints[event-1];
|
||||
final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
|
||||
xPoints, time);
|
||||
csCHFY[event-1][i] = csCHFY[event-1][i] + input.getCauseSpecificHazardFunction(event)
|
||||
.evaluateByIndex(eventTimeIndex) / expectedN;
|
||||
cifY[event-1][i] = cifY[event-1][i] + input.getCumulativeIncidenceFunction(event)
|
||||
.evaluateByIndex(eventTimeIndex) / expectedN;
|
||||
|
||||
previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompetingRiskFunctions transformToOutput() {
|
||||
rescaleOutput();
|
||||
|
||||
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
|
||||
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
|
||||
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
|
||||
|
||||
for(final int event : events){
|
||||
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
|
||||
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
|
||||
}
|
||||
|
||||
return CompetingRiskFunctions.builder()
|
||||
.causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
|
||||
.cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
|
||||
.survivalCurve(survivalFunction)
|
||||
.build();
|
||||
}
|
||||
|
||||
private void rescaleOutput() {
|
||||
rescaleArray(actualN, this.survivalY);
|
||||
|
||||
for(int event : events){
|
||||
rescaleArray(actualN, this.cifY[event - 1]);
|
||||
rescaleArray(actualN, this.csCHFY[event - 1]);
|
||||
}
|
||||
|
||||
this.expectedN = actualN;
|
||||
|
||||
}
|
||||
|
||||
private void rescaleArray(double newN, double[] array){
|
||||
for(int i=0; i<array.length; i++){
|
||||
array[i] = array[i] * (this.expectedN / newN);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -28,6 +28,7 @@ import java.util.List;
|
|||
*
|
||||
*/
|
||||
public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponseWithCensorTime> {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final int[] eventsOfFocus;
|
||||
private final int[] events;
|
||||
|
|
|
@ -28,6 +28,7 @@ import java.util.List;
|
|||
*
|
||||
*/
|
||||
public class LogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponse> {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final int[] eventsOfFocus;
|
||||
private final int[] events;
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.responses.regression;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -24,7 +25,8 @@ import java.util.List;
|
|||
* Returns the Mean value of a group of Doubles.
|
||||
*
|
||||
*/
|
||||
public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
|
||||
public class MeanResponseCombiner implements ForestResponseCombiner<Double, Double> {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@Override
|
||||
public Double combine(List<Double> responses) {
|
||||
|
@ -34,5 +36,39 @@ public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
|
|||
|
||||
}
|
||||
|
||||
@Override
|
||||
public IntermediateCombinedResponse<Double, Double> startIntermediateCombinedResponse(int countInputs) {
|
||||
return new MeanIntermediateCombinedResponse(countInputs);
|
||||
}
|
||||
|
||||
public static class MeanIntermediateCombinedResponse implements IntermediateCombinedResponse<Double, Double>{
|
||||
|
||||
private double expectedN;
|
||||
private int actualN;
|
||||
private double currentMean;
|
||||
|
||||
public MeanIntermediateCombinedResponse(int n){
|
||||
this.expectedN = n;
|
||||
this.actualN = 0;
|
||||
this.currentMean = 0.0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void processNewInput(Double input) {
|
||||
this.currentMean = this.currentMean + input / expectedN;
|
||||
this.actualN ++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double transformToOutput() {
|
||||
// rescale if necessary
|
||||
this.currentMean = this.currentMean * (this.expectedN / (double) actualN);
|
||||
this.expectedN = actualN;
|
||||
|
||||
return currentMean;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.util.List;
|
|||
import java.util.stream.Collectors;
|
||||
|
||||
public class WeightedVarianceSplitFinder implements SplitFinder<Double> {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private Double getScore(Set leftHand, Set rightHand) {
|
||||
|
||||
|
|
|
@ -17,31 +17,18 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Builder
|
||||
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
||||
public abstract class Forest<O, FO> {
|
||||
|
||||
private final List<Tree<O>> trees;
|
||||
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||
private final List<Covariate> covariateList;
|
||||
|
||||
public FO evaluate(CovariateRow row){
|
||||
|
||||
return treeResponseCombiner.combine(
|
||||
trees.stream()
|
||||
.map(node -> node.evaluate(row))
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
|
||||
}
|
||||
public abstract FO evaluate(CovariateRow row);
|
||||
public abstract FO evaluateOOB(CovariateRow row);
|
||||
public abstract Iterable<Tree<O>> getTrees();
|
||||
public abstract int getNumberOfTrees();
|
||||
|
||||
/**
|
||||
* Used primarily in the R package interface to avoid R loops; and for easier parallelization.
|
||||
|
@ -93,21 +80,6 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
|||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public FO evaluateOOB(CovariateRow row){
|
||||
|
||||
return treeResponseCombiner.combine(
|
||||
trees.stream()
|
||||
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
||||
.map(node -> node.evaluate(row))
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
public List<Tree<O>> getTrees(){
|
||||
return Collections.unmodifiableList(trees);
|
||||
}
|
||||
|
||||
public Map<Integer, Integer> findSplitsByCovariate(){
|
||||
final Map<Integer, Integer> countMap = new TreeMap<>();
|
||||
|
||||
|
@ -158,4 +130,5 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
|||
return countTerminalNodes;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
public interface ForestResponseCombiner<I, O> extends ResponseCombiner<I, O>{
|
||||
|
||||
IntermediateCombinedResponse<I, O> startIntermediateCombinedResponse(int countInputs);
|
||||
|
||||
}
|
|
@ -38,7 +38,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
private final TreeTrainer<Y, TO> treeTrainer;
|
||||
private final List<Covariate> covariates;
|
||||
private final ResponseCombiner<TO, FO> treeResponseCombiner;
|
||||
private final ForestResponseCombiner<TO, FO> treeResponseCombiner;
|
||||
private final List<Row<Y>> data;
|
||||
|
||||
// number of trees to try
|
||||
|
@ -57,10 +57,10 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
* in which case its trees are combined with the new one.
|
||||
* @return A trained forest.
|
||||
*/
|
||||
public Forest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
||||
public OnlineForest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
|
||||
|
||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||
initialForest.ifPresent(forest -> trees.addAll(forest.getTrees()));
|
||||
initialForest.ifPresent(forest -> forest.getTrees().forEach(trees::add));
|
||||
|
||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||
|
||||
|
@ -77,11 +77,9 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
System.out.println("Finished");
|
||||
}
|
||||
|
||||
|
||||
return Forest.<TO, FO>builder()
|
||||
return OnlineForest.<TO, FO>builder()
|
||||
.treeResponseCombiner(treeResponseCombiner)
|
||||
.trees(trees)
|
||||
.covariateList(covariates)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
@ -94,7 +92,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
* There cannot be existing trees if the initial forest is
|
||||
* specified.
|
||||
*/
|
||||
public void trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
||||
public OfflineForest<TO, FO> trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
|
||||
// First we need to see how many trees there currently are
|
||||
final File folder = new File(saveTreeLocation);
|
||||
if(!folder.exists()){
|
||||
|
@ -115,17 +113,14 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
|
||||
if(initialForest.isPresent()){
|
||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||
|
||||
for(int j=0; j<initialTrees.size(); j++){
|
||||
int j=0;
|
||||
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||
final String filename = "tree-" + (j+1) + ".tree";
|
||||
final Tree<TO> tree = initialTrees.get(j);
|
||||
|
||||
saveTree(tree, filename);
|
||||
|
||||
j++;
|
||||
}
|
||||
|
||||
treeCount = new AtomicInteger(initialTrees.size());
|
||||
treeCount = new AtomicInteger(j);
|
||||
} else{
|
||||
treeCount = new AtomicInteger(treeFiles.length);
|
||||
}
|
||||
|
@ -153,6 +148,8 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
System.out.println("Finished");
|
||||
}
|
||||
|
||||
return new OfflineForest<>(folder, treeResponseCombiner);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -162,7 +159,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
* in which case its trees are combined with the new one.
|
||||
* @param threads The number of trees to train at once.
|
||||
*/
|
||||
public Forest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||
public OnlineForest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||
|
||||
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
|
||||
// the earlier indexes being filled.
|
||||
|
@ -170,11 +167,12 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
final int startingCount;
|
||||
if(initialForest.isPresent()){
|
||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||
for(int j=0; j<initialTrees.size(); j++) {
|
||||
trees.set(j, initialTrees.get(j));
|
||||
int j = 0;
|
||||
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||
trees.set(j, tree);
|
||||
j++;
|
||||
}
|
||||
startingCount = initialTrees.size();
|
||||
startingCount = initialForest.get().getNumberOfTrees();
|
||||
}
|
||||
else{
|
||||
startingCount = 0;
|
||||
|
@ -219,7 +217,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
System.out.println("\nFinished");
|
||||
}
|
||||
|
||||
return Forest.<TO, FO>builder()
|
||||
return OnlineForest.<TO, FO>builder()
|
||||
.treeResponseCombiner(treeResponseCombiner)
|
||||
.trees(trees)
|
||||
.build();
|
||||
|
@ -235,7 +233,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
* specified.
|
||||
* @param threads The number of trees to train at once.
|
||||
*/
|
||||
public void trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||
public OfflineForest<TO, FO> trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
|
||||
// First we need to see how many trees there currently are
|
||||
final File folder = new File(saveTreeLocation);
|
||||
if(!folder.exists()){
|
||||
|
@ -255,17 +253,14 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
final AtomicInteger treeCount; // tracks how many trees are finished
|
||||
if(initialForest.isPresent()){
|
||||
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
|
||||
|
||||
for(int j=0; j<initialTrees.size(); j++){
|
||||
int j=0;
|
||||
for(final Tree<TO> tree : initialForest.get().getTrees()){
|
||||
final String filename = "tree-" + (j+1) + ".tree";
|
||||
final Tree<TO> tree = initialTrees.get(j);
|
||||
|
||||
saveTree(tree, filename);
|
||||
|
||||
j++;
|
||||
}
|
||||
|
||||
treeCount = new AtomicInteger(initialTrees.size());
|
||||
treeCount = new AtomicInteger(j);
|
||||
} else{
|
||||
treeCount = new AtomicInteger(treeFiles.length);
|
||||
}
|
||||
|
@ -309,6 +304,8 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
System.out.println("\nFinished");
|
||||
}
|
||||
|
||||
return new OfflineForest<>(folder, treeResponseCombiner);
|
||||
|
||||
}
|
||||
|
||||
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
/**
|
||||
* Similar to ResponseCombiner, but an IntermediateCombinedResponse represents the intermediate state of a single output in the process of being combined.
|
||||
* This class is only used in OfflineForests where we can only load one Tree in memory at a time.
|
||||
*
|
||||
*/
|
||||
public interface IntermediateCombinedResponse<I, O> {
|
||||
|
||||
void processNewInput(I input);
|
||||
|
||||
O transformToOutput();
|
||||
|
||||
}
|
|
@ -0,0 +1,198 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.utils.IterableOfflineTree;
|
||||
import lombok.AllArgsConstructor;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
@AllArgsConstructor
|
||||
public class OfflineForest<O, FO> extends Forest<O, FO> {
|
||||
|
||||
private final File[] treeFiles;
|
||||
private final ForestResponseCombiner<O, FO> treeResponseCombiner;
|
||||
|
||||
public OfflineForest(File treeDirectoryPath, ForestResponseCombiner<O, FO> treeResponseCombiner){
|
||||
this.treeResponseCombiner = treeResponseCombiner;
|
||||
|
||||
if(!treeDirectoryPath.isDirectory()){
|
||||
throw new IllegalArgumentException("treeDirectoryPath must point to a directory!");
|
||||
}
|
||||
|
||||
this.treeFiles = treeDirectoryPath.listFiles((file, s) -> s.endsWith(".tree"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public FO evaluate(CovariateRow row) {
|
||||
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
|
||||
for(final Tree<O> tree : getTrees()){
|
||||
final O prediction = tree.evaluate(row);
|
||||
predictedOutputs.add(prediction);
|
||||
}
|
||||
|
||||
return treeResponseCombiner.combine(predictedOutputs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FO evaluateOOB(CovariateRow row) {
|
||||
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
|
||||
for(final Tree<O> tree : getTrees()){
|
||||
if(!tree.idInBootstrapSample(row.getId())){
|
||||
final O prediction = tree.evaluate(row);
|
||||
predictedOutputs.add(prediction);
|
||||
}
|
||||
}
|
||||
|
||||
return treeResponseCombiner.combine(predictedOutputs);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<FO> evaluate(List<? extends CovariateRow> rowList){
|
||||
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
|
||||
IntStream.range(0, rowList.size())
|
||||
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||
final Tree<O> currentTree = treeIterator.next();
|
||||
|
||||
IntStream.range(0, rowList.size()).parallel().forEach(
|
||||
rowId -> {
|
||||
final CovariateRow row = rowList.get(rowId);
|
||||
final O prediction = currentTree.evaluate(row);
|
||||
intermediatePredictions.get(rowId).processNewInput(prediction);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return intermediatePredictions.stream().parallel()
|
||||
.map(intPred -> intPred.transformToOutput())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FO> evaluateSerial(List<? extends CovariateRow> rowList){
|
||||
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
|
||||
IntStream.range(0, rowList.size())
|
||||
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||
final Tree<O> currentTree = treeIterator.next();
|
||||
|
||||
IntStream.range(0, rowList.size()).sequential().forEach(
|
||||
rowId -> {
|
||||
final CovariateRow row = rowList.get(rowId);
|
||||
final O prediction = currentTree.evaluate(row);
|
||||
intermediatePredictions.get(rowId).processNewInput(prediction);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return intermediatePredictions.stream().sequential()
|
||||
.map(intPred -> intPred.transformToOutput())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<FO> evaluateOOB(List<? extends CovariateRow> rowList){
|
||||
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
|
||||
IntStream.range(0, rowList.size())
|
||||
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||
final Tree<O> currentTree = treeIterator.next();
|
||||
|
||||
IntStream.range(0, rowList.size()).parallel().forEach(
|
||||
rowId -> {
|
||||
final CovariateRow row = rowList.get(rowId);
|
||||
if(!currentTree.idInBootstrapSample(row.getId())){
|
||||
final O prediction = currentTree.evaluate(row);
|
||||
intermediatePredictions.get(rowId).processNewInput(prediction);
|
||||
}
|
||||
// else do nothing; when we get the final output it will get scaled for the smaller N
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return intermediatePredictions.stream().parallel()
|
||||
.map(intPred -> intPred.transformToOutput())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FO> evaluateSerialOOB(List<? extends CovariateRow> rowList){
|
||||
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
|
||||
IntStream.range(0, rowList.size())
|
||||
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
|
||||
for(int treeId = 0; treeId < treeFiles.length; treeId++){
|
||||
final Tree<O> currentTree = treeIterator.next();
|
||||
|
||||
IntStream.range(0, rowList.size()).sequential().forEach(
|
||||
rowId -> {
|
||||
final CovariateRow row = rowList.get(rowId);
|
||||
if(!currentTree.idInBootstrapSample(row.getId())){
|
||||
final O prediction = currentTree.evaluate(row);
|
||||
intermediatePredictions.get(rowId).processNewInput(prediction);
|
||||
}
|
||||
// else do nothing; when we get the final output it will get scaled for the smaller N
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return intermediatePredictions.stream().sequential()
|
||||
.map(intPred -> intPred.transformToOutput())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterable<Tree<O>> getTrees() {
|
||||
return new IterableOfflineTree<>(treeFiles);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumberOfTrees() {
|
||||
return treeFiles.length;
|
||||
}
|
||||
|
||||
public OnlineForest<O, FO> createOnlineCopy(){
|
||||
final List<Tree<O>> allTrees = new ArrayList<>(getNumberOfTrees());
|
||||
getTrees().forEach(allTrees::add);
|
||||
|
||||
return OnlineForest.<O, FO>builder()
|
||||
.trees(allTrees)
|
||||
.treeResponseCombiner(treeResponseCombiner)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Builder
|
||||
public class OnlineForest<O, FO> extends Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
|
||||
|
||||
private final List<Tree<O>> trees;
|
||||
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||
|
||||
@Override
|
||||
public FO evaluate(CovariateRow row){
|
||||
|
||||
return treeResponseCombiner.combine(
|
||||
trees.stream()
|
||||
.map(node -> node.evaluate(row))
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public FO evaluateOOB(CovariateRow row){
|
||||
|
||||
return treeResponseCombiner.combine(
|
||||
trees.stream()
|
||||
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
|
||||
.map(node -> node.evaluate(row))
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Tree<O>> getTrees(){
|
||||
return Collections.unmodifiableList(trees);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumberOfTrees() {
|
||||
return trees.size();
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -17,15 +17,13 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.Data;
|
||||
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
public class SplitAndScore<Y, V> {
|
||||
|
||||
@Getter
|
||||
private final Split<Y, V> split;
|
||||
|
||||
@Getter
|
||||
private final Double score;
|
||||
private Split<Y, V> split;
|
||||
private Double score;
|
||||
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import java.util.List;
|
|||
@ToString
|
||||
@Getter
|
||||
public class SplitNode<Y> implements Node<Y> {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final Node<Y> leftHand;
|
||||
private final Node<Y> rightHand;
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.util.List;
|
|||
@RequiredArgsConstructor
|
||||
@ToString
|
||||
public class TerminalNode<Y> implements Node<Y> {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final Y responseValue;
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
|
||||
public class Tree<Y> implements Node<Y> {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@Getter
|
||||
private final Node<Y> rootNode;
|
||||
|
|
|
@ -17,7 +17,9 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.VisibleForTesting;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
|
@ -72,31 +74,12 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
}
|
||||
|
||||
|
||||
// Now that we have the best split; we need to handle any NAs that were dropped off
|
||||
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
||||
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
||||
|
||||
// Assign missing values to the split if necessary
|
||||
if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
|
||||
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
||||
|
||||
for(Row<Y> row : data) {
|
||||
final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex();
|
||||
|
||||
if(row.getValueByIndex(covariateIndex).isNA()) {
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
|
||||
if(randomDecision){
|
||||
bestSplit.getLeftHand().add(row);
|
||||
}
|
||||
else{
|
||||
bestSplit.getRightHand().add(row);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
bestSplit = randomlyAssignNAs(data, bestSplit, random);
|
||||
|
||||
final Node<O> leftNode;
|
||||
final Node<O> rightNode;
|
||||
|
@ -144,7 +127,8 @@ public class TreeTrainer<Y, O> {
|
|||
return splitCovariates;
|
||||
}
|
||||
|
||||
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
||||
@VisibleForTesting
|
||||
public Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
||||
|
||||
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
||||
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
|
||||
|
@ -157,10 +141,32 @@ public class TreeTrainer<Y, O> {
|
|||
continue;
|
||||
}
|
||||
|
||||
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
|
||||
SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
|
||||
|
||||
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
|
||||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {
|
||||
|
||||
if(candidateSplitAndScore == null){
|
||||
continue;
|
||||
}
|
||||
|
||||
// This score was based on splitting only non-NA values. However, there might be a similar covariate we are also considering
|
||||
// that is just as good at splitting but has less NAs; we should thus penalize the split score for variables with NAs
|
||||
// We do this by randomly assigning the NAs and then recalculating the split score on the best split we already have.
|
||||
//
|
||||
// We only have to penalize the score though if we know it's possible that this might be the best split. If it's not,
|
||||
// then we can skip the computations.
|
||||
final boolean mayBeGoodSplit = bestSplitAndScore == null ||
|
||||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore();
|
||||
if(mayBeGoodSplit && covariate.haveNASplitPenalty()){
|
||||
Split<Y, ?> candiateSplitWithNAs = randomlyAssignNAs(data, candidateSplitAndScore.getSplit(), random);
|
||||
final Iterator<Split<Y, ?>> newSplitWithRandomNAs = new SingletonIterator<>(candiateSplitWithNAs);
|
||||
final double newScore = splitFinder.findBestSplit(newSplitWithRandomNAs).getScore();
|
||||
|
||||
// There's a chance that NAs might add noise to *improve* the score; but we want to ensure we penalize it.
|
||||
// Thus we only change the score if its worse.
|
||||
candidateSplitAndScore.setScore(Math.min(newScore, candidateSplitAndScore.getScore()));
|
||||
}
|
||||
|
||||
if(bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
|
||||
bestSplitAndScore = candidateSplitAndScore;
|
||||
}
|
||||
|
||||
|
@ -174,6 +180,38 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
}
|
||||
|
||||
private <V> Split<Y, V> randomlyAssignNAs(List<Row<Y>> data, Split<Y, V> existingSplit, Random random){
|
||||
|
||||
// Now that we have the best split; we need to handle any NAs that were dropped off
|
||||
final double probabilityLeftHand = (double) existingSplit.leftHand.size() /
|
||||
(double) (existingSplit.leftHand.size() + existingSplit.rightHand.size());
|
||||
|
||||
|
||||
final int covariateIndex = existingSplit.getSplitRule().getParentCovariateIndex();
|
||||
|
||||
// Assign missing values to the split if necessary
|
||||
if(covariates.get(existingSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
|
||||
existingSplit = existingSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
||||
|
||||
for(Row<Y> row : data) {
|
||||
if(row.getValueByIndex(covariateIndex).isNA()) {
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
|
||||
if(randomDecision){
|
||||
existingSplit.getLeftHand().add(row);
|
||||
}
|
||||
else{
|
||||
existingSplit.getRightHand().add(row);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return existingSplit;
|
||||
|
||||
}
|
||||
|
||||
private boolean nodeIsPure(List<Row<Y>> data){
|
||||
if(!checkNodePurity){
|
||||
return false;
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
package ca.joeltherrien.randomforest.tree.vimp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Simple interface for VariableImportanceCalculator; takes in a List of observed responses and a List of predictions
|
||||
* and produces an average error measure.
|
||||
*
|
||||
* @param <Y> The class of the responses.
|
||||
* @param <P> The class of the predictions.
|
||||
*/
|
||||
public interface ErrorCalculator<Y, P>{
|
||||
|
||||
/**
|
||||
* Compares the observed responses with the predictions to produce an average error measure.
|
||||
* Lower errors should indicate a better model fit.
|
||||
*
|
||||
* @param responses
|
||||
* @param predictions
|
||||
* @return
|
||||
*/
|
||||
double averageError(List<Y> responses, List<P> predictions);
|
||||
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
package ca.joeltherrien.randomforest.tree.vimp;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Implements ErrorCalculator; essentially just wraps around IBSCalculator to fit into VariableImportanceCalculator.
|
||||
*
|
||||
*/
|
||||
public class IBSErrorCalculatorWrapper implements ErrorCalculator<CompetingRiskResponse, CompetingRiskFunctions> {
|
||||
|
||||
private final IBSCalculator calculator;
|
||||
private final int[] events;
|
||||
private final double integrationUpperBound;
|
||||
private final double[] eventWeights;
|
||||
|
||||
public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound, double[] eventWeights) {
|
||||
this.calculator = calculator;
|
||||
this.events = events;
|
||||
this.integrationUpperBound = integrationUpperBound;
|
||||
this.eventWeights = eventWeights;
|
||||
}
|
||||
|
||||
public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound) {
|
||||
this.calculator = calculator;
|
||||
this.events = events;
|
||||
this.integrationUpperBound = integrationUpperBound;
|
||||
this.eventWeights = new double[events.length];
|
||||
|
||||
Arrays.fill(this.eventWeights, 1.0); // default is to just sum all errors together
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double averageError(List<CompetingRiskResponse> responses, List<CompetingRiskFunctions> predictions) {
|
||||
final double[] errors = new double[events.length];
|
||||
final double n = responses.size();
|
||||
|
||||
for(int i=0; i < responses.size(); i++){
|
||||
final CompetingRiskResponse response = responses.get(i);
|
||||
final CompetingRiskFunctions prediction = predictions.get(i);
|
||||
|
||||
for(int k=0; k < this.events.length; k++){
|
||||
final int event = this.events[k];
|
||||
final RightContinuousStepFunction cif = prediction.getCumulativeIncidenceFunction(event);
|
||||
errors[k] += calculator.calculateError(response, cif, event, integrationUpperBound) / n;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
double totalError = 0.0;
|
||||
for(int k=0; k < this.events.length; k++){
|
||||
totalError += this.eventWeights[k] * errors[k];
|
||||
}
|
||||
|
||||
return totalError;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package ca.joeltherrien.randomforest.tree.vimp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class RegressionErrorCalculator implements ErrorCalculator<Double, Double>{
|
||||
|
||||
@Override
|
||||
public double averageError(List<Double> responses, List<Double> predictions) {
|
||||
double mean = 0.0;
|
||||
final double n = responses.size();
|
||||
|
||||
for(int i=0; i<responses.size(); i++){
|
||||
final double response = responses.get(i);
|
||||
final double prediction = predictions.get(i);
|
||||
|
||||
final double difference = response - prediction;
|
||||
|
||||
mean += difference * difference / n;
|
||||
}
|
||||
|
||||
return mean;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,117 @@
|
|||
package ca.joeltherrien.randomforest.tree.vimp;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.tree.Tree;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
public class VariableImportanceCalculator<Y, P> {
|
||||
|
||||
private final ErrorCalculator<Y, P> errorCalculator;
|
||||
private final List<Tree<P>> trees;
|
||||
private final List<Row<Y>> observations;
|
||||
|
||||
private final boolean isTrainingSet; // If true, then we use out-of-bag predictions
|
||||
private final double[] baselineErrors;
|
||||
|
||||
public VariableImportanceCalculator(
|
||||
ErrorCalculator<Y, P> errorCalculator,
|
||||
List<Tree<P>> trees,
|
||||
List<Row<Y>> observations,
|
||||
boolean isTrainingSet
|
||||
){
|
||||
this.errorCalculator = errorCalculator;
|
||||
this.trees = trees;
|
||||
this.observations = observations;
|
||||
this.isTrainingSet = isTrainingSet;
|
||||
|
||||
|
||||
try {
|
||||
|
||||
this.baselineErrors = new double[trees.size()];
|
||||
for (int i = 0; i < baselineErrors.length; i++) {
|
||||
final Tree<P> tree = trees.get(i);
|
||||
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
|
||||
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
|
||||
this.baselineErrors[i] = errorCalculator.averageError(responses, makePredictions(oobSubset, tree));
|
||||
}
|
||||
|
||||
} catch(Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array of importance values for every Tree for the given Covariate.
|
||||
*
|
||||
* @param covariate The Covariate to scramble.
|
||||
* @param random
|
||||
* @return
|
||||
*/
|
||||
public double[] calculateVariableImportanceRaw(Covariate covariate, Optional<Random> random){
|
||||
|
||||
final double[] vimp = new double[trees.size()];
|
||||
for(int i = 0; i < vimp.length; i++){
|
||||
final Tree<P> tree = trees.get(i);
|
||||
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
|
||||
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(oobSubset, covariate, random);
|
||||
|
||||
final double error = errorCalculator.averageError(responses, makePredictions(scrambledValues, tree));
|
||||
|
||||
vimp[i] = error - this.baselineErrors[i];
|
||||
}
|
||||
|
||||
return vimp;
|
||||
}
|
||||
|
||||
public double calculateVariableImportanceZScore(Covariate covariate, Optional<Random> random){
|
||||
final double[] vimpArray = calculateVariableImportanceRaw(covariate, random);
|
||||
|
||||
double mean = 0.0;
|
||||
double variance = 0.0;
|
||||
final double numTrees = vimpArray.length;
|
||||
|
||||
for(double vimp : vimpArray){
|
||||
mean += vimp / numTrees;
|
||||
}
|
||||
for(double vimp : vimpArray){
|
||||
variance += (vimp - mean)*(vimp - mean) / (numTrees - 1.0);
|
||||
}
|
||||
|
||||
final double standardError = Math.sqrt(variance / numTrees);
|
||||
|
||||
return mean / standardError;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Assume rowList has already been filtered for OOB
|
||||
private List<P> makePredictions(List<? extends CovariateRow> rowList, Tree<P> tree){
|
||||
return rowList.stream()
|
||||
.map(tree::evaluate)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private List<Row<Y>> getAppropriateSubset(List<Row<Y>> initialList, Tree<P> tree){
|
||||
if(!isTrainingSet){
|
||||
return initialList; // no need to make any subsets
|
||||
}
|
||||
|
||||
return initialList.stream()
|
||||
.filter(row -> !tree.idInBootstrapSample(row.getId()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -16,9 +16,7 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.tree.Tree;
|
||||
import ca.joeltherrien.randomforest.tree.*;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
|
@ -27,12 +25,17 @@ import java.util.zip.GZIPOutputStream;
|
|||
|
||||
public class DataUtils {
|
||||
|
||||
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||
if(!folder.isDirectory()){
|
||||
throw new IllegalArgumentException("Tree directory must be a directory!");
|
||||
}
|
||||
|
||||
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
|
||||
|
||||
return loadOnlineForest(treeFiles, treeResponseCombiner);
|
||||
}
|
||||
|
||||
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File[] treeFiles, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||
final List<File> treeFileList = Arrays.asList(treeFiles);
|
||||
|
||||
Collections.sort(treeFileList, Comparator.comparing(File::getName));
|
||||
|
@ -48,16 +51,16 @@ public class DataUtils {
|
|||
|
||||
}
|
||||
|
||||
return Forest.<O, FO>builder()
|
||||
return OnlineForest.<O, FO>builder()
|
||||
.trees(treeList)
|
||||
.treeResponseCombiner(treeResponseCombiner)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
public static <O, FO> Forest<O, FO> loadForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||
final File directory = new File(folder);
|
||||
return loadForest(directory, treeResponseCombiner);
|
||||
return loadOnlineForest(directory, treeResponseCombiner);
|
||||
}
|
||||
|
||||
public static void saveObject(Serializable object, String filename) throws IOException {
|
||||
|
|
|
@ -1,113 +0,0 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
/**
|
||||
* Represents a function represented by discrete points. However, the function may be right-continuous or left-continuous
|
||||
* at a given point, with no consistency. This function tracks that.
|
||||
*/
|
||||
public final class DiscontinuousStepFunction extends StepFunction {
|
||||
|
||||
private final double[] y;
|
||||
private final boolean[] isLeftContinuous;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
*/
|
||||
private final double defaultY;
|
||||
|
||||
public DiscontinuousStepFunction(double[] x, double[] y, boolean[] isLeftContinuous, double defaultY) {
|
||||
super(x);
|
||||
this.y = y;
|
||||
this.isLeftContinuous = isLeftContinuous;
|
||||
this.defaultY = defaultY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluate(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return evaluateByIndex(index);
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public double evaluatePrevious(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return evaluateByIndex(index);
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluateByIndex(int i) {
|
||||
if(isLeftContinuous[i]){
|
||||
i -= 1;
|
||||
}
|
||||
|
||||
if(i < 0){
|
||||
return defaultY;
|
||||
}
|
||||
|
||||
return y[i];
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
builder.append("Default point: ");
|
||||
builder.append(defaultY);
|
||||
builder.append("\n");
|
||||
|
||||
for(int i=0; i<x.length; i++){
|
||||
builder.append("x:");
|
||||
builder.append(x[i]);
|
||||
builder.append('\t');
|
||||
|
||||
if(isLeftContinuous[i]){
|
||||
builder.append("*y:");
|
||||
}
|
||||
else{
|
||||
builder.append("y*:");
|
||||
}
|
||||
builder.append(y[i]);
|
||||
builder.append("\n");
|
||||
}
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.Tree;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.util.Iterator;
|
||||
import java.util.zip.GZIPInputStream;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class IterableOfflineTree<Y> implements Iterable<Tree<Y>> {
|
||||
|
||||
private final File[] treeFiles;
|
||||
|
||||
@Override
|
||||
public Iterator<Tree<Y>> iterator() {
|
||||
return new OfflineTreeIterator<>(treeFiles);
|
||||
}
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public static class OfflineTreeIterator<Y> implements Iterator<Tree<Y>>{
|
||||
private final File[] treeFiles;
|
||||
private int position = 0;
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return position < treeFiles.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tree<Y> next() {
|
||||
final File treeFile = treeFiles[position];
|
||||
position++;
|
||||
|
||||
|
||||
try {
|
||||
final ObjectInputStream inputStream= new ObjectInputStream(new GZIPInputStream(new FileInputStream(treeFile)));
|
||||
final Tree<Y> tree = (Tree) inputStream.readObject();
|
||||
return tree;
|
||||
} catch (IOException | ClassNotFoundException e) {
|
||||
e.printStackTrace();
|
||||
throw new RuntimeException("Failed to load tree for " + treeFile.toString());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,129 +0,0 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ListIterator;
|
||||
|
||||
/**
|
||||
* Represents a function represented by discrete points. We assume that the function is a stepwise left-continuous
|
||||
* function, constant at the value of the previous encountered point.
|
||||
*
|
||||
*/
|
||||
public final class LeftContinuousStepFunction extends StepFunction {
|
||||
|
||||
private final double[] y;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
*/
|
||||
private final double defaultY;
|
||||
|
||||
public LeftContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
||||
super(x);
|
||||
this.y = y;
|
||||
this.defaultY = defaultY;
|
||||
}
|
||||
|
||||
/**
|
||||
* This isn't a formal constructor because of limitations with abstract classes.
|
||||
*
|
||||
* @param pointList
|
||||
* @param defaultY
|
||||
* @return
|
||||
*/
|
||||
public static LeftContinuousStepFunction constructFromPoints(final List<Point> pointList, final double defaultY){
|
||||
|
||||
final double[] x = new double[pointList.size()];
|
||||
final double[] y = new double[pointList.size()];
|
||||
|
||||
final ListIterator<Point> pointIterator = pointList.listIterator();
|
||||
while(pointIterator.hasNext()){
|
||||
final int index = pointIterator.nextIndex();
|
||||
final Point currentPoint = pointIterator.next();
|
||||
|
||||
x[index] = currentPoint.getTime();
|
||||
y[index] = currentPoint.getY();
|
||||
}
|
||||
|
||||
return new LeftContinuousStepFunction(x, y, defaultY);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluate(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return evaluateByIndex(index-1);
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluatePrevious(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return evaluateByIndex(index-1);
|
||||
}
|
||||
else{
|
||||
return y[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluateByIndex(int i) {
|
||||
if(i < 0){
|
||||
return defaultY;
|
||||
}
|
||||
|
||||
return y[i];
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
builder.append("Default point: ");
|
||||
builder.append(defaultY);
|
||||
builder.append("\n");
|
||||
|
||||
for(int i=0; i<x.length; i++){
|
||||
builder.append("x:");
|
||||
builder.append(x[i]);
|
||||
builder.append("\ty:");
|
||||
builder.append(y[i]);
|
||||
builder.append("\n");
|
||||
}
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
}
|
|
@ -26,6 +26,8 @@ import java.io.Serializable;
|
|||
*/
|
||||
@Data
|
||||
public class Point implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final double time;
|
||||
private final double y;
|
||||
}
|
||||
|
|
|
@ -188,4 +188,24 @@ public final class RUtils {
|
|||
return responses;
|
||||
}
|
||||
|
||||
public static List<Object> produceSublist(List<Object> initialList, int[] indices){
|
||||
final List<Object> newList = new ArrayList<>(indices.length);
|
||||
|
||||
for(int i : indices){
|
||||
newList.add(initialList.get(i));
|
||||
}
|
||||
|
||||
return newList;
|
||||
}
|
||||
|
||||
public static File[] getTreeFileArray(String folderPath, int endingId){
|
||||
final File[] fileArray = new File[endingId];
|
||||
|
||||
for(int i = 1; i <= endingId; i++){
|
||||
fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree");
|
||||
}
|
||||
|
||||
return fileArray;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -16,8 +16,14 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.ListIterator;
|
||||
import java.util.function.DoubleBinaryOperator;
|
||||
import java.util.function.DoubleUnaryOperator;
|
||||
|
||||
/**
|
||||
* Represents a function represented by discrete points. We assume that the function is a stepwise right-continuous
|
||||
|
@ -26,13 +32,16 @@ import java.util.ListIterator;
|
|||
*/
|
||||
public final class RightContinuousStepFunction extends StepFunction {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final double[] y;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
* Represents the value that should be returned by evaluate if there are no points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
* May not be null.
|
||||
*/
|
||||
@Getter
|
||||
private final double defaultY;
|
||||
|
||||
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
|
||||
|
@ -127,7 +136,12 @@ public final class RightContinuousStepFunction extends StepFunction {
|
|||
}
|
||||
|
||||
if(to < from){
|
||||
return integrate(to, from);
|
||||
return -integrate(to, from);
|
||||
}
|
||||
|
||||
// Edge case - no points; just defaultY
|
||||
if(this.x.length == 0){
|
||||
return (to - from) * this.defaultY;
|
||||
}
|
||||
|
||||
double summation = 0.0;
|
||||
|
@ -170,7 +184,7 @@ public final class RightContinuousStepFunction extends StepFunction {
|
|||
final double currentTime = xPoints[i];
|
||||
final double currentHeight = evaluateByIndex(i);
|
||||
|
||||
if(i == xPoints.length-1 || xPoints[i+1] > to){
|
||||
if(i == xPoints.length-1 || xPoints[i+1] >= to){
|
||||
summation += currentHeight * (to - currentTime);
|
||||
return summation;
|
||||
}
|
||||
|
@ -186,5 +200,76 @@ public final class RightContinuousStepFunction extends StepFunction {
|
|||
|
||||
}
|
||||
|
||||
public RightContinuousStepFunction unaryOperation(DoubleUnaryOperator operator){
|
||||
final double newDefaultY = operator.applyAsDouble(this.defaultY);
|
||||
final double[] newY = Arrays.stream(this.getY()).map(operator).toArray();
|
||||
|
||||
return new RightContinuousStepFunction(this.getX(), newY, newDefaultY);
|
||||
}
|
||||
|
||||
|
||||
public static RightContinuousStepFunction biOperation(RightContinuousStepFunction funLeft,
|
||||
RightContinuousStepFunction funRight,
|
||||
DoubleBinaryOperator operator){
|
||||
|
||||
final double newDefaultY = operator.applyAsDouble(funLeft.defaultY, funRight.defaultY);
|
||||
final double[] leftX = funLeft.x;
|
||||
final double[] rightX = funRight.x;
|
||||
|
||||
final List<Point> combinedPoints = new ArrayList<>(leftX.length + rightX.length);
|
||||
|
||||
// These indexes represent the times that have *already* been processed.
|
||||
// They start at -1 because we already processed the defaultY values.
|
||||
int indexLeft = -1;
|
||||
int indexRight = -1;
|
||||
|
||||
// This while-loop will keep going until one of the functions reaches the ends of its points
|
||||
while(indexLeft < leftX.length-1 && indexRight < rightX.length-1){
|
||||
final double time;
|
||||
if(leftX[indexLeft+1] < rightX[indexRight+1]){
|
||||
indexLeft += 1;
|
||||
|
||||
time = leftX[indexLeft];
|
||||
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||
}
|
||||
else if(leftX[indexLeft+1] > rightX[indexRight+1]){
|
||||
indexRight += 1;
|
||||
|
||||
time = rightX[indexRight];
|
||||
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||
}
|
||||
else{ // equal times
|
||||
indexLeft += 1;
|
||||
indexRight += 1;
|
||||
|
||||
time = leftX[indexLeft];
|
||||
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, at least one of function left or function right has reached the end of its points
|
||||
|
||||
// This while-loop occurring implies that functionRight can not move forward anymore
|
||||
while(indexLeft < leftX.length-1){
|
||||
indexLeft += 1;
|
||||
|
||||
final double time = leftX[indexLeft];
|
||||
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||
|
||||
}
|
||||
|
||||
// This while-loop occurring implies that functionLeft can not move forward anymore
|
||||
while(indexRight < rightX.length-1){
|
||||
indexRight += 1;
|
||||
|
||||
final double time = rightX[indexRight];
|
||||
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
|
||||
|
||||
}
|
||||
|
||||
|
||||
return RightContinuousStepFunction.constructFromPoints(combinedPoints, newDefaultY);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ import java.util.*;
|
|||
|
||||
public final class Utils {
|
||||
|
||||
public static StepFunction estimateOneMinusECDF(final double[] times){
|
||||
public static RightContinuousStepFunction estimateOneMinusECDF(final double[] times){
|
||||
Arrays.sort(times);
|
||||
|
||||
final Map<Double, Integer> timeCounterMap = new HashMap<>();
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
/*
|
||||
* Copyright (c) 2019 Joel Therrien.
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
/**
|
||||
* Represents a step function represented by discrete points. However, there may be individual time values that has
|
||||
* a y value that doesn't belong to a particular 'step'.
|
||||
*/
|
||||
public final class VeryDiscontinuousStepFunction implements MathFunction {
|
||||
|
||||
private final double[] x;
|
||||
private final double[] yAt;
|
||||
private final double[] yRight;
|
||||
|
||||
/**
|
||||
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
|
||||
*
|
||||
* Map be null.
|
||||
*/
|
||||
private final double defaultY;
|
||||
|
||||
public VeryDiscontinuousStepFunction(double[] x, double[] yAt, double[] yRight, double defaultY) {
|
||||
this.x = x;
|
||||
this.yAt = yAt;
|
||||
this.yRight = yRight;
|
||||
this.defaultY = defaultY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double evaluate(double time){
|
||||
int index = Utils.binarySearchLessThan(0, x.length, x, time);
|
||||
if(index < 0){
|
||||
return defaultY;
|
||||
}
|
||||
else{
|
||||
if(x[index] == time){
|
||||
return yAt[index];
|
||||
}
|
||||
else{ // time > x[index]
|
||||
return yRight[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
builder.append("Default point: ");
|
||||
builder.append(defaultY);
|
||||
builder.append("\n");
|
||||
|
||||
for(int i=0; i<x.length; i++){
|
||||
builder.append("x:");
|
||||
builder.append(x[i]);
|
||||
builder.append("\tyAt:");
|
||||
builder.append(yAt[i]);
|
||||
builder.append("\tyRight:");
|
||||
builder.append(yRight[i]);
|
||||
builder.append("\n");
|
||||
}
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
}
|
|
@ -45,20 +45,20 @@ public class TestDeterministicForests {
|
|||
|
||||
int index = 0;
|
||||
for(int j=0; j<5; j++){
|
||||
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index);
|
||||
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index, false);
|
||||
covariateList.add(numericCovariate);
|
||||
index++;
|
||||
}
|
||||
|
||||
for(int j=0; j<5; j++){
|
||||
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index);
|
||||
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index, false);
|
||||
covariateList.add(booleanCovariate);
|
||||
index++;
|
||||
}
|
||||
|
||||
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
|
||||
for(int j=0; j<5; j++){
|
||||
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels);
|
||||
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels, false);
|
||||
covariateList.add(factorCovariate);
|
||||
index++;
|
||||
}
|
||||
|
@ -214,14 +214,14 @@ public class TestDeterministicForests {
|
|||
|
||||
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
|
||||
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
|
||||
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
||||
final Forest<Double, Double> forestSerial = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
|
||||
|
||||
|
||||
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
|
||||
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
|
||||
final Forest<Double, Double> forestParallel = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
|
||||
|
||||
|
@ -259,7 +259,7 @@ public class TestDeterministicForests {
|
|||
|
||||
for(int k=0; k<3; k++){
|
||||
forestTrainer.trainSerialOnDisk(Optional.empty());
|
||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
||||
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
}
|
||||
|
@ -274,7 +274,7 @@ public class TestDeterministicForests {
|
|||
|
||||
for(int k=0; k<3; k++){
|
||||
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
|
||||
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
|
||||
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
|
||||
TestUtils.removeFolder(saveTreeFile);
|
||||
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
|
||||
}
|
||||
|
|
|
@ -20,8 +20,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
|
|||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.OnlineForest;
|
||||
import ca.joeltherrien.randomforest.tree.Tree;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
|
@ -39,12 +39,12 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
|
||||
public class TestProvidingInitialForest {
|
||||
|
||||
private Forest<Double, Double> initialForest;
|
||||
private OnlineForest<Double, Double> initialForest;
|
||||
private List<Covariate> covariateList;
|
||||
private List<Row<Double>> data;
|
||||
|
||||
public TestProvidingInitialForest(){
|
||||
covariateList = Collections.singletonList(new NumericCovariate("x", 0));
|
||||
covariateList = Collections.singletonList(new NumericCovariate("x", 0, false));
|
||||
|
||||
data = Utils.easyList(
|
||||
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
|
||||
|
@ -107,8 +107,8 @@ public class TestProvidingInitialForest {
|
|||
public void testSerialInMemory(){
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||
|
||||
final Forest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
||||
assertEquals(20, newForest.getTrees().size());
|
||||
final OnlineForest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
|
||||
assertEquals(20, newForest.getNumberOfTrees());
|
||||
|
||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||
assertTrue(newForest.getTrees().contains(initialTree));
|
||||
|
@ -124,8 +124,8 @@ public class TestProvidingInitialForest {
|
|||
public void testParallelInMemory(){
|
||||
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
|
||||
|
||||
final Forest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
||||
assertEquals(20, newForest.getTrees().size());
|
||||
final OnlineForest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
|
||||
assertEquals(20, newForest.getNumberOfTrees());
|
||||
|
||||
for(Tree<Double> initialTree : initialForest.getTrees()){
|
||||
assertTrue(newForest.getTrees().contains(initialTree));
|
||||
|
@ -149,11 +149,11 @@ public class TestProvidingInitialForest {
|
|||
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
|
||||
|
||||
assertEquals(20, directory.listFiles().length);
|
||||
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
||||
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
|
||||
|
||||
|
||||
|
||||
assertEquals(20, newForest.getTrees().size());
|
||||
assertEquals(20, newForest.getNumberOfTrees());
|
||||
|
||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||
|
@ -179,9 +179,9 @@ public class TestProvidingInitialForest {
|
|||
|
||||
assertEquals(20, directory.listFiles().length);
|
||||
|
||||
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
|
||||
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
|
||||
|
||||
assertEquals(20, newForest.getTrees().size());
|
||||
assertEquals(20, newForest.getNumberOfTrees());
|
||||
|
||||
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
|
||||
.map(tree -> tree.toString()).collect(Collectors.toList());
|
||||
|
@ -198,7 +198,7 @@ public class TestProvidingInitialForest {
|
|||
it's not clear if the forest being provided is the same one that trees were saved from.
|
||||
*/
|
||||
@Test
|
||||
public void verifyExceptions(){
|
||||
public void testExceptions(){
|
||||
final String filePath = "src/test/resources/trees/";
|
||||
final File directory = new File(filePath);
|
||||
if(directory.exists()){
|
||||
|
|
|
@ -24,11 +24,10 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
|
|||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.*;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.utils.ResponseLoader;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
@ -47,10 +46,10 @@ public class TestSavingLoading {
|
|||
|
||||
public List<Covariate> getCovariates(){
|
||||
return Utils.easyList(
|
||||
new NumericCovariate("ageatfda", 0),
|
||||
new BooleanCovariate("idu", 1),
|
||||
new BooleanCovariate("black", 2),
|
||||
new NumericCovariate("cd4nadir", 3)
|
||||
new NumericCovariate("ageatfda", 0, false),
|
||||
new BooleanCovariate("idu", 1, false),
|
||||
new BooleanCovariate("black", 2, false),
|
||||
new NumericCovariate("cd4nadir", 3, false)
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -119,16 +118,21 @@ public class TestSavingLoading {
|
|||
assertTrue(directory.isDirectory());
|
||||
assertEquals(NTREE, directory.listFiles().length);
|
||||
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
||||
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
|
||||
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
|
||||
|
||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||
|
||||
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
||||
assertNotNull(functions);
|
||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
|
||||
assertNotNull(functionsOnline);
|
||||
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||
|
||||
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
|
||||
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
|
||||
|
||||
|
||||
assertEquals(NTREE, forest.getTrees().size());
|
||||
assertEquals(NTREE, onlineForest.getTrees().size());
|
||||
|
||||
TestUtils.removeFolder(directory);
|
||||
|
||||
|
@ -159,17 +163,22 @@ public class TestSavingLoading {
|
|||
assertEquals(NTREE, directory.listFiles().length);
|
||||
|
||||
|
||||
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
|
||||
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
|
||||
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
|
||||
|
||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||
|
||||
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
||||
assertNotNull(functions);
|
||||
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
|
||||
assertNotNull(functionsOnline);
|
||||
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||
|
||||
|
||||
assertEquals(NTREE, forest.getTrees().size());
|
||||
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
|
||||
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
|
||||
|
||||
|
||||
assertEquals(NTREE, onlineForest.getTrees().size());
|
||||
|
||||
TestUtils.removeFolder(directory);
|
||||
|
||||
|
@ -177,6 +186,64 @@ public class TestSavingLoading {
|
|||
|
||||
}
|
||||
|
||||
/*
|
||||
We don't implement equals() methods on the below mentioned classes because then we'd need to implement an
|
||||
appropriate hashCode() method that's consistent with the equals(), and we only need plain equals() for
|
||||
these tests.
|
||||
*/
|
||||
|
||||
private boolean competingFunctionsEqual(CompetingRiskFunctions f1 ,CompetingRiskFunctions f2){
|
||||
if(!functionsEqual(f1.getSurvivalCurve(), f2.getSurvivalCurve())){
|
||||
return false;
|
||||
}
|
||||
|
||||
for(int i=1; i<=2; i++){
|
||||
if(!functionsEqual(f1.getCauseSpecificHazardFunction(i), f2.getCauseSpecificHazardFunction(i))){
|
||||
return false;
|
||||
}
|
||||
if(!functionsEqual(f1.getCumulativeIncidenceFunction(i), f2.getCumulativeIncidenceFunction(i))){
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private boolean functionsEqual(RightContinuousStepFunction f1, RightContinuousStepFunction f2){
|
||||
|
||||
final double[] f1X = f1.getX();
|
||||
final double[] f2X = f2.getX();
|
||||
|
||||
final double[] f1Y = f1.getY();
|
||||
final double[] f2Y = f2.getY();
|
||||
|
||||
// first compare array lengths
|
||||
if(f1X.length != f2X.length){
|
||||
return false;
|
||||
}
|
||||
if(f1Y.length != f2Y.length){
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO - better comparisons of doubles. I don't really care too much though as this equals method is only being used in tests
|
||||
final double delta = 0.000001;
|
||||
|
||||
if(Math.abs(f1.getDefaultY() - f2.getDefaultY()) > delta){
|
||||
return false;
|
||||
}
|
||||
|
||||
for(int i=0; i < f1X.length; i++){
|
||||
if(Math.abs(f1X[i] - f2X[i]) > delta){
|
||||
return false;
|
||||
}
|
||||
if(Math.abs(f1Y[i] - f2Y[i]) > delta){
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -156,7 +156,7 @@ public class TestUtils {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void reduceListToSize(){
|
||||
public void testReduceListToSize(){
|
||||
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||
final Random random = new Random();
|
||||
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class IBSCalculatorTest {
|
||||
|
||||
private final RightContinuousStepFunction cif;
|
||||
|
||||
public IBSCalculatorTest(){
|
||||
this.cif = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(
|
||||
new Point(1.0, 0.1),
|
||||
new Point(2.0, 0.2),
|
||||
new Point(3.0, 0.3),
|
||||
new Point(4.0, 0.8)
|
||||
), 0.0
|
||||
);
|
||||
}
|
||||
|
||||
/*
|
||||
R code to get these results:
|
||||
|
||||
predicted_cif <- stepfun(1:4, c(0, 0.1, 0.2, 0.3, 0.8))
|
||||
weights <- 1
|
||||
recorded_time <- 2.0
|
||||
recorded_status <- 1.0
|
||||
event_of_interest <- 2
|
||||
times <- 0:4
|
||||
|
||||
errors <- weights * ( as.integer(recorded_time <= times & recorded_status == event_of_interest) - predicted_cif(times))^2
|
||||
sum(errors)
|
||||
|
||||
|
||||
and run again with event_of_interest <- 1
|
||||
|
||||
|
||||
Note that in the R code I only evaluate up to 4, while in the Java code I integrate up to 5
|
||||
This is because the evaluation at 4 is giving the area of the rectangle from 4 to 5.
|
||||
|
||||
*/
|
||||
|
||||
@Test
|
||||
public void testResultsWithoutCensoringDistribution(){
|
||||
final IBSCalculator calculator = new IBSCalculator();
|
||||
|
||||
final double errorDifferentEvent = calculator.calculateError(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
this.cif,
|
||||
2,
|
||||
5.0);
|
||||
|
||||
assertEquals(0.78, errorDifferentEvent, 0.000001);
|
||||
|
||||
final double errorSameEvent = calculator.calculateError(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
this.cif,
|
||||
1,
|
||||
5.0);
|
||||
|
||||
assertEquals(1.18, errorSameEvent, 0.000001);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResultsWithCensoringDistribution(){
|
||||
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(
|
||||
new Point(0.0, 0.75),
|
||||
new Point(1.0, 0.5),
|
||||
new Point(3.0, 0.25),
|
||||
new Point(5.0, 0)
|
||||
), 1.0
|
||||
);
|
||||
|
||||
final IBSCalculator calculator = new IBSCalculator(censorSurvivalFunction);
|
||||
|
||||
final double errorDifferentEvent = calculator.calculateError(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
this.cif,
|
||||
2,
|
||||
5.0);
|
||||
|
||||
assertEquals(1.56, errorDifferentEvent, 0.000001);
|
||||
|
||||
final double errorSameEvent = calculator.calculateError(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
this.cif,
|
||||
1,
|
||||
5.0);
|
||||
|
||||
assertEquals(2.36, errorSameEvent, 0.000001);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStaticFunction(){
|
||||
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(
|
||||
new Point(0.0, 0.75),
|
||||
new Point(1.0, 0.5),
|
||||
new Point(3.0, 0.25),
|
||||
new Point(5.0, 0)
|
||||
), 1.0
|
||||
);
|
||||
|
||||
final List<CompetingRiskResponse> responseList = Utils.easyList(
|
||||
new CompetingRiskResponse(1, 2.0),
|
||||
new CompetingRiskResponse(1, 2.0));
|
||||
|
||||
// for predictions; we'll construct an improper CompetingRisksFunctions
|
||||
final RightContinuousStepFunction trivialFunction = RightContinuousStepFunction.constructFromPoints(
|
||||
Utils.easyList(new Point(1.0, 0.0)),
|
||||
1.0);
|
||||
|
||||
final CompetingRiskFunctions prediction = CompetingRiskFunctions.builder()
|
||||
.survivalCurve(trivialFunction)
|
||||
.causeSpecificHazards(Utils.easyList(trivialFunction, trivialFunction))
|
||||
.cumulativeIncidenceCurves(Utils.easyList(this.cif, trivialFunction))
|
||||
.build();
|
||||
|
||||
final List<CompetingRiskFunctions> predictionList = Utils.easyList(prediction, prediction);
|
||||
|
||||
double[] errorParallel = CompetingRiskUtils.calculateIBSError(
|
||||
responseList,
|
||||
predictionList,
|
||||
Optional.of(censorSurvivalFunction),
|
||||
1,
|
||||
5.0,
|
||||
true);
|
||||
|
||||
double[] errorSerial = CompetingRiskUtils.calculateIBSError(
|
||||
responseList,
|
||||
predictionList,
|
||||
Optional.of(censorSurvivalFunction),
|
||||
1,
|
||||
5.0,
|
||||
false);
|
||||
|
||||
assertEquals(responseList.size(), errorParallel.length);
|
||||
assertEquals(responseList.size(), errorSerial.length);
|
||||
|
||||
assertEquals(2.36, errorParallel[0], 0.000001);
|
||||
assertEquals(2.36, errorParallel[1], 0.000001);
|
||||
|
||||
assertEquals(2.36, errorSerial[0], 0.000001);
|
||||
assertEquals(2.36, errorSerial[1], 0.000001);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -53,10 +53,10 @@ public class TestCompetingRisk {
|
|||
|
||||
public List<Covariate> getCovariates(){
|
||||
return Utils.easyList(
|
||||
new NumericCovariate("ageatfda", 0),
|
||||
new BooleanCovariate("idu", 1),
|
||||
new BooleanCovariate("black", 2),
|
||||
new NumericCovariate("cd4nadir", 3)
|
||||
new NumericCovariate("ageatfda", 0, false),
|
||||
new BooleanCovariate("idu", 1, false),
|
||||
new BooleanCovariate("black", 2, false),
|
||||
new NumericCovariate("cd4nadir", 3, false)
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -109,8 +109,8 @@ public class TestCompetingRisk {
|
|||
|
||||
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||
final List<Covariate> covariates = Utils.easyList(
|
||||
new BooleanCovariate("idu", 0),
|
||||
new BooleanCovariate("black", 1)
|
||||
new BooleanCovariate("idu", 0, false),
|
||||
new BooleanCovariate("black", 1, false)
|
||||
);
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv");
|
||||
|
@ -210,8 +210,8 @@ public class TestCompetingRisk {
|
|||
public void testLogRankSplitFinderTwoBooleans() throws IOException {
|
||||
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||
final List<Covariate> covariates = Utils.easyList(
|
||||
new BooleanCovariate("idu", 0),
|
||||
new BooleanCovariate("black", 1)
|
||||
new BooleanCovariate("idu", 0, false),
|
||||
new BooleanCovariate("black", 1, false)
|
||||
);
|
||||
|
||||
|
||||
|
@ -259,7 +259,7 @@ public class TestCompetingRisk {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void verifyDataset() throws IOException {
|
||||
public void testDataset() throws IOException {
|
||||
final List<Covariate> covariates = getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);
|
||||
|
|
|
@ -16,10 +16,11 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.OnlineForest;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
|
@ -30,8 +31,6 @@ import java.util.List;
|
|||
|
||||
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class TestCompetingRiskErrorRateCalculator {
|
||||
|
||||
|
@ -48,7 +47,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
|||
|
||||
final int event = 1;
|
||||
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = OnlineForest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
|
||||
|
||||
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ public class TestLogRankSplitFinder {
|
|||
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
|
||||
|
||||
final List<Covariate> covariates = Utils.easyList(
|
||||
new NumericCovariate("x2", 0)
|
||||
new NumericCovariate("x2", 0, false)
|
||||
);
|
||||
|
||||
final List<Row<CompetingRiskResponse>> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename);
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
@ -31,13 +30,6 @@ public class TestMathFunctions {
|
|||
return new RightContinuousStepFunction(time, y, 0.1);
|
||||
}
|
||||
|
||||
private LeftContinuousStepFunction generateLeftContinuousStepFunction(){
|
||||
final double[] time = new double[]{1.0, 2.0, 3.0};
|
||||
final double[] y = new double[]{-1.0, 1.0, 0.5};
|
||||
|
||||
return new LeftContinuousStepFunction(time, y, 0.1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRightContinuousStepFunction(){
|
||||
final RightContinuousStepFunction function = generateRightContinuousStepFunction();
|
||||
|
@ -56,21 +48,5 @@ public class TestMathFunctions {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLeftContinuousStepFunction(){
|
||||
final LeftContinuousStepFunction function = generateLeftContinuousStepFunction();
|
||||
|
||||
assertEquals(0.1, function.evaluate(0.5));
|
||||
assertEquals(0.1, function.evaluate(1.0));
|
||||
assertEquals(-1.0, function.evaluate(2.0));
|
||||
assertEquals(1.0, function.evaluate(3.0));
|
||||
|
||||
|
||||
assertEquals(0.1, function.evaluate(0.6));
|
||||
assertEquals(-1.0, function.evaluate(1.1));
|
||||
assertEquals(1.0, function.evaluate(2.1));
|
||||
assertEquals(0.5, function.evaluate(3.1));
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,12 +17,15 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||
import ca.joeltherrien.randomforest.tree.Split;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.function.Executable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
|
@ -31,7 +34,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FactorCovariateTest {
|
||||
|
||||
@Test
|
||||
void verifyEqualLevels() {
|
||||
public void testEqualLevels() {
|
||||
final FactorCovariate petCovariate = createTestCovariate();
|
||||
|
||||
final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
|
||||
|
@ -53,7 +56,7 @@ public class FactorCovariateTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
void verifyBadLevelException(){
|
||||
public void testBadLevelException(){
|
||||
final FactorCovariate petCovariate = createTestCovariate();
|
||||
final Executable badCode = () -> petCovariate.createValue("vulcan");
|
||||
|
||||
|
@ -61,25 +64,169 @@ public class FactorCovariateTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
void testAllSubsets(){
|
||||
public void testAllSubsets(){
|
||||
final int n = 2*3; // ensure that n is a multiple of 3 for the test
|
||||
final FactorCovariate petCovariate = createTestCovariate();
|
||||
final List<Row<Double>> data = generateSampleData(petCovariate, n);
|
||||
|
||||
final List<SplitRule<String>> splitRules = new ArrayList<>();
|
||||
final List<Split<Double, String>> splits = new ArrayList<>();
|
||||
|
||||
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
|
||||
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
|
||||
petCovariate.generateSplitRuleUpdater(data, 100, new Random())
|
||||
.forEachRemaining(split -> splits.add(split));
|
||||
|
||||
assertEquals(splitRules.size(), 3);
|
||||
assertEquals(splits.size(), 3);
|
||||
|
||||
// TODO verify the contents of the split rules
|
||||
// These are the 3 possibilities
|
||||
boolean dog_catmouse = false;
|
||||
boolean cat_dogmouse = false;
|
||||
boolean mouse_dogcat = false;
|
||||
|
||||
for(Split<Double, String> split : splits){
|
||||
List<Row<Double>> smallerHand;
|
||||
List<Row<Double>> largerHand;
|
||||
|
||||
if(split.getLeftHand().size() < split.getRightHand().size()){
|
||||
smallerHand = split.getLeftHand();
|
||||
largerHand = split.getRightHand();
|
||||
} else{
|
||||
smallerHand = split.getRightHand();
|
||||
largerHand = split.getLeftHand();
|
||||
}
|
||||
|
||||
// There should be exactly one distinct value in the smaller list
|
||||
assertEquals(n/3, smallerHand.size());
|
||||
assertEquals(1,
|
||||
smallerHand.stream()
|
||||
.map(row -> row.getCovariateValue(petCovariate).getValue())
|
||||
.distinct()
|
||||
.count()
|
||||
);
|
||||
|
||||
// There should be exactly two distinct values in the smaller list
|
||||
assertEquals(2*n/3, largerHand.size());
|
||||
assertEquals(2,
|
||||
largerHand.stream()
|
||||
.map(row -> row.getCovariateValue(petCovariate).getValue())
|
||||
.distinct()
|
||||
.count()
|
||||
);
|
||||
|
||||
switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){
|
||||
case "DOG":
|
||||
dog_catmouse = true;
|
||||
case "CAT":
|
||||
cat_dogmouse = true;
|
||||
case "MOUSE":
|
||||
mouse_dogcat = true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
assertTrue(dog_catmouse);
|
||||
assertTrue(cat_dogmouse);
|
||||
assertTrue(mouse_dogcat);
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
* There was a bug where if number==0 in generateSplitRuleUpdater, then the result was empty.
|
||||
*/
|
||||
@Test
|
||||
public void testNumber0Subsets(){
|
||||
final int n = 2*3; // ensure that n is a multiple of 3 for the test
|
||||
final FactorCovariate petCovariate = createTestCovariate();
|
||||
final List<Row<Double>> data = generateSampleData(petCovariate, n);
|
||||
|
||||
final List<Split<Double, String>> splits = new ArrayList<>();
|
||||
|
||||
petCovariate.generateSplitRuleUpdater(data, 0, new Random())
|
||||
.forEachRemaining(split -> splits.add(split));
|
||||
|
||||
assertEquals(splits.size(), 3);
|
||||
|
||||
// These are the 3 possibilities
|
||||
boolean dog_catmouse = false;
|
||||
boolean cat_dogmouse = false;
|
||||
boolean mouse_dogcat = false;
|
||||
|
||||
for(Split<Double, String> split : splits){
|
||||
List<Row<Double>> smallerHand;
|
||||
List<Row<Double>> largerHand;
|
||||
|
||||
if(split.getLeftHand().size() < split.getRightHand().size()){
|
||||
smallerHand = split.getLeftHand();
|
||||
largerHand = split.getRightHand();
|
||||
} else{
|
||||
smallerHand = split.getRightHand();
|
||||
largerHand = split.getLeftHand();
|
||||
}
|
||||
|
||||
// There should be exactly one distinct value in the smaller list
|
||||
assertEquals(n/3, smallerHand.size());
|
||||
assertEquals(1,
|
||||
smallerHand.stream()
|
||||
.map(row -> row.getCovariateValue(petCovariate).getValue())
|
||||
.distinct()
|
||||
.count()
|
||||
);
|
||||
|
||||
// There should be exactly two distinct values in the smaller list
|
||||
assertEquals(2*n/3, largerHand.size());
|
||||
assertEquals(2,
|
||||
largerHand.stream()
|
||||
.map(row -> row.getCovariateValue(petCovariate).getValue())
|
||||
.distinct()
|
||||
.count()
|
||||
);
|
||||
|
||||
switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){
|
||||
case "DOG":
|
||||
dog_catmouse = true;
|
||||
case "CAT":
|
||||
cat_dogmouse = true;
|
||||
case "MOUSE":
|
||||
mouse_dogcat = true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
assertTrue(dog_catmouse);
|
||||
assertTrue(cat_dogmouse);
|
||||
assertTrue(mouse_dogcat);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSpitRuleUpdaterWithNAs(){
|
||||
// When some NAs were present calling generateSplitRuleUpdater caused an exception.
|
||||
|
||||
final FactorCovariate covariate = createTestCovariate();
|
||||
final List<Row<Double>> sampleData = generateSampleData(covariate, 10);
|
||||
sampleData.add(Row.createSimple(Utils.easyMap("pet", "NA"), Collections.singletonList(covariate), 11, 5.0));
|
||||
|
||||
covariate.generateSplitRuleUpdater(sampleData, 0, new Random());
|
||||
|
||||
// Test passes if no exception has occurred.
|
||||
}
|
||||
|
||||
|
||||
private FactorCovariate createTestCovariate(){
|
||||
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
|
||||
|
||||
return new FactorCovariate("pet", 0, levels);
|
||||
return new FactorCovariate("pet", 0, levels, false);
|
||||
}
|
||||
|
||||
private List<Row<Double>> generateSampleData(Covariate covariate, int n){
|
||||
final List<Covariate> covariateList = Collections.singletonList(covariate);
|
||||
final List<Row<Double>> dataList = new ArrayList<>(n);
|
||||
|
||||
final String[] levels = new String[]{"DOG", "CAT", "MOUSE"};
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
dataList.add(Row.createSimple(Utils.easyMap("pet", levels[i % 3]), covariateList, 1, 1.0));
|
||||
}
|
||||
|
||||
return dataList;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ public class NumericCovariateTest {
|
|||
|
||||
@Test
|
||||
public void testNumericCovariateDeterministic(){
|
||||
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
||||
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
||||
|
||||
final List<Row<Double>> dataset = createTestDataset(covariate);
|
||||
|
||||
|
@ -158,7 +158,7 @@ public class NumericCovariateTest {
|
|||
|
||||
@Test
|
||||
public void testNumericSplitRuleUpdaterWithIndexes(){
|
||||
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
||||
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
||||
|
||||
final List<Row<Double>> dataset = createTestDataset(covariate);
|
||||
|
||||
|
@ -223,7 +223,7 @@ public class NumericCovariateTest {
|
|||
*/
|
||||
@Test
|
||||
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
|
||||
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
||||
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
|
||||
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
|
||||
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());
|
||||
|
||||
|
|
|
@ -18,31 +18,34 @@ package ca.joeltherrien.randomforest.nas;
|
|||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||
import ca.joeltherrien.randomforest.tree.Split;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestNAs {
|
||||
|
||||
private List<Row<Double>> generateData(List<Covariate> covariates){
|
||||
private List<Row<Double>> generateData1(List<Covariate> covariates){
|
||||
final List<Row<Double>> dataList = new ArrayList<>();
|
||||
|
||||
|
||||
// We must include an NA for one of the values
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "NA", "y", "true", "z", "green"), covariates, 1, 5.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "1", "y", "NA", "z", "blue"), covariates, 2, 6.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "2", "y", "true", "z", "NA"), covariates, 3, 5.5));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "7", "y", "false", "z", "green"), covariates, 4, 0.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "8", "y", "true", "z", "blue"), covariates, 5, 1.0));
|
||||
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4", "y", "false", "z", "yellow"), covariates, 6, 1.0));
|
||||
|
||||
|
||||
return dataList;
|
||||
|
@ -54,14 +57,19 @@ public class TestNAs {
|
|||
// but NumericSplitRuleUpdater had unmodifiable lists when creating the split.
|
||||
// This bug verifies that this no longer causes a crash
|
||||
|
||||
final List<Covariate> covariates = Collections.singletonList(new NumericCovariate("x", 0));
|
||||
final List<Row<Double>> dataset = generateData(covariates);
|
||||
final List<Covariate> covariates = Utils.easyList(
|
||||
new NumericCovariate("x", 0, false),
|
||||
new BooleanCovariate("y", 1, true),
|
||||
new FactorCovariate("z", 2, Utils.easyList("green", "blue", "yellow"), true)
|
||||
);
|
||||
final List<Row<Double>> dataset = generateData1(covariates);
|
||||
|
||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||
.checkNodePurity(false)
|
||||
.covariates(covariates)
|
||||
.numberOfSplits(0)
|
||||
.nodeSize(1)
|
||||
.mtry(3)
|
||||
.maxNodeDepth(1000)
|
||||
.splitFinder(new WeightedVarianceSplitFinder())
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
|
@ -70,6 +78,87 @@ public class TestNAs {
|
|||
treeTrainer.growTree(dataset, new Random(123));
|
||||
|
||||
// As long as no exception occurs, we passed
|
||||
}
|
||||
|
||||
private List<Row<Double>> generateData2(List<Covariate> covariates){
|
||||
final List<Row<Double>> dataList = new ArrayList<>();
|
||||
// Idea - when ignoring NAs, BadVar gives a perfect split.
|
||||
// GoodVar is slightly worse than BadVar when NAs are excluded.
|
||||
// However, BadVar has a ton of NAs that will degrade its performance.
|
||||
dataList.add(Row.createSimple(
|
||||
Utils.easyMap("BadVar", "-1.0", "GoodVar", "true") // GoodVars one error
|
||||
, covariates, 1, 5.0)
|
||||
);
|
||||
dataList.add(Row.createSimple(
|
||||
Utils.easyMap("BadVar", "NA", "GoodVar", "false")
|
||||
, covariates, 2, 5.0)
|
||||
);
|
||||
dataList.add(Row.createSimple(
|
||||
Utils.easyMap("BadVar", "NA", "GoodVar", "false")
|
||||
, covariates, 3, 5.0)
|
||||
);
|
||||
dataList.add(Row.createSimple(
|
||||
Utils.easyMap("BadVar", "0.5", "GoodVar", "true")
|
||||
, covariates, 4, 10.0)
|
||||
);
|
||||
dataList.add(Row.createSimple(
|
||||
Utils.easyMap("BadVar", "NA", "GoodVar", "true")
|
||||
, covariates, 5, 10.0)
|
||||
);
|
||||
dataList.add(Row.createSimple(
|
||||
Utils.easyMap("BadVar", "NA", "GoodVar", "true")
|
||||
, covariates, 6, 10.0)
|
||||
);
|
||||
|
||||
return dataList;
|
||||
}
|
||||
|
||||
@Test
|
||||
// Test that the NA penalty works when selecting a best split.
|
||||
public void testNAPenalty(){
|
||||
final List<Covariate> covariates1 = Utils.easyList(
|
||||
new NumericCovariate("BadVar", 0, true),
|
||||
new BooleanCovariate("GoodVar", 1, false)
|
||||
);
|
||||
|
||||
final List<Row<Double>> dataList1 = generateData2(covariates1);
|
||||
|
||||
final TreeTrainer<Double, Double> treeTrainer1 = TreeTrainer.<Double, Double>builder()
|
||||
.checkNodePurity(false)
|
||||
.covariates(covariates1)
|
||||
.numberOfSplits(0)
|
||||
.nodeSize(1)
|
||||
.mtry(2)
|
||||
.maxNodeDepth(1000)
|
||||
.splitFinder(new WeightedVarianceSplitFinder())
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.build();
|
||||
|
||||
final Split<Double, ?> bestSplit1 = treeTrainer1.findBestSplitRule(dataList1, covariates1, new Random(123));
|
||||
assertEquals(1, bestSplit1.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar
|
||||
|
||||
// Run again without the penalty; verify that we get different results
|
||||
|
||||
final List<Covariate> covariates2 = Utils.easyList(
|
||||
new NumericCovariate("BadVar", 0, false),
|
||||
new BooleanCovariate("GoodVar", 1, false)
|
||||
);
|
||||
|
||||
final List<Row<Double>> dataList2 = generateData2(covariates2);
|
||||
|
||||
final TreeTrainer<Double, Double> treeTrainer2 = TreeTrainer.<Double, Double>builder()
|
||||
.checkNodePurity(false)
|
||||
.covariates(covariates2)
|
||||
.numberOfSplits(0)
|
||||
.nodeSize(1)
|
||||
.mtry(2)
|
||||
.maxNodeDepth(1000)
|
||||
.splitFinder(new WeightedVarianceSplitFinder())
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.build();
|
||||
|
||||
final Split<Double, ?> bestSplit2 = treeTrainer2.findBestSplitRule(dataList2, covariates2, new Random(123));
|
||||
assertEquals(0, bestSplit2.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
package ca.joeltherrien.randomforest.tree.vimp;
|
||||
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||
import ca.joeltherrien.randomforest.utils.Point;
|
||||
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class IBSErrorCalculatorWrapperTest {
|
||||
|
||||
/*
|
||||
We already have tests for the IBSCalculator, so these tests are concerned with making sure we correctly average
|
||||
the errors together, not that we fully test the production of each error under different scenarios (like
|
||||
providing / not providing a censoring distribution).
|
||||
*/
|
||||
|
||||
private final double integrationUpperBound = 5.0;
|
||||
|
||||
private final List<CompetingRiskResponse> responses;
|
||||
private final List<CompetingRiskFunctions> functions;
|
||||
|
||||
|
||||
private final double[][] errors;
|
||||
|
||||
public IBSErrorCalculatorWrapperTest(){
|
||||
this.responses = Utils.easyList(
|
||||
new CompetingRiskResponse(0, 2.0),
|
||||
new CompetingRiskResponse(0, 3.0),
|
||||
new CompetingRiskResponse(1, 1.0),
|
||||
new CompetingRiskResponse(1, 1.5),
|
||||
new CompetingRiskResponse(2, 3.0),
|
||||
new CompetingRiskResponse(2, 4.0)
|
||||
);
|
||||
|
||||
final RightContinuousStepFunction cif1 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||
new Point(1.0, 0.25),
|
||||
new Point(1.5, 0.45)
|
||||
), 0.0);
|
||||
|
||||
final RightContinuousStepFunction cif2 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||
new Point(3.0, 0.25),
|
||||
new Point(4.0, 0.45)
|
||||
), 0.0);
|
||||
|
||||
// This function is for the unused CHFs and survival curve
|
||||
// If we see infinities or NaNs popping up in our output we should look here.
|
||||
final RightContinuousStepFunction emptyFun = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||
new Point(0.0, Double.NaN)
|
||||
), Double.NEGATIVE_INFINITY
|
||||
);
|
||||
|
||||
final CompetingRiskFunctions function = CompetingRiskFunctions.builder()
|
||||
.cumulativeIncidenceCurves(Utils.easyList(cif1, cif2))
|
||||
.causeSpecificHazards(Utils.easyList(emptyFun, emptyFun))
|
||||
.survivalCurve(emptyFun)
|
||||
.build();
|
||||
|
||||
// Same prediction for every response.
|
||||
this.functions = Utils.easyList(function, function, function, function, function, function);
|
||||
|
||||
final IBSCalculator calculator = new IBSCalculator();
|
||||
this.errors = new double[2][6];
|
||||
|
||||
for(int event : new int[]{1, 2}){
|
||||
for(int i=0; i<6; i++){
|
||||
this.errors[event-1][i] = calculator.calculateError(
|
||||
responses.get(i), function.getCumulativeIncidenceFunction(event),
|
||||
event, integrationUpperBound
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOneEventOne(){
|
||||
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1},
|
||||
this.integrationUpperBound);
|
||||
|
||||
final double error = wrapper.averageError(this.responses, this.functions);
|
||||
double expectedError = 0.0;
|
||||
for(int i=0; i<6; i++){
|
||||
expectedError += errors[0][i] / 6.0;
|
||||
}
|
||||
|
||||
assertEquals(expectedError, error, 0.00000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOneEventTwo(){
|
||||
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{2},
|
||||
this.integrationUpperBound);
|
||||
|
||||
final double error = wrapper.averageError(this.responses, this.functions);
|
||||
double expectedError = 0.0;
|
||||
for(int i=0; i<6; i++){
|
||||
expectedError += errors[1][i] / 6.0;
|
||||
}
|
||||
|
||||
assertEquals(expectedError, error, 0.00000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTwoEventsNoWeights(){
|
||||
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
|
||||
this.integrationUpperBound);
|
||||
|
||||
final double error = wrapper.averageError(this.responses, this.functions);
|
||||
double expectedError1 = 0.0;
|
||||
double expectedError2 = 0.0;
|
||||
|
||||
for(int i=0; i<6; i++){
|
||||
expectedError1 += errors[0][i] / 6.0;
|
||||
expectedError2 += errors[1][i] / 6.0;
|
||||
}
|
||||
|
||||
assertEquals(expectedError1 + expectedError2, error, 0.00000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTwoEventsWithWeights(){
|
||||
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
|
||||
this.integrationUpperBound, new double[]{1.0, 2.0});
|
||||
|
||||
final double error = wrapper.averageError(this.responses, this.functions);
|
||||
double expectedError1 = 0.0;
|
||||
double expectedError2 = 0.0;
|
||||
|
||||
for(int i=0; i<6; i++){
|
||||
expectedError1 += errors[0][i] / 6.0;
|
||||
expectedError2 += errors[1][i] / 6.0;
|
||||
}
|
||||
|
||||
assertEquals(1.0 * expectedError1 + 2.0 * expectedError2, error, 0.00000001);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package ca.joeltherrien.randomforest.tree.vimp;
|
||||
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class RegressionErrorCalculatorTest {
|
||||
|
||||
private final RegressionErrorCalculator calculator = new RegressionErrorCalculator();
|
||||
|
||||
@Test
|
||||
public void testRegressionErrorCalculator(){
|
||||
final List<Double> responses = Utils.easyList(1.0, 1.5, 0.0, 3.0);
|
||||
final List<Double> predictions = Utils.easyList(1.5, 1.7, 0.1, 2.9);
|
||||
|
||||
// Differences are 0.5, 0.2, -0.1, 0.1
|
||||
// Squared: 0.25, 0.04, 0.01, 0.01
|
||||
|
||||
assertEquals((0.25 + 0.04 + 0.01 + 0.01)/4.0, calculator.averageError(responses, predictions), 0.000000001);
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,484 @@
|
|||
package ca.joeltherrien.randomforest.tree.vimp;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.bool.BooleanSplitRule;
|
||||
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericSplitRule;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||
import ca.joeltherrien.randomforest.tree.*;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestVariableImportanceCalculator {
|
||||
|
||||
/*
|
||||
Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
|
||||
setting.
|
||||
*/
|
||||
|
||||
// We'l have a very simple Forest of two trees
|
||||
private final OnlineForest<Double, Double> forest;
|
||||
|
||||
|
||||
private final List<Covariate> covariates;
|
||||
private final List<Row<Double>> rowList;
|
||||
|
||||
/*
|
||||
Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance.
|
||||
|
||||
*/
|
||||
public TestVariableImportanceCalculator(){
|
||||
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false);
|
||||
final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false);
|
||||
final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
|
||||
Utils.easyList("red", "blue", "green"), false);
|
||||
|
||||
this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate);
|
||||
|
||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.splitFinder(new WeightedVarianceSplitFinder())
|
||||
.numberOfSplits(0)
|
||||
.nodeSize(1)
|
||||
.maxNodeDepth(100)
|
||||
.mtry(3)
|
||||
.checkNodePurity(false)
|
||||
.covariates(this.covariates)
|
||||
.build();
|
||||
|
||||
/*
|
||||
Plan for data - BooleanCovariate is split on first and has the largest impact.
|
||||
NumericCovariate is at second level and has more minimal impact.
|
||||
FactorCovariate is useless and never used.
|
||||
Our tree (we'll duplicate it for testing OOB errors) will have a depth of 1. (0 based).
|
||||
*/
|
||||
|
||||
final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
|
||||
final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});
|
||||
|
||||
this.forest = OnlineForest.<Double, Double>builder()
|
||||
.trees(Utils.easyList(tree1, tree2))
|
||||
.treeResponseCombiner(new MeanResponseCombiner())
|
||||
.build();
|
||||
|
||||
// formula; boolean high adds 100; high numeric adds 10
|
||||
// This row list should have a baseline error of 0.0
|
||||
this.rowList = Utils.easyList(
|
||||
Row.createSimple(Utils.easyMap(
|
||||
"x", "false",
|
||||
"y", "0.0",
|
||||
"z", "red"),
|
||||
covariates, 1, 0.0
|
||||
),
|
||||
Row.createSimple(Utils.easyMap(
|
||||
"x", "false",
|
||||
"y", "10.0",
|
||||
"z", "blue"),
|
||||
covariates, 2, 10.0
|
||||
),
|
||||
Row.createSimple(Utils.easyMap(
|
||||
"x", "true",
|
||||
"y", "0.0",
|
||||
"z", "red"),
|
||||
covariates, 3, 100.0
|
||||
),
|
||||
Row.createSimple(Utils.easyMap(
|
||||
"x", "true",
|
||||
"y", "10.0",
|
||||
"z", "green"),
|
||||
covariates, 4, 110.0
|
||||
),
|
||||
|
||||
Row.createSimple(Utils.easyMap(
|
||||
"x", "false",
|
||||
"y", "0.0",
|
||||
"z", "red"),
|
||||
covariates, 5, 0.0
|
||||
),
|
||||
Row.createSimple(Utils.easyMap(
|
||||
"x", "false",
|
||||
"y", "10.0",
|
||||
"z", "blue"),
|
||||
covariates, 6, 10.0
|
||||
),
|
||||
Row.createSimple(Utils.easyMap(
|
||||
"x", "true",
|
||||
"y", "0.0",
|
||||
"z", "red"),
|
||||
covariates, 7, 100.0
|
||||
),
|
||||
Row.createSimple(Utils.easyMap(
|
||||
"x", "true",
|
||||
"y", "10.0",
|
||||
"z", "green"),
|
||||
covariates, 8, 110.0
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private Tree<Double> makeTree(List<Covariate> covariates, double offset, int[] indices){
|
||||
// Naming convention - xyTerminal where x and y are low/high denotes whether BooleanCovariate(x) is low/high and
|
||||
// whether NumericCovariate(y) is low/high.
|
||||
final TerminalNode<Double> lowLowTerminal = new TerminalNode<>(0.0 + offset, 5);
|
||||
final TerminalNode<Double> lowHighTerminal = new TerminalNode<>(10.0 + offset, 5);
|
||||
final TerminalNode<Double> highLowTerminal = new TerminalNode<>(100.0 + offset, 5);
|
||||
final TerminalNode<Double> highHighTerminal = new TerminalNode<>(110.0 + offset, 5);
|
||||
|
||||
final SplitNode<Double> lowSplitNode = SplitNode.<Double>builder()
|
||||
.leftHand(lowLowTerminal)
|
||||
.rightHand(lowHighTerminal)
|
||||
.probabilityNaLeftHand(0.5)
|
||||
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
|
||||
.build();
|
||||
|
||||
final SplitNode<Double> highSplitNode = SplitNode.<Double>builder()
|
||||
.leftHand(highLowTerminal)
|
||||
.rightHand(highHighTerminal)
|
||||
.probabilityNaLeftHand(0.5)
|
||||
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
|
||||
.build();
|
||||
|
||||
final SplitNode<Double> rootSplitNode = SplitNode.<Double>builder()
|
||||
.leftHand(lowSplitNode)
|
||||
.rightHand(highSplitNode)
|
||||
.probabilityNaLeftHand(0.5)
|
||||
.splitRule(new BooleanSplitRule((BooleanCovariate) covariates.get(0)))
|
||||
.build();
|
||||
|
||||
return new Tree<>(rootSplitNode, indices);
|
||||
|
||||
}
|
||||
|
||||
// Experiment with random seeds to first examine what a split does so we know what to expect
|
||||
/*
|
||||
public static void main(String[] args){
|
||||
|
||||
// Behaviour for OOB
|
||||
final List<Integer> ints1 = IntStream.range(5, 9).boxed().collect(Collectors.toList());
|
||||
final List<Integer> ints2 = IntStream.range(1, 5).boxed().collect(Collectors.toList());
|
||||
|
||||
final Random random = new Random(123);
|
||||
Collections.shuffle(ints1, random);
|
||||
Collections.shuffle(ints2, random);
|
||||
|
||||
System.out.println(ints1);
|
||||
System.out.println(ints2);
|
||||
// [5, 6, 8, 7]
|
||||
// [3, 4, 1, 2]
|
||||
|
||||
|
||||
// Behaviour for no-OOB
|
||||
final List<Integer> fullInts1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||
final List<Integer> fullInts2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||
final Random fullIntsRandom = new Random(123);
|
||||
|
||||
Collections.shuffle(fullInts1, fullIntsRandom);
|
||||
Collections.shuffle(fullInts2, fullIntsRandom);
|
||||
System.out.println(fullInts1);
|
||||
System.out.println(fullInts2);
|
||||
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
private double[] difference(double[] a, double[] b){
|
||||
final double[] results = new double[a.length];
|
||||
|
||||
for(int i = 0; i < a.length; i++){
|
||||
results[i] = a[i] - b[i];
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
private void assertDoubleEquals(double[] expected, double[] actual){
|
||||
assertEquals(expected.length, actual.length, "Lengths of arrays should be equal");
|
||||
|
||||
for(int i=0; i < expected.length; i++){
|
||||
assertEquals(expected[i], actual[i], 0.0000001, "Difference at " + i);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testVariableImportanceOnXNoOOB(){
|
||||
// x is the BooleanCovariate
|
||||
|
||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||
new RegressionErrorCalculator(),
|
||||
this.forest.getTrees(),
|
||||
this.rowList,
|
||||
false
|
||||
);
|
||||
|
||||
final Covariate covariate = this.covariates.get(0);
|
||||
|
||||
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||
|
||||
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||
|
||||
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||
0., 110., 100., 10., 0., 110., 100., 10.
|
||||
);
|
||||
|
||||
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||
// Actual: [F, F, T, T, F, F, T, T]
|
||||
// Seen: [F, F, T, T, F, F, T, T]
|
||||
// Difference: 0 all around; random chance
|
||||
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||
2., 12., 102., 112., 2., 12., 102., 112.
|
||||
);
|
||||
|
||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
|
||||
final double[] expectedError = new double[2];
|
||||
|
||||
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
|
||||
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
|
||||
|
||||
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||
|
||||
assertDoubleEquals(expectedVimp, importance);
|
||||
|
||||
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||
|
||||
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||
|
||||
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testVariableImportanceOnXOOB(){
|
||||
// x is the BooleanCovariate
|
||||
|
||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||
new RegressionErrorCalculator(),
|
||||
this.forest.getTrees(),
|
||||
this.rowList,
|
||||
true
|
||||
);
|
||||
|
||||
final Covariate covariate = this.covariates.get(0);
|
||||
|
||||
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||
|
||||
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||
|
||||
// [5, 6, 8, 7]
|
||||
// Actual: [F, F, T, T]
|
||||
// Seen: [F, F, T, T]
|
||||
// Difference: No differences
|
||||
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||
0., 10., 100., 110.
|
||||
);
|
||||
|
||||
// [3, 4, 1, 2]
|
||||
// Actual: [F, F, T, T]
|
||||
// Seen: [T, T, F, F]
|
||||
// Difference: +100, +100, -100, -100
|
||||
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||
102., 112., 2., 12.
|
||||
);
|
||||
|
||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
|
||||
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
|
||||
|
||||
final double[] expectedError = new double[2];
|
||||
|
||||
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
|
||||
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
|
||||
|
||||
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||
|
||||
assertDoubleEquals(expectedVimp, importance);
|
||||
|
||||
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||
|
||||
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||
|
||||
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testVariableImportanceOnYNoOOB(){
|
||||
// y is the NumericCovariate
|
||||
|
||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||
new RegressionErrorCalculator(),
|
||||
this.forest.getTrees(),
|
||||
this.rowList,
|
||||
false
|
||||
);
|
||||
|
||||
final Covariate covariate = this.covariates.get(1);
|
||||
|
||||
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||
|
||||
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||
|
||||
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||
// Actual: [F, T, F, T, F, T, F, T]
|
||||
// Seen: [F, T, T, T, F, F, F, T]
|
||||
// Difference: [=, =, +, =, =, -, =, =]x10
|
||||
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||
0., 10., 110., 110., 0., 0., 100., 110.
|
||||
);
|
||||
|
||||
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||
// Actual: [F, T, F, T, F, T, F, T]
|
||||
// Seen: [T, F, T, F, F, T, T, F]
|
||||
// Difference: [+, -, +, -, =, =, +, -]
|
||||
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||
12., 2., 112., 102., 2., 12., 112., 102.
|
||||
);
|
||||
|
||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
|
||||
final double[] expectedError = new double[2];
|
||||
|
||||
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
|
||||
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
|
||||
|
||||
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||
|
||||
assertDoubleEquals(expectedVimp, importance);
|
||||
|
||||
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||
|
||||
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||
|
||||
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testVariableImportanceOnYOOB(){
|
||||
// y is the NumericCovariate
|
||||
|
||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||
new RegressionErrorCalculator(),
|
||||
this.forest.getTrees(),
|
||||
this.rowList,
|
||||
true
|
||||
);
|
||||
|
||||
final Covariate covariate = this.covariates.get(1);
|
||||
|
||||
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||
|
||||
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||
|
||||
// [5, 6, 8, 7]
|
||||
// Actual: [F, T, F, T]
|
||||
// Seen: [F, T, T, F]
|
||||
// Difference: [=, =, +, -]x10
|
||||
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||
0., 10., 110., 100.
|
||||
);
|
||||
|
||||
// [3, 4, 1, 2]
|
||||
// Actual: [F, T, F, T]
|
||||
// Seen: [F, T, F, T]
|
||||
// Difference: [=, =, =, =]x10 no change
|
||||
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||
2., 12., 102., 112.
|
||||
);
|
||||
|
||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
|
||||
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
|
||||
|
||||
final double[] expectedError = new double[2];
|
||||
|
||||
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
|
||||
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
|
||||
|
||||
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||
|
||||
assertDoubleEquals(expectedVimp, importance);
|
||||
|
||||
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||
|
||||
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||
|
||||
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testVariableImportanceOnZNoOOB(){
|
||||
// z is the useless FactorCovariate
|
||||
|
||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||
new RegressionErrorCalculator(),
|
||||
this.forest.getTrees(),
|
||||
this.rowList,
|
||||
false
|
||||
);
|
||||
|
||||
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
|
||||
final double[] expectedImportance = {0.0, 0.0};
|
||||
|
||||
|
||||
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
||||
assertDoubleEquals(expectedImportance, importance);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVariableImportanceOnZOOB(){
|
||||
// z is the useless FactorCovariate
|
||||
|
||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||
new RegressionErrorCalculator(),
|
||||
this.forest.getTrees(),
|
||||
this.rowList,
|
||||
true
|
||||
);
|
||||
|
||||
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
|
||||
final double[] expectedImportance = {0.0, 0.0};
|
||||
|
||||
|
||||
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
||||
assertDoubleEquals(expectedImportance, importance);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -19,6 +19,9 @@ package ca.joeltherrien.randomforest.utils;
|
|||
import ca.joeltherrien.randomforest.TestUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class RightContinuousStepFunctionIntegrationTest {
|
||||
|
||||
private RightContinuousStepFunction createTestFunction(){
|
||||
|
@ -75,5 +78,78 @@ public class RightContinuousStepFunctionIntegrationTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvertedFromTo(){
|
||||
final RightContinuousStepFunction function = createTestFunction();
|
||||
|
||||
final double area1 = function.integrate(0, 3.0);
|
||||
final double area2 = function.integrate(3.0, 0.0);
|
||||
|
||||
assertEquals(area1, -area2, 0.0000001);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntegratingUpToNan(){
|
||||
// Idea here - you have a function that is valid up to point x where it becomes NaN
|
||||
// You should be able to integrate *up to* point x and not get an NaN
|
||||
|
||||
final RightContinuousStepFunction function1 = new RightContinuousStepFunction(
|
||||
new double[]{1.0, 2.0, 3.0, 4.0},
|
||||
new double[]{1.0, 1.0, 1.0, Double.NaN},
|
||||
0.0);
|
||||
|
||||
|
||||
final double area1 = function1.integrate(0.0, 4.0);
|
||||
assertEquals(3.0, area1, 0.000000001);
|
||||
|
||||
final double nanArea1 = function1.integrate(0.0, 4.0001);
|
||||
assertTrue(Double.isNaN(nanArea1));
|
||||
|
||||
|
||||
// This tests integrating over the defaultY up to the NaN point
|
||||
final RightContinuousStepFunction function2 = new RightContinuousStepFunction(
|
||||
new double[]{1.0, 2.0, 3.0, 4.0},
|
||||
new double[]{Double.NaN, 1.0, 1.0, Double.NaN},
|
||||
1.0);
|
||||
|
||||
|
||||
final double area2 = function2.integrate(0.0, 1.0);
|
||||
assertEquals(1.0, area2, 0.000000001);
|
||||
|
||||
final double nanArea2 = function2.integrate(0.0, 4.0);
|
||||
assertTrue(Double.isNaN(nanArea2));
|
||||
|
||||
|
||||
// This tests integrating between two NaN points. Note that of course for RightContinuousValues carry the previous
|
||||
// value until the next x point, so this is just making sure the code works if the x-value we pass over is NaN
|
||||
final RightContinuousStepFunction function3 = new RightContinuousStepFunction(
|
||||
new double[]{1.0, 2.0, 3.0, 4.0},
|
||||
new double[]{Double.NaN, 1.0, 1.0, Double.NaN},
|
||||
0.0);
|
||||
|
||||
|
||||
final double area3 = function3.integrate(2.0, 4.0);
|
||||
assertEquals(2.0, area3, 0.000000001);
|
||||
|
||||
final double nanArea3 = function3.integrate(0.0, 4.0001);
|
||||
assertTrue(Double.isNaN(nanArea3));
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testIntegratingEmptyFunction(){
|
||||
// A function might have no points, but we'll still need to integrate it.
|
||||
|
||||
final RightContinuousStepFunction function = new RightContinuousStepFunction(
|
||||
new double[]{}, new double[]{}, 1.0
|
||||
);
|
||||
|
||||
final double area = function.integrate(1.0 ,3.0);
|
||||
assertEquals(2.0, area, 0.000001);
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.function.DoubleBinaryOperator;
|
||||
import java.util.function.DoubleUnaryOperator;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class RightContinuousStepFunctionOperatorTests {
|
||||
|
||||
|
||||
// Idea - small and middle slightly overlap; middle and large slightly overlap.
|
||||
// small and large never overlap (i.e. small's x values always occur before large's)
|
||||
private final RightContinuousStepFunction smallNumbers;
|
||||
private final RightContinuousStepFunction middleNumbers;
|
||||
private final RightContinuousStepFunction largeNumbers;
|
||||
|
||||
private final double delta = 0.0000000001;
|
||||
|
||||
public RightContinuousStepFunctionOperatorTests(){
|
||||
smallNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||
new Point(1.0, 1.0),
|
||||
new Point(2.0, 3.0),
|
||||
new Point(3.0, 2.0),
|
||||
new Point(4.0, 1.0)
|
||||
), 0.0);
|
||||
|
||||
middleNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||
new Point(3.5, 4.0),
|
||||
new Point(4.0, 3.0),
|
||||
new Point(5.0, 2.0),
|
||||
new Point(6.0, 1.0)
|
||||
), 5.0);
|
||||
|
||||
largeNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||
new Point(5.0, 5.0),
|
||||
new Point(6.0, 6.0),
|
||||
new Point(7.0, 3.0),
|
||||
new Point(8.0, 2.0)
|
||||
), 3.0);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDifferenceNoOverlapLargeMinusSmall(){
|
||||
DoubleBinaryOperator operator = (a, b) -> a - b;
|
||||
|
||||
final RightContinuousStepFunction largeSmallDifference = RightContinuousStepFunction.biOperation(
|
||||
largeNumbers,
|
||||
smallNumbers,
|
||||
operator);
|
||||
|
||||
assertEquals(8, largeSmallDifference.getX().length);
|
||||
assertEquals(8, largeSmallDifference.getY().length);
|
||||
|
||||
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||
|
||||
for(int time = 1; time <= 9; time++){
|
||||
for(double offsetTime : offsetTimes){
|
||||
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||
|
||||
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
|
||||
final double smallFunEvaluation = smallNumbers.evaluate(timeToEvaluateAt);
|
||||
final double expectedDifference = operator.applyAsDouble(largeFunEvaluation, smallFunEvaluation);
|
||||
|
||||
final double actualEvaluation = largeSmallDifference.evaluate(timeToEvaluateAt);
|
||||
|
||||
assertEquals(expectedDifference, actualEvaluation, delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDifferenceNoOverlapSmallMinusLarge(){
|
||||
DoubleBinaryOperator operator = (a, b) -> a - b;
|
||||
|
||||
final RightContinuousStepFunction smallLargeDifference = RightContinuousStepFunction.biOperation(
|
||||
smallNumbers,
|
||||
largeNumbers,
|
||||
operator);
|
||||
|
||||
assertEquals(8, smallLargeDifference.getX().length);
|
||||
assertEquals(8, smallLargeDifference.getY().length);
|
||||
|
||||
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||
|
||||
for(int time = 1; time <= 9; time++){
|
||||
for(double offsetTime : offsetTimes){
|
||||
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||
|
||||
final double smallFunEvaluation = smallNumbers.evaluate(timeToEvaluateAt);
|
||||
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
|
||||
final double expectedDifference = operator.applyAsDouble(smallFunEvaluation, largeFunEvaluation);
|
||||
|
||||
final double actualEvaluation = smallLargeDifference.evaluate(timeToEvaluateAt);
|
||||
|
||||
assertEquals(expectedDifference, actualEvaluation, delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDifferenceSomeOverlapLargeMinusMiddle(){
|
||||
DoubleBinaryOperator operator = (a, b) -> a - b;
|
||||
|
||||
final RightContinuousStepFunction combinedFunction = RightContinuousStepFunction.biOperation(
|
||||
largeNumbers,
|
||||
middleNumbers,
|
||||
operator);
|
||||
|
||||
assertEquals(6, combinedFunction.getX().length);
|
||||
assertEquals(6, combinedFunction.getY().length);
|
||||
|
||||
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||
|
||||
for(int time = 1; time <= 9; time++){
|
||||
for(double offsetTime : offsetTimes){
|
||||
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||
|
||||
final double middleFunEvaluation = middleNumbers.evaluate(timeToEvaluateAt);
|
||||
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
|
||||
final double expectedDifference = operator.applyAsDouble(largeFunEvaluation, middleFunEvaluation);
|
||||
|
||||
final double actualEvaluation = combinedFunction.evaluate(timeToEvaluateAt);
|
||||
|
||||
assertEquals(expectedDifference, actualEvaluation, delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDifferenceCompleteOverlap(){
|
||||
DoubleBinaryOperator operator = (a, b) -> a - b;
|
||||
|
||||
final RightContinuousStepFunction combinedFunction = RightContinuousStepFunction.biOperation(
|
||||
middleNumbers,
|
||||
middleNumbers,
|
||||
operator);
|
||||
|
||||
assertEquals(4, combinedFunction.getX().length);
|
||||
assertEquals(4, combinedFunction.getY().length);
|
||||
|
||||
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||
|
||||
for(int time = 1; time <= 9; time++){
|
||||
for(double offsetTime : offsetTimes){
|
||||
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||
|
||||
final double actualEvaluation = combinedFunction.evaluate(timeToEvaluateAt);
|
||||
|
||||
assertEquals(0.0, actualEvaluation, delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPowerFunction(){
|
||||
final DoubleUnaryOperator operator = d -> d*d;
|
||||
|
||||
final RightContinuousStepFunction squaredFunction = smallNumbers.unaryOperation(operator);
|
||||
|
||||
assertEquals(4, squaredFunction.getX().length);
|
||||
assertEquals(4, squaredFunction.getY().length);
|
||||
|
||||
final double[] offsetTimes = {-0.1, 0.0, 0.1};
|
||||
|
||||
for(int time = 1; time <= 9; time++){
|
||||
for(double offsetTime : offsetTimes){
|
||||
final double timeToEvaluateAt = (double) time + offsetTime;
|
||||
|
||||
final double expectedEvaluation = operator.applyAsDouble(smallNumbers.evaluate(timeToEvaluateAt));
|
||||
final double actualEvaluation = squaredFunction.evaluate(timeToEvaluateAt);
|
||||
|
||||
assertEquals(expectedEvaluation, actualEvaluation, delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -43,7 +43,7 @@ public class TrainForest {
|
|||
|
||||
final List<Covariate> covariateList = new ArrayList<>(p);
|
||||
for(int j =0; j < p; j++){
|
||||
final NumericCovariate covariate = new NumericCovariate("x"+j, j);
|
||||
final NumericCovariate covariate = new NumericCovariate("x"+j, j, false);
|
||||
covariateList.add(covariate);
|
||||
}
|
||||
|
||||
|
|
|
@ -39,8 +39,8 @@ public class TrainSingleTree {
|
|||
final int n = 1000;
|
||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0, false);
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1, false);
|
||||
|
||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||
.generate(() -> random.nextDouble()*10.0)
|
||||
|
|
|
@ -41,9 +41,9 @@ public class TrainSingleTreeFactor {
|
|||
final int n = 10000;
|
||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
|
||||
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"));
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0, false);
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1, false);
|
||||
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"), false);
|
||||
|
||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||
.generate(() -> random.nextDouble()*10.0)
|
||||
|
|
Loading…
Reference in a new issue