Add variable importance methods to library

This commit is contained in:
Joel Therrien 2019-07-26 14:59:41 -07:00
parent f23ee21ef3
commit a56ad4433d
10 changed files with 738 additions and 8 deletions

View file

@ -21,12 +21,11 @@ import lombok.Getter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.io.Serializable; import java.io.Serializable;
import java.util.HashMap; import java.util.*;
import java.util.List; import java.util.stream.Collectors;
import java.util.Map;
@RequiredArgsConstructor @RequiredArgsConstructor
public class CovariateRow implements Serializable { public class CovariateRow implements Serializable, Cloneable {
private final Covariate.Value[] valueArray; private final Covariate.Value[] valueArray;
@ -46,6 +45,14 @@ public class CovariateRow implements Serializable {
return "CovariateRow " + this.id; return "CovariateRow " + this.id;
} }
/**
 * Creates a copy of this row that is backed by its own value array.
 *
 * The copy is shallow: the array itself is duplicated, but the Covariate.Value objects inside it are shared with
 * this row. The id is carried over unchanged.
 *
 * @return A new CovariateRow holding a separate array with the same values and the same id.
 */
@Override
public CovariateRow clone() {
    return new CovariateRow(this.valueArray.clone(), this.id);
}
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){ public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()]; final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
final Map<String, Covariate> covariateMap = new HashMap<>(); final Map<String, Covariate> covariateMap = new HashMap<>();
@ -64,4 +71,27 @@ public class CovariateRow implements Serializable {
return new CovariateRow(valueArray, id); return new CovariateRow(valueArray, id);
} }
/**
 * Used for variable importance; takes a List of CovariateRows and permutes the values of one of the Covariates
 * across the rows.
 *
 * @param covariateRows The List of CovariateRows to scramble. Note that the originals won't be modified.
 * @param covariateToScramble The Covariate to scramble on.
 * @param random The source of randomness to use. If not present, one will be created.
 * @return A List of CovariateRows where the specified covariate was scrambled. These are different objects from the ones provided.
 */
public static List<CovariateRow> scrambleCovariateValues(List<? extends CovariateRow> covariateRows, Covariate covariateToScramble, Optional<Random> random){
    // Shuffling a copy of the list yields a permutation of the rows, i.e. sampling without replacement.
    final List<CovariateRow> permutedCovariateRowList = new ArrayList<>(covariateRows);
    // orElseGet only constructs a new Random when the caller did not provide one.
    Collections.shuffle(permutedCovariateRowList, random.orElseGet(Random::new));

    // The rows are cloned first so that the assignment below never writes into the rows that were provided.
    final List<CovariateRow> clonedRowList = covariateRows.stream().map(CovariateRow::clone).collect(Collectors.toList());

    final int covariateToScrambleIndex = covariateToScramble.getIndex();
    for(int i=0; i < covariateRows.size(); i++){
        // Row i keeps all of its own values except the scrambled covariate, which comes from the i-th permuted row.
        clonedRowList.get(i).valueArray[covariateToScrambleIndex] = permutedCovariateRowList.get(i).valueArray[covariateToScrambleIndex];
    }

    return clonedRowList;
}
} }

View file

@ -32,9 +32,6 @@ public class Row<Y> extends CovariateRow {
this.response = response; this.response = response;
} }
public Y getResponse() { public Y getResponse() {
return this.response; return this.response;
} }

View file

