From a56ad4433dcb8374b5b44d6cefb56e999fcf8810 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Fri, 26 Jul 2019 14:59:41 -0700 Subject: [PATCH] Add variable importance methods to library --- .../randomforest/CovariateRow.java | 38 +- .../ca/joeltherrien/randomforest/Row.java | 3 - .../covariates/numeric/NumericSplitRule.java | 2 +- .../tree/vimp/ErrorCalculator.java | 24 ++ .../tree/vimp/IBSErrorCalculatorWrapper.java | 63 +++ .../tree/vimp/RegressionErrorCalculator.java | 23 ++ .../vimp/VariableImportanceCalculator.java | 65 ++++ .../vimp/IBSErrorCalculatorWrapperTest.java | 143 +++++++ .../vimp/RegressionErrorCalculatorTest.java | 26 ++ .../VariableImportanceCalculatorTest.java | 359 ++++++++++++++++++ 10 files changed, 738 insertions(+), 8 deletions(-) create mode 100644 library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator.java create mode 100644 library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper.java create mode 100644 library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator.java create mode 100644 library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator.java create mode 100644 library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapperTest.java create mode 100644 library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculatorTest.java create mode 100644 library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java diff --git a/library/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java b/library/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java index 1f44191..3e053bf 100644 --- a/library/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java +++ b/library/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java @@ -21,12 +21,11 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; import java.io.Serializable; 
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
+import java.util.stream.Collectors;
 
 @RequiredArgsConstructor
