Compare commits
9 commits
1.0.3.1-de
...
master
Author | SHA1 | Date | |
---|---|---|---|
589df1a18b | |||
3fa2fef82a | |||
3f0f6c0878 | |||
|
48859b0249 | ||
|
fd8621a88d | ||
|
af0c8f2e96 | ||
|
4cd322ee22 | ||
|
360be1f80e | ||
|
cb4c9b73ae |
38 changed files with 1196 additions and 202 deletions
|
@ -1,7 +1,7 @@
|
|||
Package: largeRCRF
|
||||
Type: Package
|
||||
Title: Large Random Competing Risks Forests
|
||||
Version: 1.0.3.1
|
||||
Version: 1.0.5
|
||||
Authors@R: c(
|
||||
person("Joel", "Therrien", email = "joel_therrien@sfu.ca", role = c("aut", "cre", "cph")),
|
||||
person("Jiguo", "Cao", email = "jiguo_cao@sfu.ca", role = c("aut", "dgs"))
|
||||
|
@ -15,7 +15,8 @@ Copyright: All provided source code is copyrighted and owned by Joel Therrien.
|
|||
Encoding: UTF-8
|
||||
LazyData: true
|
||||
Imports:
|
||||
rJava (>= 0.9-9)
|
||||
rJava (>= 0.9-9),
|
||||
stats (>= 3.4.0)
|
||||
Suggests:
|
||||
parallel,
|
||||
testthat,
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Generated by roxygen2: do not edit by hand
|
||||
|
||||
S3method("[",CompetingRiskResponses)
|
||||
S3method("[",CompetingRiskResponsesWithCensorTimes)
|
||||
S3method(extractCHF,CompetingRiskFunctions)
|
||||
S3method(extractCHF,CompetingRiskFunctions.List)
|
||||
S3method(extractCIF,CompetingRiskFunctions)
|
||||
|
@ -24,12 +26,15 @@ export(Numeric)
|
|||
export(WeightedVarianceSplitFinder)
|
||||
export(addTrees)
|
||||
export(connectToData)
|
||||
export(convertToOnlineForest)
|
||||
export(extractCHF)
|
||||
export(extractCIF)
|
||||
export(extractMortalities)
|
||||
export(extractSurvivorCurve)
|
||||
export(integratedBrierScore)
|
||||
export(loadForest)
|
||||
export(naiveConcordance)
|
||||
export(saveForest)
|
||||
export(train)
|
||||
export(vimp)
|
||||
import(rJava)
|
||||
|
|
|
@ -38,6 +38,51 @@ CR_Response <- function(delta, u, C = NULL){
|
|||
}
|
||||
|
||||
|
||||
# This function is useful is we ever want to do something like CR_Response(c(1,1,2), c(0.1,0.2,0.3))[1]
|
||||
#' @export
|
||||
"[.CompetingRiskResponses" <- function(object, indices){
|
||||
newList <- list(
|
||||
eventIndicator = object$eventIndicator[indices],
|
||||
eventTime = object$eventTime[indices]
|
||||
)
|
||||
|
||||
previous.java.list <- object$javaObject
|
||||
|
||||
new.java.list <- .jcall(.class_RUtils,
|
||||
makeResponse(.class_List),
|
||||
"produceSublist",
|
||||
previous.java.list,
|
||||
.jarray(as.integer(indices - 1)))
|
||||
|
||||
newList$javaObject <- new.java.list
|
||||
|
||||
class(newList) <- "CompetingRiskResponses"
|
||||
return(newList)
|
||||
}
|
||||
|
||||
# This function is useful is we ever want to do something like CR_Response(c(1,1,2), c(0.1,0.2,0.3), c(2,3,4))[1]
|
||||
#' @export
|
||||
"[.CompetingRiskResponsesWithCensorTimes" <- function(object, indices){
|
||||
newList <- list(
|
||||
eventIndicator = object$eventIndicator[indices],
|
||||
eventTime = object$eventTime[indices],
|
||||
censorTime = object$censorTime[indices]
|
||||
)
|
||||
|
||||
previous.java.list <- object$javaObject
|
||||
|
||||
new.java.list <- .jcall(.class_RUtils,
|
||||
makeResponse(.class_List),
|
||||
"produceSublist",
|
||||
previous.java.list,
|
||||
.jarray(as.integer(indices - 1)))
|
||||
|
||||
newList$javaObject <- new.java.list
|
||||
|
||||
class(newList) <- "CompetingRiskResponsesWithCensorTimes"
|
||||
return(newList)
|
||||
}
|
||||
|
||||
# Internal function
|
||||
Java_CompetingRiskResponses <- function(delta, u){
|
||||
|
||||
|
|
29
R/addTrees.R
29
R/addTrees.R
|
@ -14,6 +14,10 @@
|
|||
#' @param savePath.overwrite If \code{savePath} is pointing to an existing
|
||||
#' directory, possibly containing another forest, this specifies what should
|
||||
#' be done.
|
||||
#' @param forest.output This parameter only applies if \code{savePath} has been
|
||||
#' set; set to 'online' (default) and the saved forest will be loaded into
|
||||
#' memory after being trained. Set to 'offline' and the forest is not saved
|
||||
#' into memory, but can still be used in a memory unintensive manner.
|
||||
#' @param cores The number of cores to be used for training the new trees.
|
||||
#' @param displayProgress A logical indicating whether the progress should be
|
||||
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
|
||||
|
@ -22,7 +26,11 @@
|
|||
#' @return A new forest with the original and additional trees.
|
||||
#' @export
|
||||
#'
|
||||
addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), cores = getCores(), displayProgress = TRUE){
|
||||
addTrees <- function(forest, numTreesToAdd, savePath = NULL,
|
||||
savePath.overwrite = c("warn", "delete", "merge"),
|
||||
forest.output = c("online", "offline"),
|
||||
cores = getCores(), displayProgress = TRUE){
|
||||
|
||||
if(is.null(forest$dataset)){
|
||||
stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
|
||||
}
|
||||
|
@ -37,6 +45,10 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite
|
|||
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||
}
|
||||
|
||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||
}
|
||||
|
||||
newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd)
|
||||
|
||||
treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
|
||||
|
@ -98,22 +110,23 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite
|
|||
params=params,
|
||||
forestCall=match.call())
|
||||
|
||||
forest.java <- NULL
|
||||
if(cores > 1){
|
||||
.jcall(forestTrainer, "V", "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
|
||||
} else {
|
||||
.jcall(forestTrainer, "V", "trainSerialOnDisk", initial.forest.optional)
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", initial.forest.optional)
|
||||
}
|
||||
|
||||
# Need to now load forest trees back into memory
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forest$params$forestResponseCombiner$javaObject)
|
||||
|
||||
if(forest.output[1] == "online"){
|
||||
forest.java <- convertToOnlineForest.Java(forest.java)
|
||||
}
|
||||
|
||||
}
|
||||
else{ # save directly into memory
|
||||
if(cores > 1){
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
|
||||
} else {
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", initial.forest.optional)
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", initial.forest.optional)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
38
R/convertToOnlineForest.R
Normal file
38
R/convertToOnlineForest.R
Normal file
|
@ -0,0 +1,38 @@
|
|||
|
||||
#' Convert to Online Forest
|
||||
#'
|
||||
#' Some forests are too large to store in memory and have been saved to disk.
|
||||
#' They can still be used, but their performance is much slower. If there's
|
||||
#' enough memory, they can easily be converted into an in-memory forest that is
|
||||
#' faster to use.
|
||||
#'
|
||||
#' @param forest The offline forest.
|
||||
#'
|
||||
#' @return An online, in memory forst.
|
||||
#' @export
|
||||
#'
|
||||
convertToOnlineForest <- function(forest){
|
||||
old.forest.object <- forest$javaObject
|
||||
|
||||
if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OnlineForest"){
|
||||
|
||||
warning("forest is already in-memory")
|
||||
return(forest)
|
||||
|
||||
} else if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OfflineForest"){
|
||||
|
||||
forest$javaObject <- convertToOnlineForest.Java(old.forest.object)
|
||||
return(forest)
|
||||
|
||||
} else{
|
||||
stop("'forest' is not an online or offline forest")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
# Internal function
|
||||
convertToOnlineForest.Java <- function(forest.java){
|
||||
offline.forest <- .jcast(forest.java, .class_OfflineForest)
|
||||
online.forest <- .jcall(offline.forest, makeResponse(.class_OnlineForest), "createOnlineCopy")
|
||||
return(online.forest)
|
||||
}
|
|
@ -42,7 +42,7 @@ CR_FunctionCombiner <- function(events, times = NULL){
|
|||
}
|
||||
|
||||
javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray)
|
||||
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||
javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
|
||||
|
||||
combiner <- list(javaObject=javaObject,
|
||||
call=match.call(),
|
||||
|
|
83
R/cr_integratedBrierScore.R
Normal file
83
R/cr_integratedBrierScore.R
Normal file
|
@ -0,0 +1,83 @@
|
|||
|
||||
|
||||
#' Integrated Brier Score
|
||||
#'
|
||||
#' Used to calculate the Integrated Brier Score, which for the competing risks
|
||||
#' setting is the integral of the squared difference between each observed
|
||||
#' cumulative incidence function (CIF) for each observation and the
|
||||
#' corresponding predicted CIF. If the survivor function (1 - CDF) of the
|
||||
#' censoring distribution is provided, weights can be calculated to account for
|
||||
#' the censoring.
|
||||
#'
|
||||
#' @return A numeric vector of the Integrated Brier Score for each prediction.
|
||||
#' @param responses A list of responses corresponding to the provided
|
||||
#' mortalities; use \code{\link{CR_Response}}.
|
||||
#' @param predictions The predictions to be tested against.
|
||||
#' @param event The event type for the error to be calculated on.
|
||||
#' @param time \code{time} specifies the upper bound of the integral.
|
||||
#' @param censoringDistribution Optional; if provided then weights are
|
||||
#' calculated on the errors. There are three ways to provide it - \itemize{
|
||||
#' \item{If you have all the censor times and just want to use a simple
|
||||
#' empirical estimate of the distribution, just provide a numeric vector of
|
||||
#' all of the censor times and it will be automatically calculated.} \item{You
|
||||
#' can directly specify the survivor function by providing a list with two
|
||||
#' numeric vectors called \code{x} and \code{y}. They should be of the same
|
||||
#' length and correspond to each point. It is assumed that previous to the
|
||||
#' first value in \code{y} the \code{y} value is 1.0; and that the function
|
||||
#' you provide is a right-continuous step function.} \item{You can provide a
|
||||
#' function from \code{\link[stats]{stepfun}}. Note that this only supports
|
||||
#' functions where \code{right = FALSE} (default), and that the first y value
|
||||
#' (corresponding to y before the first x value) will be to set to 1.0
|
||||
#' regardless of what is specified.}
|
||||
#'
|
||||
#' }
|
||||
#' @param parallel A logical indicating whether multiple cores should be
|
||||
#' utilized when calculating the error. Available as an option because it's
|
||||
#' been observed that using Java's \code{parallelStream} can be unstable on
|
||||
#' some systems. Default value is \code{TRUE}; only set to \code{FALSE} if you
|
||||
#' get strange errors while predicting.
|
||||
#'
|
||||
#' @export
|
||||
#' @references Section 4.2 of Ishwaran H, Gerds TA, Kogalur UB, Moore RD, Gange
|
||||
#' SJ, Lau BM (2014). “Random Survival Forests for Competing Risks.”
|
||||
#' Biostatistics, 15(4), 757–773. doi:10.1093/ biostatistics/kxu010.
|
||||
#'
|
||||
#' @examples
|
||||
#' data <- data.frame(delta=c(1,1,0,0,2,2), T=1:6, x=1:6)
|
||||
#'
|
||||
#' model <- train(CR_Response(delta, T) ~ x, data, ntree=100, numberOfSplits=0, mtry=1, nodeSize=1)
|
||||
#'
|
||||
#' newData <- data.frame(delta=c(1,0,2,1,0,2), T=1:6, x=1:6)
|
||||
#' predictions <- predict(model, newData)
|
||||
#'
|
||||
#' scores <- integratedBrierScore(CR_Response(data$delta, data$T), predictions, 1, 6.0)
|
||||
#'
|
||||
integratedBrierScore <- function(responses, predictions, event, time, censoringDistribution = NULL, parallel = TRUE){
|
||||
if(length(responses$eventTime) != length(predictions)){
|
||||
stop("Length of responses and predictions must be equal.")
|
||||
}
|
||||
|
||||
java.censoringDistribution <- NULL
|
||||
if(!is.null(censoringDistribution)){
|
||||
java.censoringDistribution <- processCensoringDistribution(censoringDistribution)
|
||||
java.censoringDistribution <- .object_Optional(java.censoringDistribution)
|
||||
}
|
||||
else{
|
||||
java.censoringDistribution <- .object_Optional(NULL)
|
||||
}
|
||||
|
||||
|
||||
predictions.java <- lapply(predictions, function(x){return(x$javaObject)})
|
||||
predictions.java <- convertRListToJava(predictions.java)
|
||||
|
||||
errors <- .jcall(.class_CompetingRiskUtils, "[D", "calculateIBSError",
|
||||
responses$javaObject,
|
||||
predictions.java,
|
||||
java.censoringDistribution,
|
||||
as.integer(event),
|
||||
time,
|
||||
parallel)
|
||||
|
||||
return(errors)
|
||||
|
||||
}
|
|
@ -16,6 +16,11 @@
|
|||
#' list should correspond to one of the events in the order of event 1 to J,
|
||||
#' and should be a vector of the same length as responses.
|
||||
#' @export
|
||||
#' @references Section 3.2 of Wolbers, Marcel, Paul Blanche, Michael T. Koller,
|
||||
#' Jacqueline C M Witteman, and Thomas A Gerds. 2014. “Concordance for
|
||||
#' Prognostic Models with Competing Risks.” Biostatistics 15 (3): 526–39.
|
||||
#' https://doi.org/10.1093/biostatistics/kxt059.
|
||||
#'
|
||||
#' @examples
|
||||
#' data <- data.frame(delta=c(1,1,0,0,2,2), T=1:6, x=1:6)
|
||||
#'
|
||||
|
|
11
R/createRightContinuousStepFunction.R
Normal file
11
R/createRightContinuousStepFunction.R
Normal file
|
@ -0,0 +1,11 @@
|
|||
# Internal function
|
||||
createRightContinuousStepFunction <- function(x, y, defaultY){
|
||||
x.java <- .jarray(as.numeric(x))
|
||||
y.java <- .jarray(as.numeric(y))
|
||||
|
||||
# as.numeric is explicitly required in case integers were accidently passed
|
||||
# in.
|
||||
newFun <- .jnew(.class_RightContinuousStepFunction, as.numeric(x), as.numeric(y), as.numeric(defaultY))
|
||||
return(newFun)
|
||||
|
||||
}
|
|
@ -29,8 +29,8 @@
|
|||
NULL
|
||||
|
||||
# @rdname covariates
|
||||
Java_BooleanCovariate <- function(name, index){
|
||||
covariate <- .jnew(.class_BooleanCovariate, name, as.integer(index))
|
||||
Java_BooleanCovariate <- function(name, index, na.penalty){
|
||||
covariate <- .jnew(.class_BooleanCovariate, name, as.integer(index), na.penalty)
|
||||
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||
|
||||
return(covariate)
|
||||
|
@ -38,19 +38,19 @@ Java_BooleanCovariate <- function(name, index){
|
|||
|
||||
# @rdname covariates
|
||||
# @param levels The levels of the factor as a character vector
|
||||
Java_FactorCovariate <- function(name, index, levels){
|
||||
Java_FactorCovariate <- function(name, index, levels, na.penalty){
|
||||
levelsArray <- .jarray(levels, makeResponse(.class_String))
|
||||
levelsList <- .jcall("java/util/Arrays", "Ljava/util/List;", "asList", .jcast(levelsArray, "[Ljava/lang/Object;"))
|
||||
|
||||
covariate <- .jnew(.class_FactorCovariate, name, as.integer(index), levelsList)
|
||||
covariate <- .jnew(.class_FactorCovariate, name, as.integer(index), levelsList, na.penalty)
|
||||
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||
|
||||
return(covariate)
|
||||
}
|
||||
|
||||
# @rdname covariates
|
||||
Java_NumericCovariate <- function(name, index){
|
||||
covariate <- .jnew(.class_NumericCovariate, name, as.integer(index))
|
||||
Java_NumericCovariate <- function(name, index, na.penalty){
|
||||
covariate <- .jnew(.class_NumericCovariate, name, as.integer(index), na.penalty)
|
||||
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||
|
||||
return(covariate)
|
||||
|
|
|
@ -10,15 +10,20 @@
|
|||
.class_Collection <- "java/util/Collection"
|
||||
.class_Serializable <- "java/io/Serializable"
|
||||
.class_File <- "java/io/File"
|
||||
.class_Random <- "java/util/Random"
|
||||
.class_Class <- "java/lang/Class"
|
||||
|
||||
# Utility Classes
|
||||
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
|
||||
.class_RUtils <- "ca/joeltherrien/randomforest/utils/RUtils"
|
||||
.class_Utils <- "ca/joeltherrien/randomforest/utils/Utils"
|
||||
.class_CompetingRiskUtils <- "ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils"
|
||||
.class_Settings <- "ca/joeltherrien/randomforest/Settings"
|
||||
|
||||
# Misc. Classes
|
||||
.class_RightContinuousStepFunction <- "ca/joeltherrien/randomforest/utils/RightContinuousStepFunction"
|
||||
.class_CompetingRiskResponse <- "ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse"
|
||||
.class_CompetingRiskResponseWithCensorTime <- "ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime"
|
||||
|
||||
# TreeTrainer & its Builder
|
||||
.class_TreeTrainer <- "ca/joeltherrien/randomforest/tree/TreeTrainer"
|
||||
|
@ -37,9 +42,12 @@
|
|||
|
||||
# Forest class
|
||||
.class_Forest <- "ca/joeltherrien/randomforest/tree/Forest"
|
||||
.class_OnlineForest <- "ca/joeltherrien/randomforest/tree/OnlineForest"
|
||||
.class_OfflineForest <- "ca/joeltherrien/randomforest/tree/OfflineForest"
|
||||
|
||||
# ResponseCombiner classes
|
||||
.class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner"
|
||||
.class_ForestResponseCombiner <- "ca/joeltherrien/randomforest/tree/ForestResponseCombiner"
|
||||
.class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner"
|
||||
.class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner"
|
||||
.class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner"
|
||||
|
@ -50,12 +58,20 @@
|
|||
.class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder"
|
||||
.class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder"
|
||||
|
||||
.object_Optional <- function(forest=NULL){
|
||||
if(is.null(forest)){
|
||||
# VIMP classes
|
||||
.class_IBSCalculator <- "ca/joeltherrien/randomforest/responses/competingrisk/IBSCalculator"
|
||||
.class_ErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator"
|
||||
.class_RegressionErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator"
|
||||
.class_IBSErrorCalculatorWrapper <- "ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper"
|
||||
.class_VariableImportanceCalculator <- "ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator"
|
||||
|
||||
|
||||
.object_Optional <- function(object=NULL){
|
||||
if(is.null(object)){
|
||||
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "empty"))
|
||||
} else{
|
||||
forest <- .jcast(forest, .class_Object)
|
||||
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "of", forest))
|
||||
object <- .jcast(object, .class_Object)
|
||||
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "of", object))
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -66,3 +82,9 @@
|
|||
makeResponse <- function(className){
|
||||
return(paste0("L", className, ";"))
|
||||
}
|
||||
|
||||
getJavaClass <- function(object){
|
||||
class <- .jcall(object, makeResponse(.class_Class), "getClass")
|
||||
className <- .jcall(class, "S", "getName")
|
||||
return(className)
|
||||
}
|
||||
|
|
37
R/loadData.R
37
R/loadData.R
|
@ -1,4 +1,4 @@
|
|||
loadData <- function(data, xVarNames, responses, covariateList.java = NULL){
|
||||
loadData <- function(data, xVarNames, responses, covariateList.java = NULL, na.penalty = NULL){
|
||||
|
||||
if(class(responses) == "integer" | class(responses) == "numeric"){
|
||||
responses <- Numeric(responses)
|
||||
|
@ -6,7 +6,7 @@ loadData <- function(data, xVarNames, responses, covariateList.java = NULL){
|
|||
|
||||
# connectToData provides a pre-created covariate list we can re-use
|
||||
if(is.null(covariateList.java)){
|
||||
covariateList.java <- getCovariateList(data, xVarNames)
|
||||
covariateList.java <- getCovariateList(data, xVarNames, na.penalty)
|
||||
}
|
||||
|
||||
textColumns <- list()
|
||||
|
@ -18,11 +18,11 @@ loadData <- function(data, xVarNames, responses, covariateList.java = NULL){
|
|||
rowList <- .jcall(.class_RUtils, makeResponse(.class_List), "importDataWithResponses",
|
||||
responses$javaObject, covariateList.java, textData)
|
||||
|
||||
return(list(covariateList=covariateList.java, dataset=rowList))
|
||||
return(list(covariateList = covariateList.java, dataset = rowList, responses = responses))
|
||||
|
||||
}
|
||||
|
||||
getCovariateList <- function(data, xvarNames){
|
||||
getCovariateList <- function(data, xvarNames, na.penalty){
|
||||
covariateList <- .jcast(.jnew(.class_ArrayList, length(xvarNames)), .class_List)
|
||||
|
||||
for(i in 1:length(xvarNames)){
|
||||
|
@ -31,14 +31,14 @@ getCovariateList <- function(data, xvarNames){
|
|||
column <- data[,xName]
|
||||
|
||||
if(class(column) == "numeric" | class(column) == "integer"){
|
||||
covariate <- Java_NumericCovariate(xName, i-1)
|
||||
covariate <- Java_NumericCovariate(xName, i-1, na.penalty[i])
|
||||
}
|
||||
else if(class(column) == "logical"){
|
||||
covariate <- Java_BooleanCovariate(xName, i-1)
|
||||
covariate <- Java_BooleanCovariate(xName, i-1, na.penalty[i])
|
||||
}
|
||||
else if(class(column) == "factor"){
|
||||
lvls <- levels(column)
|
||||
covariate <- Java_FactorCovariate(xName, i-1, lvls)
|
||||
covariate <- Java_FactorCovariate(xName, i-1, lvls, na.penalty[i])
|
||||
}
|
||||
else{
|
||||
stop("Unknown column type")
|
||||
|
@ -54,15 +54,7 @@ getCovariateList <- function(data, xvarNames){
|
|||
|
||||
loadPredictionData <- function(newData, covariateList.java){
|
||||
|
||||
xVarNames <- character(.jcall(covariateList.java, "I", "size"))
|
||||
for(j in 1:length(xVarNames)){
|
||||
covariate.java <- .jcast(
|
||||
.jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
|
||||
.class_Covariate
|
||||
)
|
||||
|
||||
xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
|
||||
}
|
||||
xVarNames <- extractCovariateNamesFromJavaList(covariateList.java)
|
||||
|
||||
if(any(!(xVarNames %in% names(newData)))){
|
||||
varsMissing = xVarNames[!(xVarNames %in% names(newData))]
|
||||
|
@ -84,3 +76,16 @@ loadPredictionData <- function(newData, covariateList.java){
|
|||
return(rowList)
|
||||
}
|
||||
|
||||
extractCovariateNamesFromJavaList <- function(covariateList.java){
|
||||
xVarNames <- character(.jcall(covariateList.java, "I", "size"))
|
||||
for(j in 1:length(xVarNames)){
|
||||
covariate.java <- .jcast(
|
||||
.jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
|
||||
.class_Covariate
|
||||
)
|
||||
|
||||
xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
|
||||
}
|
||||
|
||||
return(xVarNames)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,12 @@
|
|||
#' Loads a random forest that was saved using \code{\link{saveForest}}.
|
||||
#'
|
||||
#' @param directory The directory created that saved the previous forest.
|
||||
#' @param forest.output Specifies whether the forest loaded should be loaded
|
||||
#' into memory, or reflect the saved files where only one tree is loaded at a
|
||||
#' time.
|
||||
#' @param maxTreeNum If for some reason you only want to load the number of
|
||||
#' trees up until a certain point, you can specify maxTreeNum as a single
|
||||
#' number.
|
||||
#' @return A JForest object; see \code{\link{train}} for details.
|
||||
#' @export
|
||||
#' @seealso \code{\link{train}}, \code{\link{saveForest}}
|
||||
|
@ -20,7 +26,11 @@
|
|||
#'
|
||||
#' saveForest(forest, "trees")
|
||||
#' new_forest <- loadForest("trees")
|
||||
loadForest <- function(directory){
|
||||
loadForest <- function(directory, forest.output = c("online", "offline"), maxTreeNum = NULL){
|
||||
|
||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||
}
|
||||
|
||||
# First load the response combiners and the split finders
|
||||
nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData"))
|
||||
|
@ -30,7 +40,7 @@ loadForest <- function(directory){
|
|||
splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder)
|
||||
|
||||
forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData"))
|
||||
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ResponseCombiner)
|
||||
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ForestResponseCombiner)
|
||||
|
||||
covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData"))
|
||||
covariateList <- .jcast(covariateList, .class_List)
|
||||
|
@ -42,8 +52,11 @@ loadForest <- function(directory){
|
|||
params$splitFinder$javaObject <- splitFinder.java
|
||||
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
|
||||
|
||||
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call,
|
||||
params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes, params$randomSeed)
|
||||
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder,
|
||||
params$forestResponseCombiner, covariateList, call,
|
||||
params$ntree, params$numberOfSplits, params$mtry,
|
||||
params$nodeSize, params$maxNodeDepth, params$splitPureNodes,
|
||||
params$randomSeed, forest.output, maxTreeNum)
|
||||
|
||||
return(forest)
|
||||
|
||||
|
@ -55,8 +68,11 @@ loadForest <- function(directory){
|
|||
# that uses the Java version's settings yaml file to recreate the forest, but
|
||||
# I'd appreciate knowing that someone's going to use it first (email me; see
|
||||
# README).
|
||||
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner,
|
||||
covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, randomSeed=NULL){
|
||||
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder,
|
||||
forestResponseCombiner, covariateList.java, call,
|
||||
ntree, numberOfSplits, mtry, nodeSize,
|
||||
maxNodeDepth = 100000, splitPureNodes=TRUE,
|
||||
randomSeed=NULL, forest.output = "online", maxTreeNum = NULL){
|
||||
|
||||
params <- list(
|
||||
splitFinder=splitFinder,
|
||||
|
@ -71,7 +87,33 @@ loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, sp
|
|||
randomSeed=randomSeed
|
||||
)
|
||||
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject)
|
||||
forest.java <- NULL
|
||||
if(forest.output[1] == "online"){
|
||||
castedForestResponseCombiner <- .jcast(forestResponseCombiner$javaObject, .class_ResponseCombiner) # OnlineForest constructor takes a ResponseCombiner
|
||||
|
||||
if(is.null(maxTreeNum)){
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
|
||||
treeDirectory, castedForestResponseCombiner)
|
||||
} else{
|
||||
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
|
||||
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
|
||||
tree.file.array, castedForestResponseCombiner)
|
||||
|
||||
}
|
||||
|
||||
} else{ # offline forest
|
||||
if(is.null(maxTreeNum)){
|
||||
path.as.file <- .jnew(.class_File, treeDirectory)
|
||||
forest.java <- .jnew(.class_OfflineForest, path.as.file, forestResponseCombiner$javaObject)
|
||||
} else{
|
||||
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
|
||||
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
|
||||
forest.java <- .jnew(.class_OfflineForest, tree.file.array, forestResponseCombiner$javaObject)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params)
|
||||
class(forestObject) <- "JRandomForest"
|
||||
|
|
17
R/misc.R
17
R/misc.R
|
@ -14,6 +14,23 @@ convertRListToJava <- function(lst){
|
|||
return(javaList)
|
||||
}
|
||||
|
||||
#Internal function
|
||||
convertJavaListToR <- function(javaList, class = .class_Object){
|
||||
lst <- list()
|
||||
|
||||
javaList.length <- .jcall(javaList, "I", "size")
|
||||
|
||||
for(i in 0:(javaList.length - 1)){
|
||||
object <- .jcall(javaList, makeResponse(.class_Object), "get", as.integer(i))
|
||||
object <- .jcast(object, class)
|
||||
|
||||
lst[[i+1]] <- object
|
||||
}
|
||||
|
||||
return(lst)
|
||||
|
||||
}
|
||||
|
||||
#' @export
|
||||
print.SplitFinder = function(x, ...) print(x$call)
|
||||
|
||||
|
|
|
@ -99,7 +99,7 @@ predict.JRandomForest <- function(object, newData=NULL, parallel=TRUE, out.of.ba
|
|||
predictionsJava <- .jcall(forestObject, makeResponse(.class_List), function.to.use, predictionDataList)
|
||||
|
||||
if(predictionClass == "numeric"){
|
||||
predictions <- vector(length=nrow(newData), mode="numeric")
|
||||
predictions <- vector(length=numRows, mode="numeric")
|
||||
}
|
||||
else{
|
||||
predictions <- list()
|
||||
|
|
39
R/processCensoringDistribution.R
Normal file
39
R/processCensoringDistribution.R
Normal file
|
@ -0,0 +1,39 @@
|
|||
|
||||
|
||||
# Internal function. Takes a censoring distribution and turns it into a
|
||||
# RightContinuousStepFunction Java object.
|
||||
processCensoringDistribution <- function(censoringDistribution){
|
||||
|
||||
if(is.numeric(censoringDistribution)){
|
||||
# estimate ECDF
|
||||
censoringTimes <- .jarray(censoringDistribution, "D")
|
||||
java.censoringDistribution <- .jcall(.class_Utils, makeResponse(.class_RightContinuousStepFunction), "estimateOneMinusECDF", censoringTimes)
|
||||
|
||||
} else if(is.list(censoringDistribution)){
|
||||
# First check that censoringDistribution fits the correct format
|
||||
if(is.null(censoringDistribution$x) | is.null(censoringDistribution$y)){
|
||||
stop("If the censoringDistribution is provided as a list, it must have an x and a y item that are numeric.")
|
||||
}
|
||||
|
||||
if(length(censoringDistribution$x) != length(censoringDistribution$y)){
|
||||
stop("x and y in censoringDistribution must have the same length.")
|
||||
}
|
||||
|
||||
if(!is.numeric(censoringDistribution$x) | !is.numeric(censoringDistribution$y)){
|
||||
stop("x and y in censoringDistribution must both be numeric.")
|
||||
}
|
||||
|
||||
java.censoringDistribution <- createRightContinuousStepFunction(censoringDistribution$x, censoringDistribution$y, defaultY = 1.0)
|
||||
|
||||
} else if("stepfun" %in% class(censoringDistribution)){
|
||||
x <- stats::knots(censoringDistribution)
|
||||
y <- censoringDistribution(x)
|
||||
|
||||
java.censoringDistribution <- createRightContinuousStepFunction(x, y, defaultY = 1.0)
|
||||
}
|
||||
else{
|
||||
stop("Invalid censoringDistribution")
|
||||
}
|
||||
|
||||
return(java.censoringDistribution)
|
||||
}
|
95
R/processFormula.R
Normal file
95
R/processFormula.R
Normal file
|
@ -0,0 +1,95 @@
|
|||
|
||||
# Internal function that takes a formula and processes it for use in the Java
|
||||
# code. existingCovariateList is optional; if not provided then a new one is
|
||||
# created internally.
|
||||
processFormula <- function(formula, data, covariateList.java = NULL, na.penalty = NULL){
|
||||
|
||||
# Having an R copy of the data loaded at the same time can be wasteful; we
|
||||
# also allow users to provide an environment of the data which gets removed
|
||||
# after being imported into Java
|
||||
if(class(data) == "environment"){
|
||||
if(is.null(data$data)){
|
||||
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
|
||||
}
|
||||
|
||||
env <- data
|
||||
data <- env$data
|
||||
env$data <- NULL
|
||||
gc()
|
||||
}
|
||||
|
||||
yVar <- formula[[2]]
|
||||
|
||||
responses <- NULL
|
||||
variablesToDrop <- character(0)
|
||||
|
||||
# yVar is a call object; as.character(yVar) will be the different components, including the parameters.
|
||||
# if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
|
||||
# then we also need to explicitly evaluate it
|
||||
if(class(yVar) == "call" || !(as.character(yVar) %in% colnames(data))){
|
||||
# yVar is a function like CR_Response
|
||||
responses <- eval(expr=yVar, envir=data)
|
||||
|
||||
if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
|
||||
# do any of the variables match data in data? We need to track that so we can drop them later
|
||||
variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
|
||||
}
|
||||
|
||||
formula[[2]] <- NULL
|
||||
|
||||
} else if(class(yVar) == "name"){ # and implicitly yVar is contained in data
|
||||
variablesToDrop <- as.character(yVar)
|
||||
}
|
||||
|
||||
# Includes responses which we may need to later cut out if `.` was used on the
|
||||
# right-hand-side
|
||||
filteredData <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass)
|
||||
|
||||
if(is.null(responses)){ # If this if-statement runs then we have a simple (i.e. numeric) response
|
||||
responses <- stats::model.response(filteredData)
|
||||
}
|
||||
|
||||
# remove any response variables on the right-hand-side
|
||||
covariateData <- filteredData[, !(names(filteredData) %in% variablesToDrop), drop=FALSE]
|
||||
|
||||
# Now that we know how many predictor variables we have, we should check na.penalty
|
||||
if(!is.null(na.penalty)){
|
||||
if(!is.numeric(na.penalty) & !is.logical(na.penalty)){
|
||||
stop("na.penalty must be either logical or numeric.")
|
||||
}
|
||||
|
||||
if(is.logical(na.penalty) & length(na.penalty) != 1 & length(na.penalty) != ncol(covariateData)){
|
||||
stop("na.penalty must have length of either 1 or the number of predictor variables if logical.")
|
||||
}
|
||||
|
||||
if(is.numeric(na.penalty) & length(na.penalty) != 1){
|
||||
stop("na.penalty must have length 1 if logical.")
|
||||
}
|
||||
|
||||
if(anyNA(na.penalty)){
|
||||
stop("na.penalty cannot contain NAs.")
|
||||
}
|
||||
|
||||
|
||||
# All good; now to transform it.
|
||||
if(is.numeric(na.penalty)){
|
||||
na.threshold <- na.penalty
|
||||
na.penalty <- apply(covariateData, 2, function(x){mean(is.na(x))}) >= na.threshold
|
||||
}
|
||||
else if(is.logical(na.penalty) & length(na.penalty) == 1){
|
||||
na.penalty <- rep(na.penalty, times = ncol(covariateData))
|
||||
}
|
||||
# else{} - na.penalty is logical and the correct length; no need to do anything to it
|
||||
|
||||
}
|
||||
|
||||
dataset <- loadData(
|
||||
covariateData,
|
||||
colnames(covariateData),
|
||||
responses,
|
||||
covariateList.java = covariateList.java,
|
||||
na.penalty = na.penalty
|
||||
)
|
||||
|
||||
return(dataset)
|
||||
}
|
|
@ -48,7 +48,7 @@ WeightedVarianceSplitFinder <- function(){
|
|||
#'
|
||||
MeanResponseCombiner <- function(){
|
||||
javaObject <- .jnew(.class_MeanResponseCombiner)
|
||||
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||
javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
|
||||
|
||||
combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric")
|
||||
combiner$convertToRFunction <- function(javaObject, ...){
|
||||
|
|
166
R/train.R
166
R/train.R
|
@ -14,11 +14,11 @@ getCores <- function(){
|
|||
return(cores)
|
||||
}
|
||||
|
||||
train.internal <- function(responses, covariateData, splitFinder,
|
||||
train.internal <- function(dataset, splitFinder,
|
||||
nodeResponseCombiner, forestResponseCombiner, ntree,
|
||||
numberOfSplits, mtry, nodeSize, maxNodeDepth,
|
||||
splitPureNodes, savePath, savePath.overwrite,
|
||||
cores, randomSeed, displayProgress){
|
||||
forest.output, cores, randomSeed, displayProgress){
|
||||
|
||||
# Some quick checks on parameters
|
||||
ntree <- as.integer(ntree)
|
||||
|
@ -51,16 +51,20 @@ train.internal <- function(responses, covariateData, splitFinder,
|
|||
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||
}
|
||||
|
||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||
}
|
||||
|
||||
if(is.null(splitFinder)){
|
||||
splitFinder <- splitFinderDefault(responses)
|
||||
splitFinder <- splitFinderDefault(dataset$responses)
|
||||
}
|
||||
|
||||
if(is.null(nodeResponseCombiner)){
|
||||
nodeResponseCombiner <- nodeResponseCombinerDefault(responses)
|
||||
nodeResponseCombiner <- nodeResponseCombinerDefault(dataset$responses)
|
||||
}
|
||||
|
||||
if(is.null(forestResponseCombiner)){
|
||||
forestResponseCombiner <- forestResponseCombinerDefault(responses)
|
||||
forestResponseCombiner <- forestResponseCombinerDefault(dataset$responses)
|
||||
}
|
||||
|
||||
|
||||
|
@ -75,20 +79,6 @@ train.internal <- function(responses, covariateData, splitFinder,
|
|||
stop("forestResponseCombiner must be a ResponseCombiner")
|
||||
}
|
||||
|
||||
if(class(covariateData)=="environment"){
|
||||
if(is.null(covariateData$data)){
|
||||
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
|
||||
}
|
||||
dataset <- loadData(covariateData$data, colnames(covariateData$data), responses)
|
||||
covariateData$data <- NULL # save memory, hopefully
|
||||
gc() # explicitly try to save memory
|
||||
}
|
||||
else{
|
||||
dataset <- loadData(covariateData, colnames(covariateData), responses)
|
||||
}
|
||||
|
||||
|
||||
|
||||
treeTrainer <- createTreeTrainer(responseCombiner=nodeResponseCombiner,
|
||||
splitFinder=splitFinder,
|
||||
covariateList=dataset$covariateList,
|
||||
|
@ -143,22 +133,23 @@ train.internal <- function(responses, covariateData, splitFinder,
|
|||
params=params,
|
||||
forestCall=match.call())
|
||||
|
||||
forest.java <- NULL
|
||||
if(cores > 1){
|
||||
.jcall(forestTrainer, "V", "trainParallelOnDisk", .object_Optional(), as.integer(cores))
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", .object_Optional(), as.integer(cores))
|
||||
} else {
|
||||
.jcall(forestTrainer, "V", "trainSerialOnDisk", .object_Optional())
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", .object_Optional())
|
||||
}
|
||||
|
||||
# Need to now load forest trees back into memory
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forestResponseCombiner$javaObject)
|
||||
|
||||
if(forest.output[1] == "online"){
|
||||
forest.java <- convertToOnlineForest.Java(forest.java)
|
||||
}
|
||||
|
||||
}
|
||||
else{ # save directly into memory
|
||||
if(cores > 1){
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
|
||||
} else {
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", .object_Optional())
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", .object_Optional())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -188,13 +179,13 @@ train.internal <- function(responses, covariateData, splitFinder,
|
|||
#'
|
||||
#' @param formula You may specify the response and covariates as a formula
|
||||
#' instead; make sure the response in the formula is still properly
|
||||
#' constructed; see \code{responses}
|
||||
#' constructed.
|
||||
#' @param data A data.frame containing the columns of the predictors and
|
||||
#' responses.
|
||||
#' @param splitFinder A split finder that's used to score splits in the random
|
||||
#' forest training algorithm. See \code{\link{CompetingRiskSplitFinders}}
|
||||
#' or \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one,
|
||||
#' this function tries to pick one based on the response. For
|
||||
#' forest training algorithm. See \code{\link{CompetingRiskSplitFinders}} or
|
||||
#' \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one, this
|
||||
#' function tries to pick one based on the response. For
|
||||
#' \code{\link{CR_Response}} without censor times, it will pick a
|
||||
#' \code{\link{LogRankSplitFinder}}; while if censor times were provided it
|
||||
#' will pick \code{\link{GrayLogRankSplitFinder}}; for integer or numeric
|
||||
|
@ -202,19 +193,19 @@ train.internal <- function(responses, covariateData, splitFinder,
|
|||
#' @param nodeResponseCombiner A response combiner that's used to combine
|
||||
#' responses for each terminal node in a tree (regression example; average the
|
||||
#' observations in each tree into a single number). See
|
||||
#' \code{\link{CR_ResponseCombiner}} or
|
||||
#' \code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||
#' tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||
#' picks a \code{\link{CR_ResponseCombiner}}; for integer or numeric
|
||||
#' responses it picks a \code{\link{MeanResponseCombiner}}.
|
||||
#' \code{\link{CR_ResponseCombiner}} or \code{\link{MeanResponseCombiner}}. If
|
||||
#' you don't specify one, this function tries to pick one based on the
|
||||
#' response. For \code{\link{CR_Response}} it picks a
|
||||
#' \code{\link{CR_ResponseCombiner}}; for integer or numeric responses it
|
||||
#' picks a \code{\link{MeanResponseCombiner}}.
|
||||
#' @param forestResponseCombiner A response combiner that's used to combine
|
||||
#' predictions across trees into one final result (regression example; average
|
||||
#' the prediction of each tree into a single number). See
|
||||
#' \code{\link{CR_FunctionCombiner}} or
|
||||
#' \code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||
#' tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||
#' picks a \code{\link{CR_FunctionCombiner}}; for integer or numeric
|
||||
#' responses it picks a \code{\link{MeanResponseCombiner}}.
|
||||
#' \code{\link{CR_FunctionCombiner}} or \code{\link{MeanResponseCombiner}}. If
|
||||
#' you don't specify one, this function tries to pick one based on the
|
||||
#' response. For \code{\link{CR_Response}} it picks a
|
||||
#' \code{\link{CR_FunctionCombiner}}; for integer or numeric responses it
|
||||
#' picks a \code{\link{MeanResponseCombiner}}.
|
||||
#' @param ntree An integer that specifies how many trees should be trained.
|
||||
#' @param numberOfSplits A tuning parameter specifying how many random splits
|
||||
#' should be tried for a covariate; a value of 0 means all splits will be
|
||||
|
@ -231,6 +222,20 @@ train.internal <- function(responses, covariateData, splitFinder,
|
|||
#' @param maxNodeDepth This parameter is analogous to \code{nodeSize} in that it
|
||||
#' controls tree length; by default \code{maxNodeDepth} is an extremely high
|
||||
#' number and tree depth is controlled by \code{nodeSize}.
|
||||
#' @param na.penalty This parameter controls whether predictor variables with
|
||||
#' NAs should be penalized when being considered for a best split. Best splits
|
||||
#' (and the associated score) are determined on only non-NA data; the penalty
|
||||
#' is to take the best split identified, and to randomly assign any NAs
|
||||
#' (according to the proportion of data split left and right), and then
|
||||
#' recalculate the corresponding split score, when is then compared with the
|
||||
#' other split candiate variables. This penalty adds some computational time,
|
||||
#' so it may be disabled for some variables. \code{na.penalty} may be
|
||||
#' specified as a vector of logicals indicating, for each predictor variable,
|
||||
#' whether the penalty should be applied to that variable. If it's length 1
|
||||
#' then it applies to all variables. Alternatively, a single numeric value may
|
||||
#' be provided to indicate a threshold whereby the penalty is activated only
|
||||
#' if the proportion of NAs for that variable in the training set exceeds that
|
||||
#' threshold.
|
||||
#' @param splitPureNodes This parameter determines whether the algorithm will
|
||||
#' split a pure node. If set to FALSE, then before every split it will check
|
||||
#' that every response is the same, and if so, not split. If set to TRUE it
|
||||
|
@ -253,6 +258,10 @@ train.internal <- function(responses, covariateData, splitFinder,
|
|||
#' assumes (without checking) that the existing trees are from a previous run
|
||||
#' and starts from where it left off. This option is useful if recovering from
|
||||
#' a crash.
|
||||
#' @param forest.output This parameter only applies if \code{savePath} has been
|
||||
#' set; set to 'online' (default) and the saved forest will be loaded into
|
||||
#' memory after being trained. Set to 'offline' and the forest is not saved
|
||||
#' into memory, but can still be used in a memory unintensive manner.
|
||||
#' @param cores This parameter specifies how many trees will be simultaneously
|
||||
#' trained. By default the package attempts to detect how many cores you have
|
||||
#' by using the \code{parallel} package and using all of them. You may specify
|
||||
|
@ -310,78 +319,21 @@ train.internal <- function(responses, covariateData, splitFinder,
|
|||
#' ypred <- predict(forest, newData)
|
||||
train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
||||
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry,
|
||||
nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, savePath=NULL,
|
||||
savePath.overwrite=c("warn", "delete", "merge"), cores = getCores(),
|
||||
randomSeed = NULL, displayProgress = TRUE){
|
||||
nodeSize, maxNodeDepth = 100000, na.penalty = TRUE, splitPureNodes=TRUE,
|
||||
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
||||
forest.output = c("online", "offline"),
|
||||
cores = getCores(), randomSeed = NULL, displayProgress = TRUE){
|
||||
|
||||
# Having an R copy of the data loaded at the same time can be wasteful; we
|
||||
# also allow users to provide an environment of the data which gets removed
|
||||
# after being imported into Java
|
||||
env <- NULL
|
||||
if(class(data) == "environment"){
|
||||
if(is.null(data$data)){
|
||||
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
|
||||
}
|
||||
dataset <- processFormula(formula, data, na.penalty = na.penalty)
|
||||
|
||||
env <- data
|
||||
data <- env$data
|
||||
}
|
||||
|
||||
yVar <- formula[[2]]
|
||||
|
||||
responses <- NULL
|
||||
variablesToDrop <- character(0)
|
||||
|
||||
# yVar is a call object; as.character(yVar) will be the different components, including the parameters.
|
||||
# if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
|
||||
# then we also need to explicitly evaluate it
|
||||
if(class(yVar)=="call" || !(as.character(yVar) %in% colnames(data))){
|
||||
# yVar is a function like CompetingRiskResponses
|
||||
responses <- eval(expr=yVar, envir=data)
|
||||
|
||||
if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
|
||||
# do any of the variables match data in data? We need to track that so we can drop them later
|
||||
variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
|
||||
}
|
||||
|
||||
formula[[2]] <- NULL
|
||||
|
||||
} else if(class(yVar)=="name"){ # and implicitly yVar is contained in data
|
||||
variablesToDrop <- as.character(yVar)
|
||||
}
|
||||
|
||||
# Includes responses which we may need to later cut out
|
||||
mf <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass)
|
||||
|
||||
if(is.null(responses)){
|
||||
responses <- stats::model.response(mf)
|
||||
}
|
||||
|
||||
# remove any response variables
|
||||
mf <- mf[,!(names(mf) %in% variablesToDrop), drop=FALSE]
|
||||
|
||||
# If environment was provided instead of data
|
||||
if(!is.null(env)){
|
||||
env$data <- mf
|
||||
rm(data)
|
||||
forest <- train.internal(responses, env, splitFinder = splitFinder,
|
||||
forest <- train.internal(dataset, splitFinder = splitFinder,
|
||||
nodeResponseCombiner = nodeResponseCombiner,
|
||||
forestResponseCombiner = forestResponseCombiner,
|
||||
ntree = ntree, numberOfSplits = numberOfSplits,
|
||||
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
|
||||
splitPureNodes = splitPureNodes, savePath = savePath,
|
||||
savePath.overwrite = savePath.overwrite, cores = cores,
|
||||
randomSeed = randomSeed, displayProgress = displayProgress)
|
||||
} else{
|
||||
forest <- train.internal(responses, mf, splitFinder = splitFinder,
|
||||
nodeResponseCombiner = nodeResponseCombiner,
|
||||
forestResponseCombiner = forestResponseCombiner,
|
||||
ntree = ntree, numberOfSplits = numberOfSplits,
|
||||
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
|
||||
splitPureNodes = splitPureNodes, savePath = savePath,
|
||||
savePath.overwrite = savePath.overwrite, cores = cores,
|
||||
randomSeed = randomSeed, displayProgress = displayProgress)
|
||||
}
|
||||
savePath.overwrite = savePath.overwrite, forest.output = forest.output,
|
||||
cores = cores, randomSeed = randomSeed, displayProgress = displayProgress)
|
||||
|
||||
forest$call <- match.call()
|
||||
forest$formula <- formula
|
||||
|
@ -429,7 +381,9 @@ createTreeTrainer <- function(responseCombiner, splitFinder, covariateList, numb
|
|||
|
||||
builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder")
|
||||
|
||||
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombiner$javaObject)
|
||||
responseCombinerCasted <- .jcast(responseCombiner$javaObject, .class_ResponseCombiner) # might need to cast a ForestResponseCombiner down
|
||||
|
||||
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombinerCasted)
|
||||
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject)
|
||||
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
|
||||
builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits))
|
||||
|
|
175
R/vimp.R
Normal file
175
R/vimp.R
Normal file
|
@ -0,0 +1,175 @@
|
|||
|
||||
|
||||
#' Variable Importance
|
||||
#'
|
||||
#' Calculate variable importance by recording the increase in error when a given
|
||||
#' predictor is randomly permuted. Regression forests uses mean squared error;
|
||||
#' competing risks uses integrated Brier score.
|
||||
#'
|
||||
#' @param forest The forest that was trained.
|
||||
#' @param newData A test set of the data if available. If not, then out of bag
|
||||
#' errors will be attempted on the training set.
|
||||
#' @param randomSeed The source of randomness used to permute the values. Can be
|
||||
#' left blank.
|
||||
#' @param events If using competing risks forest, the events that the error
|
||||
#' measure used for VIMP should be calculated on.
|
||||
#' @param time If using competing risks forest, the upper bound of the
|
||||
#' integrated Brier score.
|
||||
#' @param censoringDistribution (Optional) If using competing risks forest, the
|
||||
#' censoring distribution. See \code{\link{integratedBrierScore} for details.}
|
||||
#' @param eventWeights (Optional) If using competing risks forest, weights to be
|
||||
#' applied to the error for each of the \code{events}.
|
||||
#'
|
||||
#' @return A named numeric vector of importance values.
|
||||
#' @export
|
||||
#'
|
||||
#' @examples
|
||||
#' data(wihs)
|
||||
#'
|
||||
#' forest <- train(CR_Response(status, time) ~ ., wihs,
|
||||
#' ntree = 100, numberOfSplits = 0, mtry=3, nodeSize = 5)
|
||||
#'
|
||||
#' vimp(forest, events = 1:2, time = 8.0)
|
||||
#'
|
||||
vimp <- function(
|
||||
forest,
|
||||
newData = NULL,
|
||||
randomSeed = NULL,
|
||||
type = c("mean", "z", "raw"),
|
||||
events = NULL,
|
||||
time = NULL,
|
||||
censoringDistribution = NULL,
|
||||
eventWeights = NULL){
|
||||
|
||||
if(is.null(newData) & is.null(forest$dataset)){
|
||||
stop("forest doesn't have a copy of the training data loaded (this happens if you just loaded it); please manually specify newData and possibly out.of.bag")
|
||||
}
|
||||
|
||||
# Basically we check if type is either null, length 0, or one of the invalid values.
|
||||
# We can't include the last statement in the same statement as length(tyoe) < 1,
|
||||
# because R checks both cases and a different error would display if length(type) == 0
|
||||
typeError = is.null(type) | length(type) < 1
|
||||
if(!typeError){
|
||||
typeError = !(type[1] %in% c("mean", "z", "raw"))
|
||||
}
|
||||
if(typeError){
|
||||
stop("A valid response type must be provided.")
|
||||
}
|
||||
|
||||
if(is.null(newData)){
|
||||
data.java <- forest$dataset
|
||||
out.of.bag <- TRUE
|
||||
|
||||
}
|
||||
else{ # newData is provided
|
||||
data.java <- processFormula(forest$formula, newData, forest$covariateList)$dataset
|
||||
out.of.bag <- FALSE
|
||||
}
|
||||
|
||||
predictionClass <- forest$params$forestResponseCombiner$outputClass
|
||||
|
||||
if(predictionClass == "CompetingRiskFunctions"){
|
||||
if(is.null(time) | length(time) != 1){
|
||||
stop("time must be set at length 1")
|
||||
}
|
||||
|
||||
errorCalculator.java <- ibsCalculatorWrapper(
|
||||
events = events,
|
||||
time = time,
|
||||
censoringDistribution = censoringDistribution,
|
||||
eventWeights = eventWeights)
|
||||
|
||||
} else if(predictionClass == "numeric"){
|
||||
errorCalculator.java <- .jnew(.class_RegressionErrorCalculator)
|
||||
errorCalculator.java <- .jcast(errorCalculator.java, .class_ErrorCalculator)
|
||||
|
||||
} else{
|
||||
stop(paste0("VIMP not yet supported for ", predictionClass, ". If you're just using a non-custom version of largeRCRF then this is a bug and should be reported."))
|
||||
|
||||
}
|
||||
|
||||
forest.trees.java <- .jcall(forest$javaObject, makeResponse(.class_List), "getTrees")
|
||||
|
||||
vimp.calculator <- .jnew(.class_VariableImportanceCalculator,
|
||||
errorCalculator.java,
|
||||
forest.trees.java,
|
||||
data.java,
|
||||
out.of.bag # isTrainingSet parameter
|
||||
)
|
||||
|
||||
random.java <- NULL
|
||||
if(!is.null(randomSeed)){
|
||||
random.java <- .jnew(.class_Random, .jlong(as.integer(randomSeed)))
|
||||
}
|
||||
random.java <- .object_Optional(random.java)
|
||||
|
||||
covariateRList <- convertJavaListToR(forest$covariateList, class = .class_Covariate)
|
||||
importanceValues <- matrix(nrow = forest$params$ntree, ncol = length(covariateRList))
|
||||
colnames(importanceValues) <- extractCovariateNamesFromJavaList(forest$covariateList)
|
||||
|
||||
for(j in 1:length(covariateRList)){
|
||||
covariateJava <- covariateRList[[j]]
|
||||
covariateJava <-
|
||||
|
||||
importanceValues[, j] <- .jcall(vimp.calculator, "[D", "calculateVariableImportanceRaw", covariateJava, random.java)
|
||||
}
|
||||
|
||||
if(type[1] == "raw"){
|
||||
return(importanceValues)
|
||||
} else if(type[1] == "mean"){
|
||||
meanImportanceValues <- apply(importanceValues, 2, mean)
|
||||
return(meanImportanceValues)
|
||||
} else if(type[1] == "z"){
|
||||
zImportanceValues <- apply(importanceValues, 2, function(x){
|
||||
meanValue <- mean(x)
|
||||
standardError <- sd(x)/sqrt(length(x))
|
||||
return(meanValue / standardError)
|
||||
})
|
||||
return(zImportanceValues)
|
||||
|
||||
} else{
|
||||
stop("A valid response type must be provided.")
|
||||
}
|
||||
|
||||
|
||||
return(importance)
|
||||
|
||||
}
|
||||
|
||||
# Internal function
|
||||
ibsCalculatorWrapper <- function(events, time, censoringDistribution = NULL, eventWeights = NULL){
|
||||
if(is.null(events)){
|
||||
stop("events must be specified if using vimp on competing risks data")
|
||||
}
|
||||
|
||||
if(is.null(time)){
|
||||
stop("time must be specified if using vimp on competing risks data")
|
||||
}
|
||||
|
||||
|
||||
java.censoringDistribution <- NULL
|
||||
if(!is.null(censoringDistribution)){
|
||||
java.censoringDistribution <- processCensoringDistribution(censoringDistribution)
|
||||
java.censoringDistribution <- .object_Optional(java.censoringDistribution)
|
||||
}
|
||||
else{
|
||||
java.censoringDistribution <- .object_Optional(NULL)
|
||||
}
|
||||
|
||||
ibsCalculator.java <- .jnew(.class_IBSCalculator, java.censoringDistribution)
|
||||
|
||||
if(is.null(eventWeights)){
|
||||
eventWeights <- rep(1, times = length(events))
|
||||
}
|
||||
|
||||
ibsCalculatorWrapper.java <- .jnew(.class_IBSErrorCalculatorWrapper,
|
||||
ibsCalculator.java,
|
||||
.jarray(as.integer(events)),
|
||||
as.numeric(time),
|
||||
.jarray(as.numeric(eventWeights)))
|
||||
|
||||
ibsCalculatorWrapper.java <- .jcast(ibsCalculatorWrapper.java, .class_ErrorCalculator)
|
||||
return(ibsCalculatorWrapper.java)
|
||||
|
||||
|
||||
}
|
10
R/wihs.R
10
R/wihs.R
|
@ -15,9 +15,13 @@
|
|||
#' @source The data was obtained from the randomForestSRC R package.
|
||||
#'
|
||||
#' @references Bacon MC, von Wyl V, Alden C, Sharp G, Robison E, Hessol N, Gange
|
||||
#' S, Barranday Y, Holman S, Weber K, Young MA (2005). “The Women’s
|
||||
#' Interagency HIV Study: an Observational Cohort Brings Clinical Sciences to
|
||||
#' the Bench.” Clinical and Vaccine Immunology, 12(9), 1013–1019.
|
||||
#' S, Barranday Y, Holman S, Weber K, Young MA (2005). “The Women’s Interagency
|
||||
#' HIV Study: an Observational Cohort Brings Clinical Sciences to the Bench.”
|
||||
#' Clinical and Vaccine Immunology, 12(9), 1013–1019.
|
||||
#' doi:10.1128/CDLI.12.9.1013-1019.2005.
|
||||
#'
|
||||
#' Ishwaran, H., and Udaya B. Kogalur. 2018. Random Forests for Survival,
|
||||
#' Regression, and Classification (Rf-Src). manual.
|
||||
#' https://cran.r-project.org/package=randomForestSRC.
|
||||
#'
|
||||
"wihs"
|
Binary file not shown.
|
@ -2,4 +2,4 @@ The Java source code for this package can be obtained at https://github.com/jath
|
|||
|
||||
* Delete the Jar file in `inst/java/`
|
||||
* Build the Java code in its own separate directory using `mvn clean package` in the root of the directory (same folder containing `README.md`). Make sure you have [Maven](https://maven.apache.org/) installed.
|
||||
* Copy the `library/target/largeRCRF-library/1.0-SNAPSHOT.jar` file produced in the Java directory into `inst/java/`
|
||||
* Copy the `library/target/largeRCRF-library-1.0-SNAPSHOT.jar` file produced in the Java directory into `inst/java/`
|
|
@ -6,7 +6,8 @@
|
|||
\usage{
|
||||
addTrees(forest, numTreesToAdd, savePath = NULL,
|
||||
savePath.overwrite = c("warn", "delete", "merge"),
|
||||
cores = getCores(), displayProgress = TRUE)
|
||||
forest.output = c("online", "offline"), cores = getCores(),
|
||||
displayProgress = TRUE)
|
||||
}
|
||||
\arguments{
|
||||
\item{forest}{An existing forest.}
|
||||
|
@ -21,6 +22,11 @@ a previously saved forest.}
|
|||
directory, possibly containing another forest, this specifies what should
|
||||
be done.}
|
||||
|
||||
\item{forest.output}{This parameter only applies if \code{savePath} has been
|
||||
set; set to 'online' (default) and the saved forest will be loaded into
|
||||
memory after being trained. Set to 'offline' and the forest is not saved
|
||||
into memory, but can still be used in a memory unintensive manner.}
|
||||
|
||||
\item{cores}{The number of cores to be used for training the new trees.}
|
||||
|
||||
\item{displayProgress}{A logical indicating whether the progress should be
|
||||
|
|
20
man/convertToOnlineForest.Rd
Normal file
20
man/convertToOnlineForest.Rd
Normal file
|
@ -0,0 +1,20 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/convertToOnlineForest.R
|
||||
\name{convertToOnlineForest}
|
||||
\alias{convertToOnlineForest}
|
||||
\title{Convert to Online Forest}
|
||||
\usage{
|
||||
convertToOnlineForest(forest)
|
||||
}
|
||||
\arguments{
|
||||
\item{forest}{The offline forest.}
|
||||
}
|
||||
\value{
|
||||
An online, in memory forst.
|
||||
}
|
||||
\description{
|
||||
Some forests are too large to store in memory and have been saved to disk.
|
||||
They can still be used, but their performance is much slower. If there's
|
||||
enough memory, they can easily be converted into an in-memory forest that is
|
||||
faster to use.
|
||||
}
|
69
man/integratedBrierScore.Rd
Normal file
69
man/integratedBrierScore.Rd
Normal file
|
@ -0,0 +1,69 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/cr_integratedBrierScore.R
|
||||
\name{integratedBrierScore}
|
||||
\alias{integratedBrierScore}
|
||||
\title{Integrated Brier Score}
|
||||
\usage{
|
||||
integratedBrierScore(responses, predictions, event, time,
|
||||
censoringDistribution = NULL, parallel = TRUE)
|
||||
}
|
||||
\arguments{
|
||||
\item{responses}{A list of responses corresponding to the provided
|
||||
mortalities; use \code{\link{CR_Response}}.}
|
||||
|
||||
\item{predictions}{The predictions to be tested against.}
|
||||
|
||||
\item{event}{The event type for the error to be calculated on.}
|
||||
|
||||
\item{time}{\code{time} specifies the upper bound of the integral.}
|
||||
|
||||
\item{censoringDistribution}{Optional; if provided then weights are
|
||||
calculated on the errors. There are three ways to provide it - \itemize{
|
||||
\item{If you have all the censor times and just want to use a simple
|
||||
empirical estimate of the distribution, just provide a numeric vector of
|
||||
all of the censor times and it will be automatically calculated.} \item{You
|
||||
can directly specify the survivor function by providing a list with two
|
||||
numeric vectors called \code{x} and \code{y}. They should be of the same
|
||||
length and correspond to each point. It is assumed that previous to the
|
||||
first value in \code{y} the \code{y} value is 1.0; and that the function
|
||||
you provide is a right-continuous step function.} \item{You can provide a
|
||||
function from \code{\link[stats]{stepfun}}. Note that this only supports
|
||||
functions where \code{right = FALSE} (default), and that the first y value
|
||||
(corresponding to y before the first x value) will be to set to 1.0
|
||||
regardless of what is specified.}
|
||||
|
||||
}}
|
||||
|
||||
\item{parallel}{A logical indicating whether multiple cores should be
|
||||
utilized when calculating the error. Available as an option because it's
|
||||
been observed that using Java's \code{parallelStream} can be unstable on
|
||||
some systems. Default value is \code{TRUE}; only set to \code{FALSE} if you
|
||||
get strange errors while predicting.}
|
||||
}
|
||||
\value{
|
||||
A numeric vector of the Integrated Brier Score for each prediction.
|
||||
}
|
||||
\description{
|
||||
Used to calculate the Integrated Brier Score, which for the competing risks
|
||||
setting is the integral of the squared difference between each observed
|
||||
cumulative incidence function (CIF) for each observation and the
|
||||
corresponding predicted CIF. If the survivor function (1 - CDF) of the
|
||||
censoring distribution is provided, weights can be calculated to account for
|
||||
the censoring.
|
||||
}
|
||||
\examples{
|
||||
data <- data.frame(delta=c(1,1,0,0,2,2), T=1:6, x=1:6)
|
||||
|
||||
model <- train(CR_Response(delta, T) ~ x, data, ntree=100, numberOfSplits=0, mtry=1, nodeSize=1)
|
||||
|
||||
newData <- data.frame(delta=c(1,0,2,1,0,2), T=1:6, x=1:6)
|
||||
predictions <- predict(model, newData)
|
||||
|
||||
scores <- integratedBrierScore(CR_Response(data$delta, data$T), predictions, 1, 6.0)
|
||||
|
||||
}
|
||||
\references{
|
||||
Section 4.2 of Ishwaran H, Gerds TA, Kogalur UB, Moore RD, Gange
|
||||
SJ, Lau BM (2014). “Random Survival Forests for Competing Risks.”
|
||||
Biostatistics, 15(4), 757–773. doi:10.1093/ biostatistics/kxu010.
|
||||
}
|
|
@ -4,10 +4,19 @@
|
|||
\alias{loadForest}
|
||||
\title{Load Random Forest}
|
||||
\usage{
|
||||
loadForest(directory)
|
||||
loadForest(directory, forest.output = c("online", "offline"),
|
||||
maxTreeNum = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{directory}{The directory created that saved the previous forest.}
|
||||
|
||||
\item{forest.output}{Specifies whether the forest loaded should be loaded
|
||||
into memory, or reflect the saved files where only one tree is loaded at a
|
||||
time.}
|
||||
|
||||
\item{maxTreeNum}{If for some reason you only want to load the number of
|
||||
trees up until a certain point, you can specify maxTreeNum as a single
|
||||
number.}
|
||||
}
|
||||
\value{
|
||||
A JForest object; see \code{\link{train}} for details.
|
||||
|
|
|
@ -41,3 +41,9 @@ mortalities <- list(
|
|||
naiveConcordance(CR_Response(newData$delta, newData$T), mortalities)
|
||||
|
||||
}
|
||||
\references{
|
||||
Section 3.2 of Wolbers, Marcel, Paul Blanche, Michael T. Koller,
|
||||
Jacqueline C M Witteman, and Thomas A Gerds. 2014. “Concordance for
|
||||
Prognostic Models with Competing Risks.” Biostatistics 15 (3): 526–39.
|
||||
https://doi.org/10.1093/biostatistics/kxt059.
|
||||
}
|
||||
|
|
55
man/train.Rd
55
man/train.Rd
|
@ -6,22 +6,23 @@
|
|||
\usage{
|
||||
train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
||||
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize,
|
||||
maxNodeDepth = 1e+05, splitPureNodes = TRUE, savePath = NULL,
|
||||
savePath.overwrite = c("warn", "delete", "merge"),
|
||||
cores = getCores(), randomSeed = NULL, displayProgress = TRUE)
|
||||
maxNodeDepth = 1e+05, na.penalty = TRUE, splitPureNodes = TRUE,
|
||||
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
||||
forest.output = c("online", "offline"), cores = getCores(),
|
||||
randomSeed = NULL, displayProgress = TRUE)
|
||||
}
|
||||
\arguments{
|
||||
\item{formula}{You may specify the response and covariates as a formula
|
||||
instead; make sure the response in the formula is still properly
|
||||
constructed; see \code{responses}}
|
||||
constructed.}
|
||||
|
||||
\item{data}{A data.frame containing the columns of the predictors and
|
||||
responses.}
|
||||
|
||||
\item{splitFinder}{A split finder that's used to score splits in the random
|
||||
forest training algorithm. See \code{\link{CompetingRiskSplitFinders}}
|
||||
or \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one,
|
||||
this function tries to pick one based on the response. For
|
||||
forest training algorithm. See \code{\link{CompetingRiskSplitFinders}} or
|
||||
\code{\link{WeightedVarianceSplitFinder}}. If you don't specify one, this
|
||||
function tries to pick one based on the response. For
|
||||
\code{\link{CR_Response}} without censor times, it will pick a
|
||||
\code{\link{LogRankSplitFinder}}; while if censor times were provided it
|
||||
will pick \code{\link{GrayLogRankSplitFinder}}; for integer or numeric
|
||||
|
@ -30,20 +31,20 @@ responses it picks a \code{\link{WeightedVarianceSplitFinder}}.}
|
|||
\item{nodeResponseCombiner}{A response combiner that's used to combine
|
||||
responses for each terminal node in a tree (regression example; average the
|
||||
observations in each tree into a single number). See
|
||||
\code{\link{CR_ResponseCombiner}} or
|
||||
\code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||
tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||
picks a \code{\link{CR_ResponseCombiner}}; for integer or numeric
|
||||
responses it picks a \code{\link{MeanResponseCombiner}}.}
|
||||
\code{\link{CR_ResponseCombiner}} or \code{\link{MeanResponseCombiner}}. If
|
||||
you don't specify one, this function tries to pick one based on the
|
||||
response. For \code{\link{CR_Response}} it picks a
|
||||
\code{\link{CR_ResponseCombiner}}; for integer or numeric responses it
|
||||
picks a \code{\link{MeanResponseCombiner}}.}
|
||||
|
||||
\item{forestResponseCombiner}{A response combiner that's used to combine
|
||||
predictions across trees into one final result (regression example; average
|
||||
the prediction of each tree into a single number). See
|
||||
\code{\link{CR_FunctionCombiner}} or
|
||||
\code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||
tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||
picks a \code{\link{CR_FunctionCombiner}}; for integer or numeric
|
||||
responses it picks a \code{\link{MeanResponseCombiner}}.}
|
||||
\code{\link{CR_FunctionCombiner}} or \code{\link{MeanResponseCombiner}}. If
|
||||
you don't specify one, this function tries to pick one based on the
|
||||
response. For \code{\link{CR_Response}} it picks a
|
||||
\code{\link{CR_FunctionCombiner}}; for integer or numeric responses it
|
||||
picks a \code{\link{MeanResponseCombiner}}.}
|
||||
|
||||
\item{ntree}{An integer that specifies how many trees should be trained.}
|
||||
|
||||
|
@ -66,6 +67,21 @@ as large as \code{nodeSize}.}
|
|||
controls tree length; by default \code{maxNodeDepth} is an extremely high
|
||||
number and tree depth is controlled by \code{nodeSize}.}
|
||||
|
||||
\item{na.penalty}{This parameter controls whether predictor variables with
|
||||
NAs should be penalized when being considered for a best split. Best splits
|
||||
(and the associated score) are determined on only non-NA data; the penalty
|
||||
is to take the best split identified, and to randomly assign any NAs
|
||||
(according to the proportion of data split left and right), and then
|
||||
recalculate the corresponding split score, when is then compared with the
|
||||
other split candiate variables. This penalty adds some computational time,
|
||||
so it may be disabled for some variables. \code{na.penalty} may be
|
||||
specified as a vector of logicals indicating, for each predictor variable,
|
||||
whether the penalty should be applied to that variable. If it's length 1
|
||||
then it applies to all variables. Alternatively, a single numeric value may
|
||||
be provided to indicate a threshold whereby the penalty is activated only
|
||||
if the proportion of NAs for that variable in the training set exceeds that
|
||||
threshold.}
|
||||
|
||||
\item{splitPureNodes}{This parameter determines whether the algorithm will
|
||||
split a pure node. If set to FALSE, then before every split it will check
|
||||
that every response is the same, and if so, not split. If set to TRUE it
|
||||
|
@ -91,6 +107,11 @@ assumes (without checking) that the existing trees are from a previous run
|
|||
and starts from where it left off. This option is useful if recovering from
|
||||
a crash.}
|
||||
|
||||
\item{forest.output}{This parameter only applies if \code{savePath} has been
|
||||
set; set to 'online' (default) and the saved forest will be loaded into
|
||||
memory after being trained. Set to 'offline' and the forest is not saved
|
||||
into memory, but can still be used in a memory unintensive manner.}
|
||||
|
||||
\item{cores}{This parameter specifies how many trees will be simultaneously
|
||||
trained. By default the package attempts to detect how many cores you have
|
||||
by using the \code{parallel} package and using all of them. You may specify
|
||||
|
|
48
man/vimp.Rd
Normal file
48
man/vimp.Rd
Normal file
|
@ -0,0 +1,48 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/vimp.R
|
||||
\name{vimp}
|
||||
\alias{vimp}
|
||||
\title{Variable Importance}
|
||||
\usage{
|
||||
vimp(forest, newData = NULL, randomSeed = NULL, type = c("mean", "z",
|
||||
"raw"), events = NULL, time = NULL, censoringDistribution = NULL,
|
||||
eventWeights = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{forest}{The forest that was trained.}
|
||||
|
||||
\item{newData}{A test set of the data if available. If not, then out of bag
|
||||
errors will be attempted on the training set.}
|
||||
|
||||
\item{randomSeed}{The source of randomness used to permute the values. Can be
|
||||
left blank.}
|
||||
|
||||
\item{events}{If using competing risks forest, the events that the error
|
||||
measure used for VIMP should be calculated on.}
|
||||
|
||||
\item{time}{If using competing risks forest, the upper bound of the
|
||||
integrated Brier score.}
|
||||
|
||||
\item{censoringDistribution}{(Optional) If using competing risks forest, the
|
||||
censoring distribution. See \code{\link{integratedBrierScore} for details.}}
|
||||
|
||||
\item{eventWeights}{(Optional) If using competing risks forest, weights to be
|
||||
applied to the error for each of the \code{events}.}
|
||||
}
|
||||
\value{
|
||||
A named numeric vector of importance values.
|
||||
}
|
||||
\description{
|
||||
Calculate variable importance by recording the increase in error when a given
|
||||
predictor is randomly permuted. Regression forests uses mean squared error;
|
||||
competing risks uses integrated Brier score.
|
||||
}
|
||||
\examples{
|
||||
data(wihs)
|
||||
|
||||
forest <- train(CR_Response(status, time) ~ ., wihs,
|
||||
ntree = 100, numberOfSplits = 0, mtry=3, nodeSize = 5)
|
||||
|
||||
vimp(forest, events = 1:2, time = 8.0)
|
||||
|
||||
}
|
10
man/wihs.Rd
10
man/wihs.Rd
|
@ -25,9 +25,13 @@ time may also be censored.
|
|||
}
|
||||
\references{
|
||||
Bacon MC, von Wyl V, Alden C, Sharp G, Robison E, Hessol N, Gange
|
||||
S, Barranday Y, Holman S, Weber K, Young MA (2005). “The Women’s
|
||||
Interagency HIV Study: an Observational Cohort Brings Clinical Sciences to
|
||||
the Bench.” Clinical and Vaccine Immunology, 12(9), 1013–1019.
|
||||
S, Barranday Y, Holman S, Weber K, Young MA (2005). “The Women’s Interagency
|
||||
HIV Study: an Observational Cohort Brings Clinical Sciences to the Bench.”
|
||||
Clinical and Vaccine Immunology, 12(9), 1013–1019.
|
||||
doi:10.1128/CDLI.12.9.1013-1019.2005.
|
||||
|
||||
Ishwaran, H., and Udaya B. Kogalur. 2018. Random Forests for Survival,
|
||||
Regression, and Classification (Rf-Src). manual.
|
||||
https://cran.r-project.org/package=randomForestSRC.
|
||||
}
|
||||
\keyword{datasets}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
context("Add trees on existing forest")
|
||||
|
||||
test_that("Can add trees on existing forest", {
|
||||
test_that("Can add trees on existing online forest", {
|
||||
|
||||
trainingData <- data.frame(x=rnorm(100))
|
||||
trainingData$T <- rexp(100) + abs(trainingData$x)
|
||||
|
@ -20,6 +20,44 @@ test_that("Can add trees on existing forest", {
|
|||
|
||||
})
|
||||
|
||||
test_that("Can add trees on existing offline forest", {
|
||||
|
||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||
unlink("trees", recursive=TRUE)
|
||||
}
|
||||
|
||||
|
||||
trainingData <- data.frame(x=rnorm(100))
|
||||
trainingData$T <- rexp(100) + abs(trainingData$x)
|
||||
trainingData$delta <- sample(0:2, size = 100, replace=TRUE)
|
||||
|
||||
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50,
|
||||
numberOfSplits=0, mtry=1, nodeSize=5,
|
||||
forestResponseCombiner = CR_FunctionCombiner(events = 1:2, times = 0:10), # TODO - remove specifing times; this is workaround around unimplemented feature for offline forests
|
||||
cores=2, displayProgress=FALSE, savePath="trees",
|
||||
forest.output = "offline")
|
||||
warning("TODO - need to implement feature; test workaround in the meantime")
|
||||
|
||||
predictions <- predict(forest)
|
||||
|
||||
warning_message <- "Assuming that the previous forest at savePath is the provided forest argument; if not true then your results will be suspect"
|
||||
|
||||
forest.more <- expect_warning(addTrees(forest, 50, cores=2, displayProgress=FALSE,
|
||||
savePath="trees", savePath.overwrite = "merge",
|
||||
forest.output = "offline"), fixed=warning_message) # test multi-core
|
||||
|
||||
predictions <- predict(forest)
|
||||
|
||||
forest.more <- expect_warning(addTrees(forest, 50, cores=1, displayProgress=FALSE,
|
||||
savePath="trees", savePath.overwrite = "merge",
|
||||
forest.output = "offline"), fixed=warning_message) # test single-core
|
||||
|
||||
expect_true(T) # show Ok if we got this far
|
||||
|
||||
unlink("trees", recursive=TRUE)
|
||||
|
||||
})
|
||||
|
||||
test_that("Test adding trees on saved forest - using delete", {
|
||||
|
||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||
|
|
38
tests/testthat/test_brier_score.R
Normal file
38
tests/testthat/test_brier_score.R
Normal file
|
@ -0,0 +1,38 @@
|
|||
context("Calculate integrated Brier score without error")
|
||||
|
||||
# This code is more concerned that the code runs without error. The tests in the
|
||||
# Java code check that the results it returns are accurate.
|
||||
test_that("Can calculate Integrated Brier Score", {
|
||||
|
||||
sampleData <- data.frame(x=rnorm(100))
|
||||
sampleData$T <- sample(0:4, size=100, replace = TRUE) # the censor distribution we provide needs to conform to the data or we can get NaNs
|
||||
sampleData$delta <- sample(0:2, size = 100, replace = TRUE)
|
||||
|
||||
testData <- sampleData[1:5,]
|
||||
trainingData <- sampleData[6:100,]
|
||||
|
||||
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)
|
||||
|
||||
predictions <- predict(forest, testData)
|
||||
|
||||
scores_test <- integratedBrierScore(CR_Response(testData$delta, testData$T), predictions, event = 1, time = 4,
|
||||
censoringDistribution = NULL)
|
||||
# Check that we don't get a crash if we calculate the error for only one observation
|
||||
scores_one <- integratedBrierScore(CR_Response(testData$delta, testData$T)[1], predictions[1], event = 1, time = 4,
|
||||
censoringDistribution = NULL)
|
||||
|
||||
# Make sure our error didn't somehow change
|
||||
expect_equal(scores_one, scores_test[1])
|
||||
|
||||
# Provide a censoring distribution via censor times
|
||||
scores_censoring1 <- integratedBrierScore(CR_Response(testData$delta, testData$T), predictions, event = 1, time = 4,
|
||||
censoringDistribution = c(0,1,1,2,3,4))
|
||||
scores_censoring2 <- integratedBrierScore(CR_Response(testData$delta, testData$T), predictions, event = 1, time = 4,
|
||||
censoringDistribution = list(x = 0:4, y = 1 - c(1/6, 3/6, 4/6, 5/6, 6/6)))
|
||||
scores_censoring3 <- integratedBrierScore(CR_Response(testData$delta, testData$T), predictions, event = 1, time = 4,
|
||||
censoringDistribution = stepfun(x=0:4, y=1 - c(0, 1/6, 3/6, 4/6, 5/6, 6/6)))
|
||||
|
||||
expect_equal(scores_censoring1, scores_censoring2)
|
||||
expect_equal(scores_censoring1, scores_censoring3)
|
||||
|
||||
})
|
|
@ -15,3 +15,44 @@ test_that("CR_Response of length 1 - no censor times", {
|
|||
expect_true(T) # show Ok if we got this far
|
||||
|
||||
})
|
||||
|
||||
test_that("Can sub-index CR_Response - no censor times", {
|
||||
|
||||
x <- CR_Response(1:5, 1:5)
|
||||
|
||||
index <- 5
|
||||
|
||||
y <- x[index]
|
||||
|
||||
expect_equal(y$eventTime, index)
|
||||
expect_equal(y$eventIndicator, index)
|
||||
|
||||
expect_equal(rJava::.jcall(y$javaObject, "I", "size"), 1)
|
||||
oneJavaItem <- rJava::.jcall(y$javaObject, largeRCRF:::makeResponse(largeRCRF:::.class_Object), "get", 0L)
|
||||
oneJavaItem <- rJava::.jcast(oneJavaItem, largeRCRF:::.class_CompetingRiskResponse)
|
||||
delta <- rJava::.jcall(oneJavaItem, "I", "getDelta")
|
||||
|
||||
expect_equal(delta, index)
|
||||
|
||||
})
|
||||
|
||||
test_that("Can sub-index CR_Response - censor times", {
|
||||
|
||||
x <- CR_Response(1:5, 1:5, 1:5)
|
||||
|
||||
index <- 5
|
||||
|
||||
y <- x[index]
|
||||
|
||||
expect_equal(y$eventTime, index)
|
||||
expect_equal(y$eventIndicator, index)
|
||||
expect_equal(y$censorTime, index)
|
||||
|
||||
expect_equal(rJava::.jcall(y$javaObject, "I", "size"), 1)
|
||||
oneJavaItem <- rJava::.jcall(y$javaObject, largeRCRF:::makeResponse(largeRCRF:::.class_Object), "get", 0L)
|
||||
oneJavaItem <- rJava::.jcast(oneJavaItem, largeRCRF:::.class_CompetingRiskResponseWithCensorTime)
|
||||
delta <- rJava::.jcall(oneJavaItem, "D", "getC")
|
||||
|
||||
expect_equal(delta, index)
|
||||
|
||||
})
|
|
@ -28,6 +28,7 @@ test_that("Regresssion doesn't crash", {
|
|||
forest <- train(y ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)
|
||||
|
||||
predictions <- predict(forest, testData)
|
||||
other_predictions <- predict(forest) # there was a bug if newData wasn't provided.
|
||||
|
||||
expect_true(T) # show Ok if we got this far
|
||||
|
||||
|
|
|
@ -2,7 +2,11 @@ context("Train, save, and load without error")
|
|||
|
||||
test_that("Can save & load regression example", {
|
||||
|
||||
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist yet
|
||||
if(file.exists("trees_saving_loading")){
|
||||
unlink("trees_saving_loading", recursive=TRUE)
|
||||
}
|
||||
|
||||
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist at this point
|
||||
|
||||
x1 <- rnorm(1000)
|
||||
x2 <- rnorm(1000)
|
||||
|
|
|
@ -13,7 +13,7 @@ test_that("Can save a random forest while training, and use it afterward", {
|
|||
data <- data.frame(x1, x2, y)
|
||||
forest <- train(y ~ x1 + x2, data,
|
||||
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
||||
savePath="trees", displayProgress=FALSE)
|
||||
savePath="trees", forest.output = "online", displayProgress=FALSE)
|
||||
|
||||
expect_true(file.exists("trees")) # Something should have been saved
|
||||
|
||||
|
@ -26,6 +26,39 @@ test_that("Can save a random forest while training, and use it afterward", {
|
|||
predictions <- predict(newforest, newData)
|
||||
|
||||
|
||||
unlink("trees", recursive=TRUE)
|
||||
|
||||
})
|
||||
|
||||
test_that("Can save a random forest while training, and use it afterward with pure offline forest", {
|
||||
|
||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||
unlink("trees", recursive=TRUE)
|
||||
}
|
||||
|
||||
x1 <- rnorm(1000)
|
||||
x2 <- rnorm(1000)
|
||||
y <- 1 + x1 + x2 + rnorm(1000)
|
||||
|
||||
data <- data.frame(x1, x2, y)
|
||||
forest <- train(y ~ x1 + x2, data,
|
||||
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
||||
savePath="trees", forest.output = "offline", displayProgress=FALSE)
|
||||
|
||||
expect_true(file.exists("trees")) # Something should have been saved
|
||||
|
||||
# try making a little prediction to verify it works
|
||||
newData <- data.frame(x1=seq(from=-3, to=3, by=0.5), x2=0)
|
||||
predictions <- predict(forest, newData)
|
||||
|
||||
# Also make sure we can load the forest too
|
||||
newforest <- loadForest("trees")
|
||||
predictions <- predict(newforest, newData)
|
||||
|
||||
# Last, make sure we can take the forest online
|
||||
onlineForest <- convertToOnlineForest(forest)
|
||||
predictions <- predict(onlineForest, newData)
|
||||
|
||||
unlink("trees", recursive=TRUE)
|
||||
|
||||
})
|
102
tests/testthat/test_vimp.R
Normal file
102
tests/testthat/test_vimp.R
Normal file
|
@ -0,0 +1,102 @@
|
|||
context("Use VIMP without error")
|
||||
|
||||
test_that("VIMP doesn't crash; no test dataset", {
|
||||
|
||||
data(wihs)
|
||||
|
||||
forest <- train(CR_Response(status, time) ~ ., wihs, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE)
|
||||
|
||||
# Run VIMP several times under different scenarios
|
||||
importance <- vimp(forest, type="raw", events=1:2, time=5.0)
|
||||
vimp(forest, type="raw", events=1, time=5.0)
|
||||
vimp(forest, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8))
|
||||
|
||||
# Not much of a test, but the Java code tests more for correctness. This just
|
||||
# tests that the R code runs without error.
|
||||
expect_equal(ncol(importance), 4) # 4 predictors
|
||||
|
||||
})
|
||||
|
||||
|
||||
test_that("VIMP doesn't crash; test dataset", {
|
||||
|
||||
data(wihs)
|
||||
|
||||
trainingData <- wihs[1:1000,]
|
||||
testData <- wihs[1001:nrow(wihs),]
|
||||
|
||||
forest <- train(CR_Response(status, time) ~ ., trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE, cores=1)
|
||||
|
||||
# Run VIMP several times under different scenarios
|
||||
importance <- vimp(forest, newData=testData, type="raw", events=1:2, time=5.0)
|
||||
vimp(forest, newData=testData, type="raw", events=1, time=5.0)
|
||||
vimp(forest, newData=testData, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8))
|
||||
|
||||
# Not much of a test, but the Java code tests more for correctness. This just
|
||||
# tests that the R code runs without error.
|
||||
expect_equal(ncol(importance), 4) # 4 predictors
|
||||
|
||||
})
|
||||
|
||||
|
||||
test_that("VIMP doesn't crash; censoring distribution; all methods equal", {
|
||||
|
||||
sampleData <- data.frame(x=rnorm(100))
|
||||
sampleData$T <- sample(0:4, size=100, replace = TRUE) # the censor distribution we provide needs to conform to the data or we can get NaNs
|
||||
sampleData$delta <- sample(0:2, size = 100, replace = TRUE)
|
||||
|
||||
testData <- sampleData[1:5,]
|
||||
trainingData <- sampleData[6:100,]
|
||||
|
||||
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)
|
||||
|
||||
importance1 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
|
||||
censoringDistribution = c(0,1,1,2,3,4))
|
||||
importance2 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
|
||||
censoringDistribution = list(x = 0:4, y = 1 - c(1/6, 3/6, 4/6, 5/6, 6/6)))
|
||||
importance3 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
|
||||
censoringDistribution = stepfun(x=0:4, y=1 - c(0, 1/6, 3/6, 4/6, 5/6, 6/6)))
|
||||
|
||||
expect_equal(importance1, importance2)
|
||||
expect_equal(importance1, importance3)
|
||||
|
||||
})
|
||||
|
||||
test_that("VIMP doesn't crash; regression dataset", {
|
||||
|
||||
data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000))
|
||||
data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000)
|
||||
|
||||
forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE)
|
||||
|
||||
importance <- vimp(forest, type="mean")
|
||||
|
||||
expect_true(importance["x2"] > importance["x3"])
|
||||
|
||||
# Not much of a test, but the Java code tests more for correctness. This just
|
||||
# tests that the R code runs without error.
|
||||
expect_equal(length(importance), 3) # 3 predictors
|
||||
|
||||
})
|
||||
|
||||
test_that("VIMP produces mean and z scores correctly", {
|
||||
|
||||
data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000))
|
||||
data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000)
|
||||
|
||||
forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE)
|
||||
|
||||
actual.importance.raw <- vimp(forest, type="raw", randomSeed=5)
|
||||
actual.importance.mean <- vimp(forest, type="mean", randomSeed=5)
|
||||
actual.importance.z <- vimp(forest, type="z", randomSeed=5)
|
||||
|
||||
expected.importance.mean <- apply(actual.importance.raw, 2, mean)
|
||||
expected.importance.z <- apply(actual.importance.raw, 2, function(x){
|
||||
mn <- mean(x)
|
||||
return( mn / (sd(x) / sqrt(length(x))) )
|
||||
})
|
||||
|
||||
expect_equal(expected.importance.mean, actual.importance.mean)
|
||||
expect_equal(expected.importance.z, actual.importance.z)
|
||||
|
||||
})
|
Loading…
Reference in a new issue