Fix how variable importance works to be tree based and not forest based

This commit is contained in:
Joel Therrien 2019-08-12 14:28:31 -07:00
parent f1c5b292ed
commit 51696e2546
4 changed files with 310 additions and 113 deletions

View file

@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.tree.vimp;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.Tree;
import java.util.List;
import java.util.Optional;
@ -14,52 +14,104 @@ import java.util.stream.Collectors;
public class VariableImportanceCalculator<Y, P> {
private final ErrorCalculator<Y, P> errorCalculator;
private final Forest<Y, P> forest;
private final List<Tree<P>> trees;
private final List<Row<Y>> observations;
private final List<Y> observedResponses;
private final boolean isTrainingSet; // If true, then we use out-of-bag predictions
private final double baselineError;
private final double[] baselineErrors;
public VariableImportanceCalculator(
ErrorCalculator<Y, P> errorCalculator,
Forest<Y, P> forest,
List<Tree<P>> trees,
List<Row<Y>> observations,
boolean isTrainingSet
){
this.errorCalculator = errorCalculator;
this.forest = forest;
this.trees = trees;
this.observations = observations;
this.isTrainingSet = isTrainingSet;
this.observedResponses = observations.stream()
.map(row -> row.getResponse()).collect(Collectors.toList());
final List<P> baselinePredictions = makePredictions(observations);
this.baselineError = errorCalculator.averageError(observedResponses, baselinePredictions);
try {
this.baselineErrors = new double[trees.size()];
for (int i = 0; i < baselineErrors.length; i++) {
final Tree<P> tree = trees.get(i);
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
this.baselineErrors[i] = errorCalculator.averageError(responses, makePredictions(oobSubset, tree));
}
} catch(Exception e){
e.printStackTrace();
throw e;
}
}
public double calculateVariableImportance(Covariate covariate, Optional<Random> random){
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(this.observations, covariate, random);
final List<P> alternatePredictions = makePredictions(scrambledValues);
final double newError = errorCalculator.averageError(this.observedResponses, alternatePredictions);
/**
* Returns an array of importance values for every Tree for the given Covariate.
*
* @param covariate The Covariate to scramble.
* @param random
* @return
*/
public double[] calculateVariableImportanceRaw(Covariate covariate, Optional<Random> random){
return newError - this.baselineError;
final double[] vimp = new double[trees.size()];
for(int i = 0; i < vimp.length; i++){
final Tree<P> tree = trees.get(i);
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(oobSubset, covariate, random);
final double error = errorCalculator.averageError(responses, makePredictions(scrambledValues, tree));
vimp[i] = error - this.baselineErrors[i];
}
public double[] calculateVariableImportance(List<Covariate> covariates, Optional<Random> random){
return covariates.stream()
.mapToDouble(covariate -> calculateVariableImportance(covariate, random))
.toArray();
return vimp;
}
private List<P> makePredictions(List<? extends CovariateRow> rowList){
if(isTrainingSet){
return forest.evaluateOOB(rowList);
} else{
return forest.evaluate(rowList);
public double calculateVariableImportanceZScore(Covariate covariate, Optional<Random> random){
final double[] vimpArray = calculateVariableImportanceRaw(covariate, random);
double mean = 0.0;
double variance = 0.0;
final double numTrees = vimpArray.length;
for(double vimp : vimpArray){
mean += vimp / numTrees;
}
for(double vimp : vimpArray){
variance += (vimp - mean)*(vimp - mean) / (numTrees - 1.0);
}
final double standardError = Math.sqrt(variance / numTrees);
return mean / standardError;
}
// Assume rowList has already been filtered for OOB
private List<P> makePredictions(List<? extends CovariateRow> rowList, Tree<P> tree){
return rowList.stream()
.map(tree::evaluate)
.collect(Collectors.toList());
}
private List<Row<Y>> getAppropriateSubset(List<Row<Y>> initialList, Tree<P> tree){
if(!isTrainingSet){
return initialList; // no need to make any subsets
}
return initialList.stream()
.filter(row -> !tree.idInBootstrapSample(row.getId()))
.collect(Collectors.toList());
}
}

View file

@ -136,6 +136,11 @@ public final class RightContinuousStepFunction extends StepFunction {
return -integrate(to, from);
}
// Edge case - no points; just defaultY
if(this.x.length == 0){
return (to - from) * this.defaultY;
}
double summation = 0.0;
final double[] xPoints = getX();
final int startingIndex;

View file

@ -13,9 +13,10 @@ import ca.joeltherrien.randomforest.tree.*;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.*;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
@ -163,197 +164,322 @@ public class VariableImportanceCalculatorTest {
// Experiment with random seeds to first examine what a split does so we know what to expect
/*
public static void main(String[] args){
final List<Integer> ints1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
final List<Integer> ints2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
// Behaviour for OOB
final List<Integer> ints1 = IntStream.range(5, 9).boxed().collect(Collectors.toList());
final List<Integer> ints2 = IntStream.range(1, 5).boxed().collect(Collectors.toList());
final Random random = new Random(123);
Collections.shuffle(ints1, random);
Collections.shuffle(ints2, random);
System.out.println(ints1);
// [1, 4, 8, 2, 5, 3, 7, 6]
System.out.println(ints2);
[6, 1, 4, 7, 5, 2, 8, 3]
// [5, 6, 8, 7]
// [3, 4, 1, 2]
// Behaviour for no-OOB
final List<Integer> fullInts1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
final List<Integer> fullInts2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
final Random fullIntsRandom = new Random(123);
Collections.shuffle(fullInts1, fullIntsRandom);
Collections.shuffle(fullInts2, fullIntsRandom);
System.out.println(fullInts1);
System.out.println(fullInts2);
// [1, 4, 8, 2, 5, 3, 7, 6]
// [6, 1, 4, 7, 5, 2, 8, 3]
}
*/
private double[] difference(double[] a, double[] b){
final double[] results = new double[a.length];
for(int i = 0; i < a.length; i++){
results[i] = a[i] - b[i];
}
return results;
}
private void assertDoubleEquals(double[] expected, double[] actual){
assertEquals(expected.length, actual.length, "Lengths of arrays should be equal");
for(int i=0; i < expected.length; i++){
assertEquals(expected[i], actual[i], 0.0000001, "Difference at " + i);
}
}
@Test
public void testVariableImportanceOnXNoOOB(){
// x is the BooleanCovariate
Random random = new Random(123);
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest,
this.forest.getTrees(),
this.rowList,
false
);
double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));
final Covariate covariate = this.covariates.get(0);
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
final List<Double> permutedPredictions = Utils.easyList(
1., 111., 101., 11., 1., 111., 101., 11.
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
// [1, 4, 8, 2, 5, 3, 7, 6]
final List<Double> permutedPredictionsTree1 = Utils.easyList(
0., 110., 100., 10., 0., 110., 100., 10.
);
// [6, 1, 4, 7, 5, 2, 8, 3]
// Actual: [F, F, T, T, F, F, T, T]
// Seen: [F, F, T, T, F, F, T, T]
// Difference: 0 all around; random chance
final List<Double> permutedPredictionsTree2 = Utils.easyList(
2., 12., 102., 112., 2., 12., 102., 112.
);
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
final double[] expectedError = new double[2];
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
assertDoubleEquals(expectedVimp, importance);
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
}
@Test
public void testVariableImportanceOnXOOB(){
// x is the BooleanCovariate
Random random = new Random(123);
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest,
this.forest.getTrees(),
this.rowList,
true
);
double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));
final Covariate covariate = this.covariates.get(0);
// First 4 observations are off by 2, last 4 are off by 0
final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
// Remember we are working with OOB predictions
final List<Double> permutedPredictions = Utils.easyList(
2., 112., 102., 12., 0., 110., 100., 10.
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
// [5, 6, 8, 7]
// Actual: [F, F, T, T]
// Seen: [F, F, T, T]
// Difference: No differences
final List<Double> permutedPredictionsTree1 = Utils.easyList(
0., 10., 100., 110.
);
// [3, 4, 1, 2]
// Actual: [F, F, T, T]
// Seen: [T, T, F, F]
// Difference: +100, +100, -100, -100
final List<Double> permutedPredictionsTree2 = Utils.easyList(
102., 112., 2., 12.
);
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
final double[] expectedError = new double[2];
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
assertDoubleEquals(expectedVimp, importance);
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
}
@Test
public void testVariableImportanceOnYNoOOB(){
// y is the NumericCovariate
Random random = new Random(123);
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest,
this.forest.getTrees(),
this.rowList,
false
);
double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));
final Covariate covariate = this.covariates.get(1);
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
final List<Double> permutedPredictions = Utils.easyList(
1., 11., 111., 111., 1., 1., 101., 111.
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
// [1, 4, 8, 2, 5, 3, 7, 6]
// Actual: [F, T, F, T, F, T, F, T]
// Seen: [F, T, T, T, F, F, F, T]
// Difference: [=, =, +, =, =, -, =, =]x10
final List<Double> permutedPredictionsTree1 = Utils.easyList(
0., 10., 110., 110., 0., 0., 100., 110.
);
// [6, 1, 4, 7, 5, 2, 8, 3]
// Actual: [F, T, F, T, F, T, F, T]
// Seen: [T, F, T, F, F, T, T, F]
// Difference: [+, -, +, -, =, =, +, -]
final List<Double> permutedPredictionsTree2 = Utils.easyList(
12., 2., 112., 102., 2., 12., 112., 102.
);
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
final double[] expectedError = new double[2];
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
assertDoubleEquals(expectedVimp, importance);
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
}
@Test
public void testVariableImportanceOnYOOB(){
// y is the NumericCovariate
Random random = new Random(123);
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest,
this.forest.getTrees(),
this.rowList,
true
);
double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));
final Covariate covariate = this.covariates.get(1);
// First 4 observations are off by 2, last 4 are off by 0
final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
// Remember we are working with OOB predictions
final List<Double> permutedPredictions = Utils.easyList(
2., 12., 112., 112., 0., 0., 100., 110.
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
// [5, 6, 8, 7]
// Actual: [F, T, F, T]
// Seen: [F, T, T, F]
// Difference: [=, =, +, -]x10
final List<Double> permutedPredictionsTree1 = Utils.easyList(
0., 10., 110., 100.
);
// [3, 4, 1, 2]
// Actual: [F, T, F, T]
// Seen: [F, T, F, T]
// Difference: [=, =, =, =]x10 no change
final List<Double> permutedPredictionsTree2 = Utils.easyList(
2., 12., 102., 112.
);
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
final double[] expectedError = new double[2];
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
assertDoubleEquals(expectedVimp, importance);
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
}
@Test
public void testVariableImportanceOnZNoOOB(){
// z is the useless FactorCovariate
Random random = new Random(123);
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest,
this.forest.getTrees(),
this.rowList,
false
);
double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
final double[] expectedImportance = {0.0, 0.0};
// FactorImportance did nothing; so permuting it will make no difference to baseline error
assertEquals(0, importance, 0.0000001);
assertDoubleEquals(expectedImportance, importance);
}
@Test
public void testVariableImportanceOnZOOB(){
// z is the useless FactorCovariate
Random random = new Random(123);
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest,
this.forest.getTrees(),
this.rowList,
true
);
double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
final double[] expectedImportance = {0.0, 0.0};
// FactorImportance did nothing; so permuting it will make no difference to baseline error
assertEquals(0, importance, 0.0000001);
assertDoubleEquals(expectedImportance, importance);
}
@Test
public void testVariableImportanceMultiple(){
Random random = new Random(123);
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
new RegressionErrorCalculator(),
this.forest,
this.rowList,
false
);
double importance[] = calculator.calculateVariableImportance(covariates, Optional.of(random));
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
final List<Double> permutedPredictionsX = Utils.easyList(
1., 111., 101., 11., 1., 111., 101., 11.
);
// [6, 1, 4, 7, 5, 2, 8, 3]
final List<Double> permutedPredictionsY = Utils.easyList(
11., 1., 111., 101., 1., 11., 111., 101.
);
final double expectedErrorX = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsX);
final double expectedErrorY = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsY);
assertEquals(expectedErrorX - expectedBaselineError, importance[0], 0.0000001);
assertEquals(expectedErrorY - expectedBaselineError, importance[1], 0.0000001);
assertEquals(0, importance[2], 0.0000001);
}
}

View file

@ -138,4 +138,18 @@ public class RightContinuousStepFunctionIntegrationTest {
}
@Test
public void testIntegratingEmptyFunction(){
// A function might have no points, but we'll still need to integrate it.
final RightContinuousStepFunction function = new RightContinuousStepFunction(
new double[]{}, new double[]{}, 1.0
);
final double area = function.integrate(1.0 ,3.0);
assertEquals(2.0, area, 0.000001);
}
}