@ -26,7 +26,7 @@ public class NumericSplitRule implements SplitRule<Double> {
private final int parentCovariateIndex; private final int parentCovariateIndex;
private final double threshold; private final double threshold;
NumericSplitRule(NumericCovariate parent, final double threshold){ public NumericSplitRule(NumericCovariate parent, final double threshold){
this.parentCovariateIndex = parent.getIndex(); this.parentCovariateIndex = parent.getIndex();
this.threshold = threshold; this.threshold = threshold;
} }

View file

@ -0,0 +1,24 @@
package ca.joeltherrien.randomforest.tree.vimp;
import java.util.List;
/**
 * Simple interface for VariableImportanceCalculator; takes in a List of observed responses and a List of predictions
 * and produces an average error measure.
 *
 * The interface has a single abstract method, so it can also be implemented with a lambda or method reference.
 *
 * @param <Y> The class of the responses.
 * @param <P> The class of the predictions.
 */
@FunctionalInterface
public interface ErrorCalculator<Y, P>{

    /**
     * Compares the observed responses with the predictions to produce an average error measure.
     * Lower errors should indicate a better model fit.
     *
     * @param responses The observed responses.
     * @param predictions The predictions, in the same order as the responses; the prediction at index i is compared
     *                    with the response at index i.
     * @return The average error over all response / prediction pairs.
     */
    double averageError(List<Y> responses, List<P> predictions);

}

View file

@ -0,0 +1,63 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import java.util.Arrays;
import java.util.List;
/**
 * Implements ErrorCalculator; essentially just wraps around IBSCalculator to fit into VariableImportanceCalculator.
 *
 * For every event of interest the error reported by the IBSCalculator is averaged over all observations; the
 * per-event averages are then combined into a single number as a weighted sum.
 */
public class IBSErrorCalculatorWrapper implements ErrorCalculator<CompetingRiskResponse, CompetingRiskFunctions> {

    private final IBSCalculator calculator;
    private final int[] events;
    private final double integrationUpperBound;
    private final double[] eventWeights;

    /**
     * @param calculator The IBSCalculator used to produce the error of a single observation for a single event.
     * @param events The events to calculate errors for.
     * @param integrationUpperBound Passed on to the IBSCalculator for every error that is calculated.
     * @param eventWeights Weights for the per-event average errors; eventWeights[k] is applied to events[k].
     * @throws IllegalArgumentException if events and eventWeights do not have the same length.
     */
    public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound, double[] eventWeights) {
        if(events.length != eventWeights.length){
            throw new IllegalArgumentException("events has length " + events.length
                    + " but eventWeights has length " + eventWeights.length);
        }

        this.calculator = calculator;
        // Copies keep later changes to the caller's arrays from altering this calculator.
        this.events = events.clone();
        this.integrationUpperBound = integrationUpperBound;
        this.eventWeights = eventWeights.clone();
    }

    /**
     * Same as the four argument constructor with every event weight set to 1.0, so that the per-event average
     * errors are simply summed together.
     *
     * @param calculator The IBSCalculator used to produce the error of a single observation for a single event.
     * @param events The events to calculate errors for.
     * @param integrationUpperBound Passed on to the IBSCalculator for every error that is calculated.
     */
    public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound) {
        this(calculator, events, integrationUpperBound, unitWeights(events.length));
    }

    // Builds the default weights; every entry is 1.0.
    private static double[] unitWeights(int length){
        final double[] weights = new double[length];
        Arrays.fill(weights, 1.0);
        return weights;
    }

    /**
     * Calculates the weighted sum, over the events, of the error averaged over all observations.
     *
     * @param responses The observed responses.
     * @param predictions The predictions; predictions.get(i) is compared with responses.get(i).
     * @return The weighted sum of the per-event average errors; 0.0 if the lists are empty.
     * @throws IllegalArgumentException if responses and predictions do not have the same size.
     */
    @Override
    public double averageError(List<CompetingRiskResponse> responses, List<CompetingRiskFunctions> predictions) {
        if(responses.size() != predictions.size()){
            throw new IllegalArgumentException("responses has size " + responses.size()
                    + " but predictions has size " + predictions.size());
        }

        final double[] errors = new double[events.length];
        final double n = responses.size();

        for(int i=0; i < responses.size(); i++){
            final CompetingRiskResponse response = responses.get(i);
            final CompetingRiskFunctions prediction = predictions.get(i);

            for(int k=0; k < this.events.length; k++){
                final int event = this.events[k];
                final RightContinuousStepFunction cif = prediction.getCumulativeIncidenceFunction(event);
                // Dividing each term by n accumulates the average directly.
                errors[k] += calculator.calculateError(response, cif, event, integrationUpperBound) / n;
            }
        }

        double totalError = 0.0;
        for(int k=0; k < this.events.length; k++){
            totalError += this.eventWeights[k] * errors[k];
        }

        return totalError;
    }

}

View file

@ -0,0 +1,23 @@
package ca.joeltherrien.randomforest.tree.vimp;
import java.util.List;
/**
 * ErrorCalculator for regression; the error measure is the mean of the squared differences between each observed
 * response and its prediction.
 */
public class RegressionErrorCalculator implements ErrorCalculator<Double, Double>{

    /**
     * Calculates the mean squared difference between the responses and the predictions.
     *
     * @param responses The observed responses.
     * @param predictions The predictions; predictions.get(i) is compared with responses.get(i).
     * @return The mean squared error; 0.0 if the lists are empty.
     * @throws IllegalArgumentException if responses and predictions do not have the same size.
     */
    @Override
    public double averageError(List<Double> responses, List<Double> predictions) {
        if(responses.size() != predictions.size()){
            throw new IllegalArgumentException("responses has size " + responses.size()
                    + " but predictions has size " + predictions.size());
        }

        double mean = 0.0;
        final double n = responses.size();

        for(int i=0; i<responses.size(); i++){
            final double response = responses.get(i);
            final double prediction = predictions.get(i);

            final double difference = response - prediction;

            // Dividing each term by n accumulates the mean directly.
            mean += difference * difference / n;
        }

        return mean;
    }

}

View file

@ -0,0 +1,65 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Forest;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
/**
 * Calculates variable importance by permutation. The error of the forest's predictions on the observations is
 * recorded once as a baseline; the importance of a covariate is the error obtained after scrambling that covariate's
 * values across the observations, minus the baseline.
 *
 * @param <Y> The class of the responses.
 * @param <P> The class of the predictions.
 */
public class VariableImportanceCalculator<Y, P> {

    private final ErrorCalculator<Y, P> errorCalculator;
    private final Forest<Y, P> forest;
    private final List<Row<Y>> observations;
    private final List<Y> observedResponses;

    private final boolean isTrainingSet; // When true, out-of-bag predictions are used

    private final double baselineError;

    /**
     * Stores the arguments and immediately calculates the baseline error on the unscrambled observations.
     *
     * @param errorCalculator Produces the error measure from responses and predictions.
     * @param forest The forest used to make the predictions.
     * @param observations The observations that importance is evaluated on.
     * @param isTrainingSet If true, predictions are made with forest.evaluateOOB; otherwise with forest.evaluate.
     */
    public VariableImportanceCalculator(
            ErrorCalculator<Y, P> errorCalculator,
            Forest<Y, P> forest,
            List<Row<Y>> observations,
            boolean isTrainingSet
    ){
        this.errorCalculator = errorCalculator;
        this.forest = forest;
        this.observations = observations;
        this.isTrainingSet = isTrainingSet;

        this.observedResponses = observations.stream()
                .map(Row::getResponse)
                .collect(Collectors.toList());

        final List<P> unscrambledPredictions = makePredictions(observations);
        this.baselineError = errorCalculator.averageError(this.observedResponses, unscrambledPredictions);
    }

    /**
     * Calculates the importance of a single covariate.
     *
     * @param covariate The covariate whose values get scrambled.
     * @param random The source of randomness handed to CovariateRow.scrambleCovariateValues.
     * @return The error after scrambling minus the baseline error.
     */
    public double calculateVariableImportance(Covariate covariate, Optional<Random> random){
        final List<CovariateRow> scrambledRows =
                CovariateRow.scrambleCovariateValues(this.observations, covariate, random);
        final List<P> scrambledPredictions = makePredictions(scrambledRows);

        final double scrambledError = errorCalculator.averageError(this.observedResponses, scrambledPredictions);
        return scrambledError - this.baselineError;
    }

    /**
     * Calculates the importance of several covariates, one after the other in list order.
     *
     * @param covariates The covariates to calculate importance for.
     * @param random The source of randomness; the same Optional is passed on for every covariate.
     * @return The importance of each covariate, in the same order as the list provided.
     */
    public double[] calculateVariableImportance(List<Covariate> covariates, Optional<Random> random){
        final double[] importances = new double[covariates.size()];

        int position = 0;
        for(final Covariate covariate : covariates){
            importances[position++] = calculateVariableImportance(covariate, random);
        }

        return importances;
    }

    // Out-of-bag predictions for a training set, regular predictions otherwise.
    private List<P> makePredictions(List<? extends CovariateRow> rowList){
        if(!isTrainingSet){
            return forest.evaluate(rowList);
        }

        return forest.evaluateOOB(rowList);
    }

}

View file

