From 51696e254618e766b6c5abb8e57e3548e8b184cd Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Mon, 12 Aug 2019 14:28:31 -0700 Subject: [PATCH] Fix how variable importance works to be tree based and not forest based --- .../vimp/VariableImportanceCalculator.java | 110 +++++-- .../utils/RightContinuousStepFunction.java | 5 + .../VariableImportanceCalculatorTest.java | 294 +++++++++++++----- ...ContinuousStepFunctionIntegrationTest.java | 14 + 4 files changed, 310 insertions(+), 113 deletions(-) diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator.java index 1986214..1248a9b 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator.java @@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.tree.vimp; import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.tree.Forest; +import ca.joeltherrien.randomforest.tree.Tree; import java.util.List; import java.util.Optional; @@ -14,52 +14,104 @@ import java.util.stream.Collectors; public class VariableImportanceCalculator { private final ErrorCalculator errorCalculator; - private final Forest forest; + private final List> trees; private final List> observations; - private final List observedResponses; private final boolean isTrainingSet; // If true, then we use out-of-bag predictions - private final double baselineError; + private final double[] baselineErrors; public VariableImportanceCalculator( ErrorCalculator errorCalculator, - Forest forest, + List> trees, List> observations, boolean isTrainingSet ){ this.errorCalculator = errorCalculator; - this.forest = forest; + this.trees = trees; this.observations = observations; this.isTrainingSet = isTrainingSet; - this.observedResponses = observations.stream() - .map(row -> row.getResponse()).collect(Collectors.toList()); - final List

baselinePredictions = makePredictions(observations); - this.baselineError = errorCalculator.averageError(observedResponses, baselinePredictions); + try { - } + this.baselineErrors = new double[trees.size()]; + for (int i = 0; i < baselineErrors.length; i++) { + final Tree

tree = trees.get(i); + final List> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet + final List responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList()); - public double calculateVariableImportance(Covariate covariate, Optional random){ - final List scrambledValues = CovariateRow.scrambleCovariateValues(this.observations, covariate, random); - final List

alternatePredictions = makePredictions(scrambledValues); - final double newError = errorCalculator.averageError(this.observedResponses, alternatePredictions); + this.baselineErrors[i] = errorCalculator.averageError(responses, makePredictions(oobSubset, tree)); + } - return newError - this.baselineError; - } - - public double[] calculateVariableImportance(List covariates, Optional random){ - return covariates.stream() - .mapToDouble(covariate -> calculateVariableImportance(covariate, random)) - .toArray(); - } - - private List