-public class CovariateRow implements Serializable {
+public class CovariateRow implements Serializable, Cloneable {
 
     private final Covariate.Value[] valueArray;
@@ -46,6 +45,14 @@ public class CovariateRow implements Serializable {
         return "CovariateRow " + this.id;
     }
 
+    @Override
+    public CovariateRow clone() {
+        // Shallow copy is intentional: the array is new (so entries can be reassigned independently) while the Value objects are shared.
+        final Covariate.Value[] copyValueArray = this.valueArray.clone();
+
+        return new CovariateRow(copyValueArray, this.id);
+    }
+
     public static CovariateRow createSimple(Map simpleMap, List covariateList, int id){
         final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
         final Map covariateMap = new HashMap<>();
@@ -64,4 +71,27 @@ public class CovariateRow implements Serializable {
         return new CovariateRow(valueArray, id);
     }
 
+    /**
+     * Used for variable importance; returns clones of the given CovariateRows in which one Covariate's values are permuted across rows.
+     *
+     * @param covariateRows The List of CovariateRows to scramble. Note that the originals won't be modified.
+     * @param covariateToScramble The Covariate to scramble on; only the value stored at its index is exchanged between rows.
+     * @param random The source of randomness to use. If not present, a new Random is created (lazily) for this call.
+     * @return A List of CovariateRows, in the original row order, where the specified covariate was scrambled. These are different objects from the ones provided.
+     */
+    public static List<CovariateRow> scrambleCovariateValues(List<? extends CovariateRow> covariateRows, Covariate covariateToScramble, Optional<Random> random){
+        final List<CovariateRow> permutedCovariateRowList = new ArrayList<>(covariateRows);
+        Collections.shuffle(permutedCovariateRowList, random.orElseGet(Random::new)); // without replacement; orElseGet avoids building an unused Random
+
+        final List<CovariateRow> clonedRowList = covariateRows.stream().map(CovariateRow::clone).collect(Collectors.toList());
+
+        final int covariateToScrambleIndex = covariateToScramble.getIndex();
+        for(int i=0; i < covariateRows.size(); i++){
+            clonedRowList.get(i).valueArray[covariateToScrambleIndex] = permutedCovariateRowList.get(i).valueArray[covariateToScrambleIndex];
+        }
+
+        return clonedRowList;
+
+    }
+
 }
diff --git a/library/src/main/java/ca/joeltherrien/randomforest/Row.java b/library/src/main/java/ca/joeltherrien/randomforest/Row.java
index 43b26ed..77357e5 100644
--- a/library/src/main/java/ca/joeltherrien/randomforest/Row.java
+++ b/library/src/main/java/ca/joeltherrien/randomforest/Row.java
@@ -32,9 +32,6 @@ public class Row extends CovariateRow {
         this.response = response;
     }
 
-
-
-
     public Y getResponse() {
         return this.response;
     }
diff --git a/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java b/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java
index c4e4150..43b0c95 100644
--- a/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java
+++ b/library/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java
@@ -26,7 +26,7 @@ public class NumericSplitRule implements SplitRule {
     private final int parentCovariateIndex;
     private final double threshold;
 
-    NumericSplitRule(NumericCovariate parent, final double threshold){
+    public NumericSplitRule(NumericCovariate parent, final double threshold){
         this.parentCovariateIndex = parent.getIndex();
         this.threshold = threshold;
     }
diff --git 
a/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator.java new file mode 100644 index 0000000..0fba4fd --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator.java @@ -0,0 +1,24 @@ +package ca.joeltherrien.randomforest.tree.vimp; + +import java.util.List; + +/** + * Simple interface for VariableImportanceCalculator; takes in a List of observed responses and a List of predictions + * and produces an average error measure. + * + * @param <Y> The class of the responses. + * @param <P> The class of the predictions. + */ +public interface ErrorCalculator<Y, P>{ + + /** + * Compares the observed responses with the predictions to produce an average error measure. + * Lower errors should indicate a better model fit. + * + * @param responses The observed responses. + * @param predictions The predictions to compare against the observed responses. + * @return The average error measure. + */ + double averageError(List<Y> responses, List<P>

predictions); + +} diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper.java new file mode 100644 index 0000000..2ee4a00 --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper.java @@ -0,0 +1,63 @@ +package ca.joeltherrien.randomforest.tree.vimp; + +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; + +import java.util.Arrays; +import java.util.List; + +/** + * Implements ErrorCalculator; essentially just wraps around IBSCalculator to fit into VariableImportanceCalculator. + * + */ +public class IBSErrorCalculatorWrapper implements ErrorCalculator { + + private final IBSCalculator calculator; + private final int[] events; + private final double integrationUpperBound; + private final double[] eventWeights; + + public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound, double[] eventWeights) { + this.calculator = calculator; + this.events = events; + this.integrationUpperBound = integrationUpperBound; + this.eventWeights = eventWeights; + } + + public IBSErrorCalculatorWrapper(IBSCalculator calculator, int[] events, double integrationUpperBound) { + this.calculator = calculator; + this.events = events; + this.integrationUpperBound = integrationUpperBound; + this.eventWeights = new double[events.length]; + + Arrays.fill(this.eventWeights, 1.0); // default is to just sum all errors together + + } + + @Override + public double averageError(List responses, List predictions) { + final double[] errors = new double[events.length]; + final double n = responses.size(); + + for(int 
i=0; i < responses.size(); i++){ + final CompetingRiskResponse response = responses.get(i); + final CompetingRiskFunctions prediction = predictions.get(i); + + for(int k=0; k < this.events.length; k++){ + final int event = this.events[k]; + final RightContinuousStepFunction cif = prediction.getCumulativeIncidenceFunction(event); + errors[k] += calculator.calculateError(response, cif, event, integrationUpperBound) / n; + } + + } + + double totalError = 0.0; + for(int k=0; k < this.events.length; k++){ + totalError += this.eventWeights[k] * errors[k]; + } + + return totalError; + } +} diff --git a/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator.java b/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator.java new file mode 100644 index 0000000..09e839a --- /dev/null +++ b/library/src/main/java/ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator.java @@ -0,0 +1,23 @@ +package ca.joeltherrien.randomforest.tree.vimp; + +import java.util.List; + +public class RegressionErrorCalculator implements ErrorCalculator{ + + @Override + public double averageError(List responses, List predictions) { + double mean = 0.0; + final double n = responses.size(); + + for(int i=0; i { + + private final ErrorCalculator errorCalculator; + private final Forest forest; + private final List> observations; + private final List observedResponses; + + private final boolean isTrainingSet; // If true, then we use out-of-bag predictions + private final double baselineError; + + public VariableImportanceCalculator( + ErrorCalculator errorCalculator, + Forest forest, + List> observations, + boolean isTrainingSet + ){ + this.errorCalculator = errorCalculator; + this.forest = forest; + this.observations = observations; + this.isTrainingSet = isTrainingSet; + + this.observedResponses = observations.stream() + .map(row -> row.getResponse()).collect(Collectors.toList()); + + final List

baselinePredictions = makePredictions(observations); + this.baselineError = errorCalculator.averageError(observedResponses, baselinePredictions); + + } + + public double calculateVariableImportance(Covariate covariate, Optional random){ + final List scrambledValues = CovariateRow.scrambleCovariateValues(this.observations, covariate, random); + final List

alternatePredictions = makePredictions(scrambledValues); + final double newError = errorCalculator.averageError(this.observedResponses, alternatePredictions); + + return newError - this.baselineError; + } + + public double[] calculateVariableImportance(List covariates, Optional random){ + return covariates.stream() + .mapToDouble(covariate -> calculateVariableImportance(covariate, random)) + .toArray(); + } + + private List

makePredictions(List rowList){ + if(isTrainingSet){ + return forest.evaluateOOB(rowList); + } else{ + return forest.evaluate(rowList); + } + } + +} diff --git a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapperTest.java b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapperTest.java new file mode 100644 index 0000000..5b17d33 --- /dev/null +++ b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapperTest.java @@ -0,0 +1,143 @@ +package ca.joeltherrien.randomforest.tree.vimp; + +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.IBSCalculator; +import ca.joeltherrien.randomforest.utils.Point; +import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; +import ca.joeltherrien.randomforest.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class IBSErrorCalculatorWrapperTest { + + /* + We already have tests for the IBSCalculator, so these tests are concerned with making sure we correctly average + the errors together, not that we fully test the production of each error under different scenarios (like + providing / not providing a censoring distribution). 
+ */ + + private final double integrationUpperBound = 5.0; + + private final List responses; + private final List functions; + + + private final double[][] errors; + + public IBSErrorCalculatorWrapperTest(){ + this.responses = Utils.easyList( + new CompetingRiskResponse(0, 2.0), + new CompetingRiskResponse(0, 3.0), + new CompetingRiskResponse(1, 1.0), + new CompetingRiskResponse(1, 1.5), + new CompetingRiskResponse(2, 3.0), + new CompetingRiskResponse(2, 4.0) + ); + + final RightContinuousStepFunction cif1 = RightContinuousStepFunction.constructFromPoints(Utils.easyList( + new Point(1.0, 0.25), + new Point(1.5, 0.45) + ), 0.0); + + final RightContinuousStepFunction cif2 = RightContinuousStepFunction.constructFromPoints(Utils.easyList( + new Point(3.0, 0.25), + new Point(4.0, 0.45) + ), 0.0); + + // This function is for the unused CHFs and survival curve + // If we see infinities or NaNs popping up in our output we should look here. + final RightContinuousStepFunction emptyFun = RightContinuousStepFunction.constructFromPoints(Utils.easyList( + new Point(0.0, Double.NaN) + ), Double.NEGATIVE_INFINITY + ); + + final CompetingRiskFunctions function = CompetingRiskFunctions.builder() + .cumulativeIncidenceCurves(Utils.easyList(cif1, cif2)) + .causeSpecificHazards(Utils.easyList(emptyFun, emptyFun)) + .survivalCurve(emptyFun) + .build(); + + // Same prediction for every response. 
+ this.functions = Utils.easyList(function, function, function, function, function, function); + + final IBSCalculator calculator = new IBSCalculator(); + this.errors = new double[2][6]; + + for(int event : new int[]{1, 2}){ + for(int i=0; i<6; i++){ + this.errors[event-1][i] = calculator.calculateError( + responses.get(i), function.getCumulativeIncidenceFunction(event), + event, integrationUpperBound + ); + } + } + + } + + @Test + public void testOneEventOne(){ + final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1}, + this.integrationUpperBound); + + final double error = wrapper.averageError(this.responses, this.functions); + double expectedError = 0.0; + for(int i=0; i<6; i++){ + expectedError += errors[0][i] / 6.0; + } + + assertEquals(expectedError, error, 0.00000001); + } + + @Test + public void testOneEventTwo(){ + final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{2}, + this.integrationUpperBound); + + final double error = wrapper.averageError(this.responses, this.functions); + double expectedError = 0.0; + for(int i=0; i<6; i++){ + expectedError += errors[1][i] / 6.0; + } + + assertEquals(expectedError, error, 0.00000001); + } + + @Test + public void testTwoEventsNoWeights(){ + final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2}, + this.integrationUpperBound); + + final double error = wrapper.averageError(this.responses, this.functions); + double expectedError1 = 0.0; + double expectedError2 = 0.0; + + for(int i=0; i<6; i++){ + expectedError1 += errors[0][i] / 6.0; + expectedError2 += errors[1][i] / 6.0; + } + + assertEquals(expectedError1 + expectedError2, error, 0.00000001); + } + + @Test + public void testTwoEventsWithWeights(){ + final IBSErrorCalculatorWrapper wrapper = new IBSErrorCalculatorWrapper(new IBSCalculator(), new int[]{1, 2}, + this.integrationUpperBound, new double[]{1.0, 2.0}); + 
+ final double error = wrapper.averageError(this.responses, this.functions); + double expectedError1 = 0.0; + double expectedError2 = 0.0; + + for(int i=0; i<6; i++){ + expectedError1 += errors[0][i] / 6.0; + expectedError2 += errors[1][i] / 6.0; + } + + assertEquals(1.0 * expectedError1 + 2.0 * expectedError2, error, 0.00000001); + } + +} diff --git a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculatorTest.java new file mode 100644 index 0000000..df8e5d0 --- /dev/null +++ b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculatorTest.java @@ -0,0 +1,26 @@ +package ca.joeltherrien.randomforest.tree.vimp; + +import ca.joeltherrien.randomforest.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RegressionErrorCalculatorTest { + + private final RegressionErrorCalculator calculator = new RegressionErrorCalculator(); + + @Test + public void testRegressionErrorCalculator(){ + final List responses = Utils.easyList(1.0, 1.5, 0.0, 3.0); + final List predictions = Utils.easyList(1.5, 1.7, 0.1, 2.9); + + // Differences are 0.5, 0.2, -0.1, 0.1 + // Squared: 0.25, 0.04, 0.01, 0.01 + + assertEquals((0.25 + 0.04 + 0.01 + 0.01)/4.0, calculator.averageError(responses, predictions), 0.000000001); + + } + +} diff --git a/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java new file mode 100644 index 0000000..8d393a5 --- /dev/null +++ b/library/src/test/java/ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculatorTest.java @@ -0,0 +1,359 @@ +package ca.joeltherrien.randomforest.tree.vimp; + +import ca.joeltherrien.randomforest.Row; +import 
ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate; +import ca.joeltherrien.randomforest.covariates.bool.BooleanSplitRule; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; +import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; +import ca.joeltherrien.randomforest.covariates.numeric.NumericSplitRule; +import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder; +import ca.joeltherrien.randomforest.tree.*; +import ca.joeltherrien.randomforest.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class VariableImportanceCalculatorTest { + + /* + Since the logic for VariableImportanceCalculator is generic, it will be much easier to test under a regression + setting. + */ + + // We'll have a very simple Forest of two trees + private final Forest forest; + + + private final List covariates; + private final List> rowList; + + /* + Long setup process; forest is manually constructed so that we can be exactly sure of our variable importance. 
+ + */ + public VariableImportanceCalculatorTest(){ + final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0); + final NumericCovariate numericCovariate = new NumericCovariate("y", 1); + final FactorCovariate factorCovariate = new FactorCovariate("z", 2, + Utils.easyList("red", "blue", "green")); + + this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate); + + final TreeTrainer treeTrainer = TreeTrainer.builder() + .responseCombiner(new MeanResponseCombiner()) + .splitFinder(new WeightedVarianceSplitFinder()) + .numberOfSplits(0) + .nodeSize(1) + .maxNodeDepth(100) + .mtry(3) + .checkNodePurity(false) + .covariates(this.covariates) + .build(); + + /* + Plan for data - BooleanCovariate is split on first and has the largest impact. + NumericCovariate is at second level and has more minimal impact. + FactorCovariate is useless and never used. + Our tree (we'll duplicate it for testing OOB errors) will have a depth of 1. (0 based). + */ + + final Tree tree1 = makeTree(covariates, 0.0, new int[]{1,2,3,4}); + final Tree tree2 = makeTree(covariates, 2.0, new int[]{5,6,7,8}); + + this.forest = Forest.builder() + .trees(Utils.easyList(tree1, tree2)) + .treeResponseCombiner(new MeanResponseCombiner()) + .covariateList(this.covariates) + .build(); + + // formula; boolean high adds 100; high numeric adds 10 + // This row list should have a baseline error of 0.0 + this.rowList = Utils.easyList( + Row.createSimple(Utils.easyMap( + "x", "false", + "y", "0.0", + "z", "red"), + covariates, 1, 0.0 + ), + Row.createSimple(Utils.easyMap( + "x", "false", + "y", "10.0", + "z", "blue"), + covariates, 2, 10.0 + ), + Row.createSimple(Utils.easyMap( + "x", "true", + "y", "0.0", + "z", "red"), + covariates, 3, 100.0 + ), + Row.createSimple(Utils.easyMap( + "x", "true", + "y", "10.0", + "z", "green"), + covariates, 4, 110.0 + ), + + Row.createSimple(Utils.easyMap( + "x", "false", + "y", "0.0", + "z", "red"), + covariates, 5, 0.0 + ), + 
Row.createSimple(Utils.easyMap( + "x", "false", + "y", "10.0", + "z", "blue"), + covariates, 6, 10.0 + ), + Row.createSimple(Utils.easyMap( + "x", "true", + "y", "0.0", + "z", "red"), + covariates, 7, 100.0 + ), + Row.createSimple(Utils.easyMap( + "x", "true", + "y", "10.0", + "z", "green"), + covariates, 8, 110.0 + ) + ); + } + + private Tree makeTree(List covariates, double offset, int[] indices){ + // Naming convention - xyTerminal where x and y are low/high denotes whether BooleanCovariate(x) is low/high and + // whether NumericCovariate(y) is low/high. + final TerminalNode lowLowTerminal = new TerminalNode<>(0.0 + offset, 5); + final TerminalNode lowHighTerminal = new TerminalNode<>(10.0 + offset, 5); + final TerminalNode highLowTerminal = new TerminalNode<>(100.0 + offset, 5); + final TerminalNode highHighTerminal = new TerminalNode<>(110.0 + offset, 5); + + final SplitNode lowSplitNode = SplitNode.builder() + .leftHand(lowLowTerminal) + .rightHand(lowHighTerminal) + .probabilityNaLeftHand(0.5) + .splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0)) + .build(); + + final SplitNode highSplitNode = SplitNode.builder() + .leftHand(highLowTerminal) + .rightHand(highHighTerminal) + .probabilityNaLeftHand(0.5) + .splitRule(new NumericSplitRule((NumericCovariate) covariates.get(1), 5.0)) + .build(); + + final SplitNode rootSplitNode = SplitNode.builder() + .leftHand(lowSplitNode) + .rightHand(highSplitNode) + .probabilityNaLeftHand(0.5) + .splitRule(new BooleanSplitRule((BooleanCovariate) covariates.get(0))) + .build(); + + return new Tree<>(rootSplitNode, indices); + + } + + // Experiment with random seeds to first examine what a split does so we know what to expect + /* + public static void main(String[] args){ + final List ints1 = IntStream.range(1, 9).boxed().collect(Collectors.toList()); + final List ints2 = IntStream.range(1, 9).boxed().collect(Collectors.toList()); + + final Random random = new Random(123); + Collections.shuffle(ints1, 
random); + Collections.shuffle(ints2, random); + + System.out.println(ints1); + // [1, 4, 8, 2, 5, 3, 7, 6] + + System.out.println(ints2); + [6, 1, 4, 7, 5, 2, 8, 3] + } + */ + + @Test + public void testVariableImportanceOnXNoOOB(){ + // x is the BooleanCovariate + + Random random = new Random(123); + final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( + new RegressionErrorCalculator(), + this.forest, + this.rowList, + false + ); + + double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random)); + + final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0 + + final List permutedPredictions = Utils.easyList( + 1., 111., 101., 11., 1., 111., 101., 11. + ); + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); + + final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions); + + assertEquals(expectedError - expectedBaselineError, importance, 0.0000001); + } + + @Test + public void testVariableImportanceOnXOOB(){ + // x is the BooleanCovariate + + Random random = new Random(123); + final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( + new RegressionErrorCalculator(), + this.forest, + this.rowList, + true + ); + + double importance = calculator.calculateVariableImportance(this.covariates.get(0), Optional.of(random)); + + // First 4 observations are off by 2, last 4 are off by 0 + final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0; + + // Remember we are working with OOB predictions + final List permutedPredictions = Utils.easyList( + 2., 112., 102., 12., 0., 110., 100., 10. 
+ ); + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); + + final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions); + + assertEquals(expectedError - expectedBaselineError, importance, 0.0000001); + } + + @Test + public void testVariableImportanceOnYNoOOB(){ + // y is the NumericCovariate + + Random random = new Random(123); + final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( + new RegressionErrorCalculator(), + this.forest, + this.rowList, + false + ); + + double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random)); + + final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0 + + final List permutedPredictions = Utils.easyList( + 1., 11., 111., 111., 1., 1., 101., 111. + ); + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); + + final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions); + + assertEquals(expectedError - expectedBaselineError, importance, 0.0000001); + } + + @Test + public void testVariableImportanceOnYOOB(){ + // y is the NumericCovariate + + Random random = new Random(123); + final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( + new RegressionErrorCalculator(), + this.forest, + this.rowList, + true + ); + + double importance = calculator.calculateVariableImportance(this.covariates.get(1), Optional.of(random)); + + // First 4 observations are off by 2, last 4 are off by 0 + final double expectedBaselineError = 2.0*2.0 * 4.0 / 8.0; + + // Remember we are working with OOB predictions + final List permutedPredictions = Utils.easyList( + 2., 12., 112., 112., 0., 0., 100., 110. 
+ ); + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); + + final double expectedError = new RegressionErrorCalculator().averageError(observedValues, permutedPredictions); + + assertEquals(expectedError - expectedBaselineError, importance, 0.0000001); + } + + @Test + public void testVariableImportanceOnZNoOOB(){ + // z is the useless FactorCovariate + + Random random = new Random(123); + final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( + new RegressionErrorCalculator(), + this.forest, + this.rowList, + false + ); + + double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random)); + + // The FactorCovariate is never split on; so permuting it will make no difference to baseline error + assertEquals(0, importance, 0.0000001); + } + + @Test + public void testVariableImportanceOnZOOB(){ + // z is the useless FactorCovariate + + Random random = new Random(123); + final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( + new RegressionErrorCalculator(), + this.forest, + this.rowList, + true + ); + + double importance = calculator.calculateVariableImportance(this.covariates.get(2), Optional.of(random)); + + // The FactorCovariate is never split on; so permuting it will make no difference to baseline error + assertEquals(0, importance, 0.0000001); + } + + @Test + public void testVariableImportanceMultiple(){ + Random random = new Random(123); + final VariableImportanceCalculator calculator = new VariableImportanceCalculator<>( + new RegressionErrorCalculator(), + this.forest, + this.rowList, + false + ); + + double importance[] = calculator.calculateVariableImportance(covariates, Optional.of(random)); + + final double expectedBaselineError = 1.0; // Everything is off by 1, so average is 1.0 + + final List observedValues = this.rowList.stream().map(Row::getResponse).collect(Collectors.toList()); + + final List permutedPredictionsX = 
Utils.easyList( + 1., 111., 101., 11., 1., 111., 101., 11. + ); + + // [6, 1, 4, 7, 5, 2, 8, 3] + final List permutedPredictionsY = Utils.easyList( + 11., 1., 111., 101., 1., 11., 111., 101. + ); + + final double expectedErrorX = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsX); + final double expectedErrorY = new RegressionErrorCalculator().averageError(observedValues, permutedPredictionsY); + + assertEquals(expectedErrorX - expectedBaselineError, importance[0], 0.0000001); + assertEquals(expectedErrorY - expectedBaselineError, importance[1], 0.0000001); + assertEquals(0, importance[2], 0.0000001); + + } + +}