Add variable importance methods to library
This commit is contained in:
parent
f23ee21ef3
commit
a56ad4433d
10 changed files with 738 additions and 8 deletions
|
@ -21,12 +21,11 @@ import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.HashMap;
|
import java.util.*;
|
||||||
import java.util.List;
|
import java.util.stream.Collectors;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class CovariateRow implements Serializable {
|
public class CovariateRow implements Serializable, Cloneable {
|
||||||
|
|
||||||
private final Covariate.Value[] valueArray;
|
private final Covariate.Value[] valueArray;
|
||||||
|
|
||||||
|
@ -46,6 +45,14 @@ public class CovariateRow implements Serializable {
|
||||||
return "CovariateRow " + this.id;
|
return "CovariateRow " + this.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CovariateRow clone() {
|
||||||
|
// shallow clone, which is fine. I want a new array, but the values don't need to be copied
|
||||||
|
final Covariate.Value[] copyValueArray = this.valueArray.clone();
|
||||||
|
|
||||||
|
return new CovariateRow(copyValueArray, this.id);
|
||||||
|
}
|
||||||
|
|
||||||
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
||||||
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||||
final Map<String, Covariate> covariateMap = new HashMap<>();
|
final Map<String, Covariate> covariateMap = new HashMap<>();
|
||||||
|
@ -64,4 +71,27 @@ public class CovariateRow implements Serializable {
|
||||||
return new CovariateRow(valueArray, id);
|
return new CovariateRow(valueArray, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used for variable importance; takes a List of CovariateRows and permute one of the Covariates.
|
||||||
|
*
|
||||||
|
* @param covariateRows The List of CovariateRows to scramble. Note that the originals won't be modified.
|
||||||
|
* @param covariateToScramble The Covariate to scramble on.
|
||||||
|
* @param random The source of randomness to use. If not present, one will be created.
|
||||||
|
* @return A List of CovariateRows where the specified covariate was scrambled. These are different objects from the ones provided.
|
||||||
|
*/
|
||||||
|
public static List<CovariateRow> scrambleCovariateValues(List<? extends CovariateRow> covariateRows, Covariate covariateToScramble, Optional<Random> random){
|
||||||
|
final List<CovariateRow> permutedCovariateRowList = new ArrayList<>(covariateRows);
|
||||||
|
Collections.shuffle(permutedCovariateRowList, random.orElse(new Random())); // without replacement
|
||||||
|
|
||||||
|
final List<CovariateRow> clonedRowList = covariateRows.stream().map(CovariateRow::clone).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final int covariateToScrambleIndex = covariateToScramble.getIndex();
|
||||||
|
for(int i=0; i < covariateRows.size(); i++){
|
||||||
|
clonedRowList.get(i).valueArray[covariateToScrambleIndex] = permutedCovariateRowList.get(i).valueArray[covariateToScrambleIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
return clonedRowList;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,9 +32,6 @@ public class Row<Y> extends CovariateRow {
|
||||||
this.response = response;
|
this.response = response;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public Y getResponse() {
|
public Y getResponse() {
|
||||||
return this.response;
|
return this.response;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ public class NumericSplitRule implements SplitRule<Double> {
|
||||||
private final int parentCovariateIndex;
|
private final int parentCovariateIndex;
|
||||||
private final double threshold;
|
private final double threshold;
|
||||||
|
|
||||||
NumericSplitRule(NumericCovariate parent, final double threshold){
|
public NumericSplitRule(NumericCovariate parent, final double threshold){
|
||||||
this.parentCovariateIndex = parent.getIndex();
|
this.parentCovariateIndex = parent.getIndex();
|
||||||
this.threshold = threshold;
|
this.threshold = threshold;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
 * Simple interface used by VariableImportanceCalculator; turns a List of observed responses and a List of
 * predictions into a single average error measure.
 *
 * @param <Y> The class of the responses.
 * @param <P> The class of the predictions.
 */
public interface ErrorCalculator<Y, P> {

    /**
     * Compares the observed responses with the predictions to produce an average error measure.
     * Lower errors should indicate a better model fit.
     *
     * @param responses The observed responses.
     * @param predictions The predictions to score; the implementations in this package pair the element at
     *                    position i with the response at position i.
     * @return The average error of the predictions with respect to the responses.
     */
    double averageError(List<Y> responses, List<P> predictions);

}
|
|
@ -0,0 +1,63 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implements ErrorCalculator; essentially just wraps around IBSCalculator to fit into VariableImportanceCalculator.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class IBSErrorCalculatorWrapper implements ErrorCalculator<CompetingRiskResponse, CompetingRiskFunctions> {
|
||||||
|
|
||||||
|
private final IBSCalculator calculator;
|
||||||
|
private final int[] events;
|
||||||
|
private final double integrationUpperBound;
|
||||||
|
private final double[] eventWeights;
|
||||||
|
|
||||||
|
public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound, double[] eventWeights) {
|
||||||
|
this.calculator = calculator;
|
||||||
|
this.events = events;
|
||||||
|
this.integrationUpperBound = integrationUpperBound;
|
||||||
|
this.eventWeights = eventWeights;
|
||||||
|
}
|
||||||
|
|
||||||
|
public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound) {
|
||||||
|
this.calculator = calculator;
|
||||||
|
this.events = events;
|
||||||
|
this.integrationUpperBound = integrationUpperBound;
|
||||||
|
this.eventWeights = new double[events.length];
|
||||||
|
|
||||||
|
Arrays.fill(this.eventWeights, 1.0); // default is to just sum all errors together
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double averageError(List<CompetingRiskResponse> responses, List<CompetingRiskFunctions> predictions) {
|
||||||
|
final double[] errors = new double[events.length];
|
||||||
|
final double n = responses.size();
|
||||||
|
|
||||||
|
for(int i=0; i < responses.size(); i++){
|
||||||
|
final CompetingRiskResponse response = responses.get(i);
|
||||||
|
final CompetingRiskFunctions prediction = predictions.get(i);
|
||||||
|
|
||||||
|
for(int k=0; k < this.events.length; k++){
|
||||||
|
final int event = this.events[k];
|
||||||
|
final RightContinuousStepFunction cif = prediction.getCumulativeIncidenceFunction(event);
|
||||||
|
errors[k] += calculator.calculateError(response, cif, event, integrationUpperBound) / n;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
double totalError = 0.0;
|
||||||
|
for(int k=0; k < this.events.length; k++){
|
||||||
|
totalError += this.eventWeights[k] * errors[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
return totalError;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class RegressionErrorCalculator implements ErrorCalculator<Double, Double>{
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double averageError(List<Double> responses, List<Double> predictions) {
|
||||||
|
double mean = 0.0;
|
||||||
|
final double n = responses.size();
|
||||||
|
|
||||||
|
for(int i=0; i<responses.size(); i++){
|
||||||
|
final double response = responses.get(i);
|
||||||
|
final double prediction = predictions.get(i);
|
||||||
|
|
||||||
|
final double difference = response - prediction;
|
||||||
|
|
||||||
|
mean += difference * difference / n;
|
||||||
|
}
|
||||||
|
|
||||||
|
return mean;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
|
||||||
|
public class VariableImportanceCalculator<Y, P> {
|
||||||
|
|
||||||
|
private final ErrorCalculator<Y, P> errorCalculator;
|
||||||
|
private final Forest<Y, P> forest;
|
||||||
|
private final List<Row<Y>> observations;
|
||||||
|
private final List<Y> observedResponses;
|
||||||
|
|
||||||
|
private final boolean isTrainingSet; // If true, then we use out-of-bag predictions
|
||||||
|
private final double baselineError;
|
||||||
|
|
||||||
|
public VariableImportanceCalculator(
|
||||||
|
ErrorCalculator<Y, P> errorCalculator,
|
||||||
|
Forest<Y, P> forest,
|
||||||
|
List<Row<Y>> observations,
|
||||||
|
boolean isTrainingSet
|
||||||
|
){
|
||||||
|
this.errorCalculator = errorCalculator;
|
||||||
|
this.forest = forest;
|
||||||
|
this.observations = observations;
|
||||||
|
this.isTrainingSet = isTrainingSet;
|
||||||
|
|
||||||
|
this.observedResponses = observations.stream()
|
||||||
|
.map(row -> row.getResponse()).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<P> baselinePredictions = makePredictions(observations);
|
||||||
|
this.baselineError = errorCalculator.averageError(observedResponses, baselinePredictions);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculateVariableImportance(Covariate covariate, Optional<Random> random){
|
||||||
|
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(this.observations, covariate, random);
|
||||||
|
final List<P> alternatePredictions = makePredictions(scrambledValues);
|
||||||
|
final double newError = errorCalculator.averageError(this.observedResponses, alternatePredictions);
|
||||||
|
|
||||||
|
return newError - this.baselineError;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] calculateVariableImportance(List<Covariate> covariates, Optional<Random> random){
|
||||||
|
return covariates.stream()
|
||||||
|
.mapToDouble(covariate -> calculateVariableImportance(covariate, random))
|
||||||
|
.toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<P> makePredictions(List<? extends CovariateRow> rowList){
|
||||||
|
if(isTrainingSet){
|
||||||
|
return forest.evaluateOOB(rowList);
|
||||||
|
} else{
|
||||||
|
return forest.evaluate(rowList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,143 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
|
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class IBSErrorCalculatorWrapperTest {
|
||||||
|
|
||||||
|
/*
|
||||||
|
We already have tests for the IBSCalculator, so these tests are concerned with making sure we correctly average
|
||||||
|
the errors together, not that we fully test the production of each error under different scenarios (like
|
||||||
|
providing / not providing a censoring distribution).
|
||||||
|
*/
|
||||||
|
|
||||||
|
private final double integrationUpperBound = 5.0;
|
||||||
|
|
||||||
|
private final List<CompetingRiskResponse> responses;
|
||||||
|
private final List<CompetingRiskFunctions> functions;
|
||||||
|
|
||||||
|
|
||||||
|
private final double[][] errors;
|
||||||
|
|
||||||
|
public IBSErrorCalculatorWrapperTest(){
|
||||||
|
this.responses = Utils.easyList(
|
||||||
|
new CompetingRiskResponse(0, 2.0),
|
||||||
|
new CompetingRiskResponse(0, 3.0),
|
||||||
|
new CompetingRiskResponse(1, 1.0),
|
||||||
|
new CompetingRiskResponse(1, 1.5),
|
||||||
|
new CompetingRiskResponse(2, 3.0),
|
||||||
|
new CompetingRiskResponse(2, 4.0)
|
||||||
|
);
|
||||||
|
|
||||||
|
final RightContinuousStepFunction cif1 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(1.0, 0.25),
|
||||||
|
new Point(1.5, 0.45)
|
||||||
|
), 0.0);
|
||||||
|
|
||||||
|
final RightContinuousStepFunction cif2 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(3.0, 0.25),
|
||||||
|
new Point(4.0, 0.45)
|
||||||
|
), 0.0);
|
||||||
|
|
||||||
|
// This function is for the unused CHFs and survival curve
|
||||||
|
// If we see infinities or NaNs popping up in our output we should look here.
|
||||||
|
final RightContinuousStepFunction emptyFun = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
|
||||||
|
new Point(0.0, Double.NaN)
|
||||||
|
), Double.NEGATIVE_INFINITY
|
||||||
|
);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions function = CompetingRiskFunctions.builder()
|
||||||
|
.cumulativeIncidenceCurves(Utils.easyList(cif1, cif2))
|
||||||
|
.causeSpecificHazards(Utils.easyList(emptyFun, emptyFun))
|
||||||
|
.survivalCurve(emptyFun)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Same prediction for every response.
|
||||||
|
this.functions = Utils.easyList(function, function, function, function, function, function);
|
||||||
|
|
||||||
|
final IBSCalculator calculator = new IBSCalculator();
|
||||||
|
this.errors = new double[2][6];
|
||||||
|
|
||||||
|
for(int event : new int[]{1, 2}){
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
this.errors[event-1][i] = calculator.calculateError(
|
||||||
|
responses.get(i), function.getCumulativeIncidenceFunction(event),
|
||||||
|
event, integrationUpperBound
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOneEventOne(){
|
||||||
|
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1},
|
||||||
|
this.integrationUpperBound);
|
||||||
|
|
||||||
|
final double error = wrapper.averageError(this.responses, this.functions);
|
||||||
|
double expectedError = 0.0;
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
expectedError += errors[0][i] / 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(expectedError, error, 0.00000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOneEventTwo(){
|
||||||
|
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{2},
|
||||||
|
this.integrationUpperBound);
|
||||||
|
|
||||||
|
final double error = wrapper.averageError(this.responses, this.functions);
|
||||||
|
double expectedError = 0.0;
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
expectedError += errors[1][i] / 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(expectedError, error, 0.00000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTwoEventsNoWeights(){
|
||||||
|
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
|
||||||
|
this.integrationUpperBound);
|
||||||
|
|
||||||
|
final double error = wrapper.averageError(this.responses, this.functions);
|
||||||
|
double expectedError1 = 0.0;
|
||||||
|
double expectedError2 = 0.0;
|
||||||
|
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
expectedError1 += errors[0][i] / 6.0;
|
||||||
|
expectedError2 += errors[1][i] / 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(expectedError1 + expectedError2, error, 0.00000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTwoEventsWithWeights(){
|
||||||
|
final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
|
||||||
|
this.integrationUpperBound, new double[]{1.0, 2.0});
|
||||||
|
|
||||||
|
final double error = wrapper.averageError(this.responses, this.functions);
|
||||||
|
double expectedError1 = 0.0;
|
||||||
|
double expectedError2 = 0.0;
|
||||||
|
|
||||||
|
for(int i=0; i<6; i++){
|
||||||
|
expectedError1 += errors[0][i] / 6.0;
|
||||||
|
expectedError2 += errors[1][i] / 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(1.0 * expectedError1 + 2.0 * expectedError2, error, 0.00000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class RegressionErrorCalculatorTest {
|
||||||
|
|
||||||
|
private final RegressionErrorCalculator calculator = new RegressionErrorCalculator();
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRegressionErrorCalculator(){
|
||||||
|
final List<Double> responses = Utils.easyList(1.0, 1.5, 0.0, 3.0);
|
||||||
|
final List<Double> predictions = Utils.easyList(1.5, 1.7, 0.1, 2.9);
|
||||||
|
|
||||||
|
// Differences are 0.5, 0.2, -0.1, 0.1
|
||||||
|
// Squared: 0.25, 0.04, 0.01, 0.01
|
||||||
|
|
||||||
|
assertEquals((0.25 + 0.04 + 0.01 + 0.01)/4.0, calculator.averageError(responses, predictions), 0.000000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,359 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanSplitRule;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericSplitRule;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
|
||||||
|
import ca.joeltherrien.randomforest.tree.*;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class VariableImportanceCalculatorTest {
|
||||||
|
|
||||||
|
/*
|
||||||
|
Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
|
||||||
|
setting.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// We'l have a very simple Forest of two trees
|
||||||
|
private final Forest<Double, Double> forest;
|
||||||
|
|
||||||
|
|
||||||
|
private final List<Covariate> covariates;
|
||||||
|
private final List<Row<Double>> rowList;
|
||||||
|
|
||||||
|
/*
|
||||||
|
Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance.
|
||||||
|
|
||||||
|
*/
|
||||||
|
public VariableImportanceCalculatorTest(){
|
||||||
|
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0);
|
||||||
|
final NumericCovariate numericCovariate = new NumericCovariate("y", 1);
|
||||||
|
final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
|
||||||
|
Utils.easyList("red", "blue", "green"));
|
||||||
|
|
||||||
|
this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.splitFinder(new WeightedVarianceSplitFinder())
|
||||||
|
.numberOfSplits(0)
|
||||||
|
.nodeSize(1)
|
||||||
|
.maxNodeDepth(100)
|
||||||
|
.mtry(3)
|
||||||
|
.checkNodePurity(false)
|
||||||
|
.covariates(this.covariates)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
/*
|
||||||
|
Plan for data - BooleanCovariate is split on first and has the largest impact.
|
||||||
|
NumericCovariate is at second level and has more minimal impact.
|
||||||
|
FactorCovariate is useless and never used.
|
||||||
|
Our tree (we'll duplicate it for testing OOB errors) will have a depth of 1. (0 based).
|
||||||
|
*/
|
||||||
|
|
||||||
|
final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
|
||||||
|
final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});
|
||||||
|
|
||||||
|
this.forest = Forest.<Double, Double>builder()
|
||||||
|
.trees(Utils.easyList(tree1, tree2))
|
||||||
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
|
.covariateList(this.covariates)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// formula; boolean high adds 100; high numeric adds 10
|
||||||
|
// This row list should have a baseline error of 0.0
|
||||||
|
this.rowList = Utils.easyList(
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 1, 0.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "blue"),
|
||||||
|
covariates, 2, 10.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 3, 100.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "green"),
|
||||||
|
covariates, 4, 110.0
|
||||||
|
),
|
||||||
|
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 5, 0.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "false",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "blue"),
|
||||||
|
covariates, 6, 10.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "0.0",
|
||||||
|
"z", "red"),
|
||||||
|
covariates, 7, 100.0
|
||||||
|
),
|
||||||
|
Row.createSimple(Utils.easyMap(
|
||||||
|
"x", "true",
|
||||||
|
"y", "10.0",
|
||||||
|
"z", "green"),
|
||||||
|
covariates, 8, 110.0
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Tree<Double> makeTree(List<Covariate> covariates, double offset, int[] indices){
|
||||||
|
// Naming convention - xyTerminal where x and y are low/high denotes whether BooleanCovariate(x) is low/high and
|
||||||
|
// whether NumericCovariate(y) is low/high.
|
||||||
|
final TerminalNode<Double> lowLowTerminal = new TerminalNode<>(0.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> lowHighTerminal = new TerminalNode<>(10.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> highLowTerminal = new TerminalNode<>(100.0 + offset, 5);
|
||||||
|
final TerminalNode<Double> highHighTerminal = new TerminalNode<>(110.0 + offset, 5);
|
||||||
|
|
||||||
|
final SplitNode<Double> lowSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(lowLowTerminal)
|
||||||
|
.rightHand(lowHighTerminal)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final SplitNode<Double> highSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(highLowTerminal)
|
||||||
|
.rightHand(highHighTerminal)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final SplitNode<Double> rootSplitNode = SplitNode.<Double>builder()
|
||||||
|
.leftHand(lowSplitNode)
|
||||||
|
.rightHand(highSplitNode)
|
||||||
|
.probabilityNaLeftHand(0.5)
|
||||||
|
.splitRule(new BooleanSplitRule((BooleanCovariate) covariates.get(0)))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return new Tree<>(rootSplitNode, indices);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Experiment with random seeds to first examine what a split does so we know what to expect
|
||||||
|
/*
|
||||||
|
public static void main(String[] args){
|
||||||
|
final List<Integer> ints1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||||
|
final List<Integer> ints2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||||
|
|
||||||
|
final Random random = new Random(123);
|
||||||
|
Collections.shuffle(ints1, random);
|
||||||
|
Collections.shuffle(ints2, random);
|
||||||
|
|
||||||
|
System.out.println(ints1);
|
||||||
|
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||||
|
|
||||||
|
System.out.println(ints2);
|
||||||
|
[6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnXNoOOB(){
|
||||||
|
// x is the BooleanCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));
|
||||||
|
|
||||||
|
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
||||||
|
|
||||||
|
final List<Double> permutedPredictions = Utils.easyList(
|
||||||
|
1., 111., 101., 11., 1., 111., 101., 11.
|
||||||
|
);
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
||||||
|
|
||||||
|
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnXOOB(){
|
||||||
|
// x is the BooleanCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));
|
||||||
|
|
||||||
|
// First 4 observations are off by 2, last 4 are off by 0
|
||||||
|
final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;
|
||||||
|
|
||||||
|
// Remember we are working with OOB predictions
|
||||||
|
final List<Double> permutedPredictions = Utils.easyList(
|
||||||
|
2., 112., 102., 12., 0., 110., 100., 10.
|
||||||
|
);
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
||||||
|
|
||||||
|
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnYNoOOB(){
|
||||||
|
// y is the NumericCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));
|
||||||
|
|
||||||
|
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
||||||
|
|
||||||
|
final List<Double> permutedPredictions = Utils.easyList(
|
||||||
|
1., 11., 111., 111., 1., 1., 101., 111.
|
||||||
|
);
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
||||||
|
|
||||||
|
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnYOOB(){
|
||||||
|
// y is the NumericCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));
|
||||||
|
|
||||||
|
// First 4 observations are off by 2, last 4 are off by 0
|
||||||
|
final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;
|
||||||
|
|
||||||
|
// Remember we are working with OOB predictions
|
||||||
|
final List<Double> permutedPredictions = Utils.easyList(
|
||||||
|
2., 12., 112., 112., 0., 0., 100., 110.
|
||||||
|
);
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
||||||
|
|
||||||
|
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnZNoOOB(){
|
||||||
|
// z is the useless FactorCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));
|
||||||
|
|
||||||
|
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
||||||
|
assertEquals(0, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceOnZOOB(){
|
||||||
|
// z is the useless FactorCovariate
|
||||||
|
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));
|
||||||
|
|
||||||
|
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
||||||
|
assertEquals(0, importance, 0.0000001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVariableImportanceMultiple(){
|
||||||
|
Random random = new Random(123);
|
||||||
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
|
new RegressionErrorCalculator(),
|
||||||
|
this.forest,
|
||||||
|
this.rowList,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
double importance[] = calculator.calculateVariableImportance(covariates, Optional.of(random));
|
||||||
|
|
||||||
|
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
||||||
|
|
||||||
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<Double> permutedPredictionsX = Utils.easyList(
|
||||||
|
1., 111., 101., 11., 1., 111., 101., 11.
|
||||||
|
);
|
||||||
|
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
final List<Double> permutedPredictionsY = Utils.easyList(
|
||||||
|
11., 1., 111., 101., 1., 11., 111., 101.
|
||||||
|
);
|
||||||
|
|
||||||
|
final double expectedErrorX = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsX);
|
||||||
|
final double expectedErrorY = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsY);
|
||||||
|
|
||||||
|
assertEquals(expectedErrorX - expectedBaselineError, importance[0], 0.0000001);
|
||||||
|
assertEquals(expectedErrorY - expectedBaselineError, importance[1], 0.0000001);
|
||||||
|
assertEquals(0, importance[2], 0.0000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue