Compare commits

..

No commits in common. "master" and "library" have entirely different histories.

68 changed files with 625 additions and 2718 deletions

7
.gitignore vendored
View file

@ -2,10 +2,7 @@
.settings
.project
target/
library/target/
executable/target/
*.iml
.idea
library/dependency-reduced-pom.xml
executable/dependency-reduced-pom.xml
executable/template.yaml
template.yaml
dependency-reduced-pom.xml

View file

@ -1,20 +1,14 @@
# README
This repository contains the largeRCRF Java library, containing all of the logic used for training the random forests. This provides the Jar file used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
This Java software package contains the backend classes used in the R package [largeRCRF](https://github.com/jatherrien/largeRCRF).
Most users interested in training random competing risks forests should use the [R package component](https://github.com/jatherrien/largeRCRF); the content in this repository is only useful for advanced users.
On its own it's not useful, but you're free to integrate it into your own projects (as long as you follow the terms of the GPL-3 license), or extend it. More documentation will be added later on how to extend it, but for now if you want an idea I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes, which is a small example of a regression random forest implementation.
## License
If you've made an extension or modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package` and copy the `largeRCRF-1.0-SNAPSHOT.jar` file now found in the `target/` directory into the `inst/java/` directory for the R package (delete the previous jar file). Then just build the R package, possibly with your modifications in the R code, with `R> devtools::build()`.
You're free to use / modify / redistribute the project, as long as you follow the terms of the GPL-3 license.
If you have any questions on how to integrate this code with your own, how to integrate it with the R project, or anything else related to this project, please feel free to either [email me](mailto:joelt@sfu.ca) or create an Issue.
## Extending
Documentation on how to extend the library to add support for other types of random forests will eventually be added, but for now if you're interested in that I suggest you take a look at the `MeanResponseCombiner` and `WeightedVarianceSplitFinder` classes to see how some basic regression random forests were introduced.
If you've made a modification to the package and would like to integrate it into the R package component, build the project in Maven with `mvn clean package`, then just copy `target/largeRCRF-library-1.0-SNAPSHOT.jar` into the `inst/java/` directory for the R package, replacing the previous Jar file there. Then build the R package, possibly with your modifications to the code there, with `R> devtools::build()`.
Please don't take the current lack of documentation as a sign that I oppose others extending or modifying the project; if you have any questions on running, extending, integrating with R, or anything else related to this project, please don't hesitate to either [email me](mailto:joelt@sfu.ca) or create an Issue. Most likely my answers to your questions will end up forming the basis for any documentation written.
A small project allowing this code to be called directly outside of R will be released soon.
## System Requirements
@ -23,3 +17,5 @@ You need:
* A Java runtime version 1.8 or greater
* Maven to build the project

11
pom.xml
View file

@ -4,8 +4,8 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>ca.joeltherrien.ca</groupId>
<artifactId>largeRCRF-library</artifactId>
<groupId>ca.joeltherrien</groupId>
<artifactId>largeRCRF</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
@ -60,8 +60,7 @@
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.2.1</version>
<artifactId>maven-shade-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
@ -86,7 +85,7 @@
<configuration>
<rulesets>
<!-- Custom local file system rule set -->
<ruleset>pmd-rules.xml</ruleset>
<ruleset>${project.basedir}/pmd-rules.xml</ruleset>
</rulesets>
</configuration>
</plugin>
@ -95,4 +94,4 @@
</project>
</project>

View file

@ -21,13 +21,12 @@ import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.io.Serializable;
import java.util.*;
import java.util.stream.Collectors;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RequiredArgsConstructor
public class CovariateRow implements Serializable, Cloneable {
private static final long serialVersionUID = 1L;
public class CovariateRow implements Serializable {
private final Covariate.Value[] valueArray;
@ -47,14 +46,6 @@ public class CovariateRow implements Serializable, Cloneable {
return "CovariateRow " + this.id;
}
@Override
public CovariateRow clone() {
// shallow clone, which is fine. I want a new array, but the values don't need to be copied
final Covariate.Value[] copyValueArray = this.valueArray.clone();
return new CovariateRow(copyValueArray, this.id);
}
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
final Map<String, Covariate> covariateMap = new HashMap<>();
@ -73,27 +64,4 @@ public class CovariateRow implements Serializable, Cloneable {
return new CovariateRow(valueArray, id);
}
/**
* Used for variable importance; takes a List of CovariateRows and permute one of the Covariates.
*
* @param covariateRows The List of CovariateRows to scramble. Note that the originals won't be modified.
* @param covariateToScramble The Covariate to scramble on.
* @param random The source of randomness to use. If not present, one will be created.
* @return A List of CovariateRows where the specified covariate was scrambled. These are different objects from the ones provided.
*/
public static List<CovariateRow> scrambleCovariateValues(List<? extends CovariateRow> covariateRows, Covariate covariateToScramble, Optional<Random> random){
final List<CovariateRow> permutedCovariateRowList = new ArrayList<>(covariateRows);
Collections.shuffle(permutedCovariateRowList, random.orElse(new Random())); // without replacement
final List<CovariateRow> clonedRowList = covariateRows.stream().map(CovariateRow::clone).collect(Collectors.toList());
final int covariateToScrambleIndex = covariateToScramble.getIndex();
for(int i=0; i < covariateRows.size(); i++){
clonedRowList.get(i).valueArray[covariateToScrambleIndex] = permutedCovariateRowList.get(i).valueArray[covariateToScrambleIndex];
}
return clonedRowList;
}
}

View file

@ -32,6 +32,9 @@ public class Row<Y> extends CovariateRow {
this.response = response;
}
public Y getResponse() {
return this.response;
}

View file

@ -49,8 +49,6 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
return getIndex() - other.getIndex();
}
boolean haveNASplitPenalty();
interface Value<V> extends Serializable{
Covariate<V> getParent();

View file

@ -25,12 +25,9 @@ import lombok.Getter;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
public final class BooleanCovariate implements Covariate<Boolean> {
private static final long serialVersionUID = 1L;
@Getter
private final String name;
@ -41,26 +38,14 @@ public final class BooleanCovariate implements Covariate<Boolean> {
private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
private final boolean haveNASplitPenalty;
@Override
public boolean haveNASplitPenalty(){
// penalty would add worthless computational time if there are no NAs
return hasNAs && haveNASplitPenalty;
}
public BooleanCovariate(String name, int index, boolean haveNASplitPenalty){
public BooleanCovariate(String name, int index){
this.name = name;
this.index = index;
this.splitRule = new BooleanSplitRule(this);
this.haveNASplitPenalty = haveNASplitPenalty;
splitRule = new BooleanSplitRule(this);
}
@Override
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
if(hasNAs){
data = data.stream().filter(row -> !row.getValueByIndex(index).isNA()).collect(Collectors.toList());
}
return new SingletonIterator<>(this.splitRule.applyRule(data));
}
@ -99,8 +84,6 @@ public final class BooleanCovariate implements Covariate<Boolean> {
public class BooleanValue implements Value<Boolean>{
private static final long serialVersionUID = 1L;
private final Boolean value;
private BooleanValue(final Boolean value){

View file

@ -21,8 +21,6 @@ import ca.joeltherrien.randomforest.covariates.SplitRule;
public class BooleanSplitRule implements SplitRule<Boolean> {
private static final long serialVersionUID = 1L;
private final int parentCovariateIndex;
public BooleanSplitRule(BooleanCovariate parent){

View file

@ -23,12 +23,9 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import java.util.*;
import java.util.stream.Collectors;
public final class FactorCovariate implements Covariate<String> {
private static final long serialVersionUID = 1L;
@Getter
private final String name;
@ -41,15 +38,8 @@ public final class FactorCovariate implements Covariate<String> {
private boolean hasNAs;
private final boolean haveNASplitPenalty;
@Override
public boolean haveNASplitPenalty(){
// penalty would add worthless computational time if there are no NAs
return hasNAs && haveNASplitPenalty;
}
public FactorCovariate(final String name, final int index, List<String> levels, final boolean haveNASplitPenalty){
public FactorCovariate(final String name, final int index, List<String> levels){
this.name = name;
this.index = index;
this.factorLevels = new HashMap<>();
@ -71,22 +61,12 @@ public final class FactorCovariate implements Covariate<String> {
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
this.naValue = new FactorValue(null);
this.haveNASplitPenalty = haveNASplitPenalty;
}
@Override
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
if(hasNAs()){
data = data.stream().filter(row -> !row.getCovariateValue(this).isNA()).collect(Collectors.toList());
}
if(number == 0){ // nsplit = 0 => try every possibility, although we limit it to the number of observations.
number = data.size();
}
final Set<Split<Y, String>> splits = new HashSet<>();
// This is to ensure we don't get stuck in an infinite loop for small factors
@ -142,8 +122,6 @@ public final class FactorCovariate implements Covariate<String> {
@EqualsAndHashCode
public final class FactorValue implements Covariate.Value<String>{
private static final long serialVersionUID = 1L;
private final String value;
private FactorValue(final String value){

View file

@ -25,8 +25,6 @@ import java.util.Set;
@EqualsAndHashCode
public final class FactorSplitRule implements SplitRule<String> {
private static final long serialVersionUID = 1L;
private final int parentCovariateIndex;
private final Set<String> leftSideValues;

View file

@ -37,8 +37,6 @@ import java.util.stream.Stream;
@ToString
public final class NumericCovariate implements Covariate<Double> {
private static final long serialVersionUID = 1L;
@Getter
private final String name;
@ -47,13 +45,6 @@ public final class NumericCovariate implements Covariate<Double> {
private boolean hasNAs = false;
private final boolean haveNASplitPenalty;
@Override
public boolean haveNASplitPenalty(){
// penalty would add worthless computational time if there are no NAs
return hasNAs && haveNASplitPenalty;
}
@Override
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
Stream<Row<Y>> stream = data.stream();
@ -131,8 +122,6 @@ public final class NumericCovariate implements Covariate<Double> {
@EqualsAndHashCode
public class NumericValue implements Covariate.Value<Double>{
private static final long serialVersionUID = 1L;
private final Double value; // may be null
private NumericValue(final Double value){

View file

@ -23,12 +23,10 @@ import lombok.EqualsAndHashCode;
@EqualsAndHashCode
public class NumericSplitRule implements SplitRule<Double> {
private static final long serialVersionUID = 1L;
private final int parentCovariateIndex;
private final double threshold;
public NumericSplitRule(NumericCovariate parent, final double threshold){
NumericSplitRule(NumericCovariate parent, final double threshold){
this.parentCovariateIndex = parent.getIndex();
this.threshold = threshold;
}

View file

@ -26,8 +26,6 @@ import java.util.List;
@Builder
public class CompetingRiskFunctions implements Serializable {
private static final long serialVersionUID = 1L;
private final List<RightContinuousStepFunction> causeSpecificHazards;
private final List<RightContinuousStepFunction> cumulativeIncidenceCurves;

View file

@ -23,8 +23,6 @@ import java.io.Serializable;
@Data
public class CompetingRiskResponse implements Serializable {
private static final long serialVersionUID = 1L;
private final int delta;
private final double u;

View file

@ -26,9 +26,6 @@ import lombok.EqualsAndHashCode;
@EqualsAndHashCode(callSuper = true)
@Data
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
private static final long serialVersionUID = 1L;
private final double c;
public CompetingRiskResponseWithCensorTime(int delta, double u, double c) {

View file

@ -16,11 +16,9 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import java.util.*;
import java.util.stream.IntStream;
import java.util.stream.Stream;
public class CompetingRiskUtils {
@ -118,44 +116,6 @@ public class CompetingRiskUtils {
}
/**
* Calculate the Integrated Brier Score error on a list of responses and predictions.
*
* @param responses A List of responses
* @param predictions The corresponding List of predictions.
* @param censoringDistribution The censoring distribution.
* @param eventOfFocus The event we are calculating the error for.
* @param integrationUpperBound The upper bound to integrate to.
* @param isParallel Whether we should use parallel streams or not (provided because of bugs on a particular system).
* @return
*/
public static double[] calculateIBSError(final List<CompetingRiskResponse> responses,
List<CompetingRiskFunctions> predictions,
Optional<RightContinuousStepFunction> censoringDistribution,
int eventOfFocus,
double integrationUpperBound,
boolean isParallel){
if(responses.size() != predictions.size()){
throw new IllegalArgumentException("Length of responses and predictions must be equal.");
}
final IBSCalculator calculator = new IBSCalculator(censoringDistribution);
IntStream stream = IntStream.range(0, responses.size());
if(isParallel){
stream = stream.parallel();
}
return stream.mapToDouble(i -> {
CompetingRiskResponse response = responses.get(i);
RightContinuousStepFunction cif = predictions.get(i).getCumulativeIncidenceFunction(eventOfFocus);
return calculator.calculateError(response, cif, eventOfFocus, integrationUpperBound);
}).toArray();
}
public static CompetingRiskSetsImpl calculateSetsEfficiently(final List<CompetingRiskResponse> initialLeftHand,
final List<CompetingRiskResponse> initialRightHand,

View file

@ -1,82 +0,0 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import java.util.Optional;
/**
* Used to calculate the Integrated Brier Score. See Section 4.2 of "Random survival forests for competing risks" by Ishwaran.
*
*/
public class IBSCalculator {
private final Optional<RightContinuousStepFunction> censoringDistribution;
public IBSCalculator(RightContinuousStepFunction censoringDistribution){
this.censoringDistribution = Optional.of(censoringDistribution);
}
public IBSCalculator(){
this.censoringDistribution = Optional.empty();
}
public IBSCalculator(Optional<RightContinuousStepFunction> censoringDistribution){
this.censoringDistribution = censoringDistribution;
}
public double calculateError(CompetingRiskResponse response, RightContinuousStepFunction cif, int eventOfInterest, double integrationUpperBound){
// return integral of weights*(I(response.getU() <= times & response.getDelta() == eventOfInterest) - cif(times))^2
// Note that if we don't have weights, just treat them all as one (i.e. don't bother multiplying)
RightContinuousStepFunction functionToIntegrate = cif;
if(response.getDelta() == eventOfInterest){
final RightContinuousStepFunction observedFunction = new RightContinuousStepFunction(new double[]{response.getU()}, new double[]{1.0}, 0.0);
functionToIntegrate = RightContinuousStepFunction.biOperation(observedFunction, functionToIntegrate, (a, b) -> (a - b) * (a - b));
} else{
functionToIntegrate = functionToIntegrate.unaryOperation(a -> a*a);
}
if(censoringDistribution.isPresent()){
final RightContinuousStepFunction weights = calculateWeights(response, censoringDistribution.get());
functionToIntegrate = RightContinuousStepFunction.biOperation(weights, functionToIntegrate, (a, b) -> a*b);
// the censoring weights go to 0 after the response is censored, so we can speed up results by only integrating
// prior to the censor times
if(response.isCensored()){
integrationUpperBound = Math.min(integrationUpperBound, response.getU());
}
}
return functionToIntegrate.integrate(0.0, integrationUpperBound);
}
private RightContinuousStepFunction calculateWeights(CompetingRiskResponse response, RightContinuousStepFunction censoringDistribution){
final double recordedTime = response.getU();
// Function(t) = firstPart(t) + secondPart(t)/thirdPart(t) where:
// firstPart(t) = I(recordedTime <= t & !response.isCensored()) / censoringDistribution.evaluate(recordedTime);
// secondPart(t) = I(recordedTime > t) = 1 - I(recordedTime <= t)
// thirdPart(t) = censoringDistribution.evaluate(t)
final RightContinuousStepFunction secondPart = new RightContinuousStepFunction(new double[]{recordedTime}, new double[]{0.0}, 1.0);
RightContinuousStepFunction result = RightContinuousStepFunction.biOperation(secondPart, censoringDistribution,
(second, third) -> second / third);
if(!response.isCensored()){
final RightContinuousStepFunction firstPart = new RightContinuousStepFunction(
new double[]{recordedTime},
new double[]{1.0 / censoringDistribution.evaluate(recordedTime)},
0.0);
result = RightContinuousStepFunction.biOperation(firstPart, result, Double::sum);
}
return result;
}
}

View file

@ -17,17 +17,17 @@
package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@RequiredArgsConstructor
public class CompetingRiskFunctionCombiner implements ForestResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
private static final long serialVersionUID = 1L;
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
private final int[] events;
private final double[] times; // We may restrict ourselves to specific times.
@ -55,22 +55,72 @@ public class CompetingRiskFunctionCombiner implements ForestResponseCombiner<Com
).sorted().distinct().toArray();
}
final IntermediateCompetingRisksFunctionsTimesKnown intermediateResult = new IntermediateCompetingRisksFunctionsTimesKnown(responses.size(), this.events, timesToUse);
final double n = responses.size();
final double[] survivalY = new double[timesToUse.length];
final double[][] csCHFY = new double[events.length][timesToUse.length];
final double[][] cifY = new double[events.length][timesToUse.length];
/*
We're going to try to efficiently put our predictions together -
Assumptions - for each event on a response, the hazard and CIF functions share the same x points
Plan - go through the time on each response and make use of that so that when we search for a time index
to evaluate the function at, we don't need to re-search the earlier times.
*/
for(final CompetingRiskFunctions currentFunctions : responses){
final double[] survivalXPoints = currentFunctions.getSurvivalCurve().getX();
final double[][] eventSpecificXPoints = new double[events.length][];
for(final int event : events){
eventSpecificXPoints[event-1] = currentFunctions.getCumulativeIncidenceFunction(event)
.getX();
}
int previousSurvivalIndex = 0;
final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
for(int i=0; i<timesToUse.length; i++){
final double time = timesToUse[i];
// Survival curve
final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
survivalY[i] = survivalY[i] + currentFunctions.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / n;
previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
// -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
// CHFs and CIFs
for(final int event : events){
final double[] xPoints = eventSpecificXPoints[event-1];
final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
xPoints, time);
csCHFY[event-1][i] = csCHFY[event-1][i] + currentFunctions.getCauseSpecificHazardFunction(event)
.evaluateByIndex(eventTimeIndex) / n;
cifY[event-1][i] = cifY[event-1][i] + currentFunctions.getCumulativeIncidenceFunction(event)
.evaluateByIndex(eventTimeIndex) / n;
previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
}
}
for(CompetingRiskFunctions input : responses){
intermediateResult.processNewInput(input);
}
return intermediateResult.transformToOutput();
}
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
@Override
public IntermediateCombinedResponse<CompetingRiskFunctions, CompetingRiskFunctions> startIntermediateCombinedResponse(int countInputs) {
if(this.times != null){
return new IntermediateCompetingRisksFunctionsTimesKnown(countInputs, this.events, this.times);
for(final int event : events){
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
}
// TODO - implement
throw new RuntimeException("startIntermediateCombinedResponse when times is unknown is not yet implemented");
return CompetingRiskFunctions.builder()
.causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
.cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
.survivalCurve(survivalFunction)
.build();
}
}

View file

@ -35,8 +35,6 @@ import java.util.List;
*/
public class CompetingRiskResponseCombiner implements ResponseCombiner<CompetingRiskResponse, CompetingRiskFunctions> {
private static final long serialVersionUID = 1L;
private final int[] events;
public CompetingRiskResponseCombiner(final int[] events){

View file

@ -1,118 +0,0 @@
package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import java.util.ArrayList;
import java.util.List;
public class IntermediateCompetingRisksFunctionsTimesKnown implements IntermediateCombinedResponse<CompetingRiskFunctions, CompetingRiskFunctions> {
private double expectedN;
private final int[] events;
private final double[] timesToUse;
private int actualN;
private final double[] survivalY;
private final double[][] csCHFY;
private final double[][] cifY;
public IntermediateCompetingRisksFunctionsTimesKnown(int n, int[] events, double[] timesToUse){
this.expectedN = n;
this.events = events;
this.timesToUse = timesToUse;
this.actualN = 0;
this.survivalY = new double[timesToUse.length];
this.csCHFY = new double[events.length][timesToUse.length];
this.cifY = new double[events.length][timesToUse.length];
}
@Override
public void processNewInput(CompetingRiskFunctions input) {
/*
We're going to try to efficiently put our predictions together -
Assumptions - for each event on a response, the hazard and CIF functions share the same x points
Plan - go through the time on each response and make use of that so that when we search for a time index
to evaluate the function at, we don't need to re-search the earlier times.
*/
this.actualN++;
final double[] survivalXPoints = input.getSurvivalCurve().getX();
final double[][] eventSpecificXPoints = new double[events.length][];
for(final int event : events){
eventSpecificXPoints[event-1] = input.getCumulativeIncidenceFunction(event)
.getX();
}
int previousSurvivalIndex = 0;
final int[] previousEventSpecificIndex = new int[events.length]; // relying on 0 being default value
for(int i=0; i<timesToUse.length; i++){
final double time = timesToUse[i];
// Survival curve
final int survivalTimeIndex = Utils.binarySearchLessThan(previousSurvivalIndex, survivalXPoints.length, survivalXPoints, time);
survivalY[i] = survivalY[i] + input.getSurvivalCurve().evaluateByIndex(survivalTimeIndex) / expectedN;
previousSurvivalIndex = Math.max(survivalTimeIndex, 0); // if our current time is less than the smallest time in xPoints then binarySearchLessThan returned a -1.
// -1's not an issue for evaluateByIndex, but it is an issue for the next time binarySearchLessThan is called.
// CHFs and CIFs
for(final int event : events){
final double[] xPoints = eventSpecificXPoints[event-1];
final int eventTimeIndex = Utils.binarySearchLessThan(previousEventSpecificIndex[event-1], xPoints.length,
xPoints, time);
csCHFY[event-1][i] = csCHFY[event-1][i] + input.getCauseSpecificHazardFunction(event)
.evaluateByIndex(eventTimeIndex) / expectedN;
cifY[event-1][i] = cifY[event-1][i] + input.getCumulativeIncidenceFunction(event)
.evaluateByIndex(eventTimeIndex) / expectedN;
previousEventSpecificIndex[event-1] = Math.max(eventTimeIndex, 0);
}
}
}
@Override
public CompetingRiskFunctions transformToOutput() {
rescaleOutput();
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
for(final int event : events){
causeSpecificCumulativeHazardFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, csCHFY[event-1], 0));
cumulativeIncidenceFunctionList.add(event-1, new RightContinuousStepFunction(timesToUse, cifY[event-1], 0));
}
return CompetingRiskFunctions.builder()
.causeSpecificHazards(causeSpecificCumulativeHazardFunctionList)
.cumulativeIncidenceCurves(cumulativeIncidenceFunctionList)
.survivalCurve(survivalFunction)
.build();
}
private void rescaleOutput() {
rescaleArray(actualN, this.survivalY);
for(int event : events){
rescaleArray(actualN, this.cifY[event - 1]);
rescaleArray(actualN, this.csCHFY[event - 1]);
}
this.expectedN = actualN;
}
private void rescaleArray(double newN, double[] array){
for(int i=0; i<array.length; i++){
array[i] = array[i] * (this.expectedN / newN);
}
}
}

View file

@ -28,7 +28,6 @@ import java.util.List;
*
*/
public class GrayLogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponseWithCensorTime> {
private static final long serialVersionUID = 1L;
private final int[] eventsOfFocus;
private final int[] events;

View file

@ -28,7 +28,6 @@ import java.util.List;
*
*/
public class LogRankSplitFinder extends CompetingRiskSplitFinder<CompetingRiskResponse> {
private static final long serialVersionUID = 1L;
private final int[] eventsOfFocus;
private final int[] events;

View file

@ -16,8 +16,7 @@
package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.ForestResponseCombiner;
import ca.joeltherrien.randomforest.tree.IntermediateCombinedResponse;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import java.util.List;
@ -25,8 +24,7 @@ import java.util.List;
* Returns the Mean value of a group of Doubles.
*
*/
public class MeanResponseCombiner implements ForestResponseCombiner<Double, Double> {
private static final long serialVersionUID = 1L;
public class MeanResponseCombiner implements ResponseCombiner<Double, Double> {
@Override
public Double combine(List<Double> responses) {
@ -36,39 +34,5 @@ public class MeanResponseCombiner implements ForestResponseCombiner<Double, Doub
}
@Override
public IntermediateCombinedResponse<Double, Double> startIntermediateCombinedResponse(int countInputs) {
return new MeanIntermediateCombinedResponse(countInputs);
}
public static class MeanIntermediateCombinedResponse implements IntermediateCombinedResponse<Double, Double>{
private double expectedN;
private int actualN;
private double currentMean;
public MeanIntermediateCombinedResponse(int n){
this.expectedN = n;
this.actualN = 0;
this.currentMean = 0.0;
}
@Override
public void processNewInput(Double input) {
this.currentMean = this.currentMean + input / expectedN;
this.actualN ++;
}
@Override
public Double transformToOutput() {
// rescale if necessary
this.currentMean = this.currentMean * (this.expectedN / (double) actualN);
this.expectedN = actualN;
return currentMean;
}
}
}

View file

@ -27,7 +27,6 @@ import java.util.List;
import java.util.stream.Collectors;
public class WeightedVarianceSplitFinder implements SplitFinder<Double> {
private static final long serialVersionUID = 1L;
private Double getScore(Set leftHand, Set rightHand) {

View file

@ -17,18 +17,31 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Builder;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
public abstract class Forest<O, FO> {
@Builder
public class Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
public abstract FO evaluate(CovariateRow row);
public abstract FO evaluateOOB(CovariateRow row);
public abstract Iterable<Tree<O>> getTrees();
public abstract int getNumberOfTrees();
private final List<Tree<O>> trees;
private final ResponseCombiner<O, FO> treeResponseCombiner;
private final List<Covariate> covariateList;
public FO evaluate(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
/**
* Used primarily in the R package interface to avoid R loops; and for easier parallelization.
@ -80,6 +93,21 @@ public abstract class Forest<O, FO> {
.collect(Collectors.toList());
}
public FO evaluateOOB(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
public List<Tree<O>> getTrees(){
return Collections.unmodifiableList(trees);
}
public Map<Integer, Integer> findSplitsByCovariate(){
final Map<Integer, Integer> countMap = new TreeMap<>();
@ -130,5 +158,4 @@ public abstract class Forest<O, FO> {
return countTerminalNodes;
}
}

View file

@ -1,23 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.tree;
public interface ForestResponseCombiner<I, O> extends ResponseCombiner<I, O>{
IntermediateCombinedResponse<I, O> startIntermediateCombinedResponse(int countInputs);
}

View file

@ -38,7 +38,7 @@ public class ForestTrainer<Y, TO, FO> {
private final TreeTrainer<Y, TO> treeTrainer;
private final List<Covariate> covariates;
private final ForestResponseCombiner<TO, FO> treeResponseCombiner;
private final ResponseCombiner<TO, FO> treeResponseCombiner;
private final List<Row<Y>> data;
// number of trees to try
@ -57,10 +57,10 @@ public class ForestTrainer<Y, TO, FO> {
* in which case its trees are combined with the new one.
* @return A trained forest.
*/
public OnlineForest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
public Forest<TO, FO> trainSerialInMemory(Optional<Forest<TO, FO>> initialForest){
final List<Tree<TO>> trees = new ArrayList<>(ntree);
initialForest.ifPresent(forest -> forest.getTrees().forEach(trees::add));
initialForest.ifPresent(forest -> trees.addAll(forest.getTrees()));
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
@ -77,9 +77,11 @@ public class ForestTrainer<Y, TO, FO> {
System.out.println("Finished");
}
return OnlineForest.<TO, FO>builder()
return Forest.<TO, FO>builder()
.treeResponseCombiner(treeResponseCombiner)
.trees(trees)
.covariateList(covariates)
.build();
}
@ -92,7 +94,7 @@ public class ForestTrainer<Y, TO, FO> {
* There cannot be existing trees if the initial forest is
* specified.
*/
public OfflineForest<TO, FO> trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
public void trainSerialOnDisk(Optional<Forest<TO, FO>> initialForest){
// First we need to see how many trees there currently are
final File folder = new File(saveTreeLocation);
if(!folder.exists()){
@ -113,14 +115,17 @@ public class ForestTrainer<Y, TO, FO> {
final AtomicInteger treeCount; // tracks how many trees are finished
// Using an AtomicInteger is overkill for serial code, but this lets us reuse TreeSavedWorker
if(initialForest.isPresent()){
int j=0;
for(final Tree<TO> tree : initialForest.get().getTrees()){
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
for(int j=0; j<initialTrees.size(); j++){
final String filename = "tree-" + (j+1) + ".tree";
final Tree<TO> tree = initialTrees.get(j);
saveTree(tree, filename);
j++;
}
treeCount = new AtomicInteger(j);
treeCount = new AtomicInteger(initialTrees.size());
} else{
treeCount = new AtomicInteger(treeFiles.length);
}
@ -148,8 +153,6 @@ public class ForestTrainer<Y, TO, FO> {
System.out.println("Finished");
}
return new OfflineForest<>(folder, treeResponseCombiner);
}
/**
@ -159,7 +162,7 @@ public class ForestTrainer<Y, TO, FO> {
* in which case its trees are combined with the new one.
* @param threads The number of trees to train at once.
*/
public OnlineForest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
public Forest<TO, FO> trainParallelInMemory(Optional<Forest<TO, FO>> initialForest, int threads){
// create a list that is pre-specified in size (I can call the .set method at any index < ntree without
// the earlier indexes being filled.
@ -167,12 +170,11 @@ public class ForestTrainer<Y, TO, FO> {
final int startingCount;
if(initialForest.isPresent()){
int j = 0;
for(final Tree<TO> tree : initialForest.get().getTrees()){
trees.set(j, tree);
j++;
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
for(int j=0; j<initialTrees.size(); j++) {
trees.set(j, initialTrees.get(j));
}
startingCount = initialForest.get().getNumberOfTrees();
startingCount = initialTrees.size();
}
else{
startingCount = 0;
@ -217,7 +219,7 @@ public class ForestTrainer<Y, TO, FO> {
System.out.println("\nFinished");
}
return OnlineForest.<TO, FO>builder()
return Forest.<TO, FO>builder()
.treeResponseCombiner(treeResponseCombiner)
.trees(trees)
.build();
@ -233,7 +235,7 @@ public class ForestTrainer<Y, TO, FO> {
* specified.
* @param threads The number of trees to train at once.
*/
public OfflineForest<TO, FO> trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
public void trainParallelOnDisk(Optional<Forest<TO, FO>> initialForest, int threads){
// First we need to see how many trees there currently are
final File folder = new File(saveTreeLocation);
if(!folder.exists()){
@ -253,14 +255,17 @@ public class ForestTrainer<Y, TO, FO> {
final AtomicInteger treeCount; // tracks how many trees are finished
if(initialForest.isPresent()){
int j=0;
for(final Tree<TO> tree : initialForest.get().getTrees()){
final List<Tree<TO>> initialTrees = initialForest.get().getTrees();
for(int j=0; j<initialTrees.size(); j++){
final String filename = "tree-" + (j+1) + ".tree";
final Tree<TO> tree = initialTrees.get(j);
saveTree(tree, filename);
j++;
}
treeCount = new AtomicInteger(j);
treeCount = new AtomicInteger(initialTrees.size());
} else{
treeCount = new AtomicInteger(treeFiles.length);
}
@ -304,8 +309,6 @@ public class ForestTrainer<Y, TO, FO> {
System.out.println("\nFinished");
}
return new OfflineForest<>(folder, treeResponseCombiner);
}
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){

View file

@ -1,30 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.tree;
/**
* Similar to ResponseCombiner, but an IntermediateCombinedResponse represents the intermediate state of a single output in the process of being combined.
* This class is only used in OfflineForests where we can only load one Tree in memory at a time.
*
*/
public interface IntermediateCombinedResponse<I, O> {
void processNewInput(I input);
O transformToOutput();
}

View file

@ -1,198 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.utils.IterableOfflineTree;
import lombok.AllArgsConstructor;
import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@AllArgsConstructor
public class OfflineForest<O, FO> extends Forest<O, FO> {
private final File[] treeFiles;
private final ForestResponseCombiner<O, FO> treeResponseCombiner;
public OfflineForest(File treeDirectoryPath, ForestResponseCombiner<O, FO> treeResponseCombiner){
this.treeResponseCombiner = treeResponseCombiner;
if(!treeDirectoryPath.isDirectory()){
throw new IllegalArgumentException("treeDirectoryPath must point to a directory!");
}
this.treeFiles = treeDirectoryPath.listFiles((file, s) -> s.endsWith(".tree"));
}
@Override
public FO evaluate(CovariateRow row) {
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
for(final Tree<O> tree : getTrees()){
final O prediction = tree.evaluate(row);
predictedOutputs.add(prediction);
}
return treeResponseCombiner.combine(predictedOutputs);
}
@Override
public FO evaluateOOB(CovariateRow row) {
final List<O> predictedOutputs = new ArrayList<>(treeFiles.length);
for(final Tree<O> tree : getTrees()){
if(!tree.idInBootstrapSample(row.getId())){
final O prediction = tree.evaluate(row);
predictedOutputs.add(prediction);
}
}
return treeResponseCombiner.combine(predictedOutputs);
}
@Override
public List<FO> evaluate(List<? extends CovariateRow> rowList){
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
IntStream.range(0, rowList.size())
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
.collect(Collectors.toList());
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
for(int treeId = 0; treeId < treeFiles.length; treeId++){
final Tree<O> currentTree = treeIterator.next();
IntStream.range(0, rowList.size()).parallel().forEach(
rowId -> {
final CovariateRow row = rowList.get(rowId);
final O prediction = currentTree.evaluate(row);
intermediatePredictions.get(rowId).processNewInput(prediction);
}
);
}
return intermediatePredictions.stream().parallel()
.map(intPred -> intPred.transformToOutput())
.collect(Collectors.toList());
}
@Override
public List<FO> evaluateSerial(List<? extends CovariateRow> rowList){
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
IntStream.range(0, rowList.size())
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
.collect(Collectors.toList());
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
for(int treeId = 0; treeId < treeFiles.length; treeId++){
final Tree<O> currentTree = treeIterator.next();
IntStream.range(0, rowList.size()).sequential().forEach(
rowId -> {
final CovariateRow row = rowList.get(rowId);
final O prediction = currentTree.evaluate(row);
intermediatePredictions.get(rowId).processNewInput(prediction);
}
);
}
return intermediatePredictions.stream().sequential()
.map(intPred -> intPred.transformToOutput())
.collect(Collectors.toList());
}
@Override
public List<FO> evaluateOOB(List<? extends CovariateRow> rowList){
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
IntStream.range(0, rowList.size())
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
.collect(Collectors.toList());
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
for(int treeId = 0; treeId < treeFiles.length; treeId++){
final Tree<O> currentTree = treeIterator.next();
IntStream.range(0, rowList.size()).parallel().forEach(
rowId -> {
final CovariateRow row = rowList.get(rowId);
if(!currentTree.idInBootstrapSample(row.getId())){
final O prediction = currentTree.evaluate(row);
intermediatePredictions.get(rowId).processNewInput(prediction);
}
// else do nothing; when we get the final output it will get scaled for the smaller N
}
);
}
return intermediatePredictions.stream().parallel()
.map(intPred -> intPred.transformToOutput())
.collect(Collectors.toList());
}
@Override
public List<FO> evaluateSerialOOB(List<? extends CovariateRow> rowList){
final List<IntermediateCombinedResponse<O, FO>> intermediatePredictions =
IntStream.range(0, rowList.size())
.mapToObj(i -> treeResponseCombiner.startIntermediateCombinedResponse(treeFiles.length))
.collect(Collectors.toList());
final Iterator<Tree<O>> treeIterator = getTrees().iterator();
for(int treeId = 0; treeId < treeFiles.length; treeId++){
final Tree<O> currentTree = treeIterator.next();
IntStream.range(0, rowList.size()).sequential().forEach(
rowId -> {
final CovariateRow row = rowList.get(rowId);
if(!currentTree.idInBootstrapSample(row.getId())){
final O prediction = currentTree.evaluate(row);
intermediatePredictions.get(rowId).processNewInput(prediction);
}
// else do nothing; when we get the final output it will get scaled for the smaller N
}
);
}
return intermediatePredictions.stream().sequential()
.map(intPred -> intPred.transformToOutput())
.collect(Collectors.toList());
}
@Override
public Iterable<Tree<O>> getTrees() {
return new IterableOfflineTree<>(treeFiles);
}
@Override
public int getNumberOfTrees() {
return treeFiles.length;
}
public OnlineForest<O, FO> createOnlineCopy(){
final List<Tree<O>> allTrees = new ArrayList<>(getNumberOfTrees());
getTrees().forEach(allTrees::add);
return OnlineForest.<O, FO>builder()
.trees(allTrees)
.treeResponseCombiner(treeResponseCombiner)
.build();
}
}

View file

@ -1,66 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Builder;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@Builder
public class OnlineForest<O, FO> extends Forest<O, FO> { // O = output of trees, FO = forest output. In practice O == FO, even in competing risk & survival settings
private final List<Tree<O>> trees;
private final ResponseCombiner<O, FO> treeResponseCombiner;
@Override
public FO evaluate(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
@Override
public FO evaluateOOB(CovariateRow row){
return treeResponseCombiner.combine(
trees.stream()
.filter(tree -> !tree.idInBootstrapSample(row.getId()))
.map(node -> node.evaluate(row))
.collect(Collectors.toList())
);
}
@Override
public List<Tree<O>> getTrees(){
return Collections.unmodifiableList(trees);
}
@Override
public int getNumberOfTrees() {
return trees.size();
}
}

View file

@ -17,13 +17,15 @@
package ca.joeltherrien.randomforest.tree;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
@AllArgsConstructor
@Data
public class SplitAndScore<Y, V> {
private Split<Y, V> split;
private Double score;
@Getter
private final Split<Y, V> split;
@Getter
private final Double score;
}

View file

@ -29,7 +29,6 @@ import java.util.List;
@ToString
@Getter
public class SplitNode<Y> implements Node<Y> {
private static final long serialVersionUID = 1L;
private final Node<Y> leftHand;
private final Node<Y> rightHand;

View file

@ -27,7 +27,6 @@ import java.util.List;
@RequiredArgsConstructor
@ToString
public class TerminalNode<Y> implements Node<Y> {
private static final long serialVersionUID = 1L;
private final Y responseValue;

View file

@ -23,7 +23,6 @@ import java.util.Arrays;
import java.util.List;
public class Tree<Y> implements Node<Y> {
private static final long serialVersionUID = 1L;
@Getter
private final Node<Y> rootNode;

View file

@ -17,9 +17,7 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.VisibleForTesting;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.utils.SingletonIterator;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
@ -74,12 +72,31 @@ public class TreeTrainer<Y, O> {
}
// Now that we have the best split; we need to handle any NAs that were dropped off
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
// Assign missing values to the split if necessary
bestSplit = randomlyAssignNAs(data, bestSplit, random);
if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
for(Row<Y> row : data) {
final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex();
if(row.getValueByIndex(covariateIndex).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
bestSplit.getLeftHand().add(row);
}
else{
bestSplit.getRightHand().add(row);
}
}
}
}
final Node<O> leftNode;
final Node<O> rightNode;
@ -127,8 +144,7 @@ public class TreeTrainer<Y, O> {
return splitCovariates;
}
@VisibleForTesting
public Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
SplitAndScore<Y, ?> bestSplitAndScore = null;
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
@ -141,32 +157,10 @@ public class TreeTrainer<Y, O> {
continue;
}
SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
if(candidateSplitAndScore == null){
continue;
}
// This score was based on splitting only non-NA values. However, there might be a similar covariate we are also considering
// that is just as good at splitting but has less NAs; we should thus penalize the split score for variables with NAs
// We do this by randomly assigning the NAs and then recalculating the split score on the best split we already have.
//
// We only have to penalize the score though if we know it's possible that this might be the best split. If it's not,
// then we can skip the computations.
final boolean mayBeGoodSplit = bestSplitAndScore == null ||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore();
if(mayBeGoodSplit && covariate.haveNASplitPenalty()){
Split<Y, ?> candiateSplitWithNAs = randomlyAssignNAs(data, candidateSplitAndScore.getSplit(), random);
final Iterator<Split<Y, ?>> newSplitWithRandomNAs = new SingletonIterator<>(candiateSplitWithNAs);
final double newScore = splitFinder.findBestSplit(newSplitWithRandomNAs).getScore();
// There's a chance that NAs might add noise to *improve* the score; but we want to ensure we penalize it.
// Thus we only change the score if its worse.
candidateSplitAndScore.setScore(Math.min(newScore, candidateSplitAndScore.getScore()));
}
if(bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {
bestSplitAndScore = candidateSplitAndScore;
}
@ -180,38 +174,6 @@ public class TreeTrainer<Y, O> {
}
private <V> Split<Y, V> randomlyAssignNAs(List<Row<Y>> data, Split<Y, V> existingSplit, Random random){
// Now that we have the best split; we need to handle any NAs that were dropped off
final double probabilityLeftHand = (double) existingSplit.leftHand.size() /
(double) (existingSplit.leftHand.size() + existingSplit.rightHand.size());
final int covariateIndex = existingSplit.getSplitRule().getParentCovariateIndex();
// Assign missing values to the split if necessary
if(covariates.get(existingSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
existingSplit = existingSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
for(Row<Y> row : data) {
if(row.getValueByIndex(covariateIndex).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
existingSplit.getLeftHand().add(row);
}
else{
existingSplit.getRightHand().add(row);
}
}
}
}
return existingSplit;
}
private boolean nodeIsPure(List<Row<Y>> data){
if(!checkNodePurity){
return false;

View file

@ -1,24 +0,0 @@
package ca.joeltherrien.randomforest.tree.vimp;
import java.util.List;
/**
* Simple interface for VariableImportanceCalculator; takes in a List of observed responses and a List of predictions
* and produces an average error measure.
*
* @param <Y> The class of the responses.
* @param <P> The class of the predictions.
*/
public interface ErrorCalculator<Y, P>{
/**
* Compares the observed responses with the predictions to produce an average error measure.
* Lower errors should indicate a better model fit.
*
* @param responses
* @param predictions
* @return
*/
double averageError(List<Y> responses, List<P> predictions);
}

View file

@ -1,63 +0,0 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import java.util.Arrays;
import java.util.List;
/**
* Implements ErrorCalculator; essentially just wraps around IBSCalculator to fit into VariableImportanceCalculator.
*
*/
public class IBSErrorCalculatorWrapper implements ErrorCalculator<CompetingRiskResponse, CompetingRiskFunctions> {
private final IBSCalculator calculator;
private final int[] events;
private final double integrationUpperBound;
private final double[] eventWeights;
public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound, double[] eventWeights) {
this.calculator = calculator;
this.events = events;
this.integrationUpperBound = integrationUpperBound;
this.eventWeights = eventWeights;
}
public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound) {
this.calculator = calculator;
this.events = events;
this.integrationUpperBound = integrationUpperBound;
this.eventWeights = new double[events.length];
Arrays.fill(this.eventWeights, 1.0); // default is to just sum all errors together
}
@Override
public double averageError(List<CompetingRiskResponse> responses, List<CompetingRiskFunctions> predictions) {
final double[] errors = new double[events.length];
final double n = responses.size();
for(int i=0; i < responses.size(); i++){
final CompetingRiskResponse response = responses.get(i);
final CompetingRiskFunctions prediction = predictions.get(i);
for(int k=0; k < this.events.length; k++){
final int event = this.events[k];
final RightContinuousStepFunction cif = prediction.getCumulativeIncidenceFunction(event);
errors[k] += calculator.calculateError(response, cif, event, integrationUpperBound) / n;
}
}
double totalError = 0.0;
for(int k=0; k < this.events.length; k++){
totalError += this.eventWeights[k] * errors[k];
}
return totalError;
}
}

View file

@ -1,23 +0,0 @@
package ca.joeltherrien.randomforest.tree.vimp;
import java.util.List;
public class RegressionErrorCalculator implements ErrorCalculator<Double, Double>{
@Override
public double averageError(List<Double> responses, List<Double> predictions) {
double mean = 0.0;
final double n = responses.size();
for(int i=0; i<responses.size(); i++){
final double response = responses.get(i);
final double prediction = predictions.get(i);
final double difference = response - prediction;
mean += difference * difference / n;
}
return mean;
}
}

View file

@ -1,117 +0,0 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Tree;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
public class VariableImportanceCalculator<Y, P> {
private final ErrorCalculator<Y, P> errorCalculator;
private final List<Tree<P>> trees;
private final List<Row<Y>> observations;
private final boolean isTrainingSet; // If true, then we use out-of-bag predictions
private final double[] baselineErrors;
public VariableImportanceCalculator(
ErrorCalculator<Y, P> errorCalculator,
List<Tree<P>> trees,
List<Row<Y>> observations,
boolean isTrainingSet
){
this.errorCalculator = errorCalculator;
this.trees = trees;
this.observations = observations;
this.isTrainingSet = isTrainingSet;
try {
this.baselineErrors = new double[trees.size()];
for (int i = 0; i < baselineErrors.length; i++) {
final Tree<P> tree = trees.get(i);
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
this.baselineErrors[i] = errorCalculator.averageError(responses, makePredictions(oobSubset, tree));
}
} catch(Exception e){
e.printStackTrace();
throw e;
}
}
/**
* Returns an array of importance values for every Tree for the given Covariate.
*
* @param covariate The Covariate to scramble.
* @param random
* @return
*/
public double[] calculateVariableImportanceRaw(Covariate covariate, Optional<Random> random){
final double[] vimp = new double[trees.size()];
for(int i = 0; i < vimp.length; i++){
final Tree<P> tree = trees.get(i);
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(oobSubset, covariate, random);
final double error = errorCalculator.averageError(responses, makePredictions(scrambledValues, tree));
vimp[i] = error - this.baselineErrors[i];
}
return vimp;
}
public double calculateVariableImportanceZScore(Covariate covariate, Optional<Random> random){
final double[] vimpArray = calculateVariableImportanceRaw(covariate, random);
double mean = 0.0;
double variance = 0.0;
final double numTrees = vimpArray.length;
for(double vimp : vimpArray){
mean += vimp / numTrees;
}
for(double vimp : vimpArray){
variance += (vimp - mean)*(vimp - mean) / (numTrees - 1.0);
}
final double standardError = Math.sqrt(variance / numTrees);
return mean / standardError;
}
// Assume rowList has already been filtered for OOB
private List<P> makePredictions(List<? extends CovariateRow> rowList, Tree<P> tree){
return rowList.stream()
.map(tree::evaluate)
.collect(Collectors.toList());
}
private List<Row<Y>> getAppropriateSubset(List<Row<Y>> initialList, Tree<P> tree){
if(!isTrainingSet){
return initialList; // no need to make any subsets
}
return initialList.stream()
.filter(row -> !tree.idInBootstrapSample(row.getId()))
.collect(Collectors.toList());
}
}

View file

@ -16,7 +16,9 @@
package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.tree.*;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.tree.Tree;
import java.io.*;
import java.util.*;
@ -25,17 +27,12 @@ import java.util.zip.GZIPOutputStream;
public class DataUtils {
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
public static <O, FO> Forest<O, FO> loadForest(File folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
if(!folder.isDirectory()){
throw new IllegalArgumentException("Tree directory must be a directory!");
}
final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
return loadOnlineForest(treeFiles, treeResponseCombiner);
}
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(File[] treeFiles, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
final List<File> treeFileList = Arrays.asList(treeFiles);
Collections.sort(treeFileList, Comparator.comparing(File::getName));
@ -51,16 +48,16 @@ public class DataUtils {
}
return OnlineForest.<O, FO>builder()
return Forest.<O, FO>builder()
.trees(treeList)
.treeResponseCombiner(treeResponseCombiner)
.build();
}
public static <O, FO> OnlineForest<O, FO> loadOnlineForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
public static <O, FO> Forest<O, FO> loadForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
final File directory = new File(folder);
return loadOnlineForest(directory, treeResponseCombiner);
return loadForest(directory, treeResponseCombiner);
}
public static void saveObject(Serializable object, String filename) throws IOException {

View file

@ -0,0 +1,113 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.utils;
/**
* Represents a function represented by discrete points. However, the function may be right-continuous or left-continuous
* at a given point, with no consistency. This function tracks that.
*/
public final class DiscontinuousStepFunction extends StepFunction {
private final double[] y;
private final boolean[] isLeftContinuous;
/**
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* Map be null.
*/
private final double defaultY;
public DiscontinuousStepFunction(double[] x, double[] y, boolean[] isLeftContinuous, double defaultY) {
super(x);
this.y = y;
this.isLeftContinuous = isLeftContinuous;
this.defaultY = defaultY;
}
@Override
public double evaluate(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time);
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return evaluateByIndex(index);
}
else{
return y[index];
}
}
}
@Override
public double evaluatePrevious(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return evaluateByIndex(index);
}
else{
return y[index];
}
}
}
@Override
public double evaluateByIndex(int i) {
if(isLeftContinuous[i]){
i -= 1;
}
if(i < 0){
return defaultY;
}
return y[i];
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultY);
builder.append("\n");
for(int i=0; i<x.length; i++){
builder.append("x:");
builder.append(x[i]);
builder.append('\t');
if(isLeftContinuous[i]){
builder.append("*y:");
}
else{
builder.append("y*:");
}
builder.append(y[i]);
builder.append("\n");
}
return builder.toString();
}
}

View file

@ -1,68 +0,0 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.tree.Tree;
import lombok.RequiredArgsConstructor;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Iterator;
import java.util.zip.GZIPInputStream;
@RequiredArgsConstructor
public class IterableOfflineTree<Y> implements Iterable<Tree<Y>> {
private final File[] treeFiles;
@Override
public Iterator<Tree<Y>> iterator() {
return new OfflineTreeIterator<>(treeFiles);
}
@RequiredArgsConstructor
public static class OfflineTreeIterator<Y> implements Iterator<Tree<Y>>{
private final File[] treeFiles;
private int position = 0;
@Override
public boolean hasNext() {
return position < treeFiles.length;
}
@Override
public Tree<Y> next() {
final File treeFile = treeFiles[position];
position++;
try {
final ObjectInputStream inputStream= new ObjectInputStream(new GZIPInputStream(new FileInputStream(treeFile)));
final Tree<Y> tree = (Tree) inputStream.readObject();
return tree;
} catch (IOException | ClassNotFoundException e) {
e.printStackTrace();
throw new RuntimeException("Failed to load tree for " + treeFile.toString());
}
}
}
}

View file

@ -0,0 +1,129 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.utils;
import java.util.List;
import java.util.ListIterator;
/**
* Represents a function represented by discrete points. We assume that the function is a stepwise left-continuous
* function, constant at the value of the previous encountered point.
*
*/
public final class LeftContinuousStepFunction extends StepFunction {
private final double[] y;
/**
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* Map be null.
*/
private final double defaultY;
public LeftContinuousStepFunction(double[] x, double[] y, double defaultY) {
super(x);
this.y = y;
this.defaultY = defaultY;
}
/**
* This isn't a formal constructor because of limitations with abstract classes.
*
* @param pointList
* @param defaultY
* @return
*/
public static LeftContinuousStepFunction constructFromPoints(final List<Point> pointList, final double defaultY){
final double[] x = new double[pointList.size()];
final double[] y = new double[pointList.size()];
final ListIterator<Point> pointIterator = pointList.listIterator();
while(pointIterator.hasNext()){
final int index = pointIterator.nextIndex();
final Point currentPoint = pointIterator.next();
x[index] = currentPoint.getTime();
y[index] = currentPoint.getY();
}
return new LeftContinuousStepFunction(x, y, defaultY);
}
@Override
public double evaluate(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time);
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return evaluateByIndex(index-1);
}
else{
return y[index];
}
}
}
@Override
public double evaluatePrevious(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return evaluateByIndex(index-1);
}
else{
return y[index];
}
}
}
@Override
public double evaluateByIndex(int i) {
if(i < 0){
return defaultY;
}
return y[i];
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultY);
builder.append("\n");
for(int i=0; i<x.length; i++){
builder.append("x:");
builder.append(x[i]);
builder.append("\ty:");
builder.append(y[i]);
builder.append("\n");
}
return builder.toString();
}
}

View file

@ -26,8 +26,6 @@ import java.io.Serializable;
*/
@Data
public class Point implements Serializable {
private static final long serialVersionUID = 1L;
private final double time;
private final double y;
}

View file

@ -188,24 +188,4 @@ public final class RUtils {
return responses;
}
public static List<Object> produceSublist(List<Object> initialList, int[] indices){
final List<Object> newList = new ArrayList<>(indices.length);
for(int i : indices){
newList.add(initialList.get(i));
}
return newList;
}
public static File[] getTreeFileArray(String folderPath, int endingId){
final File[] fileArray = new File[endingId];
for(int i = 1; i <= endingId; i++){
fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree");
}
return fileArray;
}
}

View file

@ -16,14 +16,8 @@
package ca.joeltherrien.randomforest.utils;
import lombok.Getter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
/**
* Represents a function represented by discrete points. We assume that the function is a stepwise right-continuous
@ -32,16 +26,13 @@ import java.util.function.DoubleUnaryOperator;
*/
public final class RightContinuousStepFunction extends StepFunction {
private static final long serialVersionUID = 1L;
private final double[] y;
/**
* Represents the value that should be returned by evaluate if there are no points prior to the time the function is being evaluated at.
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* May not be null.
* Map be null.
*/
@Getter
private final double defaultY;
public RightContinuousStepFunction(double[] x, double[] y, double defaultY) {
@ -136,12 +127,7 @@ public final class RightContinuousStepFunction extends StepFunction {
}
if(to < from){
return -integrate(to, from);
}
// Edge case - no points; just defaultY
if(this.x.length == 0){
return (to - from) * this.defaultY;
return integrate(to, from);
}
double summation = 0.0;
@ -184,7 +170,7 @@ public final class RightContinuousStepFunction extends StepFunction {
final double currentTime = xPoints[i];
final double currentHeight = evaluateByIndex(i);
if(i == xPoints.length-1 || xPoints[i+1] >= to){
if(i == xPoints.length-1 || xPoints[i+1] > to){
summation += currentHeight * (to - currentTime);
return summation;
}
@ -200,76 +186,5 @@ public final class RightContinuousStepFunction extends StepFunction {
}
public RightContinuousStepFunction unaryOperation(DoubleUnaryOperator operator){
final double newDefaultY = operator.applyAsDouble(this.defaultY);
final double[] newY = Arrays.stream(this.getY()).map(operator).toArray();
return new RightContinuousStepFunction(this.getX(), newY, newDefaultY);
}
public static RightContinuousStepFunction biOperation(RightContinuousStepFunction funLeft,
RightContinuousStepFunction funRight,
DoubleBinaryOperator operator){
final double newDefaultY = operator.applyAsDouble(funLeft.defaultY, funRight.defaultY);
final double[] leftX = funLeft.x;
final double[] rightX = funRight.x;
final List<Point> combinedPoints = new ArrayList<>(leftX.length + rightX.length);
// These indexes represent the times that have *already* been processed.
// They start at -1 because we already processed the defaultY values.
int indexLeft = -1;
int indexRight = -1;
// This while-loop will keep going until one of the functions reaches the ends of its points
while(indexLeft < leftX.length-1 && indexRight < rightX.length-1){
final double time;
if(leftX[indexLeft+1] < rightX[indexRight+1]){
indexLeft += 1;
time = leftX[indexLeft];
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
}
else if(leftX[indexLeft+1] > rightX[indexRight+1]){
indexRight += 1;
time = rightX[indexRight];
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
}
else{ // equal times
indexLeft += 1;
indexRight += 1;
time = leftX[indexLeft];
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
}
}
// At this point, at least one of function left or function right has reached the end of its points
// This while-loop occurring implies that functionRight can not move forward anymore
while(indexLeft < leftX.length-1){
indexLeft += 1;
final double time = leftX[indexLeft];
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
}
// This while-loop occurring implies that functionLeft can not move forward anymore
while(indexRight < rightX.length-1){
indexRight += 1;
final double time = rightX[indexRight];
combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
}
return RightContinuousStepFunction.constructFromPoints(combinedPoints, newDefaultY);
}
}

View file

@ -20,7 +20,7 @@ import java.util.*;
public final class Utils {
public static RightContinuousStepFunction estimateOneMinusECDF(final double[] times){
public static StepFunction estimateOneMinusECDF(final double[] times){
Arrays.sort(times);
final Map<Double, Integer> timeCounterMap = new HashMap<>();

View file

@ -0,0 +1,80 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.utils;
/**
* Represents a step function represented by discrete points. However, there may be individual time values that has
* a y value that doesn't belong to a particular 'step'.
*/
public final class VeryDiscontinuousStepFunction implements MathFunction {
private final double[] x;
private final double[] yAt;
private final double[] yRight;
/**
* Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
*
* Map be null.
*/
private final double defaultY;
public VeryDiscontinuousStepFunction(double[] x, double[] yAt, double[] yRight, double defaultY) {
this.x = x;
this.yAt = yAt;
this.yRight = yRight;
this.defaultY = defaultY;
}
@Override
public double evaluate(double time){
int index = Utils.binarySearchLessThan(0, x.length, x, time);
if(index < 0){
return defaultY;
}
else{
if(x[index] == time){
return yAt[index];
}
else{ // time > x[index]
return yRight[index];
}
}
}
@Override
public String toString(){
final StringBuilder builder = new StringBuilder();
builder.append("Default point: ");
builder.append(defaultY);
builder.append("\n");
for(int i=0; i<x.length; i++){
builder.append("x:");
builder.append(x[i]);
builder.append("\tyAt:");
builder.append(yAt[i]);
builder.append("\tyRight:");
builder.append(yRight[i]);
builder.append("\n");
}
return builder.toString();
}
}

View file

@ -45,20 +45,20 @@ public class TestDeterministicForests {
int index = 0;
for(int j=0; j<5; j++){
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index, false);
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index);
covariateList.add(numericCovariate);
index++;
}
for(int j=0; j<5; j++){
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index, false);
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index);
covariateList.add(booleanCovariate);
index++;
}
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
for(int j=0; j<5; j++){
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels, false);
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels);
covariateList.add(factorCovariate);
index++;
}
@ -214,14 +214,14 @@ public class TestDeterministicForests {
forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
final Forest<Double, Double> forestSerial = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
final Forest<Double, Double> forestSerial = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, forestSerial);
forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
final Forest<Double, Double> forestParallel = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
final Forest<Double, Double> forestParallel = DataUtils.loadForest(saveTreeFile, new MeanResponseCombiner());
TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, forestParallel);
@ -259,7 +259,7 @@ public class TestDeterministicForests {
for(int k=0; k<3; k++){
forestTrainer.trainSerialOnDisk(Optional.empty());
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
}
@ -274,7 +274,7 @@ public class TestDeterministicForests {
for(int k=0; k<3; k++){
forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
final Forest<Double, Double> replicantForest = DataUtils.loadForest(saveTreeFile, responseCombiner);
TestUtils.removeFolder(saveTreeFile);
verifyTwoForestsEqual(testData, referenceForest, replicantForest);
}

View file

@ -20,8 +20,8 @@ import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.OnlineForest;
import ca.joeltherrien.randomforest.tree.Tree;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils;
@ -39,12 +39,12 @@ import static org.junit.jupiter.api.Assertions.*;
public class TestProvidingInitialForest {
private OnlineForest<Double, Double> initialForest;
private Forest<Double, Double> initialForest;
private List<Covariate> covariateList;
private List<Row<Double>> data;
public TestProvidingInitialForest(){
covariateList = Collections.singletonList(new NumericCovariate("x", 0, false));
covariateList = Collections.singletonList(new NumericCovariate("x", 0));
data = Utils.easyList(
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
@ -107,8 +107,8 @@ public class TestProvidingInitialForest {
public void testSerialInMemory(){
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
final OnlineForest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
assertEquals(20, newForest.getNumberOfTrees());
final Forest<Double, Double> newForest = forestTrainer.trainSerialInMemory(Optional.of(initialForest));
assertEquals(20, newForest.getTrees().size());
for(Tree<Double> initialTree : initialForest.getTrees()){
assertTrue(newForest.getTrees().contains(initialTree));
@ -124,8 +124,8 @@ public class TestProvidingInitialForest {
public void testParallelInMemory(){
final ForestTrainer<Double, Double, Double> forestTrainer = getForestTrainer(null, 20);
final OnlineForest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
assertEquals(20, newForest.getNumberOfTrees());
final Forest<Double, Double> newForest = forestTrainer.trainParallelInMemory(Optional.of(initialForest), 2);
assertEquals(20, newForest.getTrees().size());
for(Tree<Double> initialTree : initialForest.getTrees()){
assertTrue(newForest.getTrees().contains(initialTree));
@ -149,11 +149,11 @@ public class TestProvidingInitialForest {
forestTrainer.trainParallelOnDisk(Optional.of(initialForest), 2);
assertEquals(20, directory.listFiles().length);
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
assertEquals(20, newForest.getNumberOfTrees());
assertEquals(20, newForest.getTrees().size());
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
.map(tree -> tree.toString()).collect(Collectors.toList());
@ -179,9 +179,9 @@ public class TestProvidingInitialForest {
assertEquals(20, directory.listFiles().length);
final OnlineForest<Double, Double> newForest = DataUtils.loadOnlineForest(directory, new MeanResponseCombiner());
final Forest<Double, Double> newForest = DataUtils.loadForest(directory, new MeanResponseCombiner());
assertEquals(20, newForest.getNumberOfTrees());
assertEquals(20, newForest.getTrees().size());
final List<String> newForestTreesAsStrings = newForest.getTrees().stream()
.map(tree -> tree.toString()).collect(Collectors.toList());
@ -198,7 +198,7 @@ public class TestProvidingInitialForest {
it's not clear if the forest being provided is the same one that trees were saved from.
*/
@Test
public void testExceptions(){
public void verifyExceptions(){
final String filePath = "src/test/resources/trees/";
final File directory = new File(filePath);
if(directory.exists()){

View file

@ -24,10 +24,11 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.splitfinder.LogRankSplitFinder;
import ca.joeltherrien.randomforest.tree.*;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.ResponseLoader;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
@ -46,10 +47,10 @@ public class TestSavingLoading {
public List<Covariate> getCovariates(){
return Utils.easyList(
new NumericCovariate("ageatfda", 0, false),
new BooleanCovariate("idu", 1, false),
new BooleanCovariate("black", 2, false),
new NumericCovariate("cd4nadir", 3, false)
new NumericCovariate("ageatfda", 0),
new BooleanCovariate("idu", 1),
new BooleanCovariate("black", 2),
new NumericCovariate("cd4nadir", 3)
);
}
@ -118,21 +119,16 @@ public class TestSavingLoading {
assertTrue(directory.isDirectory());
assertEquals(NTREE, directory.listFiles().length);
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
final CovariateRow predictionRow = getPredictionRow(covariates);
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
assertNotNull(functionsOnline);
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
assertNotNull(functions);
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
assertEquals(NTREE, onlineForest.getTrees().size());
assertEquals(NTREE, forest.getTrees().size());
TestUtils.removeFolder(directory);
@ -163,22 +159,17 @@ public class TestSavingLoading {
assertEquals(NTREE, directory.listFiles().length);
final CompetingRiskFunctionCombiner treeResponseCombiner = new CompetingRiskFunctionCombiner(new int[]{1,2}, null);
final OnlineForest<CompetingRiskFunctions, CompetingRiskFunctions> onlineForest = DataUtils.loadOnlineForest(directory, treeResponseCombiner);
final OfflineForest<CompetingRiskFunctions, CompetingRiskFunctions> offlineForest = new OfflineForest<>(directory, treeResponseCombiner);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
final CovariateRow predictionRow = getPredictionRow(covariates);
final CompetingRiskFunctions functionsOnline = onlineForest.evaluate(predictionRow);
assertNotNull(functionsOnline);
assertTrue(functionsOnline.getCumulativeIncidenceFunction(1).getX().length > 2);
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
assertNotNull(functions);
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
final CompetingRiskFunctions functionsOffline = offlineForest.evaluate(predictionRow);
assertTrue(competingFunctionsEqual(functionsOffline, functionsOnline));
assertEquals(NTREE, onlineForest.getTrees().size());
assertEquals(NTREE, forest.getTrees().size());
TestUtils.removeFolder(directory);
@ -186,64 +177,6 @@ public class TestSavingLoading {
}
/*
We don't implement equals() methods on the below mentioned classes because then we'd need to implement an
appropriate hashCode() method that's consistent with the equals(), and we only need plain equals() for
these tests.
*/
private boolean competingFunctionsEqual(CompetingRiskFunctions f1 ,CompetingRiskFunctions f2){
if(!functionsEqual(f1.getSurvivalCurve(), f2.getSurvivalCurve())){
return false;
}
for(int i=1; i<=2; i++){
if(!functionsEqual(f1.getCauseSpecificHazardFunction(i), f2.getCauseSpecificHazardFunction(i))){
return false;
}
if(!functionsEqual(f1.getCumulativeIncidenceFunction(i), f2.getCumulativeIncidenceFunction(i))){
return false;
}
}
return true;
}
private boolean functionsEqual(RightContinuousStepFunction f1, RightContinuousStepFunction f2){
final double[] f1X = f1.getX();
final double[] f2X = f2.getX();
final double[] f1Y = f1.getY();
final double[] f2Y = f2.getY();
// first compare array lengths
if(f1X.length != f2X.length){
return false;
}
if(f1Y.length != f2Y.length){
return false;
}
// TODO - better comparisons of doubles. I don't really care too much though as this equals method is only being used in tests
final double delta = 0.000001;
if(Math.abs(f1.getDefaultY() - f2.getDefaultY()) > delta){
return false;
}
for(int i=0; i < f1X.length; i++){
if(Math.abs(f1X[i] - f2X[i]) > delta){
return false;
}
if(Math.abs(f1Y[i] - f2Y[i]) > delta){
return false;
}
}
return true;
}
}

View file

@ -156,7 +156,7 @@ public class TestUtils {
}
@Test
public void testReduceListToSize(){
public void reduceListToSize(){
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
final Random random = new Random();
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness

View file

@ -1,164 +0,0 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class IBSCalculatorTest {
private final RightContinuousStepFunction cif;
public IBSCalculatorTest(){
this.cif = RightContinuousStepFunction.constructFromPoints(
Utils.easyList(
new Point(1.0, 0.1),
new Point(2.0, 0.2),
new Point(3.0, 0.3),
new Point(4.0, 0.8)
), 0.0
);
}
/*
R code to get these results:
predicted_cif <- stepfun(1:4, c(0, 0.1, 0.2, 0.3, 0.8))
weights <- 1
recorded_time <- 2.0
recorded_status <- 1.0
event_of_interest <- 2
times <- 0:4
errors <- weights * ( as.integer(recorded_time <= times & recorded_status == event_of_interest) - predicted_cif(times))^2
sum(errors)
and run again with event_of_interest <- 1
Note that in the R code I only evaluate up to 4, while in the Java code I integrate up to 5
This is because the evaluation at 4 is giving the area of the rectangle from 4 to 5.
*/
@Test
public void testResultsWithoutCensoringDistribution(){
final IBSCalculator calculator = new IBSCalculator();
final double errorDifferentEvent = calculator.calculateError(
new CompetingRiskResponse(1, 2.0),
this.cif,
2,
5.0);
assertEquals(0.78, errorDifferentEvent, 0.000001);
final double errorSameEvent = calculator.calculateError(
new CompetingRiskResponse(1, 2.0),
this.cif,
1,
5.0);
assertEquals(1.18, errorSameEvent, 0.000001);
}
@Test
public void testResultsWithCensoringDistribution(){
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
Utils.easyList(
new Point(0.0, 0.75),
new Point(1.0, 0.5),
new Point(3.0, 0.25),
new Point(5.0, 0)
), 1.0
);
final IBSCalculator calculator = new IBSCalculator(censorSurvivalFunction);
final double errorDifferentEvent = calculator.calculateError(
new CompetingRiskResponse(1, 2.0),
this.cif,
2,
5.0);
assertEquals(1.56, errorDifferentEvent, 0.000001);
final double errorSameEvent = calculator.calculateError(
new CompetingRiskResponse(1, 2.0),
this.cif,
1,
5.0);
assertEquals(2.36, errorSameEvent, 0.000001);
}
@Test
public void testStaticFunction(){
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
Utils.easyList(
new Point(0.0, 0.75),
new Point(1.0, 0.5),
new Point(3.0, 0.25),
new Point(5.0, 0)
), 1.0
);
final List<CompetingRiskResponse> responseList = Utils.easyList(
new CompetingRiskResponse(1, 2.0),
new CompetingRiskResponse(1, 2.0));
// for predictions; we'll construct an improper CompetingRisksFunctions
final RightContinuousStepFunction trivialFunction = RightContinuousStepFunction.constructFromPoints(
Utils.easyList(new Point(1.0, 0.0)),
1.0);
final CompetingRiskFunctions prediction = CompetingRiskFunctions.builder()
.survivalCurve(trivialFunction)
.causeSpecificHazards(Utils.easyList(trivialFunction, trivialFunction))
.cumulativeIncidenceCurves(Utils.easyList(this.cif, trivialFunction))
.build();
final List<CompetingRiskFunctions> predictionList = Utils.easyList(prediction, prediction);
double[] errorParallel = CompetingRiskUtils.calculateIBSError(
responseList,
predictionList,
Optional.of(censorSurvivalFunction),
1,
5.0,
true);
double[] errorSerial = CompetingRiskUtils.calculateIBSError(
responseList,
predictionList,
Optional.of(censorSurvivalFunction),
1,
5.0,
false);
assertEquals(responseList.size(), errorParallel.length);
assertEquals(responseList.size(), errorSerial.length);
assertEquals(2.36, errorParallel[0], 0.000001);
assertEquals(2.36, errorParallel[1], 0.000001);
assertEquals(2.36, errorSerial[0], 0.000001);
assertEquals(2.36, errorSerial[1], 0.000001);
}
}

View file

@ -53,10 +53,10 @@ public class TestCompetingRisk {
public List<Covariate> getCovariates(){
return Utils.easyList(
new NumericCovariate("ageatfda", 0, false),
new BooleanCovariate("idu", 1, false),
new BooleanCovariate("black", 2, false),
new NumericCovariate("cd4nadir", 3, false)
new NumericCovariate("ageatfda", 0),
new BooleanCovariate("idu", 1),
new BooleanCovariate("black", 2),
new NumericCovariate("cd4nadir", 3)
);
}
@ -109,8 +109,8 @@ public class TestCompetingRisk {
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = Utils.easyList(
new BooleanCovariate("idu", 0, false),
new BooleanCovariate("black", 1, false)
new BooleanCovariate("idu", 0),
new BooleanCovariate("black", 1)
);
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv");
@ -210,8 +210,8 @@ public class TestCompetingRisk {
public void testLogRankSplitFinderTwoBooleans() throws IOException {
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = Utils.easyList(
new BooleanCovariate("idu", 0, false),
new BooleanCovariate("black", 1, false)
new BooleanCovariate("idu", 0),
new BooleanCovariate("black", 1)
);
@ -259,7 +259,7 @@ public class TestCompetingRisk {
}
@Test
public void testDataset() throws IOException {
public void verifyDataset() throws IOException {
final List<Covariate> covariates = getCovariates();
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);

View file

@ -16,11 +16,10 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.OnlineForest;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
@ -31,6 +30,8 @@ import java.util.List;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TestCompetingRiskErrorRateCalculator {
@ -47,7 +48,7 @@ public class TestCompetingRiskErrorRateCalculator {
final int event = 1;
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = OnlineForest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> fakeForest = Forest.<CompetingRiskFunctions, CompetingRiskFunctions>builder().build();
final double naiveConcordance = CompetingRiskUtils.calculateConcordance(responseList, mortalityArray, event);

View file

@ -46,7 +46,7 @@ public class TestLogRankSplitFinder {
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
final List<Covariate> covariates = Utils.easyList(
new NumericCovariate("x2", 0, false)
new NumericCovariate("x2", 0)
);
final List<Row<CompetingRiskResponse>> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename);

View file

@ -16,6 +16,7 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import org.junit.jupiter.api.Test;
@ -30,6 +31,13 @@ public class TestMathFunctions {
return new RightContinuousStepFunction(time, y, 0.1);
}
private LeftContinuousStepFunction generateLeftContinuousStepFunction(){
final double[] time = new double[]{1.0, 2.0, 3.0};
final double[] y = new double[]{-1.0, 1.0, 0.5};
return new LeftContinuousStepFunction(time, y, 0.1);
}
@Test
public void testRightContinuousStepFunction(){
final RightContinuousStepFunction function = generateRightContinuousStepFunction();
@ -48,5 +56,21 @@ public class TestMathFunctions {
}
@Test
public void testLeftContinuousStepFunction(){
final LeftContinuousStepFunction function = generateLeftContinuousStepFunction();
assertEquals(0.1, function.evaluate(0.5));
assertEquals(0.1, function.evaluate(1.0));
assertEquals(-1.0, function.evaluate(2.0));
assertEquals(1.0, function.evaluate(3.0));
assertEquals(0.1, function.evaluate(0.6));
assertEquals(-1.0, function.evaluate(1.1));
assertEquals(1.0, function.evaluate(2.1));
assertEquals(0.5, function.evaluate(3.1));
}
}

View file

@ -17,15 +17,12 @@
package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
@ -34,7 +31,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FactorCovariateTest {
@Test
public void testEqualLevels() {
void verifyEqualLevels() {
final FactorCovariate petCovariate = createTestCovariate();
final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
@ -56,7 +53,7 @@ public class FactorCovariateTest {
}
@Test
public void testBadLevelException(){
void verifyBadLevelException(){
final FactorCovariate petCovariate = createTestCovariate();
final Executable badCode = () -> petCovariate.createValue("vulcan");
@ -64,169 +61,25 @@ public class FactorCovariateTest {
}
@Test
public void testAllSubsets(){
final int n = 2*3; // ensure that n is a multiple of 3 for the test
void testAllSubsets(){
final FactorCovariate petCovariate = createTestCovariate();
final List<Row<Double>> data = generateSampleData(petCovariate, n);
final List<Split<Double, String>> splits = new ArrayList<>();
final List<SplitRule<String>> splitRules = new ArrayList<>();
petCovariate.generateSplitRuleUpdater(data, 100, new Random())
.forEachRemaining(split -> splits.add(split));
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
assertEquals(splits.size(), 3);
assertEquals(splitRules.size(), 3);
// These are the 3 possibilities
boolean dog_catmouse = false;
boolean cat_dogmouse = false;
boolean mouse_dogcat = false;
// TODO verify the contents of the split rules
for(Split<Double, String> split : splits){
List<Row<Double>> smallerHand;
List<Row<Double>> largerHand;
if(split.getLeftHand().size() < split.getRightHand().size()){
smallerHand = split.getLeftHand();
largerHand = split.getRightHand();
} else{
smallerHand = split.getRightHand();
largerHand = split.getLeftHand();
}
// There should be exactly one distinct value in the smaller list
assertEquals(n/3, smallerHand.size());
assertEquals(1,
smallerHand.stream()
.map(row -> row.getCovariateValue(petCovariate).getValue())
.distinct()
.count()
);
// There should be exactly two distinct values in the smaller list
assertEquals(2*n/3, largerHand.size());
assertEquals(2,
largerHand.stream()
.map(row -> row.getCovariateValue(petCovariate).getValue())
.distinct()
.count()
);
switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){
case "DOG":
dog_catmouse = true;
case "CAT":
cat_dogmouse = true;
case "MOUSE":
mouse_dogcat = true;
}
}
assertTrue(dog_catmouse);
assertTrue(cat_dogmouse);
assertTrue(mouse_dogcat);
}
/*
* There was a bug where if number==0 in generateSplitRuleUpdater, then the result was empty.
*/
@Test
public void testNumber0Subsets(){
final int n = 2*3; // ensure that n is a multiple of 3 for the test
final FactorCovariate petCovariate = createTestCovariate();
final List<Row<Double>> data = generateSampleData(petCovariate, n);
final List<Split<Double, String>> splits = new ArrayList<>();
petCovariate.generateSplitRuleUpdater(data, 0, new Random())
.forEachRemaining(split -> splits.add(split));
assertEquals(splits.size(), 3);
// These are the 3 possibilities
boolean dog_catmouse = false;
boolean cat_dogmouse = false;
boolean mouse_dogcat = false;
for(Split<Double, String> split : splits){
List<Row<Double>> smallerHand;
List<Row<Double>> largerHand;
if(split.getLeftHand().size() < split.getRightHand().size()){
smallerHand = split.getLeftHand();
largerHand = split.getRightHand();
} else{
smallerHand = split.getRightHand();
largerHand = split.getLeftHand();
}
// There should be exactly one distinct value in the smaller list
assertEquals(n/3, smallerHand.size());
assertEquals(1,
smallerHand.stream()
.map(row -> row.getCovariateValue(petCovariate).getValue())
.distinct()
.count()
);
// There should be exactly two distinct values in the smaller list
assertEquals(2*n/3, largerHand.size());
assertEquals(2,
largerHand.stream()
.map(row -> row.getCovariateValue(petCovariate).getValue())
.distinct()
.count()
);
switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){
case "DOG":
dog_catmouse = true;
case "CAT":
cat_dogmouse = true;
case "MOUSE":
mouse_dogcat = true;
}
}
assertTrue(dog_catmouse);
assertTrue(cat_dogmouse);
assertTrue(mouse_dogcat);
}
@Test
public void testSpitRuleUpdaterWithNAs(){
// When some NAs were present calling generateSplitRuleUpdater caused an exception.
final FactorCovariate covariate = createTestCovariate();
final List<Row<Double>> sampleData = generateSampleData(covariate, 10);
sampleData.add(Row.createSimple(Utils.easyMap("pet", "NA"), Collections.singletonList(covariate), 11, 5.0));
covariate.generateSplitRuleUpdater(sampleData, 0, new Random());
// Test passes if no exception has occurred.
}
private FactorCovariate createTestCovariate(){
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
return new FactorCovariate("pet", 0, levels, false);
}
private List<Row<Double>> generateSampleData(Covariate covariate, int n){
final List<Covariate> covariateList = Collections.singletonList(covariate);
final List<Row<Double>> dataList = new ArrayList<>(n);
final String[] levels = new String[]{"DOG", "CAT", "MOUSE"};
for(int i=0; i<n; i++){
dataList.add(Row.createSimple(Utils.easyMap("pet", levels[i % 3]), covariateList, 1, 1.0));
}
return dataList;
return new FactorCovariate("pet", 0, levels);
}

View file

@ -70,7 +70,7 @@ public class NumericCovariateTest {
@Test
public void testNumericCovariateDeterministic(){
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
final NumericCovariate covariate = new NumericCovariate("x", 0);
final List<Row<Double>> dataset = createTestDataset(covariate);
@ -158,7 +158,7 @@ public class NumericCovariateTest {
@Test
public void testNumericSplitRuleUpdaterWithIndexes(){
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
final NumericCovariate covariate = new NumericCovariate("x", 0);
final List<Row<Double>> dataset = createTestDataset(covariate);
@ -223,7 +223,7 @@ public class NumericCovariateTest {
*/
@Test
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
final NumericCovariate covariate = new NumericCovariate("x", 0);
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());

View file

@ -18,34 +18,31 @@ package ca.joeltherrien.randomforest.nas;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
public class TestNAs {
private List<Row<Double>> generateData1(List<Covariate> covariates){
private List<Row<Double>> generateData(List<Covariate> covariates){
final List<Row<Double>> dataList = new ArrayList<>();
// We must include an NA for one of the values
dataList.add(Row.createSimple(Utils.easyMap("x", "NA", "y", "true", "z", "green"), covariates, 1, 5.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "1", "y", "NA", "z", "blue"), covariates, 2, 6.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "2", "y", "true", "z", "NA"), covariates, 3, 5.5));
dataList.add(Row.createSimple(Utils.easyMap("x", "7", "y", "false", "z", "green"), covariates, 4, 0.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "8", "y", "true", "z", "blue"), covariates, 5, 1.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4", "y", "false", "z", "yellow"), covariates, 6, 1.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5));
dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0));
return dataList;
@ -57,19 +54,14 @@ public class TestNAs {
// but NumericSplitRuleUpdater had unmodifiable lists when creating the split.
// This bug verifies that this no longer causes a crash
final List<Covariate> covariates = Utils.easyList(
new NumericCovariate("x", 0, false),
new BooleanCovariate("y", 1, true),
new FactorCovariate("z", 2, Utils.easyList("green", "blue", "yellow"), true)
);
final List<Row<Double>> dataset = generateData1(covariates);
final List<Covariate> covariates = Collections.singletonList(new NumericCovariate("x", 0));
final List<Row<Double>> dataset = generateData(covariates);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.checkNodePurity(false)
.covariates(covariates)
.numberOfSplits(0)
.nodeSize(1)
.mtry(3)
.maxNodeDepth(1000)
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
@ -78,87 +70,6 @@ public class TestNAs {
treeTrainer.growTree(dataset, new Random(123));
// As long as no exception occurs, we passed
}
private List<Row<Double>> generateData2(List<Covariate> covariates){
final List<Row<Double>> dataList = new ArrayList<>();
// Idea - when ignoring NAs, BadVar gives a perfect split.
// GoodVar is slightly worse than BadVar when NAs are excluded.
// However, BadVar has a ton of NAs that will degrade its performance.
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "-1.0", "GoodVar", "true") // GoodVars one error
, covariates, 1, 5.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "NA", "GoodVar", "false")
, covariates, 2, 5.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "NA", "GoodVar", "false")
, covariates, 3, 5.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "0.5", "GoodVar", "true")
, covariates, 4, 10.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "NA", "GoodVar", "true")
, covariates, 5, 10.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "NA", "GoodVar", "true")
, covariates, 6, 10.0)
);
return dataList;
}
@Test
// Test that the NA penalty works when selecting a best split.
public void testNAPenalty(){
final List<Covariate> covariates1 = Utils.easyList(
new NumericCovariate("BadVar", 0, true),
new BooleanCovariate("GoodVar", 1, false)
);
final List<Row<Double>> dataList1 = generateData2(covariates1);
final TreeTrainer<Double, Double> treeTrainer1 = TreeTrainer.<Double, Double>builder()
.checkNodePurity(false)
.covariates(covariates1)
.numberOfSplits(0)
.nodeSize(1)
.mtry(2)
.maxNodeDepth(1000)
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
.build();
final Split<Double, ?> bestSplit1 = treeTrainer1.findBestSplitRule(dataList1, covariates1, new Random(123));
assertEquals(1, bestSplit1.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar
// Run again without the penalty; verify that we get different results
final List<Covariate> covariates2 = Utils.easyList(
new NumericCovariate("BadVar", 0, false),
new BooleanCovariate("GoodVar", 1, false)
);
final List<Row<Double>> dataList2 = generateData2(covariates2);
final TreeTrainer<Double, Double> treeTrainer2 = TreeTrainer.<Double, Double>builder()
.checkNodePurity(false)
.covariates(covariates2)
.numberOfSplits(0)
.nodeSize(1)
.mtry(2)
.maxNodeDepth(1000)
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
.build();
final Split<Double, ?> bestSplit2 = treeTrainer2.findBestSplitRule(dataList2, covariates2, new Random(123));
assertEquals(0, bestSplit2.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar
}

View file

@ -1,143 +0,0 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class IBSErrorCalculatorWrapperTest {
/*
We already have tests for the IBSCalculator, so these tests are concerned with making sure we correctly average
the errors together, not that we fully test the production of each error under different scenarios (like
providing / not providing a censoring distribution).
*/
private final double integrationUpperBound = 5.0;
private final List<CompetingRiskResponse> responses;
private final List<CompetingRiskFunctions> functions;
private final double[][] errors;
public IBSErrorCalculatorWrapperTest(){
this.responses = Utils.easyList(
new CompetingRiskResponse(0, 2.0),
new CompetingRiskResponse(0, 3.0),
new CompetingRiskResponse(1, 1.0),
new CompetingRiskResponse(1, 1.5),
new CompetingRiskResponse(2, 3.0),
new CompetingRiskResponse(2, 4.0)
);
final RightContinuousStepFunction cif1 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
new Point(1.0, 0.25),
new Point(1.5, 0.45)
), 0.0);
final RightContinuousStepFunction cif2 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
new Point(3.0, 0.25),
new Point(4.0, 0.45)
), 0.0);
// This function is for the unused CHFs and survival curve
// If we see infinities or NaNs popping up in our output we should look here.
final RightContinuousStepFunction emptyFun = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
new Point(0.0, Double.NaN)
), Double.NEGATIVE_INFINITY
);
final CompetingRiskFunctions function = CompetingRiskFunctions.builder()
.cumulativeIncidenceCurves(Utils.easyList(cif1, cif2))
.causeSpecificHazards(Utils.easyList(emptyFun, emptyFun))
.survivalCurve(emptyFun)
.build();
// Same prediction for every response.
this.functions = Utils.easyList(function, function, function, function, function, function);
final IBSCalculator calculator = new IBSCalculator();
this.errors = new double[2][6];
for(int event : new int[]{1, 2}){
for(int i=0; i<6; i++){
this.errors[event-1][i] = calculator.calculateError(
responses.get(i), function.getCumulativeIncidenceFunction(event),
event, integrationUpperBound
);
}
}
}
@Test
public void testOneEventOne(){
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1},
this.integrationUpperBound);
final double error = wrapper.averageError(this.responses, this.functions);
double expectedError = 0.0;
for(int i=0; i<6; i++){
expectedError += errors[0][i] / 6.0;
}
assertEquals(expectedError, error, 0.00000001);
}
@Test
public void testOneEventTwo(){
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{2},
this.integrationUpperBound);
final double error = wrapper.averageError(this.responses, this.functions);
double expectedError = 0.0;
for(int i=0; i<6; i++){
expectedError += errors[1][i] / 6.0;
}
assertEquals(expectedError, error, 0.00000001);
}
@Test
public void testTwoEventsNoWeights(){
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
this.integrationUpperBound);
final double error = wrapper.averageError(this.responses, this.functions);
double expectedError1 = 0.0;
double expectedError2 = 0.0;
for(int i=0; i<6; i++){
expectedError1 += errors[0][i] / 6.0;
expectedError2 += errors[1][i] / 6.0;
}
assertEquals(expectedError1 + expectedError2, error, 0.00000001);
}
@Test
public void testTwoEventsWithWeights(){
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
this.integrationUpperBound, new double[]{1.0, 2.0});
final double error = wrapper.averageError(this.responses, this.functions);
double expectedError1 = 0.0;
double expectedError2 = 0.0;
for(int i=0; i<6; i++){
expectedError1 += errors[0][i] / 6.0;
expectedError2 += errors[1][i] / 6.0;
}
assertEquals(1.0 * expectedError1 + 2.0 * expectedError2, error, 0.00000001);
}
}

View file

@ -1,26 +0,0 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class RegressionErrorCalculatorTest {
private final RegressionErrorCalculator calculator = new RegressionErrorCalculator();
@Test
public void testRegressionErrorCalculator(){
final List<Double> responses = Utils.easyList(1.0, 1.5, 0.0, 3.0);
final List<Double> predictions = Utils.easyList(1.5, 1.7, 0.1, 2.9);
// Differences are 0.5, 0.2, -0.1, 0.1
// Squared: 0.25, 0.04, 0.01, 0.01
assertEquals((0.25 + 0.04 + 0.01 + 0.01)/4.0, calculator.averageError(responses, predictions), 0.000000001);
}
}

View file

@ -1,484 +0,0 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import ca.joeltherrien.randomforest.covariates.bool.BooleanSplitRule;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericSplitRule;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.*;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestVariableImportanceCalculator {
/*
Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
setting.
*/
// We'l have a very simple Forest of two trees
private final OnlineForest<Double, Double> forest;
private final List<Covariate> covariates;
private final List<Row<Double>> rowList;
/*
Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance.
*/
public TestVariableImportanceCalculator(){
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false);
final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false);
final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
Utils.easyList("red", "blue", "green"), false);
this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.responseCombiner(new MeanResponseCombiner())
.splitFinder(new WeightedVarianceSplitFinder())
.numberOfSplits(0)
.nodeSize(1)
.maxNodeDepth(100)
.mtry(3)
.checkNodePurity(false)
.covariates(this.covariates)
.build();
/*
Plan for data - BooleanCovariate is split on first and has the largest impact.
NumericCovariate is at second level and has more minimal impact.
FactorCovariate is useless and never used.
Our tree (we'll duplicate it for testing OOB errors) will have a depth of 1. (0 based).
*/
final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});
this.forest = OnlineForest.<Double, Double>builder()
.trees(Utils.easyList(tree1, tree2))
.treeResponseCombiner(new MeanResponseCombiner())
.build();
// formula; boolean high adds 100; high numeric adds 10
// This row list should have a baseline error of 0.0
this.rowList = Utils.easyList(
Row.createSimple(Utils.easyMap(
"x", "false",
"y", "0.0",
"z", "red"),
covariates, 1, 0.0
),
Row.createSimple(Utils.easyMap(
"x", "false",
"y", "10.0",
"z", "blue"),
covariates, 2, 10.0
),
Row.createSimple(Utils.easyMap(
"x", "true",
"y", "0.0",
"z", "red"),
covariates, 3, 100.0
),
Row.createSimple(Utils.easyMap(
"x", "true",
"y", "10.0",
"z", "green"),
covariates, 4, 110.0
),
Row.createSimple(Utils.easyMap(
"x", "false",
"y", "0.0",
"z", "red"),
covariates, 5, 0.0
),
Row.createSimple(Utils.easyMap(
"x", "false",
"y", "10.0",
"z", "blue"),
covariates, 6, 10.0
),
Row.createSimple(Utils.easyMap(
"x", "true",
"y", "0.0",
"z", "red"),
covariates, 7, 100.0
),
Row.createSimple(Utils.easyMap(
"x", "true",
"y", "10.0",
"z", "green"),
covariates, 8, 110.0
)
);
}
private Tree<Double> makeTree(List<Covariate> covariates, double offset, int[] indices){
// Naming convention - xyTerminal where x and y are low/high denotes whether BooleanCovariate(x) is low/high and
// whether NumericCovariate(y) is low/high.
final TerminalNode<Double> lowLowTerminal = new TerminalNode<>(0.0 + offset, 5);
final TerminalNode<Double> lowHighTerminal = new TerminalNode<>(10.0 + offset, 5);
final TerminalNode<Double> highLowTerminal = new TerminalNode<>(100.0 + offset, 5);
final TerminalNode<Double> highHighTerminal = new TerminalNode<>(110.0 + offset, 5);
final SplitNode<Double> lowSplitNode = SplitNode.<Double>builder()
.leftHand(lowLowTerminal)
.rightHand(lowHighTerminal)
.probabilityNaLeftHand(0.5)
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
.build();
final SplitNode<Double> highSplitNode = SplitNode.<Double>builder()
.leftHand(highLowTerminal)
.rightHand(highHighTerminal)
.probabilityNaLeftHand(0.5)
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
.build();
final SplitNode<Double> rootSplitNode = SplitNode.<Double>builder()
.leftHand(lowSplitNode)
.rightHand(highSplitNode)
.probabilityNaLeftHand(0.5)
.splitRule(new BooleanSplitRule((BooleanCovariate) covariates.get(0)))
.build();
return new Tree<>(rootSplitNode, indices);
}
// Experiment with random seeds to first examine what a split does so we know what to expect
/*
public static void main(String[] args){
// Behaviour for OOB
final List<Integer> ints1 = IntStream.range(5, 9).boxed().collect(Collectors.toList());
final List<Integer> ints2 = IntStream.range(1, 5).boxed().collect(Collectors.toList());
final Random random = new Random(123);
Collections.shuffle(ints1, random);
Collections.shuffle(ints2, random);
System.out.println(ints1);
System.out.println(ints2);
// [5, 6, 8, 7]
// [3, 4, 1, 2]
// Behaviour for no-OOB
final List<Integer> fullInts1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
final List<Integer> fullInts2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
final Random fullIntsRandom = new Random(123);
Collections.shuffle(fullInts1, fullIntsRandom);
Collections.shuffle(fullInts2, fullIntsRandom);
System.out.println(fullInts1);
System.out.println(fullInts2);
// [1, 4, 8, 2, 5, 3, 7, 6]
// [6, 1, 4, 7, 5, 2, 8, 3]
}
*/
private double[] difference(double[] a, double[] b){
final double[] results = new double[a.length];
for(int i = 0; i < a.length; i++){
results[i] = a[i] - b[i];
}
return results;
}
private void assertDoubleEquals(double[] expected, double[] actual){
assertEquals(expected.length, actual.length, "Lengths of arrays should be equal");
for(int i=0; i < expected.length; i++){
assertEquals(expected[i], actual[i], 0.0000001, "Difference at " + i);
}
}
@Test
public void testVariableImportanceOnXNoOOB(){
// x is the BooleanCovariate
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest.getTrees(),
this.rowList,
false
);
final Covariate covariate = this.covariates.get(0);
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
// [1, 4, 8, 2, 5, 3, 7, 6]
final List<Double> permutedPredictionsTree1 = Utils.easyList(
0., 110., 100., 10., 0., 110., 100., 10.
);
// [6, 1, 4, 7, 5, 2, 8, 3]
// Actual: [F, F, T, T, F, F, T, T]
// Seen: [F, F, T, T, F, F, T, T]
// Difference: 0 all around; random chance
final List<Double> permutedPredictionsTree2 = Utils.easyList(
2., 12., 102., 112., 2., 12., 102., 112.
);
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final double[] expectedError = new double[2];
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
assertDoubleEquals(expectedVimp, importance);
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
}
@Test
public void testVariableImportanceOnXOOB(){
// x is the BooleanCovariate
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest.getTrees(),
this.rowList,
true
);
final Covariate covariate = this.covariates.get(0);
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
// [5, 6, 8, 7]
// Actual: [F, F, T, T]
// Seen: [F, F, T, T]
// Difference: No differences
final List<Double> permutedPredictionsTree1 = Utils.easyList(
0., 10., 100., 110.
);
// [3, 4, 1, 2]
// Actual: [F, F, T, T]
// Seen: [T, T, F, F]
// Difference: +100, +100, -100, -100
final List<Double> permutedPredictionsTree2 = Utils.easyList(
102., 112., 2., 12.
);
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
final double[] expectedError = new double[2];
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
assertDoubleEquals(expectedVimp, importance);
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
}
@Test
public void testVariableImportanceOnYNoOOB(){
// y is the NumericCovariate
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest.getTrees(),
this.rowList,
false
);
final Covariate covariate = this.covariates.get(1);
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
// [1, 4, 8, 2, 5, 3, 7, 6]
// Actual: [F, T, F, T, F, T, F, T]
// Seen: [F, T, T, T, F, F, F, T]
// Difference: [=, =, +, =, =, -, =, =]x10
final List<Double> permutedPredictionsTree1 = Utils.easyList(
0., 10., 110., 110., 0., 0., 100., 110.
);
// [6, 1, 4, 7, 5, 2, 8, 3]
// Actual: [F, T, F, T, F, T, F, T]
// Seen: [T, F, T, F, F, T, T, F]
// Difference: [+, -, +, -, =, =, +, -]
final List<Double> permutedPredictionsTree2 = Utils.easyList(
12., 2., 112., 102., 2., 12., 112., 102.
);
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final double[] expectedError = new double[2];
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
assertDoubleEquals(expectedVimp, importance);
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
}
@Test
public void testVariableImportanceOnYOOB(){
// y is the NumericCovariate
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest.getTrees(),
this.rowList,
true
);
final Covariate covariate = this.covariates.get(1);
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
// [5, 6, 8, 7]
// Actual: [F, T, F, T]
// Seen: [F, T, T, F]
// Difference: [=, =, +, -]x10
final List<Double> permutedPredictionsTree1 = Utils.easyList(
0., 10., 110., 100.
);
// [3, 4, 1, 2]
// Actual: [F, T, F, T]
// Seen: [F, T, F, T]
// Difference: [=, =, =, =]x10 no change
final List<Double> permutedPredictionsTree2 = Utils.easyList(
2., 12., 102., 112.
);
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
final double[] expectedError = new double[2];
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
assertDoubleEquals(expectedVimp, importance);
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
}
@Test
public void testVariableImportanceOnZNoOOB(){
// z is the useless FactorCovariate
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest.getTrees(),
this.rowList,
false
);
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
final double[] expectedImportance = {0.0, 0.0};
// FactorImportance did nothing; so permuting it will make no difference to baseline error
assertDoubleEquals(expectedImportance, importance);
}
@Test
public void testVariableImportanceOnZOOB(){
// z is the useless FactorCovariate
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest.getTrees(),
this.rowList,
true
);
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
final double[] expectedImportance = {0.0, 0.0};
// FactorImportance did nothing; so permuting it will make no difference to baseline error
assertDoubleEquals(expectedImportance, importance);
}
}

View file

@ -19,9 +19,6 @@ package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.TestUtils;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class RightContinuousStepFunctionIntegrationTest {
private RightContinuousStepFunction createTestFunction(){
@ -78,78 +75,5 @@ public class RightContinuousStepFunctionIntegrationTest {
}
@Test
public void testInvertedFromTo(){
final RightContinuousStepFunction function = createTestFunction();
final double area1 = function.integrate(0, 3.0);
final double area2 = function.integrate(3.0, 0.0);
assertEquals(area1, -area2, 0.0000001);
}
@Test
public void testIntegratingUpToNan(){
// Idea here - you have a function that is valid up to point x where it becomes NaN
// You should be able to integrate *up to* point x and not get an NaN
final RightContinuousStepFunction function1 = new RightContinuousStepFunction(
new double[]{1.0, 2.0, 3.0, 4.0},
new double[]{1.0, 1.0, 1.0, Double.NaN},
0.0);
final double area1 = function1.integrate(0.0, 4.0);
assertEquals(3.0, area1, 0.000000001);
final double nanArea1 = function1.integrate(0.0, 4.0001);
assertTrue(Double.isNaN(nanArea1));
// This tests integrating over the defaultY up to the NaN point
final RightContinuousStepFunction function2 = new RightContinuousStepFunction(
new double[]{1.0, 2.0, 3.0, 4.0},
new double[]{Double.NaN, 1.0, 1.0, Double.NaN},
1.0);
final double area2 = function2.integrate(0.0, 1.0);
assertEquals(1.0, area2, 0.000000001);
final double nanArea2 = function2.integrate(0.0, 4.0);
assertTrue(Double.isNaN(nanArea2));
// This tests integrating between two NaN points. Note that of course for RightContinuousValues carry the previous
// value until the next x point, so this is just making sure the code works if the x-value we pass over is NaN
final RightContinuousStepFunction function3 = new RightContinuousStepFunction(
new double[]{1.0, 2.0, 3.0, 4.0},
new double[]{Double.NaN, 1.0, 1.0, Double.NaN},
0.0);
final double area3 = function3.integrate(2.0, 4.0);
assertEquals(2.0, area3, 0.000000001);
final double nanArea3 = function3.integrate(0.0, 4.0001);
assertTrue(Double.isNaN(nanArea3));
}
@Test
public void testIntegratingEmptyFunction(){
// A function might have no points, but we'll still need to integrate it.
final RightContinuousStepFunction function = new RightContinuousStepFunction(
new double[]{}, new double[]{}, 1.0
);
final double area = function.integrate(1.0 ,3.0);
assertEquals(2.0, area, 0.000001);
}
}

View file

@ -1,180 +0,0 @@
package ca.joeltherrien.randomforest.utils;
import org.junit.jupiter.api.Test;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class RightContinuousStepFunctionOperatorTests {
// Idea - small and middle slightly overlap; middle and large slightly overlap.
// small and large never overlap (i.e. small's x values always occur before large's)
private final RightContinuousStepFunction smallNumbers;
private final RightContinuousStepFunction middleNumbers;
private final RightContinuousStepFunction largeNumbers;
private final double delta = 0.0000000001;
public RightContinuousStepFunctionOperatorTests(){
smallNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
new Point(1.0, 1.0),
new Point(2.0, 3.0),
new Point(3.0, 2.0),
new Point(4.0, 1.0)
), 0.0);
middleNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
new Point(3.5, 4.0),
new Point(4.0, 3.0),
new Point(5.0, 2.0),
new Point(6.0, 1.0)
), 5.0);
largeNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
new Point(5.0, 5.0),
new Point(6.0, 6.0),
new Point(7.0, 3.0),
new Point(8.0, 2.0)
), 3.0);
}
@Test
public void testDifferenceNoOverlapLargeMinusSmall(){
DoubleBinaryOperator operator = (a, b) -> a - b;
final RightContinuousStepFunction largeSmallDifference = RightContinuousStepFunction.biOperation(
largeNumbers,
smallNumbers,
operator);
assertEquals(8, largeSmallDifference.getX().length);
assertEquals(8, largeSmallDifference.getY().length);
final double[] offsetTimes = {-0.1, 0.0, 0.1};
for(int time = 1; time <= 9; time++){
for(double offsetTime : offsetTimes){
final double timeToEvaluateAt = (double) time + offsetTime;
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
final double smallFunEvaluation = smallNumbers.evaluate(timeToEvaluateAt);
final double expectedDifference = operator.applyAsDouble(largeFunEvaluation, smallFunEvaluation);
final double actualEvaluation = largeSmallDifference.evaluate(timeToEvaluateAt);
assertEquals(expectedDifference, actualEvaluation, delta);
}
}
}
@Test
public void testDifferenceNoOverlapSmallMinusLarge(){
DoubleBinaryOperator operator = (a, b) -> a - b;
final RightContinuousStepFunction smallLargeDifference = RightContinuousStepFunction.biOperation(
smallNumbers,
largeNumbers,
operator);
assertEquals(8, smallLargeDifference.getX().length);
assertEquals(8, smallLargeDifference.getY().length);
final double[] offsetTimes = {-0.1, 0.0, 0.1};
for(int time = 1; time <= 9; time++){
for(double offsetTime : offsetTimes){
final double timeToEvaluateAt = (double) time + offsetTime;
final double smallFunEvaluation = smallNumbers.evaluate(timeToEvaluateAt);
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
final double expectedDifference = operator.applyAsDouble(smallFunEvaluation, largeFunEvaluation);
final double actualEvaluation = smallLargeDifference.evaluate(timeToEvaluateAt);
assertEquals(expectedDifference, actualEvaluation, delta);
}
}
}
@Test
public void testDifferenceSomeOverlapLargeMinusMiddle(){
DoubleBinaryOperator operator = (a, b) -> a - b;
final RightContinuousStepFunction combinedFunction = RightContinuousStepFunction.biOperation(
largeNumbers,
middleNumbers,
operator);
assertEquals(6, combinedFunction.getX().length);
assertEquals(6, combinedFunction.getY().length);
final double[] offsetTimes = {-0.1, 0.0, 0.1};
for(int time = 1; time <= 9; time++){
for(double offsetTime : offsetTimes){
final double timeToEvaluateAt = (double) time + offsetTime;
final double middleFunEvaluation = middleNumbers.evaluate(timeToEvaluateAt);
final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
final double expectedDifference = operator.applyAsDouble(largeFunEvaluation, middleFunEvaluation);
final double actualEvaluation = combinedFunction.evaluate(timeToEvaluateAt);
assertEquals(expectedDifference, actualEvaluation, delta);
}
}
}
@Test
public void testDifferenceCompleteOverlap(){
DoubleBinaryOperator operator = (a, b) -> a - b;
final RightContinuousStepFunction combinedFunction = RightContinuousStepFunction.biOperation(
middleNumbers,
middleNumbers,
operator);
assertEquals(4, combinedFunction.getX().length);
assertEquals(4, combinedFunction.getY().length);
final double[] offsetTimes = {-0.1, 0.0, 0.1};
for(int time = 1; time <= 9; time++){
for(double offsetTime : offsetTimes){
final double timeToEvaluateAt = (double) time + offsetTime;
final double actualEvaluation = combinedFunction.evaluate(timeToEvaluateAt);
assertEquals(0.0, actualEvaluation, delta);
}
}
}
@Test
public void testPowerFunction(){
final DoubleUnaryOperator operator = d -> d*d;
final RightContinuousStepFunction squaredFunction = smallNumbers.unaryOperation(operator);
assertEquals(4, squaredFunction.getX().length);
assertEquals(4, squaredFunction.getY().length);
final double[] offsetTimes = {-0.1, 0.0, 0.1};
for(int time = 1; time <= 9; time++){
for(double offsetTime : offsetTimes){
final double timeToEvaluateAt = (double) time + offsetTime;
final double expectedEvaluation = operator.applyAsDouble(smallNumbers.evaluate(timeToEvaluateAt));
final double actualEvaluation = squaredFunction.evaluate(timeToEvaluateAt);
assertEquals(expectedEvaluation, actualEvaluation, delta);
}
}
}
}

View file

@ -43,7 +43,7 @@ public class TrainForest {
final List<Covariate> covariateList = new ArrayList<>(p);
for(int j =0; j < p; j++){
final NumericCovariate covariate = new NumericCovariate("x"+j, j, false);
final NumericCovariate covariate = new NumericCovariate("x"+j, j);
covariateList.add(covariate);
}

View file

@ -39,8 +39,8 @@ public class TrainSingleTree {
final int n = 1000;
final List<Row<Double>> trainingSet = new ArrayList<>(n);
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0, false);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1, false);
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0)

View file

@ -41,9 +41,9 @@ public class TrainSingleTreeFactor {
final int n = 10000;
final List<Row<Double>> trainingSet = new ArrayList<>(n);
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0, false);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1, false);
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"), false);
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"));
final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0)