Add variable importance

parent fd8621a88d
commit 48859b0249

13 changed files with 482 additions and 130 deletions
@@ -35,4 +35,5 @@ export(loadForest)
 export(naiveConcordance)
 export(saveForest)
 export(train)
+export(vimp)
 import(rJava)
@@ -59,45 +59,14 @@ integratedBrierScore <- function(responses, predictions, event, time, censoringD
 
   java.censoringDistribution <- NULL
   if(!is.null(censoringDistribution)){
-    if(is.numeric(censoringDistribution)){
-      # estimate ECDF
-      censoringTimes <- .jarray(censoringDistribution, "D")
-      java.censoringDistribution <- .jcall(.class_Utils, makeResponse(.class_RightContinuousStepFunction), "estimateOneMinusECDF", censoringTimes)
-
-    } else if(is.list(censoringDistribution)){
-      # First check that censoringDistribution fits the correct format
-      if(is.null(censoringDistribution$x) | is.null(censoringDistribution$y)){
-        stop("If the censoringDistribution is provided as a list, it must have an x and a y item that are numeric.")
-      }
-
-      if(length(censoringDistribution$x) != length(censoringDistribution$y)){
-        stop("x and y in censoringDistribution must have the same length.")
-      }
-
-      if(!is.numeric(censoringDistribution$x) | !is.numeric(censoringDistribution$y)){
-        stop("x and y in censoringDistribution must both be numeric.")
-      }
-
-      java.censoringDistribution <- createRightContinuousStepFunction(censoringDistribution$x, censoringDistribution$y, defaultY = 1.0)
-
-    } else if("stepfun" %in% class(censoringDistribution)){
-      x <- stats::knots(censoringDistribution)
-      y <- censoringDistribution(x)
-
-      java.censoringDistribution <- createRightContinuousStepFunction(x, y, defaultY = 1.0)
-    }
-    else{
-      stop("Invalid censoringDistribution")
-    }
-
-    # Make sure we wrap it in an Optional
+    java.censoringDistribution <- processCensoringDistribution(censoringDistribution)
     java.censoringDistribution <- .object_Optional(java.censoringDistribution)
 
   }
   else{
     java.censoringDistribution <- .object_Optional(NULL)
   }
 
 
   predictions.java <- lapply(predictions, function(x){return(x$javaObject)})
   predictions.java <- convertRListToJava(predictions.java)
@@ -10,6 +10,7 @@
 .class_Collection <- "java/util/Collection"
 .class_Serializable <- "java/io/Serializable"
 .class_File <- "java/io/File"
+.class_Random <- "java/util/Random"
 
 # Utility Classes
 .class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
@@ -53,6 +54,14 @@
 .class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder"
 .class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder"
 
+# VIMP classes
+.class_IBSCalculator <- "ca/joeltherrien/randomforest/responses/competingrisk/IBSCalculator"
+.class_ErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator"
+.class_RegressionErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator"
+.class_IBSErrorCalculatorWrapper <- "ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper"
+.class_VariableImportanceCalculator <- "ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator"
+
+
 .object_Optional <- function(object=NULL){
   if(is.null(object)){
     return(.jcall("java/util/Optional", "Ljava/util/Optional;", "empty"))
R/loadData.R (25 changed lines)

@@ -18,7 +18,7 @@ loadData <- function(data, xVarNames, responses, covariateList.java = NULL){
   rowList <- .jcall(.class_RUtils, makeResponse(.class_List), "importDataWithResponses",
                     responses$javaObject, covariateList.java, textData)
 
-  return(list(covariateList=covariateList.java, dataset=rowList))
+  return(list(covariateList = covariateList.java, dataset = rowList, responses = responses))
 
 }
 
@@ -54,15 +54,7 @@ getCovariateList <- function(data, xvarNames){
 
 loadPredictionData <- function(newData, covariateList.java){
 
-  xVarNames <- character(.jcall(covariateList.java, "I", "size"))
-  for(j in 1:length(xVarNames)){
-    covariate.java <- .jcast(
-      .jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
-      .class_Covariate
-    )
-
-    xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
-  }
+  xVarNames <- extractCovariateNamesFromJavaList(covariateList.java)
 
   if(any(!(xVarNames %in% names(newData)))){
     varsMissing = xVarNames[!(xVarNames %in% names(newData))]
@@ -84,3 +76,16 @@ loadPredictionData <- function(newData, covariateList.java){
   return(rowList)
 }
 
+extractCovariateNamesFromJavaList <- function(covariateList.java){
+  xVarNames <- character(.jcall(covariateList.java, "I", "size"))
+  for(j in 1:length(xVarNames)){
+    covariate.java <- .jcast(
+      .jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
+      .class_Covariate
+    )
+
+    xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
+  }
+
+  return(xVarNames)
+}
R/misc.R (17 changed lines)

@@ -14,6 +14,23 @@ convertRListToJava <- function(lst){
   return(javaList)
 }
 
+# Internal function
+convertJavaListToR <- function(javaList, class = .class_Object){
+  lst <- list()
+
+  javaList.length <- .jcall(javaList, "I", "size")
+
+  for(i in 0:(javaList.length - 1)){
+    object <- .jcall(javaList, makeResponse(.class_Object), "get", as.integer(i))
+    object <- .jcast(object, class)
+
+    lst[[i+1]] <- object
+  }
+
+  return(lst)
+
+}
+
 #' @export
 print.SplitFinder = function(x, ...) print(x$call)
 
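For context, a brief sketch of how this new internal helper gets used later in this commit: vimp() converts the forest's Java covariate list into an R list before permuting each covariate. The `forest` object here is assumed to be the result of a train() call.

# Illustration only; mirrors the call made inside vimp() later in this commit.
covariateRList <- convertJavaListToR(forest$covariateList, class = .class_Covariate)
length(covariateRList)  # one Java Covariate object per predictor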
R/processCensoringDistribution.R (new file, 39 lines)

# Internal function. Takes a censoring distribution and turns it into a
# RightContinuousStepFunction Java object.
processCensoringDistribution <- function(censoringDistribution){

  if(is.numeric(censoringDistribution)){
    # estimate ECDF
    censoringTimes <- .jarray(censoringDistribution, "D")
    java.censoringDistribution <- .jcall(.class_Utils, makeResponse(.class_RightContinuousStepFunction), "estimateOneMinusECDF", censoringTimes)

  } else if(is.list(censoringDistribution)){
    # First check that censoringDistribution fits the correct format
    if(is.null(censoringDistribution$x) | is.null(censoringDistribution$y)){
      stop("If the censoringDistribution is provided as a list, it must have an x and a y item that are numeric.")
    }

    if(length(censoringDistribution$x) != length(censoringDistribution$y)){
      stop("x and y in censoringDistribution must have the same length.")
    }

    if(!is.numeric(censoringDistribution$x) | !is.numeric(censoringDistribution$y)){
      stop("x and y in censoringDistribution must both be numeric.")
    }

    java.censoringDistribution <- createRightContinuousStepFunction(censoringDistribution$x, censoringDistribution$y, defaultY = 1.0)

  } else if("stepfun" %in% class(censoringDistribution)){
    x <- stats::knots(censoringDistribution)
    y <- censoringDistribution(x)

    java.censoringDistribution <- createRightContinuousStepFunction(x, y, defaultY = 1.0)
  }
  else{
    stop("Invalid censoringDistribution")
  }

  return(java.censoringDistribution)
}
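A short illustration of the three censoringDistribution formats this helper accepts; the values below mirror the ones exercised later in tests/testthat/test_vimp.R and are purely illustrative.

# 1. A numeric vector of observed censoring times; 1 - ECDF is estimated from it.
cd1 <- c(0, 1, 1, 2, 3, 4)

# 2. A list with numeric x and y describing the censoring survivor function.
cd2 <- list(x = 0:4, y = 1 - c(1/6, 3/6, 4/6, 5/6, 6/6))

# 3. A stepfun object; its knots and values are extracted.
cd3 <- stepfun(x = 0:4, y = 1 - c(0, 1/6, 3/6, 4/6, 5/6, 6/6))

# Any of the three can be passed as censoringDistribution to integratedBrierScore() or vimp().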
R/processFormula.R (new file, 59 lines)

# Internal function that takes a formula and processes it for use in the Java
# code. covariateList.java is optional; if not provided then a new one is
# created internally.
processFormula <- function(formula, data, covariateList.java = NULL){

  # Having an R copy of the data loaded at the same time can be wasteful; we
  # also allow users to provide an environment of the data which gets removed
  # after being imported into Java
  if(class(data) == "environment"){
    if(is.null(data$data)){
      stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
    }

    env <- data
    data <- env$data
    env$data <- NULL
    gc()
  }

  yVar <- formula[[2]]

  responses <- NULL
  variablesToDrop <- character(0)

  # yVar is a call object; as.character(yVar) will be the different components, including the parameters.
  # If the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
  # then we also need to explicitly evaluate it.
  if(class(yVar) == "call" || !(as.character(yVar) %in% colnames(data))){
    # yVar is a function like CR_Response
    responses <- eval(expr=yVar, envir=data)

    if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
      # Do any of the variables match columns in data? We need to track that so we can drop them later.
      variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
    }

    formula[[2]] <- NULL

  } else if(class(yVar) == "name"){ # and implicitly yVar is contained in data
    variablesToDrop <- as.character(yVar)
  }

  # Includes responses which we may need to later cut out if `.` was used on the
  # right-hand-side
  filteredData <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass)

  if(is.null(responses)){ # If this if-statement runs then we have a simple (i.e. numeric) response
    responses <- stats::model.response(filteredData)
  }

  # remove any response variables on the right-hand-side
  covariateData <- filteredData[, !(names(filteredData) %in% variablesToDrop), drop=FALSE]

  dataset <- loadData(covariateData, colnames(covariateData), responses, covariateList.java = covariateList.java)

  return(dataset)
}
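To illustrate the two response shapes this helper now handles on behalf of train() (and, later, vimp()), here is a sketch; the calls mirror the examples and tests elsewhere in this diff, and the regression data frame is assumed to contain columns y, x1, x2, x3.

# Response built by a call such as CR_Response(); with `.` on the right-hand side,
# status and time are detected and dropped from the covariates.
forest.cr <- train(CR_Response(status, time) ~ ., wihs,
                   ntree = 50, numberOfSplits = 0, mtry = 1, nodeSize = 5)

# Plain numeric response already contained in the data.
forest.reg <- train(y ~ x1 + x2 + x3, data,
                    ntree = 50, numberOfSplits = 100, mtry = 2, nodeSize = 5)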
R/train.R (100 changed lines)

@@ -14,7 +14,7 @@ getCores(){
   return(cores)
 }
 
-train.internal <- function(responses, covariateData, splitFinder,
+train.internal <- function(dataset, splitFinder,
                            nodeResponseCombiner, forestResponseCombiner, ntree,
                            numberOfSplits, mtry, nodeSize, maxNodeDepth,
                            splitPureNodes, savePath, savePath.overwrite,
@@ -52,15 +52,15 @@ train.internal <- function(responses, covariateData, splitFinder,
   }
 
   if(is.null(splitFinder)){
-    splitFinder <- splitFinderDefault(responses)
+    splitFinder <- splitFinderDefault(dataset$responses)
   }
 
   if(is.null(nodeResponseCombiner)){
-    nodeResponseCombiner <- nodeResponseCombinerDefault(responses)
+    nodeResponseCombiner <- nodeResponseCombinerDefault(dataset$responses)
   }
 
   if(is.null(forestResponseCombiner)){
-    forestResponseCombiner <- forestResponseCombinerDefault(responses)
+    forestResponseCombiner <- forestResponseCombinerDefault(dataset$responses)
   }
 
 
@@ -75,20 +75,6 @@ train.internal <- function(responses, covariateData, splitFinder,
     stop("forestResponseCombiner must be a ResponseCombiner")
   }
 
-  if(class(covariateData)=="environment"){
-    if(is.null(covariateData$data)){
-      stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
-    }
-    dataset <- loadData(covariateData$data, colnames(covariateData$data), responses)
-    covariateData$data <- NULL # save memory, hopefully
-    gc() # explicitly try to save memory
-  }
-  else{
-    dataset <- loadData(covariateData, colnames(covariateData), responses)
-  }
-
-
   treeTrainer <- createTreeTrainer(responseCombiner=nodeResponseCombiner,
                                    splitFinder=splitFinder,
                                    covariateList=dataset$covariateList,
@@ -188,7 +174,7 @@ train.internal <- function(responses, covariateData, splitFinder,
 #'
 #' @param formula You may specify the response and covariates as a formula
 #'   instead; make sure the response in the formula is still properly
-#'   constructed; see \code{responses}
+#'   constructed.
 #' @param data A data.frame containing the columns of the predictors and
 #'   responses.
 #' @param splitFinder A split finder that's used to score splits in the random
@@ -314,74 +300,16 @@ train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL
                   savePath.overwrite=c("warn", "delete", "merge"), cores = getCores(),
                   randomSeed = NULL, displayProgress = TRUE){
 
-  # Having an R copy of the data loaded at the same time can be wasteful; we
-  # also allow users to provide an environment of the data which gets removed
-  # after being imported into Java
-  env <- NULL
-  if(class(data) == "environment"){
-    if(is.null(data$data)){
-      stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
-    }
-
-    env <- data
-    data <- env$data
-  }
-
-  yVar <- formula[[2]]
-
-  responses <- NULL
-  variablesToDrop <- character(0)
-
-  # yVar is a call object; as.character(yVar) will be the different components, including the parameters.
-  # if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
-  # then we also need to explicitly evaluate it
-  if(class(yVar)=="call" || !(as.character(yVar) %in% colnames(data))){
-    # yVar is a function like CompetingRiskResponses
-    responses <- eval(expr=yVar, envir=data)
-
-    if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
-      # do any of the variables match data in data? We need to track that so we can drop them later
-      variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
-    }
-
-    formula[[2]] <- NULL
-
-  } else if(class(yVar)=="name"){ # and implicitly yVar is contained in data
-    variablesToDrop <- as.character(yVar)
-  }
-
-  # Includes responses which we may need to later cut out
-  mf <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass)
-
-  if(is.null(responses)){
-    responses <- stats::model.response(mf)
-  }
-
-  # remove any response variables
-  mf <- mf[,!(names(mf) %in% variablesToDrop), drop=FALSE]
-
-  # If environment was provided instead of data
-  if(!is.null(env)){
-    env$data <- mf
-    rm(data)
-    forest <- train.internal(responses, env, splitFinder = splitFinder,
-                             nodeResponseCombiner = nodeResponseCombiner,
-                             forestResponseCombiner = forestResponseCombiner,
-                             ntree = ntree, numberOfSplits = numberOfSplits,
-                             mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
-                             splitPureNodes = splitPureNodes, savePath = savePath,
-                             savePath.overwrite = savePath.overwrite, cores = cores,
-                             randomSeed = randomSeed, displayProgress = displayProgress)
-  } else{
-    forest <- train.internal(responses, mf, splitFinder = splitFinder,
-                             nodeResponseCombiner = nodeResponseCombiner,
-                             forestResponseCombiner = forestResponseCombiner,
-                             ntree = ntree, numberOfSplits = numberOfSplits,
-                             mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
-                             splitPureNodes = splitPureNodes, savePath = savePath,
-                             savePath.overwrite = savePath.overwrite, cores = cores,
-                             randomSeed = randomSeed, displayProgress = displayProgress)
-  }
+  dataset <- processFormula(formula, data)
+
+  forest <- train.internal(dataset, splitFinder = splitFinder,
+                           nodeResponseCombiner = nodeResponseCombiner,
+                           forestResponseCombiner = forestResponseCombiner,
+                           ntree = ntree, numberOfSplits = numberOfSplits,
+                           mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
+                           splitPureNodes = splitPureNodes, savePath = savePath,
+                           savePath.overwrite = savePath.overwrite, cores = cores,
+                           randomSeed = randomSeed, displayProgress = displayProgress)
 
   forest$call <- match.call()
   forest$formula <- formula
 
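One behaviour worth noting: the memory-saving call path that used to live in train() (passing an environment that contains an item named data) is preserved, it just runs inside processFormula() now. A minimal sketch, assuming the wihs example data used elsewhere in this diff:

env <- new.env()
env$data <- wihs  # the environment must contain an item called 'data'

# train() hands the environment to processFormula(), which imports the data into
# Java and removes it from the environment to free memory.
forest <- train(CR_Response(status, time) ~ ., env,
                ntree = 50, numberOfSplits = 0, mtry = 1, nodeSize = 5)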
R/vimp.R (new file, 175 lines)

#' Variable Importance
#'
#' Calculate variable importance by recording the increase in error when a given
#' predictor is randomly permuted. Regression forests use mean squared error;
#' competing risks forests use the integrated Brier score.
#'
#' @param forest The forest that was trained.
#' @param newData A test set of the data if available. If not, then out of bag
#'   errors will be attempted on the training set.
#' @param randomSeed The source of randomness used to permute the values. Can be
#'   left blank.
#' @param type One of "mean" (the default), "z", or "raw"; controls how the
#'   per-tree importance values are aggregated.
#' @param events If using competing risks forest, the events that the error
#'   measure used for VIMP should be calculated on.
#' @param time If using competing risks forest, the upper bound of the
#'   integrated Brier score.
#' @param censoringDistribution (Optional) If using competing risks forest, the
#'   censoring distribution. See \code{\link{integratedBrierScore}} for details.
#' @param eventWeights (Optional) If using competing risks forest, weights to be
#'   applied to the error for each of the \code{events}.
#'
#' @return A named numeric vector of importance values.
#' @export
#'
#' @examples
#' data(wihs)
#'
#' forest <- train(CR_Response(status, time) ~ ., wihs,
#'    ntree = 100, numberOfSplits = 0, mtry = 3, nodeSize = 5)
#'
#' vimp(forest, events = 1:2, time = 8.0)
#'
vimp <- function(
  forest,
  newData = NULL,
  randomSeed = NULL,
  type = c("mean", "z", "raw"),
  events = NULL,
  time = NULL,
  censoringDistribution = NULL,
  eventWeights = NULL){

  if(is.null(newData) & is.null(forest$dataset)){
    stop("forest doesn't have a copy of the training data loaded (this happens if you just loaded it); please manually specify newData and possibly out.of.bag")
  }

  # Check whether type is NULL, length 0, or an invalid value. The last check
  # can't be combined with length(type) < 1, because R evaluates both conditions
  # and a different error would display if length(type) == 0.
  typeError = is.null(type) | length(type) < 1
  if(!typeError){
    typeError = !(type[1] %in% c("mean", "z", "raw"))
  }
  if(typeError){
    stop("A valid response type must be provided.")
  }

  if(is.null(newData)){
    data.java <- forest$dataset
    out.of.bag <- TRUE

  }
  else{ # newData is provided
    data.java <- processFormula(forest$formula, newData, forest$covariateList)$dataset
    out.of.bag <- FALSE
  }

  predictionClass <- forest$params$forestResponseCombiner$outputClass

  if(predictionClass == "CompetingRiskFunctions"){
    if(is.null(time) | length(time) != 1){
      stop("time must be set at length 1")
    }

    errorCalculator.java <- ibsCalculatorWrapper(
      events = events,
      time = time,
      censoringDistribution = censoringDistribution,
      eventWeights = eventWeights)

  } else if(predictionClass == "numeric"){
    errorCalculator.java <- .jnew(.class_RegressionErrorCalculator)
    errorCalculator.java <- .jcast(errorCalculator.java, .class_ErrorCalculator)

  } else{
    stop(paste0("VIMP not yet supported for ", predictionClass, ". If you're just using a non-custom version of largeRCRF then this is a bug and should be reported."))
  }

  forest.trees.java <- .jcall(forest$javaObject, makeResponse(.class_List), "getTrees")

  vimp.calculator <- .jnew(.class_VariableImportanceCalculator,
                           errorCalculator.java,
                           forest.trees.java,
                           data.java,
                           out.of.bag # isTrainingSet parameter
  )

  random.java <- NULL
  if(!is.null(randomSeed)){
    random.java <- .jnew(.class_Random, .jlong(as.integer(randomSeed)))
  }
  random.java <- .object_Optional(random.java)

  covariateRList <- convertJavaListToR(forest$covariateList, class = .class_Covariate)
  importanceValues <- matrix(nrow = forest$params$ntree, ncol = length(covariateRList))
  colnames(importanceValues) <- extractCovariateNamesFromJavaList(forest$covariateList)

  for(j in 1:length(covariateRList)){
    covariateJava <- covariateRList[[j]]

    importanceValues[, j] <- .jcall(vimp.calculator, "[D", "calculateVariableImportanceRaw", covariateJava, random.java)
  }

  if(type[1] == "raw"){
    return(importanceValues)
  } else if(type[1] == "mean"){
    meanImportanceValues <- apply(importanceValues, 2, mean)
    return(meanImportanceValues)
  } else if(type[1] == "z"){
    zImportanceValues <- apply(importanceValues, 2, function(x){
      meanValue <- mean(x)
      standardError <- sd(x)/sqrt(length(x))
      return(meanValue / standardError)
    })
    return(zImportanceValues)

  } else{
    stop("A valid response type must be provided.")
  }

}

# Internal function
ibsCalculatorWrapper <- function(events, time, censoringDistribution = NULL, eventWeights = NULL){
  if(is.null(events)){
    stop("events must be specified if using vimp on competing risks data")
  }

  if(is.null(time)){
    stop("time must be specified if using vimp on competing risks data")
  }

  java.censoringDistribution <- NULL
  if(!is.null(censoringDistribution)){
    java.censoringDistribution <- processCensoringDistribution(censoringDistribution)
    java.censoringDistribution <- .object_Optional(java.censoringDistribution)
  }
  else{
    java.censoringDistribution <- .object_Optional(NULL)
  }

  ibsCalculator.java <- .jnew(.class_IBSCalculator, java.censoringDistribution)

  if(is.null(eventWeights)){
    eventWeights <- rep(1, times = length(events))
  }

  ibsCalculatorWrapper.java <- .jnew(.class_IBSErrorCalculatorWrapper,
                                     ibsCalculator.java,
                                     .jarray(as.integer(events)),
                                     as.numeric(time),
                                     .jarray(as.numeric(eventWeights)))

  ibsCalculatorWrapper.java <- .jcast(ibsCalculatorWrapper.java, .class_ErrorCalculator)
  return(ibsCalculatorWrapper.java)
}
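Since the heavy lifting happens in the Java VariableImportanceCalculator, here is a small, self-contained R sketch of the permutation-importance idea itself (illustration only, not the package's code path): permute one predictor at a time and record how much the error grows. An ordinary lm() fit and a toy data frame stand in for the forest and the real data.

set.seed(1)
toy <- data.frame(x1 = rnorm(500), x2 = rnorm(500), x3 = rnorm(500))
toy$y <- toy$x1 + 3 * toy$x2 + 0.05 * toy$x3 + rnorm(500)

model <- lm(y ~ ., toy)                              # stand-in for a trained forest
baseError <- mean((toy$y - predict(model, toy))^2)   # baseline mean squared error

importance <- sapply(c("x1", "x2", "x3"), function(name){
  permuted <- toy
  permuted[[name]] <- sample(permuted[[name]])       # break the association
  mean((toy$y - predict(model, permuted))^2) - baseError
})
importance  # larger increases in error indicate more important predictors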
Binary file not shown.
@@ -13,7 +13,7 @@ train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
 \arguments{
 \item{formula}{You may specify the response and covariates as a formula
 instead; make sure the response in the formula is still properly
-constructed; see \code{responses}}
+constructed.}
 
 \item{data}{A data.frame containing the columns of the predictors and
 responses.}
 
man/vimp.Rd (new file, 48 lines)

% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/vimp.R
\name{vimp}
\alias{vimp}
\title{Variable Importance}
\usage{
vimp(forest, newData = NULL, randomSeed = NULL, type = c("mean", "z",
  "raw"), events = NULL, time = NULL, censoringDistribution = NULL,
  eventWeights = NULL)
}
\arguments{
\item{forest}{The forest that was trained.}

\item{newData}{A test set of the data if available. If not, then out of bag
errors will be attempted on the training set.}

\item{randomSeed}{The source of randomness used to permute the values. Can be
left blank.}

\item{type}{One of "mean" (the default), "z", or "raw"; controls how the
per-tree importance values are aggregated.}

\item{events}{If using competing risks forest, the events that the error
measure used for VIMP should be calculated on.}

\item{time}{If using competing risks forest, the upper bound of the
integrated Brier score.}

\item{censoringDistribution}{(Optional) If using competing risks forest, the
censoring distribution. See \code{\link{integratedBrierScore}} for details.}

\item{eventWeights}{(Optional) If using competing risks forest, weights to be
applied to the error for each of the \code{events}.}
}
\value{
A named numeric vector of importance values.
}
\description{
Calculate variable importance by recording the increase in error when a given
predictor is randomly permuted. Regression forests use mean squared error;
competing risks forests use the integrated Brier score.
}
\examples{
data(wihs)

forest <- train(CR_Response(status, time) ~ ., wihs,
   ntree = 100, numberOfSplits = 0, mtry = 3, nodeSize = 5)

vimp(forest, events = 1:2, time = 8.0)

}
tests/testthat/test_vimp.R (new file, 102 lines)

context("Use VIMP without error")

test_that("VIMP doesn't crash; no test dataset", {

  data(wihs)

  forest <- train(CR_Response(status, time) ~ ., wihs, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE)

  # Run VIMP several times under different scenarios
  importance <- vimp(forest, type="raw", events=1:2, time=5.0)
  vimp(forest, type="raw", events=1, time=5.0)
  vimp(forest, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8))

  # Not much of a test, but the Java code tests more for correctness. This just
  # tests that the R code runs without error.
  expect_equal(ncol(importance), 4) # 4 predictors

})

test_that("VIMP doesn't crash; test dataset", {

  data(wihs)

  trainingData <- wihs[1:1000,]
  testData <- wihs[1001:nrow(wihs),]

  forest <- train(CR_Response(status, time) ~ ., trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE, cores=1)

  # Run VIMP several times under different scenarios
  importance <- vimp(forest, newData=testData, type="raw", events=1:2, time=5.0)
  vimp(forest, newData=testData, type="raw", events=1, time=5.0)
  vimp(forest, newData=testData, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8))

  # Not much of a test, but the Java code tests more for correctness. This just
  # tests that the R code runs without error.
  expect_equal(ncol(importance), 4) # 4 predictors

})

test_that("VIMP doesn't crash; censoring distribution; all methods equal", {

  sampleData <- data.frame(x=rnorm(100))
  sampleData$T <- sample(0:4, size=100, replace = TRUE) # the censor distribution we provide needs to conform to the data or we can get NaNs
  sampleData$delta <- sample(0:2, size = 100, replace = TRUE)

  testData <- sampleData[1:5,]
  trainingData <- sampleData[6:100,]

  forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)

  importance1 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
                      censoringDistribution = c(0,1,1,2,3,4))
  importance2 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
                      censoringDistribution = list(x = 0:4, y = 1 - c(1/6, 3/6, 4/6, 5/6, 6/6)))
  importance3 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
                      censoringDistribution = stepfun(x=0:4, y=1 - c(0, 1/6, 3/6, 4/6, 5/6, 6/6)))

  expect_equal(importance1, importance2)
  expect_equal(importance1, importance3)

})

test_that("VIMP doesn't crash; regression dataset", {

  data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000))
  data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000)

  forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE)

  importance <- vimp(forest, type="mean")

  expect_true(importance["x2"] > importance["x3"])

  # Not much of a test, but the Java code tests more for correctness. This just
  # tests that the R code runs without error.
  expect_equal(length(importance), 3) # 3 predictors

})

test_that("VIMP produces mean and z scores correctly", {

  data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000))
  data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000)

  forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE)

  actual.importance.raw <- vimp(forest, type="raw", randomSeed=5)
  actual.importance.mean <- vimp(forest, type="mean", randomSeed=5)
  actual.importance.z <- vimp(forest, type="z", randomSeed=5)

  expected.importance.mean <- apply(actual.importance.raw, 2, mean)
  expected.importance.z <- apply(actual.importance.raw, 2, function(x){
    mn <- mean(x)
    return( mn / (sd(x) / sqrt(length(x))) )
  })

  expect_equal(expected.importance.mean, actual.importance.mean)
  expect_equal(expected.importance.z, actual.importance.z)

})