Add variable importance

This commit is contained in:
Joel Therrien 2019-08-12 14:19:45 -07:00
parent fd8621a88d
commit 48859b0249
13 changed files with 482 additions and 130 deletions

View file

@ -35,4 +35,5 @@ export(loadForest)
export(naiveConcordance)
export(saveForest)
export(train)
export(vimp)
import(rJava)

View file

@ -59,45 +59,14 @@ integratedBrierScore <- function(responses, predictions, event, time, censoringD
java.censoringDistribution <- NULL
if(!is.null(censoringDistribution)){
if(is.numeric(censoringDistribution)){
# estimate ECDF
censoringTimes <- .jarray(censoringDistribution, "D")
java.censoringDistribution <- .jcall(.class_Utils, makeResponse(.class_RightContinuousStepFunction), "estimateOneMinusECDF", censoringTimes)
} else if(is.list(censoringDistribution)){
# First check that censoringDistribution fits the correct format
if(is.null(censoringDistribution$x) | is.null(censoringDistribution$y)){
stop("If the censoringDistribution is provided as a list, it must have an x and a y item that are numeric.")
}
if(length(censoringDistribution$x) != length(censoringDistribution$y)){
stop("x and y in censoringDistribution must have the same length.")
}
if(!is.numeric(censoringDistribution$x) | !is.numeric(censoringDistribution$y)){
stop("x and y in censoringDistribution must both be numeric.")
}
java.censoringDistribution <- createRightContinuousStepFunction(censoringDistribution$x, censoringDistribution$y, defaultY = 1.0)
} else if("stepfun" %in% class(censoringDistribution)){
x <- stats::knots(censoringDistribution)
y <- censoringDistribution(x)
java.censoringDistribution <- createRightContinuousStepFunction(x, y, defaultY = 1.0)
}
else{
stop("Invalid censoringDistribution")
}
# Make sure we wrap it in an Optional
java.censoringDistribution <- processCensoringDistribution(censoringDistribution)
java.censoringDistribution <- .object_Optional(java.censoringDistribution)
}
else{
java.censoringDistribution <- .object_Optional(NULL)
}
predictions.java <- lapply(predictions, function(x){return(x$javaObject)})
predictions.java <- convertRListToJava(predictions.java)

View file

@ -10,6 +10,7 @@
.class_Collection <- "java/util/Collection"
.class_Serializable <- "java/io/Serializable"
.class_File <- "java/io/File"
.class_Random <- "java/util/Random"
# Utility Classes
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
@ -53,6 +54,14 @@
.class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder"
.class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder"
# VIMP classes
.class_IBSCalculator <- "ca/joeltherrien/randomforest/responses/competingrisk/IBSCalculator"
.class_ErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator"
.class_RegressionErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator"
.class_IBSErrorCalculatorWrapper <- "ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper"
.class_VariableImportanceCalculator <- "ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator"
.object_Optional <- function(object=NULL){
if(is.null(object)){
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "empty"))

View file

@ -18,7 +18,7 @@ loadData <- function(data, xVarNames, responses, covariateList.java = NULL){
rowList <- .jcall(.class_RUtils, makeResponse(.class_List), "importDataWithResponses",
responses$javaObject, covariateList.java, textData)
return(list(covariateList=covariateList.java, dataset=rowList))
return(list(covariateList = covariateList.java, dataset = rowList, responses = responses))
}
@ -54,15 +54,7 @@ getCovariateList <- function(data, xvarNames){
loadPredictionData <- function(newData, covariateList.java){
xVarNames <- character(.jcall(covariateList.java, "I", "size"))
for(j in 1:length(xVarNames)){
covariate.java <- .jcast(
.jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
.class_Covariate
)
xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
}
xVarNames <- extractCovariateNamesFromJavaList(covariateList.java)
if(any(!(xVarNames %in% names(newData)))){
varsMissing = xVarNames[!(xVarNames %in% names(newData))]
@ -84,3 +76,16 @@ loadPredictionData <- function(newData, covariateList.java){
return(rowList)
}
extractCovariateNamesFromJavaList <- function(covariateList.java){
xVarNames <- character(.jcall(covariateList.java, "I", "size"))
for(j in 1:length(xVarNames)){
covariate.java <- .jcast(
.jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
.class_Covariate
)
xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
}
return(xVarNames)
}

View file

@ -14,6 +14,23 @@ convertRListToJava <- function(lst){
return(javaList)
}
#Internal function
convertJavaListToR <- function(javaList, class = .class_Object){
lst <- list()
javaList.length <- .jcall(javaList, "I", "size")
for(i in 0:(javaList.length - 1)){
object <- .jcall(javaList, makeResponse(.class_Object), "get", as.integer(i))
object <- .jcast(object, class)
lst[[i+1]] <- object
}
return(lst)
}
#' @export
print.SplitFinder = function(x, ...) print(x$call)

View file

@ -0,0 +1,39 @@
# Internal function. Takes a censoring distribution and turns it into a
# RightContinuousStepFunction Java object.
processCensoringDistribution <- function(censoringDistribution){
if(is.numeric(censoringDistribution)){
# estimate ECDF
censoringTimes <- .jarray(censoringDistribution, "D")
java.censoringDistribution <- .jcall(.class_Utils, makeResponse(.class_RightContinuousStepFunction), "estimateOneMinusECDF", censoringTimes)
} else if(is.list(censoringDistribution)){
# First check that censoringDistribution fits the correct format
if(is.null(censoringDistribution$x) | is.null(censoringDistribution$y)){
stop("If the censoringDistribution is provided as a list, it must have an x and a y item that are numeric.")
}
if(length(censoringDistribution$x) != length(censoringDistribution$y)){
stop("x and y in censoringDistribution must have the same length.")
}
if(!is.numeric(censoringDistribution$x) | !is.numeric(censoringDistribution$y)){
stop("x and y in censoringDistribution must both be numeric.")
}
java.censoringDistribution <- createRightContinuousStepFunction(censoringDistribution$x, censoringDistribution$y, defaultY = 1.0)
} else if("stepfun" %in% class(censoringDistribution)){
x <- stats::knots(censoringDistribution)
y <- censoringDistribution(x)
java.censoringDistribution <- createRightContinuousStepFunction(x, y, defaultY = 1.0)
}
else{
stop("Invalid censoringDistribution")
}
return(java.censoringDistribution)
}

59
R/processFormula.R Normal file
View file

@ -0,0 +1,59 @@
# Internal function that takes a formula and processes it for use in the Java
# code. existingCovariateList is optional; if not provided then a new one is
# created internally.
processFormula <- function(formula, data, covariateList.java = NULL){
# Having an R copy of the data loaded at the same time can be wasteful; we
# also allow users to provide an environment of the data which gets removed
# after being imported into Java
if(class(data) == "environment"){
if(is.null(data$data)){
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
}
env <- data
data <- env$data
env$data <- NULL
gc()
}
yVar <- formula[[2]]
responses <- NULL
variablesToDrop <- character(0)
# yVar is a call object; as.character(yVar) will be the different components, including the parameters.
# if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
# then we also need to explicitly evaluate it
if(class(yVar) == "call" || !(as.character(yVar) %in% colnames(data))){
# yVar is a function like CR_Response
responses <- eval(expr=yVar, envir=data)
if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
# do any of the variables match data in data? We need to track that so we can drop them later
variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
}
formula[[2]] <- NULL
} else if(class(yVar) == "name"){ # and implicitly yVar is contained in data
variablesToDrop <- as.character(yVar)
}
# Includes responses which we may need to later cut out if `.` was used on the
# right-hand-side
filteredData <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass)
if(is.null(responses)){ # If this if-statement runs then we have a simple (i.e. numeric) response
responses <- stats::model.response(filteredData)
}
# remove any response variables on the right-hand-side
covariateData <- filteredData[, !(names(filteredData) %in% variablesToDrop), drop=FALSE]
dataset <- loadData(covariateData, colnames(covariateData), responses, covariateList.java = covariateList.java)
return(dataset)
}

100
R/train.R
View file

@ -14,7 +14,7 @@ getCores <- function(){
return(cores)
}
train.internal <- function(responses, covariateData, splitFinder,
train.internal <- function(dataset, splitFinder,
nodeResponseCombiner, forestResponseCombiner, ntree,
numberOfSplits, mtry, nodeSize, maxNodeDepth,
splitPureNodes, savePath, savePath.overwrite,
@ -52,15 +52,15 @@ train.internal <- function(responses, covariateData, splitFinder,
}
if(is.null(splitFinder)){
splitFinder <- splitFinderDefault(responses)
splitFinder <- splitFinderDefault(dataset$responses)
}
if(is.null(nodeResponseCombiner)){
nodeResponseCombiner <- nodeResponseCombinerDefault(responses)
nodeResponseCombiner <- nodeResponseCombinerDefault(dataset$responses)
}
if(is.null(forestResponseCombiner)){
forestResponseCombiner <- forestResponseCombinerDefault(responses)
forestResponseCombiner <- forestResponseCombinerDefault(dataset$responses)
}
@ -75,20 +75,6 @@ train.internal <- function(responses, covariateData, splitFinder,
stop("forestResponseCombiner must be a ResponseCombiner")
}
if(class(covariateData)=="environment"){
if(is.null(covariateData$data)){
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
}
dataset <- loadData(covariateData$data, colnames(covariateData$data), responses)
covariateData$data <- NULL # save memory, hopefully
gc() # explicitly try to save memory
}
else{
dataset <- loadData(covariateData, colnames(covariateData), responses)
}
treeTrainer <- createTreeTrainer(responseCombiner=nodeResponseCombiner,
splitFinder=splitFinder,
covariateList=dataset$covariateList,
@ -188,7 +174,7 @@ train.internal <- function(responses, covariateData, splitFinder,
#'
#' @param formula You may specify the response and covariates as a formula
#' instead; make sure the response in the formula is still properly
#' constructed; see \code{responses}
#' constructed.
#' @param data A data.frame containing the columns of the predictors and
#' responses.
#' @param splitFinder A split finder that's used to score splits in the random
@ -314,74 +300,16 @@ train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL
savePath.overwrite=c("warn", "delete", "merge"), cores = getCores(),
randomSeed = NULL, displayProgress = TRUE){
# Having an R copy of the data loaded at the same time can be wasteful; we
# also allow users to provide an environment of the data which gets removed
# after being imported into Java
env <- NULL
if(class(data) == "environment"){
if(is.null(data$data)){
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
}
dataset <- processFormula(formula, data)
env <- data
data <- env$data
}
yVar <- formula[[2]]
responses <- NULL
variablesToDrop <- character(0)
# yVar is a call object; as.character(yVar) will be the different components, including the parameters.
# if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
# then we also need to explicitly evaluate it
if(class(yVar)=="call" || !(as.character(yVar) %in% colnames(data))){
# yVar is a function like CompetingRiskResponses
responses <- eval(expr=yVar, envir=data)
if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
# do any of the variables match data in data? We need to track that so we can drop them later
variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
}
formula[[2]] <- NULL
} else if(class(yVar)=="name"){ # and implicitly yVar is contained in data
variablesToDrop <- as.character(yVar)
}
# Includes responses which we may need to later cut out
mf <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass)
if(is.null(responses)){
responses <- stats::model.response(mf)
}
# remove any response variables
mf <- mf[,!(names(mf) %in% variablesToDrop), drop=FALSE]
# If environment was provided instead of data
if(!is.null(env)){
env$data <- mf
rm(data)
forest <- train.internal(responses, env, splitFinder = splitFinder,
nodeResponseCombiner = nodeResponseCombiner,
forestResponseCombiner = forestResponseCombiner,
ntree = ntree, numberOfSplits = numberOfSplits,
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
splitPureNodes = splitPureNodes, savePath = savePath,
savePath.overwrite = savePath.overwrite, cores = cores,
randomSeed = randomSeed, displayProgress = displayProgress)
} else{
forest <- train.internal(responses, mf, splitFinder = splitFinder,
nodeResponseCombiner = nodeResponseCombiner,
forestResponseCombiner = forestResponseCombiner,
ntree = ntree, numberOfSplits = numberOfSplits,
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
splitPureNodes = splitPureNodes, savePath = savePath,
savePath.overwrite = savePath.overwrite, cores = cores,
randomSeed = randomSeed, displayProgress = displayProgress)
}
forest <- train.internal(dataset, splitFinder = splitFinder,
nodeResponseCombiner = nodeResponseCombiner,
forestResponseCombiner = forestResponseCombiner,
ntree = ntree, numberOfSplits = numberOfSplits,
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
splitPureNodes = splitPureNodes, savePath = savePath,
savePath.overwrite = savePath.overwrite, cores = cores,
randomSeed = randomSeed, displayProgress = displayProgress)
forest$call <- match.call()
forest$formula <- formula

175
R/vimp.R Normal file
View file

@ -0,0 +1,175 @@
#' Variable Importance
#'
#' Calculate variable importance by recording the increase in error when a given
#' predictor is randomly permuted. Regression forests uses mean squared error;
#' competing risks uses integrated Brier score.
#'
#' @param forest The forest that was trained.
#' @param newData A test set of the data if available. If not, then out of bag
#' errors will be attempted on the training set.
#' @param randomSeed The source of randomness used to permute the values. Can be
#' left blank.
#' @param events If using competing risks forest, the events that the error
#' measure used for VIMP should be calculated on.
#' @param time If using competing risks forest, the upper bound of the
#' integrated Brier score.
#' @param censoringDistribution (Optional) If using competing risks forest, the
#' censoring distribution. See \code{\link{integratedBrierScore} for details.}
#' @param eventWeights (Optional) If using competing risks forest, weights to be
#' applied to the error for each of the \code{events}.
#'
#' @return A named numeric vector of importance values.
#' @export
#'
#' @examples
#' data(wihs)
#'
#' forest <- train(CR_Response(status, time) ~ ., wihs,
#' ntree = 100, numberOfSplits = 0, mtry=3, nodeSize = 5)
#'
#' vimp(forest, events = 1:2, time = 8.0)
#'
vimp <- function(
forest,
newData = NULL,
randomSeed = NULL,
type = c("mean", "z", "raw"),
events = NULL,
time = NULL,
censoringDistribution = NULL,
eventWeights = NULL){
if(is.null(newData) & is.null(forest$dataset)){
stop("forest doesn't have a copy of the training data loaded (this happens if you just loaded it); please manually specify newData and possibly out.of.bag")
}
# Basically we check if type is either null, length 0, or one of the invalid values.
# We can't include the last statement in the same statement as length(tyoe) < 1,
# because R checks both cases and a different error would display if length(type) == 0
typeError = is.null(type) | length(type) < 1
if(!typeError){
typeError = !(type[1] %in% c("mean", "z", "raw"))
}
if(typeError){
stop("A valid response type must be provided.")
}
if(is.null(newData)){
data.java <- forest$dataset
out.of.bag <- TRUE
}
else{ # newData is provided
data.java <- processFormula(forest$formula, newData, forest$covariateList)$dataset
out.of.bag <- FALSE
}
predictionClass <- forest$params$forestResponseCombiner$outputClass
if(predictionClass == "CompetingRiskFunctions"){
if(is.null(time) | length(time) != 1){
stop("time must be set at length 1")
}
errorCalculator.java <- ibsCalculatorWrapper(
events = events,
time = time,
censoringDistribution = censoringDistribution,
eventWeights = eventWeights)
} else if(predictionClass == "numeric"){
errorCalculator.java <- .jnew(.class_RegressionErrorCalculator)
errorCalculator.java <- .jcast(errorCalculator.java, .class_ErrorCalculator)
} else{
stop(paste0("VIMP not yet supported for ", predictionClass, ". If you're just using a non-custom version of largeRCRF then this is a bug and should be reported."))
}
forest.trees.java <- .jcall(forest$javaObject, makeResponse(.class_List), "getTrees")
vimp.calculator <- .jnew(.class_VariableImportanceCalculator,
errorCalculator.java,
forest.trees.java,
data.java,
out.of.bag # isTrainingSet parameter
)
random.java <- NULL
if(!is.null(randomSeed)){
random.java <- .jnew(.class_Random, .jlong(as.integer(randomSeed)))
}
random.java <- .object_Optional(random.java)
covariateRList <- convertJavaListToR(forest$covariateList, class = .class_Covariate)
importanceValues <- matrix(nrow = forest$params$ntree, ncol = length(covariateRList))
colnames(importanceValues) <- extractCovariateNamesFromJavaList(forest$covariateList)
for(j in 1:length(covariateRList)){
covariateJava <- covariateRList[[j]]
covariateJava <-
importanceValues[, j] <- .jcall(vimp.calculator, "[D", "calculateVariableImportanceRaw", covariateJava, random.java)
}
if(type[1] == "raw"){
return(importanceValues)
} else if(type[1] == "mean"){
meanImportanceValues <- apply(importanceValues, 2, mean)
return(meanImportanceValues)
} else if(type[1] == "z"){
zImportanceValues <- apply(importanceValues, 2, function(x){
meanValue <- mean(x)
standardError <- sd(x)/sqrt(length(x))
return(meanValue / standardError)
})
return(zImportanceValues)
} else{
stop("A valid response type must be provided.")
}
return(importance)
}
# Internal function
ibsCalculatorWrapper <- function(events, time, censoringDistribution = NULL, eventWeights = NULL){
if(is.null(events)){
stop("events must be specified if using vimp on competing risks data")
}
if(is.null(time)){
stop("time must be specified if using vimp on competing risks data")
}
java.censoringDistribution <- NULL
if(!is.null(censoringDistribution)){
java.censoringDistribution <- processCensoringDistribution(censoringDistribution)
java.censoringDistribution <- .object_Optional(java.censoringDistribution)
}
else{
java.censoringDistribution <- .object_Optional(NULL)
}
ibsCalculator.java <- .jnew(.class_IBSCalculator, java.censoringDistribution)
if(is.null(eventWeights)){
eventWeights <- rep(1, times = length(events))
}
ibsCalculatorWrapper.java <- .jnew(.class_IBSErrorCalculatorWrapper,
ibsCalculator.java,
.jarray(as.integer(events)),
as.numeric(time),
.jarray(as.numeric(eventWeights)))
ibsCalculatorWrapper.java <- .jcast(ibsCalculatorWrapper.java, .class_ErrorCalculator)
return(ibsCalculatorWrapper.java)
}

View file

@ -13,7 +13,7 @@ train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
\arguments{
\item{formula}{You may specify the response and covariates as a formula
instead; make sure the response in the formula is still properly
constructed; see \code{responses}}
constructed.}
\item{data}{A data.frame containing the columns of the predictors and
responses.}

48
man/vimp.Rd Normal file
View file

@ -0,0 +1,48 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/vimp.R
\name{vimp}
\alias{vimp}
\title{Variable Importance}
\usage{
vimp(forest, newData = NULL, randomSeed = NULL, type = c("mean", "z",
"raw"), events = NULL, time = NULL, censoringDistribution = NULL,
eventWeights = NULL)
}
\arguments{
\item{forest}{The forest that was trained.}
\item{newData}{A test set of the data if available. If not, then out of bag
errors will be attempted on the training set.}
\item{randomSeed}{The source of randomness used to permute the values. Can be
left blank.}
\item{events}{If using competing risks forest, the events that the error
measure used for VIMP should be calculated on.}
\item{time}{If using competing risks forest, the upper bound of the
integrated Brier score.}
\item{censoringDistribution}{(Optional) If using competing risks forest, the
censoring distribution. See \code{\link{integratedBrierScore} for details.}}
\item{eventWeights}{(Optional) If using competing risks forest, weights to be
applied to the error for each of the \code{events}.}
}
\value{
A named numeric vector of importance values.
}
\description{
Calculate variable importance by recording the increase in error when a given
predictor is randomly permuted. Regression forests uses mean squared error;
competing risks uses integrated Brier score.
}
\examples{
data(wihs)
forest <- train(CR_Response(status, time) ~ ., wihs,
ntree = 100, numberOfSplits = 0, mtry=3, nodeSize = 5)
vimp(forest, events = 1:2, time = 8.0)
}

102
tests/testthat/test_vimp.R Normal file
View file

@ -0,0 +1,102 @@
context("Use VIMP without error")
test_that("VIMP doesn't crash; no test dataset", {
data(wihs)
forest <- train(CR_Response(status, time) ~ ., wihs, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE)
# Run VIMP several times under different scenarios
importance <- vimp(forest, type="raw", events=1:2, time=5.0)
vimp(forest, type="raw", events=1, time=5.0)
vimp(forest, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8))
# Not much of a test, but the Java code tests more for correctness. This just
# tests that the R code runs without error.
expect_equal(ncol(importance), 4) # 4 predictors
})
test_that("VIMP doesn't crash; test dataset", {
data(wihs)
trainingData <- wihs[1:1000,]
testData <- wihs[1001:nrow(wihs),]
forest <- train(CR_Response(status, time) ~ ., trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE, cores=1)
# Run VIMP several times under different scenarios
importance <- vimp(forest, newData=testData, type="raw", events=1:2, time=5.0)
vimp(forest, newData=testData, type="raw", events=1, time=5.0)
vimp(forest, newData=testData, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8))
# Not much of a test, but the Java code tests more for correctness. This just
# tests that the R code runs without error.
expect_equal(ncol(importance), 4) # 4 predictors
})
test_that("VIMP doesn't crash; censoring distribution; all methods equal", {
sampleData <- data.frame(x=rnorm(100))
sampleData$T <- sample(0:4, size=100, replace = TRUE) # the censor distribution we provide needs to conform to the data or we can get NaNs
sampleData$delta <- sample(0:2, size = 100, replace = TRUE)
testData <- sampleData[1:5,]
trainingData <- sampleData[6:100,]
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)
importance1 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
censoringDistribution = c(0,1,1,2,3,4))
importance2 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
censoringDistribution = list(x = 0:4, y = 1 - c(1/6, 3/6, 4/6, 5/6, 6/6)))
importance3 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
censoringDistribution = stepfun(x=0:4, y=1 - c(0, 1/6, 3/6, 4/6, 5/6, 6/6)))
expect_equal(importance1, importance2)
expect_equal(importance1, importance3)
})
test_that("VIMP doesn't crash; regression dataset", {
data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000))
data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000)
forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE)
importance <- vimp(forest, type="mean")
expect_true(importance["x2"] > importance["x3"])
# Not much of a test, but the Java code tests more for correctness. This just
# tests that the R code runs without error.
expect_equal(length(importance), 3) # 3 predictors
})
test_that("VIMP produces mean and z scores correctly", {
data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000))
data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000)
forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE)
actual.importance.raw <- vimp(forest, type="raw", randomSeed=5)
actual.importance.mean <- vimp(forest, type="mean", randomSeed=5)
actual.importance.z <- vimp(forest, type="z", randomSeed=5)
expected.importance.mean <- apply(actual.importance.raw, 2, mean)
expected.importance.z <- apply(actual.importance.raw, 2, function(x){
mn <- mean(x)
return( mn / (sd(x) / sqrt(length(x))) )
})
expect_equal(expected.importance.mean, actual.importance.mean)
expect_equal(expected.importance.z, actual.importance.z)
})