makePredictions(List rowList){ - if(isTrainingSet){ - return forest.evaluateOOB(rowList); - } else{ - return forest.evaluate(rowList); + } catch(Exception e){ + e.printStackTrace(); + throw e; } + } + /** + * Returns an array of importance values for every Tree for the given Covariate. + * + * @param covariate The Covariate to scramble. + * @param random + * @return + */ + public double[] calculateVariableImportanceRaw(Covariate covariate, Optional random){ + + final double[] vimp = new double[trees.size()]; + for(int i = 0; i < vimp.length; i++){ + final Tree

tree = trees.get(i); + final List> oobSubset = getAppropriateSubset(observations, tree); // may not actually be OOB depending on isTrainingSet + final List responses = oobSubset.stream().map(Row::getResponse).collect(Collectors.toList()); + final List scrambledValues = CovariateRow.scrambleCovariateValues(oobSubset, covariate, random); + + final double error = errorCalculator.averageError(responses, makePredictions(scrambledValues, tree)); + + vimp[i] = error - this.baselineErrors[i]; + } + + return vimp; + } + + public double calculateVariableImportanceZScore(Covariate covariate, Optional random){ + final double[] vimpArray = calculateVariableImportanceRaw(covariate, random); + + double mean = 0.0; + double variance = 0.0; + final double numTrees = vimpArray.length; + + for(double vimp : vimpArray){ + mean += vimp / numTrees; + } + for(double vimp : vimpArray){ + variance += (vimp - mean)*(vimp - mean) / (numTrees - 1.0); + } + + final double standardError = Math.sqrt(variance / numTrees); + + return mean / standardError; + } + + + + // Assume rowList has already been filtered for OOB + private List

makePredictions(List rowList, Tree

tree){ + return rowList.stream() + .map(tree::evaluate) + .collect(Collectors.toList()); + } + + private List> getAppropriateSubset(List> initialList, Tree

tree){ + if(!isTrainingSet){ + return initialList; // no need to make any subsets + } + + return initialList.stream() + .filter(row -> !tree.idInBootstrapSample(row.getId())) + .collect(Collectors.toList()); + + } + + } diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java index 2ae8056..c3ca9d2 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunction.java @@ -136,6 +136,11 @@ public final class RightContinuousStepFunction extends StepFunction { return -integrate(to, from); } + // Edge case - no points; just defaultY + if(this.x.length == 0){ + return (to - from) * this.defaultY; + } + double summation = 0.0; final double[] xPoints = getX(); final int startingIndex; diff --git a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java index 8d393a5..b656ce0 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java @@ -13,9 +13,10 @@ import ca.joeltherrien.randomforest.tree.*; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; -import java.util.*; +import java.util.List; +import java.util.Optional; +import java.util.Random; import java.util.stream.Collectors; -import java.util.stream.IntStream; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -163,197 +164,322 @@ public class VariableImportanceCalculatorTest { // Experiment with random seeds to first examine what a split does so we know what to expect /* public static void main(String[] args){ - final List ints1 = IntStream.range(1, 9).boxed().collect(Collectors.toList()); - final List ints2 = IntStream.range(1, 9).boxed().collect(Collectors.toList()); + + // Behaviour for OOB + final List ints1 = IntStream.range(5, 9).boxed().collect(Collectors.toList()); + final List ints2 = IntStream.range(1, 5).boxed().collect(Collectors.toList()); final Random random = new Random(123); Collections.shuffle(ints1, random); Collections.shuffle(ints2, random); System.out.println(ints1); - // [1, 4, 8, 2, 5, 3, 7, 6] - System.out.println(ints2); - [6, 1, 4, 7, 5, 2, 8, 3] + // [5, 6, 8, 7] + // [3, 4, 1, 2] + + + // Behaviour for no-OOB + final List fullInts1 = IntStream.range(1, 9).boxed().collect(Collectors.toList()); + final List fullInts2 = IntStream.range(1, 9).boxed().collect(Collectors.toList()); + final Random fullIntsRandom = new Random(123); + + Collections.shuffle(fullInts1, fullIntsRandom); + Collections.shuffle(fullInts2, fullIntsRandom); + System.out.println(fullInts1); + System.out.println(fullInts2); + // [1, 4, 8, 2, 5, 3, 7, 6] + // [6, 1, 4, 7, 5, 2, 8, 3] + } - */ + */ + + + private double[] difference(double[] a, double[] b){ + final double[] results = new double[a.length]; + + for(int i = 0; i < a.length; i++){ + results[i] = a[i] - b[i]; + } + + return results; + } + + private void assertDoubleEquals(double[] expected, double[] actual){ + assertEquals(expected.length, actual.length, "Lengths of arrays should be equal"); + + for(int i=0; i < expected.length; i++){ + assertEquals(expected[i], actual[i], 0.0000001, "Difference at " + i); + } + + } + + @Test public void testVariableImportanceOnXNoOOB(){ // x is the BooleanCovariate - Random random = new Random(123); final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( new RegressionErrorCalculator(), - this.forest, + this.forest.getTrees(), this.rowList, false ); - double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random)); + final Covariate covariate = this.covariates.get(0); - final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0 + double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123))); - final List permutedPredictions = Utils.easyList( - 1., 111., 101., 11., 1., 111., 101., 11. + final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not + + // [1, 4, 8, 2, 5, 3, 7, 6] + final List permutedPredictionsTree1 = Utils.easyList( + 0., 110., 100., 10., 0., 110., 100., 10. ); + + // [6, 1, 4, 7, 5, 2, 8, 3] + // Actual: [F, F, T, T, F, F, T, T] + // Seen: [F, F, T, T, F, F, T, T] + // Difference: 0 all around; random chance + final List permutedPredictionsTree2 = Utils.easyList( + 2., 12., 102., 112., 2., 12., 102., 112. + ); + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); - final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions); + final double[] expectedError = new double[2]; - assertEquals(expectedError - expectedBaselineError, importance, 0.0000001); + expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1); + expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2); + + final double[] expectedVimp = difference(expectedError, expectedBaselineError); + + assertDoubleEquals(expectedVimp, importance); + + final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0; + final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0; + final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0); + final double expectedZScore = expectedVimpMean / expectedVimpStandardError; + + final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123))); + + assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match"); } + + @Test public void testVariableImportanceOnXOOB(){ // x is the BooleanCovariate - Random random = new Random(123); final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( new RegressionErrorCalculator(), - this.forest, + this.forest.getTrees(), this.rowList, true ); - double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random)); + final Covariate covariate = this.covariates.get(0); - // First 4 observations are off by 2, last 4 are off by 0 - final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0; + double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123))); - // Remember we are working with OOB predictions - final List permutedPredictions = Utils.easyList( - 2., 112., 102., 12., 0., 110., 100., 10. + final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not + + // [5, 6, 8, 7] + // Actual: [F, F, T, T] + // Seen: [F, F, T, T] + // Difference: No differences + final List permutedPredictionsTree1 = Utils.easyList( + 0., 10., 100., 110. ); + + // [3, 4, 1, 2] + // Actual: [F, F, T, T] + // Seen: [T, T, F, F] + // Difference: +100, +100, -100, -100 + final List permutedPredictionsTree2 = Utils.easyList( + 102., 112., 2., 12. + ); + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); + final List tree1OOBValues = observedValues.subList(4, 8); + final List tree2OOBValues = observedValues.subList(0, 4); - final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions); + final double[] expectedError = new double[2]; - assertEquals(expectedError - expectedBaselineError, importance, 0.0000001); + expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1); + expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2); + + final double[] expectedVimp = difference(expectedError, expectedBaselineError); + + assertDoubleEquals(expectedVimp, importance); + + final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0; + final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0; + final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0); + final double expectedZScore = expectedVimpMean / expectedVimpStandardError; + + final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123))); + + assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match"); } + @Test public void testVariableImportanceOnYNoOOB(){ // y is the NumericCovariate - Random random = new Random(123); final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( new RegressionErrorCalculator(), - this.forest, + this.forest.getTrees(), this.rowList, false ); - double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random)); + final Covariate covariate = this.covariates.get(1); - final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0 + double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123))); - final List permutedPredictions = Utils.easyList( - 1., 11., 111., 111., 1., 1., 101., 111. + final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not + + // [1, 4, 8, 2, 5, 3, 7, 6] + // Actual: [F, T, F, T, F, T, F, T] + // Seen: [F, T, T, T, F, F, F, T] + // Difference: [=, =, +, =, =, -, =, =]x10 + final List permutedPredictionsTree1 = Utils.easyList( + 0., 10., 110., 110., 0., 0., 100., 110. ); + + // [6, 1, 4, 7, 5, 2, 8, 3] + // Actual: [F, T, F, T, F, T, F, T] + // Seen: [T, F, T, F, F, T, T, F] + // Difference: [+, -, +, -, =, =, +, -] + final List permutedPredictionsTree2 = Utils.easyList( + 12., 2., 112., 102., 2., 12., 112., 102. + ); + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); - final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions); + final double[] expectedError = new double[2]; - assertEquals(expectedError - expectedBaselineError, importance, 0.0000001); + expectedError[0] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree1); + expectedError[1] = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsTree2); + + final double[] expectedVimp = difference(expectedError, expectedBaselineError); + + assertDoubleEquals(expectedVimp, importance); + + final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0; + final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0; + final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0); + final double expectedZScore = expectedVimpMean / expectedVimpStandardError; + + final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123))); + + assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match"); } + + @Test public void testVariableImportanceOnYOOB(){ // y is the NumericCovariate - Random random = new Random(123); final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( new RegressionErrorCalculator(), - this.forest, + this.forest.getTrees(), this.rowList, true ); - double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random)); + final Covariate covariate = this.covariates.get(1); - // First 4 observations are off by 2, last 4 are off by 0 - final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0; + double importance[] = calculator.calculateVariableImportanceRaw(covariate, Optional.of(new Random(123))); - // Remember we are working with OOB predictions - final List permutedPredictions = Utils.easyList( - 2., 12., 112., 112., 0., 0., 100., 110. + final double[] expectedBaselineError = {0.0, 4.0}; // first tree is accurate, second tree is not + + // [5, 6, 8, 7] + // Actual: [F, T, F, T] + // Seen: [F, T, T, F] + // Difference: [=, =, +, -]x10 + final List permutedPredictionsTree1 = Utils.easyList( + 0., 10., 110., 100. ); + + // [3, 4, 1, 2] + // Actual: [F, T, F, T] + // Seen: [F, T, F, T] + // Difference: [=, =, =, =]x10 no change + final List permutedPredictionsTree2 = Utils.easyList( + 2., 12., 102., 112. + ); + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); + final List tree1OOBValues = observedValues.subList(4, 8); + final List tree2OOBValues = observedValues.subList(0, 4); + + final double[] expectedError = new double[2]; + + expectedError[0] = new RegressionErrorCalculator().averageError(tree1OOBValues, permutedPredictionsTree1); + expectedError[1] = new RegressionErrorCalculator().averageError(tree2OOBValues, permutedPredictionsTree2); + + final double[] expectedVimp = difference(expectedError, expectedBaselineError); + + assertDoubleEquals(expectedVimp, importance); + + final double expectedVimpMean = (expectedVimp[0] + expectedVimp[1]) / 2.0; + final double expectedVimpVar = (Math.pow(expectedVimp[0] - expectedVimpMean, 2) + Math.pow(expectedVimp[1] - expectedVimpMean, 2)) / 1.0; + final double expectedVimpStandardError = Math.sqrt(expectedVimpVar / 2.0); + final double expectedZScore = expectedVimpMean / expectedVimpStandardError; + + final double actualZScore = calculator.calculateVariableImportanceZScore(covariate, Optional.of(new Random(123))); + + assertEquals(expectedZScore, actualZScore, 0.000001, "Z scores must match"); - final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions); - assertEquals(expectedError - expectedBaselineError, importance, 0.0000001); } + + @Test public void testVariableImportanceOnZNoOOB(){ // z is the useless FactorCovariate - Random random = new Random(123); final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( new RegressionErrorCalculator(), - this.forest, + this.forest.getTrees(), this.rowList, false ); - double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random)); + final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123))); + final double[] expectedImportance = {0.0, 0.0}; + // FactorImportance did nothing; so permuting it will make no difference to baseline error - assertEquals(0, importance, 0.0000001); + assertDoubleEquals(expectedImportance, importance); } @Test public void testVariableImportanceOnZOOB(){ // z is the useless FactorCovariate - Random random = new Random(123); final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( new RegressionErrorCalculator(), - this.forest, + this.forest.getTrees(), this.rowList, true ); - double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random)); + final double[] importance = calculator.calculateVariableImportanceRaw(this.covariates.get(2), Optional.of(new Random(123))); + final double[] expectedImportance = {0.0, 0.0}; + // FactorImportance did nothing; so permuting it will make no difference to baseline error - assertEquals(0, importance, 0.0000001); + assertDoubleEquals(expectedImportance, importance); } - @Test - public void testVariableImportanceMultiple(){ - Random random = new Random(123); - final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( - new RegressionErrorCalculator(), - this.forest, - this.rowList, - false - ); - double importance[] = calculator.calculateVariableImportance(covariates, Optional.of(random)); - - final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0 - - final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); - - final List permutedPredictionsX = Utils.easyList( - 1., 111., 101., 11., 1., 111., 101., 11. - ); - - // [6, 1, 4, 7, 5, 2, 8, 3] - final List permutedPredictionsY = Utils.easyList( - 11., 1., 111., 101., 1., 11., 111., 101. - ); - - final double expectedErrorX = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsX); - final double expectedErrorY = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsY); - - assertEquals(expectedErrorX - expectedBaselineError, importance[0], 0.0000001); - assertEquals(expectedErrorY - expectedBaselineError, importance[1], 0.0000001); - assertEquals(0, importance[2], 0.0000001); - - } } diff --git a/library/src/test/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunctionIntegrationTest.java b/library/src/test/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunctionIntegrationTest.java index 4aeb13d..42e21c4 100644 --- a/library/src/test/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunctionIntegrationTest.java +++ b/library/src/test/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunctionIntegrationTest.java @@ -138,4 +138,18 @@ public class RightContinuousStepFunctionIntegrationTest { } + @Test + public void testIntegratingEmptyFunction(){ + // A function might have no points, but we'll still need to integrate it. + + final RightContinuousStepFunction function = new RightContinuousStepFunction( + new double[]{}, new double[]{}, 1.0 + ); + + final double area = function.integrate(1.0 ,3.0); + assertEquals(2.0, area, 0.000001); + + } + + }