@ -0,0 +1,143 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class IBSErrorCalculatorWrapperTest {

    /*
        IBSCalculator is covered by its own tests, so these tests only check that the wrapper averages and weights
        the individual errors correctly; they do not explore how each individual error is produced under different
        scenarios (like providing / not providing a censoring distribution).
     */

    private static final double TOLERANCE = 0.00000001;

    private final double integrationUpperBound = 5.0;

    private final List<CompetingRiskResponse> responses;
    private final List<CompetingRiskFunctions> functions;

    // errors[event - 1][i] holds the IBSCalculator error of observation i for that event
    private final double[][] errors;

    public IBSErrorCalculatorWrapperTest(){
        this.responses = Utils.easyList(
                new CompetingRiskResponse(0, 2.0),
                new CompetingRiskResponse(0, 3.0),
                new CompetingRiskResponse(1, 1.0),
                new CompetingRiskResponse(1, 1.5),
                new CompetingRiskResponse(2, 3.0),
                new CompetingRiskResponse(2, 4.0)
        );

        final RightContinuousStepFunction cif1 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
                new Point(1.0, 0.25),
                new Point(1.5, 0.45)
        ), 0.0);

        final RightContinuousStepFunction cif2 = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
                new Point(3.0, 0.25),
                new Point(4.0, 0.45)
        ), 0.0);

        // Placeholder for the cause-specific hazards and the survival curve, which these tests never use.
        // Infinities or NaNs showing up in the output would point back to this function.
        final RightContinuousStepFunction emptyFun = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
                new Point(0.0, Double.NaN)
                ), Double.NEGATIVE_INFINITY
        );

        final CompetingRiskFunctions sharedPrediction = CompetingRiskFunctions.builder()
                .cumulativeIncidenceCurves(Utils.easyList(cif1, cif2))
                .causeSpecificHazards(Utils.easyList(emptyFun, emptyFun))
                .survivalCurve(emptyFun)
                .build();

        // Every response gets the very same prediction.
        this.functions = Utils.easyList(sharedPrediction, sharedPrediction, sharedPrediction,
                sharedPrediction, sharedPrediction, sharedPrediction);

        final IBSCalculator calculator = new IBSCalculator();
        this.errors = new double[2][6];

        for(int event = 1; event <= 2; event++){
            for(int i=0; i<6; i++){
                this.errors[event-1][i] = calculator.calculateError(
                        responses.get(i), sharedPrediction.getCumulativeIncidenceFunction(event),
                        event, integrationUpperBound
                );
            }
        }

    }

    // Average of the six pre-computed errors for the given event (1 or 2).
    private double expectedAverageError(int event){
        double average = 0.0;
        for(final double error : this.errors[event - 1]){
            average += error / 6.0;
        }
        return average;
    }

    @Test
    public void testOneEventOne(){
        final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1},
                this.integrationUpperBound);

        final double error = wrapper.averageError(this.responses, this.functions);

        assertEquals(expectedAverageError(1), error, TOLERANCE);
    }

    @Test
    public void testOneEventTwo(){
        final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{2},
                this.integrationUpperBound);

        final double error = wrapper.averageError(this.responses, this.functions);

        assertEquals(expectedAverageError(2), error, TOLERANCE);
    }

    @Test
    public void testTwoEventsNoWeights(){
        final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
                this.integrationUpperBound);

        final double error = wrapper.averageError(this.responses, this.functions);

        final double expectedError1 = expectedAverageError(1);
        final double expectedError2 = expectedAverageError(2);

        assertEquals(expectedError1 + expectedError2, error, TOLERANCE);
    }

    @Test
    public void testTwoEventsWithWeights(){
        final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2},
                this.integrationUpperBound, new double[]{1.0, 2.0});

        final double error = wrapper.averageError(this.responses, this.functions);

        final double expectedError1 = expectedAverageError(1);
        final double expectedError2 = expectedAverageError(2);

        assertEquals(1.0 * expectedError1 + 2.0 * expectedError2, error, TOLERANCE);
    }

}

View file

@ -0,0 +1,26 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class RegressionErrorCalculatorTest {

    private final RegressionErrorCalculator calculator = new RegressionErrorCalculator();

    @Test
    public void testRegressionErrorCalculator(){
        final List<Double> observed = Utils.easyList(1.0, 1.5, 0.0, 3.0);
        final List<Double> predicted = Utils.easyList(1.5, 1.7, 0.1, 2.9);

        // Differences (response - prediction) are -0.5, -0.2, -0.1, 0.1
        // Squared: 0.25, 0.04, 0.01, 0.01
        final double expectedMeanSquaredError = (0.25 + 0.04 + 0.01 + 0.01)/4.0;

        final double actualError = calculator.averageError(observed, predicted);

        assertEquals(expectedMeanSquaredError, actualError, 0.000000001);
    }

}

View file

@ -0,0 +1,359 @@
package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import ca.joeltherrien.randomforest.covariates.bool.BooleanSplitRule;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericSplitRule;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.*;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class VariableImportanceCalculatorTest {

    /*
        Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression
        setting.
     */

    private static final double TOLERANCE = 0.0000001;

    /*
        Every test seeds its Random with 123. The expected permutations were found by experimenting with
        Collections.shuffle on the ids 1..8 using such a Random; the first shuffle gives
            [1, 4, 8, 2, 5, 3, 7, 6]
        and a second shuffle with the same Random instance gives
            [6, 1, 4, 7, 5, 2, 8, 3]
        Position i of a permutation is the id of the row whose value ends up in row i. The hard-coded permuted
        predictions in the tests are derived from these orderings.
     */
    private static final long SEED = 123;

    // We'll have a very simple Forest of two trees
    private final Forest<Double, Double> forest;

    private final List<Covariate> covariates;
    private final List<Row<Double>> rowList;

    /*
        Long setup process; forest is manually constructed so that we can be exactly sure on our variable importance.
     */
    public VariableImportanceCalculatorTest(){
        final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0);
        final NumericCovariate numericCovariate = new NumericCovariate("y", 1);
        final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
                Utils.easyList("red", "blue", "green"));

        this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate);

        /*
            Plan for data - BooleanCovariate is split on first and has the largest impact.
            NumericCovariate is at second level and has more minimal impact.
            FactorCovariate is useless and never used.
            Both trees share the same structure; they differ only in the offset added to every terminal node and in
            their bootstrap indices, which is what lets us test OOB errors.
         */
        final Tree<Double> tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4});
        final Tree<Double> tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8});

        this.forest = Forest.<Double, Double>builder()
                .trees(Utils.easyList(tree1, tree2))
                .treeResponseCombiner(new MeanResponseCombiner())
                .covariateList(this.covariates)
                .build();

        // formula; boolean high adds 100; high numeric adds 10
        // The responses follow the formula exactly, so the only baseline error comes from the offsets of the trees.
        this.rowList = Utils.easyList(
                makeRow("false", "0.0", "red", 1, 0.0),
                makeRow("false", "10.0", "blue", 2, 10.0),
                makeRow("true", "0.0", "red", 3, 100.0),
                makeRow("true", "10.0", "green", 4, 110.0),
                makeRow("false", "0.0", "red", 5, 0.0),
                makeRow("false", "10.0", "blue", 6, 10.0),
                makeRow("true", "0.0", "red", 7, 100.0),
                makeRow("true", "10.0", "green", 8, 110.0)
        );
    }

    // Creates one observation; x, y and z are the raw values of the three covariates.
    private Row<Double> makeRow(String x, String y, String z, int id, double response){
        return Row.createSimple(Utils.easyMap(
                "x", x,
                "y", y,
                "z", z),
                this.covariates, id, response
        );
    }

    private Tree<Double> makeTree(List<Covariate> covariates, double offset, int[] indices){
        // Naming convention - xyTerminal where x and y are low/high denotes whether BooleanCovariate(x) is low/high and
        // whether NumericCovariate(y) is low/high.
        final TerminalNode<Double> lowLowTerminal = new TerminalNode<>(0.0 + offset, 5);
        final TerminalNode<Double> lowHighTerminal = new TerminalNode<>(10.0 + offset, 5);
        final TerminalNode<Double> highLowTerminal = new TerminalNode<>(100.0 + offset, 5);
        final TerminalNode<Double> highHighTerminal = new TerminalNode<>(110.0 + offset, 5);

        final SplitNode<Double> lowSplitNode = SplitNode.<Double>builder()
                .leftHand(lowLowTerminal)
                .rightHand(lowHighTerminal)
                .probabilityNaLeftHand(0.5)
                .splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
                .build();

        final SplitNode<Double> highSplitNode = SplitNode.<Double>builder()
                .leftHand(highLowTerminal)
                .rightHand(highHighTerminal)
                .probabilityNaLeftHand(0.5)
                .splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0))
                .build();

        final SplitNode<Double> rootSplitNode = SplitNode.<Double>builder()
                .leftHand(lowSplitNode)
                .rightHand(highSplitNode)
                .probabilityNaLeftHand(0.5)
                .splitRule(new BooleanSplitRule((BooleanCovariate) covariates.get(0)))
                .build();

        return new Tree<>(rootSplitNode, indices);
    }

    // Calculator over the fixture forest and rows; useOob selects out-of-bag predictions.
    private VariableImportanceCalculator<Double, Double> createCalculator(boolean useOob){
        return new VariableImportanceCalculator<>(
                new RegressionErrorCalculator(),
                this.forest,
                this.rowList,
                useOob
        );
    }

    // Importance we expect when scrambling leads to the given predictions: their error minus the baseline error.
    private double expectedImportance(List<Double> permutedPredictions, double expectedBaselineError){
        final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
        final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);

        return expectedError - expectedBaselineError;
    }

    @Test
    public void testVariableImportanceOnXNoOOB(){
        // x is the BooleanCovariate
        final Random random = new Random(SEED);
        final VariableImportanceCalculator<Double, Double> calculator = createCalculator(false);

        final double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));

        final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0

        final List<Double> permutedPredictions = Utils.easyList(
                1., 111., 101., 11., 1., 111., 101., 11.
        );

        assertEquals(expectedImportance(permutedPredictions, expectedBaselineError), importance, TOLERANCE);
    }

    @Test
    public void testVariableImportanceOnXOOB(){
        // x is the BooleanCovariate
        final Random random = new Random(SEED);
        final VariableImportanceCalculator<Double, Double> calculator = createCalculator(true);

        final double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));

        // First 4 observations are off by 2, last 4 are off by 0
        final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;

        // Remember we are working with OOB predictions
        final List<Double> permutedPredictions = Utils.easyList(
                2., 112., 102., 12., 0., 110., 100., 10.
        );

        assertEquals(expectedImportance(permutedPredictions, expectedBaselineError), importance, TOLERANCE);
    }

    @Test
    public void testVariableImportanceOnYNoOOB(){
        // y is the NumericCovariate
        final Random random = new Random(SEED);
        final VariableImportanceCalculator<Double, Double> calculator = createCalculator(false);

        final double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));

        final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0

        final List<Double> permutedPredictions = Utils.easyList(
                1., 11., 111., 111., 1., 1., 101., 111.
        );

        assertEquals(expectedImportance(permutedPredictions, expectedBaselineError), importance, TOLERANCE);
    }

    @Test
    public void testVariableImportanceOnYOOB(){
        // y is the NumericCovariate
        final Random random = new Random(SEED);
        final VariableImportanceCalculator<Double, Double> calculator = createCalculator(true);

        final double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));

        // First 4 observations are off by 2, last 4 are off by 0
        final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;

        // Remember we are working with OOB predictions
        final List<Double> permutedPredictions = Utils.easyList(
                2., 12., 112., 112., 0., 0., 100., 110.
        );

        assertEquals(expectedImportance(permutedPredictions, expectedBaselineError), importance, TOLERANCE);
    }

    @Test
    public void testVariableImportanceOnZNoOOB(){
        // z is the useless FactorCovariate
        final Random random = new Random(SEED);
        final VariableImportanceCalculator<Double, Double> calculator = createCalculator(false);

        final double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));

        // The FactorCovariate is never split on, so permuting it makes no difference to the baseline error
        assertEquals(0, importance, TOLERANCE);
    }

    @Test
    public void testVariableImportanceOnZOOB(){
        // z is the useless FactorCovariate
        final Random random = new Random(SEED);
        final VariableImportanceCalculator<Double, Double> calculator = createCalculator(true);

        final double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));

        // The FactorCovariate is never split on, so permuting it makes no difference to the baseline error
        assertEquals(0, importance, TOLERANCE);
    }

    @Test
    public void testVariableImportanceMultiple(){
        final Random random = new Random(SEED);
        final VariableImportanceCalculator<Double, Double> calculator = createCalculator(false);

        final double[] importance = calculator.calculateVariableImportance(covariates, Optional.of(random));

        final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0

        // x is scrambled first and therefore uses the first permutation
        final List<Double> permutedPredictionsX = Utils.easyList(
                1., 111., 101., 11., 1., 111., 101., 11.
        );

        // y is scrambled second with the same Random, so it uses the second permutation
        // [6, 1, 4, 7, 5, 2, 8, 3]
        final List<Double> permutedPredictionsY = Utils.easyList(
                11., 1., 111., 101., 1., 11., 111., 101.
        );

        assertEquals(expectedImportance(permutedPredictionsX, expectedBaselineError), importance[0], TOLERANCE);
        assertEquals(expectedImportance(permutedPredictionsY, expectedBaselineError), importance[1], TOLERANCE);
        assertEquals(0, importance[2], TOLERANCE);
    }

}