Fix how variable importance works to be tree based and not forest based
This commit is contained in:
parent
f1c5b292ed
commit
51696e2546
4 changed files with 310 additions and 113 deletions
|
@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.tree.vimp;
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Tree;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
@ -14,52 +14,104 @@ import java.util.stream.Collectors;
|
||||||
public class VariableImportanceCalculator<Y, P> {
|
public class VariableImportanceCalculator<Y, P> {
|
||||||
|
|
||||||
private final ErrorCalculator<Y, P> errorCalculator;
|
private final ErrorCalculator<Y, P> errorCalculator;
|
||||||
private final Forest<Y, P> forest;
|
private final List<Tree<P>> trees;
|
||||||
private final List<Row<Y>> observations;
|
private final List<Row<Y>> observations;
|
||||||
private final List<Y> observedResponses;
|
|
||||||
|
|
||||||
private final boolean isTrainingSet; // If true, then we use out-of-bag predictions
|
private final boolean isTrainingSet; // If true, then we use out-of-bag predictions
|
||||||
private final double baselineError;
|
private final double[] baselineErrors;
|
||||||
|
|
||||||
public VariableImportanceCalculator(
|
public VariableImportanceCalculator(
|
||||||
ErrorCalculator<Y, P> errorCalculator,
|
ErrorCalculator<Y, P> errorCalculator,
|
||||||
Forest<Y, P> forest,
|
List<Tree<P>> trees,
|
||||||
List<Row<Y>> observations,
|
List<Row<Y>> observations,
|
||||||
boolean isTrainingSet
|
boolean isTrainingSet
|
||||||
){
|
){
|
||||||
this.errorCalculator = errorCalculator;
|
this.errorCalculator = errorCalculator;
|
||||||
this.forest = forest;
|
this.trees = trees;
|
||||||
this.observations = observations;
|
this.observations = observations;
|
||||||
this.isTrainingSet = isTrainingSet;
|
this.isTrainingSet = isTrainingSet;
|
||||||
|
|
||||||
this.observedResponses = observations.stream()
|
|
||||||
.map(row -> row.getResponse()).collect(Collectors.toList());
|
|
||||||
|
|
||||||
final List<P> baselinePredictions = makePredictions(observations);
|
try {
|
||||||
this.baselineError = errorCalculator.averageError(observedResponses, baselinePredictions);
|
|
||||||
|
|
||||||
|
this.baselineErrors = new double[trees.size()];
|
||||||
|
for (int i = 0; i < baselineErrors.length; i++) {
|
||||||
|
final Tree<P> tree = trees.get(i);
|
||||||
|
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
|
||||||
|
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
this.baselineErrors[i] = errorCalculator.averageError(responses, makePredictions(oobSubset, tree));
|
||||||
}
|
}
|
||||||
|
|
||||||
public double calculateVariableImportance(Covariate covariate, Optional<Random> random){
|
} catch(Exception e){
|
||||||
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(this.observations, covariate, random);
|
e.printStackTrace();
|
||||||
final List<P> alternatePredictions = makePredictions(scrambledValues);
|
throw e;
|
||||||
final double newError = errorCalculator.averageError(this.observedResponses, alternatePredictions);
|
|
||||||
|
|
||||||
return newError - this.baselineError;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[] calculateVariableImportance(List<Covariate> covariates, Optional<Random> random){
|
|
||||||
return covariates.stream()
|
|
||||||
.mapToDouble(covariate -> calculateVariableImportance(covariate, random))
|
|
||||||
.toArray();
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<P> makePredictions(List<? extends CovariateRow> rowList){
|
|
||||||
if(isTrainingSet){
|
|
||||||
return forest.evaluateOOB(rowList);
|
|
||||||
} else{
|
|
||||||
return forest.evaluate(rowList);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an array of importance values for every Tree for the given Covariate.
|
||||||
|
*
|
||||||
|
* @param covariate The Covariate to scramble.
|
||||||
|
* @param random
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public double[] calculateVariableImportanceRaw(Covariate covariate, Optional<Random> random){
|
||||||
|
|
||||||
|
final double[] vimp = new double[trees.size()];
|
||||||
|
for(int i = 0; i < vimp.length; i++){
|
||||||
|
final Tree<P> tree = trees.get(i);
|
||||||
|
final List<Row<Y>> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet
|
||||||
|
final List<Y> responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<CovariateRow> scrambledValues = CovariateRow.scrambleCovariateValues(oobSubset, covariate, random);
|
||||||
|
|
||||||
|
final double error = errorCalculator.averageError(responses, makePredictions(scrambledValues, tree));
|
||||||
|
|
||||||
|
vimp[i] = error - this.baselineErrors[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return vimp;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calculateVariableImportanceZScore(Covariate covariate, Optional<Random> random){
|
||||||
|
final double[] vimpArray = calculateVariableImportanceRaw(covariate, random);
|
||||||
|
|
||||||
|
double mean = 0.0;
|
||||||
|
double variance = 0.0;
|
||||||
|
final double numTrees = vimpArray.length;
|
||||||
|
|
||||||
|
for(double vimp : vimpArray){
|
||||||
|
mean += vimp / numTrees;
|
||||||
|
}
|
||||||
|
for(double vimp : vimpArray){
|
||||||
|
variance += (vimp - mean)*(vimp - mean) / (numTrees - 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
final double standardError = Math.sqrt(variance / numTrees);
|
||||||
|
|
||||||
|
return mean / standardError;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Assume rowList has already been filtered for OOB
|
||||||
|
private List<P> makePredictions(List<? extends CovariateRow> rowList, Tree<P> tree){
|
||||||
|
return rowList.stream()
|
||||||
|
.map(tree::evaluate)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Row<Y>> getAppropriateSubset(List<Row<Y>> initialList, Tree<P> tree){
|
||||||
|
if(!isTrainingSet){
|
||||||
|
return initialList; // no need to make any subsets
|
||||||
|
}
|
||||||
|
|
||||||
|
return initialList.stream()
|
||||||
|
.filter(row -> !tree.idInBootstrapSample(row.getId()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
@ -136,6 +136,11 @@ public final class RightContinuousStepFunction extends StepFunction {
|
||||||
return -integrate(to, from);
|
return -integrate(to, from);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Edge case - no points; just defaultY
|
||||||
|
if(this.x.length == 0){
|
||||||
|
return (to - from) * this.defaultY;
|
||||||
|
}
|
||||||
|
|
||||||
double summation = 0.0;
|
double summation = 0.0;
|
||||||
final double[] xPoints = getX();
|
final double[] xPoints = getX();
|
||||||
final int startingIndex;
|
final int startingIndex;
|
||||||
|
|
|
@ -13,9 +13,10 @@ import ca.joeltherrien.randomforest.tree.*;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Random;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.IntStream;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
@ -163,197 +164,322 @@ public class VariableImportanceCalculatorTest {
|
||||||
// Experiment with random seeds to first examine what a split does so we know what to expect
|
// Experiment with random seeds to first examine what a split does so we know what to expect
|
||||||
/*
|
/*
|
||||||
public static void main(String[] args){
|
public static void main(String[] args){
|
||||||
final List<Integer> ints1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
|
||||||
final List<Integer> ints2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
// Behaviour for OOB
|
||||||
|
final List<Integer> ints1 = IntStream.range(5, 9).boxed().collect(Collectors.toList());
|
||||||
|
final List<Integer> ints2 = IntStream.range(1, 5).boxed().collect(Collectors.toList());
|
||||||
|
|
||||||
final Random random = new Random(123);
|
final Random random = new Random(123);
|
||||||
Collections.shuffle(ints1, random);
|
Collections.shuffle(ints1, random);
|
||||||
Collections.shuffle(ints2, random);
|
Collections.shuffle(ints2, random);
|
||||||
|
|
||||||
System.out.println(ints1);
|
System.out.println(ints1);
|
||||||
// [1, 4, 8, 2, 5, 3, 7, 6]
|
|
||||||
|
|
||||||
System.out.println(ints2);
|
System.out.println(ints2);
|
||||||
[6, 1, 4, 7, 5, 2, 8, 3]
|
// [5, 6, 8, 7]
|
||||||
|
// [3, 4, 1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
// Behaviour for no-OOB
|
||||||
|
final List<Integer> fullInts1 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||||
|
final List<Integer> fullInts2 = IntStream.range(1, 9).boxed().collect(Collectors.toList());
|
||||||
|
final Random fullIntsRandom = new Random(123);
|
||||||
|
|
||||||
|
Collections.shuffle(fullInts1, fullIntsRandom);
|
||||||
|
Collections.shuffle(fullInts2, fullIntsRandom);
|
||||||
|
System.out.println(fullInts1);
|
||||||
|
System.out.println(fullInts2);
|
||||||
|
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
private double[] difference(double[] a, double[] b){
|
||||||
|
final double[] results = new double[a.length];
|
||||||
|
|
||||||
|
for(int i = 0; i < a.length; i++){
|
||||||
|
results[i] = a[i] - b[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertDoubleEquals(double[] expected, double[] actual){
|
||||||
|
assertEquals(expected.length, actual.length, "Lengths of arrays should be equal");
|
||||||
|
|
||||||
|
for(int i=0; i < expected.length; i++){
|
||||||
|
assertEquals(expected[i], actual[i], 0.0000001, "Difference at " + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableImportanceOnXNoOOB(){
|
public void testVariableImportanceOnXNoOOB(){
|
||||||
// x is the BooleanCovariate
|
// x is the BooleanCovariate
|
||||||
|
|
||||||
Random random = new Random(123);
|
|
||||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
new RegressionErrorCalculator(),
|
new RegressionErrorCalculator(),
|
||||||
this.forest,
|
this.forest.getTrees(),
|
||||||
this.rowList,
|
this.rowList,
|
||||||
false
|
false
|
||||||
);
|
);
|
||||||
|
|
||||||
double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));
|
final Covariate covariate = this.covariates.get(0);
|
||||||
|
|
||||||
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
final List<Double> permutedPredictions = Utils.easyList(
|
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||||
1., 111., 101., 11., 1., 111., 101., 11.
|
|
||||||
|
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||||
|
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||||
|
0., 110., 100., 10., 0., 110., 100., 10.
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
// Actual: [F, F, T, T, F, F, T, T]
|
||||||
|
// Seen: [F, F, T, T, F, F, T, T]
|
||||||
|
// Difference: 0 all around; random chance
|
||||||
|
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||||
|
2., 12., 102., 112., 2., 12., 102., 112.
|
||||||
|
);
|
||||||
|
|
||||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
final double[] expectedError = new double[2];
|
||||||
|
|
||||||
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
|
||||||
|
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
|
||||||
|
|
||||||
|
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||||
|
|
||||||
|
assertDoubleEquals(expectedVimp, importance);
|
||||||
|
|
||||||
|
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||||
|
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||||
|
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||||
|
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||||
|
|
||||||
|
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableImportanceOnXOOB(){
|
public void testVariableImportanceOnXOOB(){
|
||||||
// x is the BooleanCovariate
|
// x is the BooleanCovariate
|
||||||
|
|
||||||
Random random = new Random(123);
|
|
||||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
new RegressionErrorCalculator(),
|
new RegressionErrorCalculator(),
|
||||||
this.forest,
|
this.forest.getTrees(),
|
||||||
this.rowList,
|
this.rowList,
|
||||||
true
|
true
|
||||||
);
|
);
|
||||||
|
|
||||||
double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random));
|
final Covariate covariate = this.covariates.get(0);
|
||||||
|
|
||||||
// First 4 observations are off by 2, last 4 are off by 0
|
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||||
final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;
|
|
||||||
|
|
||||||
// Remember we are working with OOB predictions
|
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||||
final List<Double> permutedPredictions = Utils.easyList(
|
|
||||||
2., 112., 102., 12., 0., 110., 100., 10.
|
// [5, 6, 8, 7]
|
||||||
|
// Actual: [F, F, T, T]
|
||||||
|
// Seen: [F, F, T, T]
|
||||||
|
// Difference: No differences
|
||||||
|
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||||
|
0., 10., 100., 110.
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// [3, 4, 1, 2]
|
||||||
|
// Actual: [F, F, T, T]
|
||||||
|
// Seen: [T, T, F, F]
|
||||||
|
// Difference: +100, +100, -100, -100
|
||||||
|
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||||
|
102., 112., 2., 12.
|
||||||
|
);
|
||||||
|
|
||||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
|
||||||
|
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
|
||||||
|
|
||||||
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
final double[] expectedError = new double[2];
|
||||||
|
|
||||||
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
|
||||||
|
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
|
||||||
|
|
||||||
|
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||||
|
|
||||||
|
assertDoubleEquals(expectedVimp, importance);
|
||||||
|
|
||||||
|
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||||
|
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||||
|
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||||
|
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||||
|
|
||||||
|
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableImportanceOnYNoOOB(){
|
public void testVariableImportanceOnYNoOOB(){
|
||||||
// y is the NumericCovariate
|
// y is the NumericCovariate
|
||||||
|
|
||||||
Random random = new Random(123);
|
|
||||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
new RegressionErrorCalculator(),
|
new RegressionErrorCalculator(),
|
||||||
this.forest,
|
this.forest.getTrees(),
|
||||||
this.rowList,
|
this.rowList,
|
||||||
false
|
false
|
||||||
);
|
);
|
||||||
|
|
||||||
double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));
|
final Covariate covariate = this.covariates.get(1);
|
||||||
|
|
||||||
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
final List<Double> permutedPredictions = Utils.easyList(
|
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||||
1., 11., 111., 111., 1., 1., 101., 111.
|
|
||||||
|
// [1, 4, 8, 2, 5, 3, 7, 6]
|
||||||
|
// Actual: [F, T, F, T, F, T, F, T]
|
||||||
|
// Seen: [F, T, T, T, F, F, F, T]
|
||||||
|
// Difference: [=, =, +, =, =, -, =, =]x10
|
||||||
|
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||||
|
0., 10., 110., 110., 0., 0., 100., 110.
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// [6, 1, 4, 7, 5, 2, 8, 3]
|
||||||
|
// Actual: [F, T, F, T, F, T, F, T]
|
||||||
|
// Seen: [T, F, T, F, F, T, T, F]
|
||||||
|
// Difference: [+, -, +, -, =, =, +, -]
|
||||||
|
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||||
|
12., 2., 112., 102., 2., 12., 112., 102.
|
||||||
|
);
|
||||||
|
|
||||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
final double[] expectedError = new double[2];
|
||||||
|
|
||||||
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1);
|
||||||
|
expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2);
|
||||||
|
|
||||||
|
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||||
|
|
||||||
|
assertDoubleEquals(expectedVimp, importance);
|
||||||
|
|
||||||
|
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||||
|
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||||
|
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||||
|
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||||
|
|
||||||
|
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableImportanceOnYOOB(){
|
public void testVariableImportanceOnYOOB(){
|
||||||
// y is the NumericCovariate
|
// y is the NumericCovariate
|
||||||
|
|
||||||
Random random = new Random(123);
|
|
||||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
new RegressionErrorCalculator(),
|
new RegressionErrorCalculator(),
|
||||||
this.forest,
|
this.forest.getTrees(),
|
||||||
this.rowList,
|
this.rowList,
|
||||||
true
|
true
|
||||||
);
|
);
|
||||||
|
|
||||||
double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random));
|
final Covariate covariate = this.covariates.get(1);
|
||||||
|
|
||||||
// First 4 observations are off by 2, last 4 are off by 0
|
double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123)));
|
||||||
final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0;
|
|
||||||
|
|
||||||
// Remember we are working with OOB predictions
|
final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not
|
||||||
final List<Double> permutedPredictions = Utils.easyList(
|
|
||||||
2., 12., 112., 112., 0., 0., 100., 110.
|
// [5, 6, 8, 7]
|
||||||
|
// Actual: [F, T, F, T]
|
||||||
|
// Seen: [F, T, T, F]
|
||||||
|
// Difference: [=, =, +, -]x10
|
||||||
|
final List<Double> permutedPredictionsTree1 = Utils.easyList(
|
||||||
|
0., 10., 110., 100.
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// [3, 4, 1, 2]
|
||||||
|
// Actual: [F, T, F, T]
|
||||||
|
// Seen: [F, T, F, T]
|
||||||
|
// Difference: [=, =, =, =]x10 no change
|
||||||
|
final List<Double> permutedPredictionsTree2 = Utils.easyList(
|
||||||
|
2., 12., 102., 112.
|
||||||
|
);
|
||||||
|
|
||||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Double> tree1OOBValues = observedValues.subList(4, 8);
|
||||||
|
final List<Double> tree2OOBValues = observedValues.subList(0, 4);
|
||||||
|
|
||||||
|
final double[] expectedError = new double[2];
|
||||||
|
|
||||||
|
expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1);
|
||||||
|
expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2);
|
||||||
|
|
||||||
|
final double[] expectedVimp = difference(expectedError, expectedBaselineError);
|
||||||
|
|
||||||
|
assertDoubleEquals(expectedVimp, importance);
|
||||||
|
|
||||||
|
final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0;
|
||||||
|
final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0;
|
||||||
|
final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0);
|
||||||
|
final double expectedZScore = expectedVimpMean / expectedVimpStandardError;
|
||||||
|
|
||||||
|
final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123)));
|
||||||
|
|
||||||
|
assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match");
|
||||||
|
|
||||||
final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions);
|
|
||||||
|
|
||||||
assertEquals(expectedError - expectedBaselineError, importance, 0.0000001);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableImportanceOnZNoOOB(){
|
public void testVariableImportanceOnZNoOOB(){
|
||||||
// z is the useless FactorCovariate
|
// z is the useless FactorCovariate
|
||||||
|
|
||||||
Random random = new Random(123);
|
|
||||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
new RegressionErrorCalculator(),
|
new RegressionErrorCalculator(),
|
||||||
this.forest,
|
this.forest.getTrees(),
|
||||||
this.rowList,
|
this.rowList,
|
||||||
false
|
false
|
||||||
);
|
);
|
||||||
|
|
||||||
double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));
|
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
|
||||||
|
final double[] expectedImportance = {0.0, 0.0};
|
||||||
|
|
||||||
|
|
||||||
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
||||||
assertEquals(0, importance, 0.0000001);
|
assertDoubleEquals(expectedImportance, importance);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableImportanceOnZOOB(){
|
public void testVariableImportanceOnZOOB(){
|
||||||
// z is the useless FactorCovariate
|
// z is the useless FactorCovariate
|
||||||
|
|
||||||
Random random = new Random(123);
|
|
||||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
||||||
new RegressionErrorCalculator(),
|
new RegressionErrorCalculator(),
|
||||||
this.forest,
|
this.forest.getTrees(),
|
||||||
this.rowList,
|
this.rowList,
|
||||||
true
|
true
|
||||||
);
|
);
|
||||||
|
|
||||||
double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random));
|
final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123)));
|
||||||
|
final double[] expectedImportance = {0.0, 0.0};
|
||||||
|
|
||||||
|
|
||||||
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
// FactorImportance did nothing; so permuting it will make no difference to baseline error
|
||||||
assertEquals(0, importance, 0.0000001);
|
assertDoubleEquals(expectedImportance, importance);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testVariableImportanceMultiple(){
|
|
||||||
Random random = new Random(123);
|
|
||||||
final VariableImportanceCalculator<Double, Double> calculator = new VariableImportanceCalculator<>(
|
|
||||||
new RegressionErrorCalculator(),
|
|
||||||
this.forest,
|
|
||||||
this.rowList,
|
|
||||||
false
|
|
||||||
);
|
|
||||||
|
|
||||||
double importance[] = calculator.calculateVariableImportance(covariates, Optional.of(random));
|
|
||||||
|
|
||||||
final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0
|
|
||||||
|
|
||||||
final List<Double> observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList());
|
|
||||||
|
|
||||||
final List<Double> permutedPredictionsX = Utils.easyList(
|
|
||||||
1., 111., 101., 11., 1., 111., 101., 11.
|
|
||||||
);
|
|
||||||
|
|
||||||
// [6, 1, 4, 7, 5, 2, 8, 3]
|
|
||||||
final List<Double> permutedPredictionsY = Utils.easyList(
|
|
||||||
11., 1., 111., 101., 1., 11., 111., 101.
|
|
||||||
);
|
|
||||||
|
|
||||||
final double expectedErrorX = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsX);
|
|
||||||
final double expectedErrorY = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsY);
|
|
||||||
|
|
||||||
assertEquals(expectedErrorX - expectedBaselineError, importance[0], 0.0000001);
|
|
||||||
assertEquals(expectedErrorY - expectedBaselineError, importance[1], 0.0000001);
|
|
||||||
assertEquals(0, importance[2], 0.0000001);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,4 +138,18 @@ public class RightContinuousStepFunctionIntegrationTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIntegratingEmptyFunction(){
|
||||||
|
// A function might have no points, but we'll still need to integrate it.
|
||||||
|
|
||||||
|
final RightContinuousStepFunction function = new RightContinuousStepFunction(
|
||||||
|
new double[]{}, new double[]{}, 1.0
|
||||||
|
);
|
||||||
|
|
||||||
|
final double area = function.integrate(1.0 ,3.0);
|
||||||
|
assertEquals(2.0, area, 0.000001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue