diff --git a/NAMESPACE b/NAMESPACE index c60de41..402bb00 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -35,4 +35,5 @@ export(loadForest) export(naiveConcordance) export(saveForest) export(train) +export(vimp) import(rJava) diff --git a/R/cr_integratedBrierScore.R b/R/cr_integratedBrierScore.R index 9525238..a86645e 100644 --- a/R/cr_integratedBrierScore.R +++ b/R/cr_integratedBrierScore.R @@ -59,45 +59,14 @@ integratedBrierScore <- function(responses, predictions, event, time, censoringD java.censoringDistribution <- NULL if(!is.null(censoringDistribution)){ - if(is.numeric(censoringDistribution)){ - # estimate ECDF - censoringTimes <- .jarray(censoringDistribution, "D") - java.censoringDistribution <- .jcall(.class_Utils, makeResponse(.class_RightContinuousStepFunction), "estimateOneMinusECDF", censoringTimes) - - } else if(is.list(censoringDistribution)){ - # First check that censoringDistribution fits the correct format - if(is.null(censoringDistribution$x) | is.null(censoringDistribution$y)){ - stop("If the censoringDistribution is provided as a list, it must have an x and a y item that are numeric.") - } - - if(length(censoringDistribution$x) != length(censoringDistribution$y)){ - stop("x and y in censoringDistribution must have the same length.") - } - - if(!is.numeric(censoringDistribution$x) | !is.numeric(censoringDistribution$y)){ - stop("x and y in censoringDistribution must both be numeric.") - } - - java.censoringDistribution <- createRightContinuousStepFunction(censoringDistribution$x, censoringDistribution$y, defaultY = 1.0) - - } else if("stepfun" %in% class(censoringDistribution)){ - x <- stats::knots(censoringDistribution) - y <- censoringDistribution(x) - - java.censoringDistribution <- createRightContinuousStepFunction(x, y, defaultY = 1.0) - } - else{ - stop("Invalid censoringDistribution") - } - - # Make sure we wrap it in an Optional + java.censoringDistribution <- processCensoringDistribution(censoringDistribution) java.censoringDistribution <- 
.object_Optional(java.censoringDistribution) - } else{ java.censoringDistribution <- .object_Optional(NULL) } + predictions.java <- lapply(predictions, function(x){return(x$javaObject)}) predictions.java <- convertRListToJava(predictions.java) diff --git a/R/java_classes_directory.R b/R/java_classes_directory.R index 97598c3..8ffe241 100644 --- a/R/java_classes_directory.R +++ b/R/java_classes_directory.R @@ -10,6 +10,7 @@ .class_Collection <- "java/util/Collection" .class_Serializable <- "java/io/Serializable" .class_File <- "java/io/File" +.class_Random <- "java/util/Random" # Utility Classes .class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils" @@ -53,6 +54,14 @@ .class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder" .class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder" +# VIMP classes +.class_IBSCalculator <- "ca/joeltherrien/randomforest/responses/competingrisk/IBSCalculator" +.class_ErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator" +.class_RegressionErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator" +.class_IBSErrorCalculatorWrapper <- "ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper" +.class_VariableImportanceCalculator <- "ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator" + + .object_Optional <- function(object=NULL){ if(is.null(object)){ return(.jcall("java/util/Optional", "Ljava/util/Optional;", "empty")) diff --git a/R/loadData.R b/R/loadData.R index 9e870c2..7a1c0df 100644 --- a/R/loadData.R +++ b/R/loadData.R @@ -18,7 +18,7 @@ loadData <- function(data, xVarNames, responses, covariateList.java = NULL){ rowList <- .jcall(.class_RUtils, makeResponse(.class_List), "importDataWithResponses", responses$javaObject, covariateList.java, textData) - return(list(covariateList=covariateList.java, dataset=rowList)) + 
return(list(covariateList = covariateList.java, dataset = rowList, responses = responses)) } @@ -54,15 +54,7 @@ getCovariateList <- function(data, xvarNames){ loadPredictionData <- function(newData, covariateList.java){ - xVarNames <- character(.jcall(covariateList.java, "I", "size")) - for(j in 1:length(xVarNames)){ - covariate.java <- .jcast( - .jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)), - .class_Covariate - ) - - xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName") - } + xVarNames <- extractCovariateNamesFromJavaList(covariateList.java) if(any(!(xVarNames %in% names(newData)))){ varsMissing = xVarNames[!(xVarNames %in% names(newData))] @@ -84,3 +76,16 @@ loadPredictionData <- function(newData, covariateList.java){ return(rowList) } +extractCovariateNamesFromJavaList <- function(covariateList.java){ + xVarNames <- character(.jcall(covariateList.java, "I", "size")) + for(j in 1:length(xVarNames)){ + covariate.java <- .jcast( + .jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)), + .class_Covariate + ) + + xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName") + } + + return(xVarNames) +} diff --git a/R/misc.R b/R/misc.R index dc4f040..6bb24e0 100644 --- a/R/misc.R +++ b/R/misc.R @@ -14,6 +14,23 @@ convertRListToJava <- function(lst){ return(javaList) } +#Internal function +convertJavaListToR <- function(javaList, class = .class_Object){ + lst <- list() + + javaList.length <- .jcall(javaList, "I", "size") + + for(i in 0:(javaList.length - 1)){ + object <- .jcall(javaList, makeResponse(.class_Object), "get", as.integer(i)) + object <- .jcast(object, class) + + lst[[i+1]] <- object + } + + return(lst) + +} + #' @export print.SplitFinder = function(x, ...) 
print(x$call) diff --git a/R/processCensoringDistribution.R b/R/processCensoringDistribution.R new file mode 100644 index 0000000..7e57048 --- /dev/null +++ b/R/processCensoringDistribution.R @@ -0,0 +1,39 @@ + + +# Internal function. Takes a censoring distribution and turns it into a +# RightContinuousStepFunction Java object. +processCensoringDistribution <- function(censoringDistribution){ + + if(is.numeric(censoringDistribution)){ + # estimate ECDF + censoringTimes <- .jarray(censoringDistribution, "D") + java.censoringDistribution <- .jcall(.class_Utils, makeResponse(.class_RightContinuousStepFunction), "estimateOneMinusECDF", censoringTimes) + + } else if(is.list(censoringDistribution)){ + # First check that censoringDistribution fits the correct format + if(is.null(censoringDistribution$x) | is.null(censoringDistribution$y)){ + stop("If the censoringDistribution is provided as a list, it must have an x and a y item that are numeric.") + } + + if(length(censoringDistribution$x) != length(censoringDistribution$y)){ + stop("x and y in censoringDistribution must have the same length.") + } + + if(!is.numeric(censoringDistribution$x) | !is.numeric(censoringDistribution$y)){ + stop("x and y in censoringDistribution must both be numeric.") + } + + java.censoringDistribution <- createRightContinuousStepFunction(censoringDistribution$x, censoringDistribution$y, defaultY = 1.0) + + } else if("stepfun" %in% class(censoringDistribution)){ + x <- stats::knots(censoringDistribution) + y <- censoringDistribution(x) + + java.censoringDistribution <- createRightContinuousStepFunction(x, y, defaultY = 1.0) + } + else{ + stop("Invalid censoringDistribution") + } + + return(java.censoringDistribution) +} \ No newline at end of file diff --git a/R/processFormula.R b/R/processFormula.R new file mode 100644 index 0000000..26d1097 --- /dev/null +++ b/R/processFormula.R @@ -0,0 +1,59 @@ + +# Internal function that takes a formula and processes it for use in the Java +# code. 
covariateList.java is optional; if not provided then a new one is +# created internally. +processFormula <- function(formula, data, covariateList.java = NULL){ + + # Having an R copy of the data loaded at the same time can be wasteful; we + # also allow users to provide an environment of the data which gets removed + # after being imported into Java + if(class(data) == "environment"){ + if(is.null(data$data)){ + stop("When providing an environment with the dataset, the environment must contain an item called 'data'") + } + + env <- data + data <- env$data + env$data <- NULL + gc() + } + + yVar <- formula[[2]] + + responses <- NULL + variablesToDrop <- character(0) + + # yVar is a call object; as.character(yVar) will be the different components, including the parameters. + # if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data, + # then we also need to explicitly evaluate it + if(class(yVar) == "call" || !(as.character(yVar) %in% colnames(data))){ + # yVar is a function like CR_Response + responses <- eval(expr=yVar, envir=data) + + if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){ + # do any of the variables match data in data? We need to track that so we can drop them later + variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)] + } + + formula[[2]] <- NULL + + } else if(class(yVar) == "name"){ # and implicitly yVar is contained in data + variablesToDrop <- as.character(yVar) + } + + # Includes responses which we may need to later cut out if `.` was used on the + # right-hand-side + filteredData <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass) + + if(is.null(responses)){ # If this if-statement runs then we have a simple (i.e. 
numeric) response + responses <- stats::model.response(filteredData) + } + + # remove any response variables on the right-hand-side + covariateData <- filteredData[, !(names(filteredData) %in% variablesToDrop), drop=FALSE] + + + dataset <- loadData(covariateData, colnames(covariateData), responses, covariateList.java = covariateList.java) + + return(dataset) +} \ No newline at end of file diff --git a/R/train.R b/R/train.R index 2cfa379..116bc09 100644 --- a/R/train.R +++ b/R/train.R @@ -14,7 +14,7 @@ getCores <- function(){ return(cores) } -train.internal <- function(responses, covariateData, splitFinder, +train.internal <- function(dataset, splitFinder, nodeResponseCombiner, forestResponseCombiner, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth, splitPureNodes, savePath, savePath.overwrite, @@ -52,15 +52,15 @@ train.internal <- function(responses, covariateData, splitFinder, } if(is.null(splitFinder)){ - splitFinder <- splitFinderDefault(responses) + splitFinder <- splitFinderDefault(dataset$responses) } if(is.null(nodeResponseCombiner)){ - nodeResponseCombiner <- nodeResponseCombinerDefault(responses) + nodeResponseCombiner <- nodeResponseCombinerDefault(dataset$responses) } if(is.null(forestResponseCombiner)){ - forestResponseCombiner <- forestResponseCombinerDefault(responses) + forestResponseCombiner <- forestResponseCombinerDefault(dataset$responses) } @@ -75,20 +75,6 @@ train.internal <- function(responses, covariateData, splitFinder, stop("forestResponseCombiner must be a ResponseCombiner") } - if(class(covariateData)=="environment"){ - if(is.null(covariateData$data)){ - stop("When providing an environment with the dataset, the environment must contain an item called 'data'") - } - dataset <- loadData(covariateData$data, colnames(covariateData$data), responses) - covariateData$data <- NULL # save memory, hopefully - gc() # explicitly try to save memory - } - else{ - dataset <- loadData(covariateData, colnames(covariateData), responses) - } - - - 
treeTrainer <- createTreeTrainer(responseCombiner=nodeResponseCombiner, splitFinder=splitFinder, covariateList=dataset$covariateList, @@ -188,7 +174,7 @@ train.internal <- function(responses, covariateData, splitFinder, #' #' @param formula You may specify the response and covariates as a formula #' instead; make sure the response in the formula is still properly -#' constructed; see \code{responses} +#' constructed. #' @param data A data.frame containing the columns of the predictors and #' responses. #' @param splitFinder A split finder that's used to score splits in the random @@ -314,74 +300,16 @@ train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL savePath.overwrite=c("warn", "delete", "merge"), cores = getCores(), randomSeed = NULL, displayProgress = TRUE){ - # Having an R copy of the data loaded at the same time can be wasteful; we - # also allow users to provide an environment of the data which gets removed - # after being imported into Java - env <- NULL - if(class(data) == "environment"){ - if(is.null(data$data)){ - stop("When providing an environment with the dataset, the environment must contain an item called 'data'") - } - - env <- data - data <- env$data - } - - yVar <- formula[[2]] - - responses <- NULL - variablesToDrop <- character(0) + dataset <- processFormula(formula, data) - # yVar is a call object; as.character(yVar) will be the different components, including the parameters. - # if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data, - # then we also need to explicitly evaluate it - if(class(yVar)=="call" || !(as.character(yVar) %in% colnames(data))){ - # yVar is a function like CompetingRiskResponses - responses <- eval(expr=yVar, envir=data) - - if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){ - # do any of the variables match data in data? 
We need to track that so we can drop them later - variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)] - } - - formula[[2]] <- NULL - - } else if(class(yVar)=="name"){ # and implicitly yVar is contained in data - variablesToDrop <- as.character(yVar) - } - - # Includes responses which we may need to later cut out - mf <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass) - - if(is.null(responses)){ - responses <- stats::model.response(mf) - } - - # remove any response variables - mf <- mf[,!(names(mf) %in% variablesToDrop), drop=FALSE] - - # If environment was provided instead of data - if(!is.null(env)){ - env$data <- mf - rm(data) - forest <- train.internal(responses, env, splitFinder = splitFinder, - nodeResponseCombiner = nodeResponseCombiner, - forestResponseCombiner = forestResponseCombiner, - ntree = ntree, numberOfSplits = numberOfSplits, - mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth, - splitPureNodes = splitPureNodes, savePath = savePath, - savePath.overwrite = savePath.overwrite, cores = cores, - randomSeed = randomSeed, displayProgress = displayProgress) - } else{ - forest <- train.internal(responses, mf, splitFinder = splitFinder, - nodeResponseCombiner = nodeResponseCombiner, - forestResponseCombiner = forestResponseCombiner, - ntree = ntree, numberOfSplits = numberOfSplits, - mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth, - splitPureNodes = splitPureNodes, savePath = savePath, - savePath.overwrite = savePath.overwrite, cores = cores, - randomSeed = randomSeed, displayProgress = displayProgress) - } + forest <- train.internal(dataset, splitFinder = splitFinder, + nodeResponseCombiner = nodeResponseCombiner, + forestResponseCombiner = forestResponseCombiner, + ntree = ntree, numberOfSplits = numberOfSplits, + mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth, + splitPureNodes = splitPureNodes, savePath = savePath, + savePath.overwrite = savePath.overwrite, cores = 
cores, + randomSeed = randomSeed, displayProgress = displayProgress) forest$call <- match.call() forest$formula <- formula diff --git a/R/vimp.R b/R/vimp.R new file mode 100644 index 0000000..27211e1 --- /dev/null +++ b/R/vimp.R @@ -0,0 +1,175 @@ + + +#' Variable Importance +#' +#' Calculate variable importance by recording the increase in error when a given +#' predictor is randomly permuted. Regression forests use mean squared error; +#' competing risks forests use integrated Brier score. +#' +#' @param forest The forest that was trained. +#' @param newData A test set of the data if available. If not, then out of bag +#' errors will be attempted on the training set. +#' @param randomSeed The source of randomness used to permute the values. Can be +#' left blank. +#' @param events If using competing risks forest, the events that the error +#' measure used for VIMP should be calculated on. +#' @param time If using competing risks forest, the upper bound of the +#' integrated Brier score. +#' @param censoringDistribution (Optional) If using competing risks forest, the +#' censoring distribution. See \code{\link{integratedBrierScore}} for details. +#' @param eventWeights (Optional) If using competing risks forest, weights to be +#' applied to the error for each of the \code{events}. +#' +#' @return A named numeric vector of importance values; a matrix with one row per tree when \code{type = "raw"}. 
+#' @export +#' +#' @examples +#' data(wihs) +#' +#' forest <- train(CR_Response(status, time) ~ ., wihs, +#' ntree = 100, numberOfSplits = 0, mtry=3, nodeSize = 5) +#' +#' vimp(forest, events = 1:2, time = 8.0) +#' +vimp <- function( + forest, + newData = NULL, + randomSeed = NULL, + type = c("mean", "z", "raw"), + events = NULL, + time = NULL, + censoringDistribution = NULL, + eventWeights = NULL){ + + if(is.null(newData) & is.null(forest$dataset)){ + stop("forest doesn't have a copy of the training data loaded (this happens if you just loaded it); please manually specify newData and possibly out.of.bag") + } + + # Basically we check if type is either null, length 0, or one of the invalid values. + # We can't include the last statement in the same statement as length(type) < 1, + # because R checks both cases and a different error would display if length(type) == 0 + typeError = is.null(type) | length(type) < 1 + if(!typeError){ + typeError = !(type[1] %in% c("mean", "z", "raw")) + } + if(typeError){ + stop("A valid response type must be provided.") + } + + if(is.null(newData)){ + data.java <- forest$dataset + out.of.bag <- TRUE + + } + else{ # newData is provided + data.java <- processFormula(forest$formula, newData, forest$covariateList)$dataset + out.of.bag <- FALSE + } + + predictionClass <- forest$params$forestResponseCombiner$outputClass + + if(predictionClass == "CompetingRiskFunctions"){ + if(is.null(time) | length(time) != 1){ + stop("time must be set at length 1") + } + + errorCalculator.java <- ibsCalculatorWrapper( + events = events, + time = time, + censoringDistribution = censoringDistribution, + eventWeights = eventWeights) + + } else if(predictionClass == "numeric"){ + errorCalculator.java <- .jnew(.class_RegressionErrorCalculator) + errorCalculator.java <- .jcast(errorCalculator.java, .class_ErrorCalculator) + + } else{ + stop(paste0("VIMP not yet supported for ", predictionClass, ". 
If you're just using a non-custom version of largeRCRF then this is a bug and should be reported.")) + + } + + forest.trees.java <- .jcall(forest$javaObject, makeResponse(.class_List), "getTrees") + + vimp.calculator <- .jnew(.class_VariableImportanceCalculator, + errorCalculator.java, + forest.trees.java, + data.java, + out.of.bag # isTrainingSet parameter + ) + + random.java <- NULL + if(!is.null(randomSeed)){ + random.java <- .jnew(.class_Random, .jlong(as.integer(randomSeed))) + } + random.java <- .object_Optional(random.java) + + covariateRList <- convertJavaListToR(forest$covariateList, class = .class_Covariate) + importanceValues <- matrix(nrow = forest$params$ntree, ncol = length(covariateRList)) + colnames(importanceValues) <- extractCovariateNamesFromJavaList(forest$covariateList) + + for(j in 1:length(covariateRList)){ + covariateJava <- covariateRList[[j]] + covariateJava <- + + importanceValues[, j] <- .jcall(vimp.calculator, "[D", "calculateVariableImportanceRaw", covariateJava, random.java) + } + + if(type[1] == "raw"){ + return(importanceValues) + } else if(type[1] == "mean"){ + meanImportanceValues <- apply(importanceValues, 2, mean) + return(meanImportanceValues) + } else if(type[1] == "z"){ + zImportanceValues <- apply(importanceValues, 2, function(x){ + meanValue <- mean(x) + standardError <- sd(x)/sqrt(length(x)) + return(meanValue / standardError) + }) + return(zImportanceValues) + + } else{ + stop("A valid response type must be provided.") + } + + + return(importance) + +} + +# Internal function +ibsCalculatorWrapper <- function(events, time, censoringDistribution = NULL, eventWeights = NULL){ + if(is.null(events)){ + stop("events must be specified if using vimp on competing risks data") + } + + if(is.null(time)){ + stop("time must be specified if using vimp on competing risks data") + } + + + java.censoringDistribution <- NULL + if(!is.null(censoringDistribution)){ + java.censoringDistribution <- 
processCensoringDistribution(censoringDistribution) + java.censoringDistribution <- .object_Optional(java.censoringDistribution) + } + else{ + java.censoringDistribution <- .object_Optional(NULL) + } + + ibsCalculator.java <- .jnew(.class_IBSCalculator, java.censoringDistribution) + + if(is.null(eventWeights)){ + eventWeights <- rep(1, times = length(events)) + } + + ibsCalculatorWrapper.java <- .jnew(.class_IBSErrorCalculatorWrapper, + ibsCalculator.java, + .jarray(as.integer(events)), + as.numeric(time), + .jarray(as.numeric(eventWeights))) + + ibsCalculatorWrapper.java <- .jcast(ibsCalculatorWrapper.java, .class_ErrorCalculator) + return(ibsCalculatorWrapper.java) + + +} \ No newline at end of file diff --git a/inst/java/largeRCRF-library-1.0-SNAPSHOT.jar b/inst/java/largeRCRF-library-1.0-SNAPSHOT.jar index f7ce8af..ad8f858 100644 Binary files a/inst/java/largeRCRF-library-1.0-SNAPSHOT.jar and b/inst/java/largeRCRF-library-1.0-SNAPSHOT.jar differ diff --git a/man/train.Rd b/man/train.Rd index faa2542..a07e47c 100644 --- a/man/train.Rd +++ b/man/train.Rd @@ -13,7 +13,7 @@ train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL, \arguments{ \item{formula}{You may specify the response and covariates as a formula instead; make sure the response in the formula is still properly -constructed; see \code{responses}} +constructed.} \item{data}{A data.frame containing the columns of the predictors and responses.} diff --git a/man/vimp.Rd b/man/vimp.Rd new file mode 100644 index 0000000..21ce486 --- /dev/null +++ b/man/vimp.Rd @@ -0,0 +1,48 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/vimp.R +\name{vimp} +\alias{vimp} +\title{Variable Importance} +\usage{ +vimp(forest, newData = NULL, randomSeed = NULL, type = c("mean", "z", + "raw"), events = NULL, time = NULL, censoringDistribution = NULL, + eventWeights = NULL) +} +\arguments{ +\item{forest}{The forest that was trained.} + +\item{newData}{A test set of the data if 
available. If not, then out of bag +errors will be attempted on the training set.} + +\item{randomSeed}{The source of randomness used to permute the values. Can be +left blank.} + +\item{events}{If using competing risks forest, the events that the error +measure used for VIMP should be calculated on.} + +\item{time}{If using competing risks forest, the upper bound of the +integrated Brier score.} + +\item{censoringDistribution}{(Optional) If using competing risks forest, the +censoring distribution. See \code{\link{integratedBrierScore} for details.}} + +\item{eventWeights}{(Optional) If using competing risks forest, weights to be +applied to the error for each of the \code{events}.} +} +\value{ +A named numeric vector of importance values. +} +\description{ +Calculate variable importance by recording the increase in error when a given +predictor is randomly permuted. Regression forests uses mean squared error; +competing risks uses integrated Brier score. +} +\examples{ +data(wihs) + +forest <- train(CR_Response(status, time) ~ ., wihs, + ntree = 100, numberOfSplits = 0, mtry=3, nodeSize = 5) + +vimp(forest, events = 1:2, time = 8.0) + +} diff --git a/tests/testthat/test_vimp.R b/tests/testthat/test_vimp.R new file mode 100644 index 0000000..b6115d5 --- /dev/null +++ b/tests/testthat/test_vimp.R @@ -0,0 +1,102 @@ +context("Use VIMP without error") + +test_that("VIMP doesn't crash; no test dataset", { + + data(wihs) + + forest <- train(CR_Response(status, time) ~ ., wihs, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE) + + # Run VIMP several times under different scenarios + importance <- vimp(forest, type="raw", events=1:2, time=5.0) + vimp(forest, type="raw", events=1, time=5.0) + vimp(forest, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8)) + + # Not much of a test, but the Java code tests more for correctness. This just + # tests that the R code runs without error. 
+ expect_equal(ncol(importance), 4) # 4 predictors + +}) + + +test_that("VIMP doesn't crash; test dataset", { + + data(wihs) + + trainingData <- wihs[1:1000,] + testData <- wihs[1001:nrow(wihs),] + + forest <- train(CR_Response(status, time) ~ ., trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE, cores=1) + + # Run VIMP several times under different scenarios + importance <- vimp(forest, newData=testData, type="raw", events=1:2, time=5.0) + vimp(forest, newData=testData, type="raw", events=1, time=5.0) + vimp(forest, newData=testData, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8)) + + # Not much of a test, but the Java code tests more for correctness. This just + # tests that the R code runs without error. + expect_equal(ncol(importance), 4) # 4 predictors + +}) + + +test_that("VIMP doesn't crash; censoring distribution; all methods equal", { + + sampleData <- data.frame(x=rnorm(100)) + sampleData$T <- sample(0:4, size=100, replace = TRUE) # the censor distribution we provide needs to conform to the data or we can get NaNs + sampleData$delta <- sample(0:2, size = 100, replace = TRUE) + + testData <- sampleData[1:5,] + trainingData <- sampleData[6:100,] + + forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE) + + importance1 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50, + censoringDistribution = c(0,1,1,2,3,4)) + importance2 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50, + censoringDistribution = list(x = 0:4, y = 1 - c(1/6, 3/6, 4/6, 5/6, 6/6))) + importance3 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50, + censoringDistribution = stepfun(x=0:4, y=1 - c(0, 1/6, 3/6, 4/6, 5/6, 6/6))) + + expect_equal(importance1, importance2) + expect_equal(importance1, importance3) + +}) + +test_that("VIMP doesn't crash; regression dataset", { + + data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), 
x3=rnorm(1000)) + data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000) + + forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE) + + importance <- vimp(forest, type="mean") + + expect_true(importance["x2"] > importance["x3"]) + + # Not much of a test, but the Java code tests more for correctness. This just + # tests that the R code runs without error. + expect_equal(length(importance), 3) # 3 predictors + +}) + +test_that("VIMP produces mean and z scores correctly", { + + data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000)) + data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000) + + forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE) + + actual.importance.raw <- vimp(forest, type="raw", randomSeed=5) + actual.importance.mean <- vimp(forest, type="mean", randomSeed=5) + actual.importance.z <- vimp(forest, type="z", randomSeed=5) + + expected.importance.mean <- apply(actual.importance.raw, 2, mean) + expected.importance.z <- apply(actual.importance.raw, 2, function(x){ + mn <- mean(x) + return( mn / (sd(x) / sqrt(length(x))) ) + }) + + expect_equal(expected.importance.mean, actual.importance.mean) + expect_equal(expected.importance.z, actual.importance.z) + +}) \ No newline at end of file