Compare commits
No commits in common. "master" and "1.0.3.1-develop" have entirely different histories.
master
...
1.0.3.1-de
38 changed files with 202 additions and 1196 deletions
|
@ -1,7 +1,7 @@
|
||||||
Package: largeRCRF
|
Package: largeRCRF
|
||||||
Type: Package
|
Type: Package
|
||||||
Title: Large Random Competing Risks Forests
|
Title: Large Random Competing Risks Forests
|
||||||
Version: 1.0.5
|
Version: 1.0.3.1
|
||||||
Authors@R: c(
|
Authors@R: c(
|
||||||
person("Joel", "Therrien", email = "joel_therrien@sfu.ca", role = c("aut", "cre", "cph")),
|
person("Joel", "Therrien", email = "joel_therrien@sfu.ca", role = c("aut", "cre", "cph")),
|
||||||
person("Jiguo", "Cao", email = "jiguo_cao@sfu.ca", role = c("aut", "dgs"))
|
person("Jiguo", "Cao", email = "jiguo_cao@sfu.ca", role = c("aut", "dgs"))
|
||||||
|
@ -15,8 +15,7 @@ Copyright: All provided source code is copyrighted and owned by Joel Therrien.
|
||||||
Encoding: UTF-8
|
Encoding: UTF-8
|
||||||
LazyData: true
|
LazyData: true
|
||||||
Imports:
|
Imports:
|
||||||
rJava (>= 0.9-9),
|
rJava (>= 0.9-9)
|
||||||
stats (>= 3.4.0)
|
|
||||||
Suggests:
|
Suggests:
|
||||||
parallel,
|
parallel,
|
||||||
testthat,
|
testthat,
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
# Generated by roxygen2: do not edit by hand
|
# Generated by roxygen2: do not edit by hand
|
||||||
|
|
||||||
S3method("[",CompetingRiskResponses)
|
|
||||||
S3method("[",CompetingRiskResponsesWithCensorTimes)
|
|
||||||
S3method(extractCHF,CompetingRiskFunctions)
|
S3method(extractCHF,CompetingRiskFunctions)
|
||||||
S3method(extractCHF,CompetingRiskFunctions.List)
|
S3method(extractCHF,CompetingRiskFunctions.List)
|
||||||
S3method(extractCIF,CompetingRiskFunctions)
|
S3method(extractCIF,CompetingRiskFunctions)
|
||||||
|
@ -26,15 +24,12 @@ export(Numeric)
|
||||||
export(WeightedVarianceSplitFinder)
|
export(WeightedVarianceSplitFinder)
|
||||||
export(addTrees)
|
export(addTrees)
|
||||||
export(connectToData)
|
export(connectToData)
|
||||||
export(convertToOnlineForest)
|
|
||||||
export(extractCHF)
|
export(extractCHF)
|
||||||
export(extractCIF)
|
export(extractCIF)
|
||||||
export(extractMortalities)
|
export(extractMortalities)
|
||||||
export(extractSurvivorCurve)
|
export(extractSurvivorCurve)
|
||||||
export(integratedBrierScore)
|
|
||||||
export(loadForest)
|
export(loadForest)
|
||||||
export(naiveConcordance)
|
export(naiveConcordance)
|
||||||
export(saveForest)
|
export(saveForest)
|
||||||
export(train)
|
export(train)
|
||||||
export(vimp)
|
|
||||||
import(rJava)
|
import(rJava)
|
||||||
|
|
|
@ -38,51 +38,6 @@ CR_Response <- function(delta, u, C = NULL){
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# This function is useful is we ever want to do something like CR_Response(c(1,1,2), c(0.1,0.2,0.3))[1]
|
|
||||||
#' @export
|
|
||||||
"[.CompetingRiskResponses" <- function(object, indices){
|
|
||||||
newList <- list(
|
|
||||||
eventIndicator = object$eventIndicator[indices],
|
|
||||||
eventTime = object$eventTime[indices]
|
|
||||||
)
|
|
||||||
|
|
||||||
previous.java.list <- object$javaObject
|
|
||||||
|
|
||||||
new.java.list <- .jcall(.class_RUtils,
|
|
||||||
makeResponse(.class_List),
|
|
||||||
"produceSublist",
|
|
||||||
previous.java.list,
|
|
||||||
.jarray(as.integer(indices - 1)))
|
|
||||||
|
|
||||||
newList$javaObject <- new.java.list
|
|
||||||
|
|
||||||
class(newList) <- "CompetingRiskResponses"
|
|
||||||
return(newList)
|
|
||||||
}
|
|
||||||
|
|
||||||
# This function is useful is we ever want to do something like CR_Response(c(1,1,2), c(0.1,0.2,0.3), c(2,3,4))[1]
|
|
||||||
#' @export
|
|
||||||
"[.CompetingRiskResponsesWithCensorTimes" <- function(object, indices){
|
|
||||||
newList <- list(
|
|
||||||
eventIndicator = object$eventIndicator[indices],
|
|
||||||
eventTime = object$eventTime[indices],
|
|
||||||
censorTime = object$censorTime[indices]
|
|
||||||
)
|
|
||||||
|
|
||||||
previous.java.list <- object$javaObject
|
|
||||||
|
|
||||||
new.java.list <- .jcall(.class_RUtils,
|
|
||||||
makeResponse(.class_List),
|
|
||||||
"produceSublist",
|
|
||||||
previous.java.list,
|
|
||||||
.jarray(as.integer(indices - 1)))
|
|
||||||
|
|
||||||
newList$javaObject <- new.java.list
|
|
||||||
|
|
||||||
class(newList) <- "CompetingRiskResponsesWithCensorTimes"
|
|
||||||
return(newList)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Internal function
|
# Internal function
|
||||||
Java_CompetingRiskResponses <- function(delta, u){
|
Java_CompetingRiskResponses <- function(delta, u){
|
||||||
|
|
||||||
|
|
29
R/addTrees.R
29
R/addTrees.R
|
@ -14,10 +14,6 @@
|
||||||
#' @param savePath.overwrite If \code{savePath} is pointing to an existing
|
#' @param savePath.overwrite If \code{savePath} is pointing to an existing
|
||||||
#' directory, possibly containing another forest, this specifies what should
|
#' directory, possibly containing another forest, this specifies what should
|
||||||
#' be done.
|
#' be done.
|
||||||
#' @param forest.output This parameter only applies if \code{savePath} has been
|
|
||||||
#' set; set to 'online' (default) and the saved forest will be loaded into
|
|
||||||
#' memory after being trained. Set to 'offline' and the forest is not saved
|
|
||||||
#' into memory, but can still be used in a memory unintensive manner.
|
|
||||||
#' @param cores The number of cores to be used for training the new trees.
|
#' @param cores The number of cores to be used for training the new trees.
|
||||||
#' @param displayProgress A logical indicating whether the progress should be
|
#' @param displayProgress A logical indicating whether the progress should be
|
||||||
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
|
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
|
||||||
|
@ -26,11 +22,7 @@
|
||||||
#' @return A new forest with the original and additional trees.
|
#' @return A new forest with the original and additional trees.
|
||||||
#' @export
|
#' @export
|
||||||
#'
|
#'
|
||||||
addTrees <- function(forest, numTreesToAdd, savePath = NULL,
|
addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), cores = getCores(), displayProgress = TRUE){
|
||||||
savePath.overwrite = c("warn", "delete", "merge"),
|
|
||||||
forest.output = c("online", "offline"),
|
|
||||||
cores = getCores(), displayProgress = TRUE){
|
|
||||||
|
|
||||||
if(is.null(forest$dataset)){
|
if(is.null(forest$dataset)){
|
||||||
stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
|
stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
|
||||||
}
|
}
|
||||||
|
@ -45,10 +37,6 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL,
|
||||||
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||||
}
|
}
|
||||||
|
|
||||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
|
||||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
|
||||||
}
|
|
||||||
|
|
||||||
newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd)
|
newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd)
|
||||||
|
|
||||||
treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
|
treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
|
||||||
|
@ -110,23 +98,22 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL,
|
||||||
params=params,
|
params=params,
|
||||||
forestCall=match.call())
|
forestCall=match.call())
|
||||||
|
|
||||||
forest.java <- NULL
|
|
||||||
if(cores > 1){
|
if(cores > 1){
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
|
.jcall(forestTrainer, "V", "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
|
||||||
} else {
|
} else {
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", initial.forest.optional)
|
.jcall(forestTrainer, "V", "trainSerialOnDisk", initial.forest.optional)
|
||||||
}
|
}
|
||||||
|
|
||||||
if(forest.output[1] == "online"){
|
# Need to now load forest trees back into memory
|
||||||
forest.java <- convertToOnlineForest.Java(forest.java)
|
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forest$params$forestResponseCombiner$javaObject)
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
else{ # save directly into memory
|
else{ # save directly into memory
|
||||||
if(cores > 1){
|
if(cores > 1){
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
|
||||||
} else {
|
} else {
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", initial.forest.optional)
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", initial.forest.optional)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
|
|
||||||
#' Convert to Online Forest
|
|
||||||
#'
|
|
||||||
#' Some forests are too large to store in memory and have been saved to disk.
|
|
||||||
#' They can still be used, but their performance is much slower. If there's
|
|
||||||
#' enough memory, they can easily be converted into an in-memory forest that is
|
|
||||||
#' faster to use.
|
|
||||||
#'
|
|
||||||
#' @param forest The offline forest.
|
|
||||||
#'
|
|
||||||
#' @return An online, in memory forst.
|
|
||||||
#' @export
|
|
||||||
#'
|
|
||||||
convertToOnlineForest <- function(forest){
|
|
||||||
old.forest.object <- forest$javaObject
|
|
||||||
|
|
||||||
if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OnlineForest"){
|
|
||||||
|
|
||||||
warning("forest is already in-memory")
|
|
||||||
return(forest)
|
|
||||||
|
|
||||||
} else if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OfflineForest"){
|
|
||||||
|
|
||||||
forest$javaObject <- convertToOnlineForest.Java(old.forest.object)
|
|
||||||
return(forest)
|
|
||||||
|
|
||||||
} else{
|
|
||||||
stop("'forest' is not an online or offline forest")
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
# Internal function
|
|
||||||
convertToOnlineForest.Java <- function(forest.java){
|
|
||||||
offline.forest <- .jcast(forest.java, .class_OfflineForest)
|
|
||||||
online.forest <- .jcall(offline.forest, makeResponse(.class_OnlineForest), "createOnlineCopy")
|
|
||||||
return(online.forest)
|
|
||||||
}
|
|
|
@ -42,7 +42,7 @@ CR_FunctionCombiner <- function(events, times = NULL){
|
||||||
}
|
}
|
||||||
|
|
||||||
javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray)
|
javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray)
|
||||||
javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
|
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||||
|
|
||||||
combiner <- list(javaObject=javaObject,
|
combiner <- list(javaObject=javaObject,
|
||||||
call=match.call(),
|
call=match.call(),
|
||||||
|
|
|
@ -1,83 +0,0 @@
|
||||||
|
|
||||||
|
|
||||||
#' Integrated Brier Score
|
|
||||||
#'
|
|
||||||
#' Used to calculate the Integrated Brier Score, which for the competing risks
|
|
||||||
#' setting is the integral of the squared difference between each observed
|
|
||||||
#' cumulative incidence function (CIF) for each observation and the
|
|
||||||
#' corresponding predicted CIF. If the survivor function (1 - CDF) of the
|
|
||||||
#' censoring distribution is provided, weights can be calculated to account for
|
|
||||||
#' the censoring.
|
|
||||||
#'
|
|
||||||
#' @return A numeric vector of the Integrated Brier Score for each prediction.
|
|
||||||
#' @param responses A list of responses corresponding to the provided
|
|
||||||
#' mortalities; use \code{\link{CR_Response}}.
|
|
||||||
#' @param predictions The predictions to be tested against.
|
|
||||||
#' @param event The event type for the error to be calculated on.
|
|
||||||
#' @param time \code{time} specifies the upper bound of the integral.
|
|
||||||
#' @param censoringDistribution Optional; if provided then weights are
|
|
||||||
#' calculated on the errors. There are three ways to provide it - \itemize{
|
|
||||||
#' \item{If you have all the censor times and just want to use a simple
|
|
||||||
#' empirical estimate of the distribution, just provide a numeric vector of
|
|
||||||
#' all of the censor times and it will be automatically calculated.} \item{You
|
|
||||||
#' can directly specify the survivor function by providing a list with two
|
|
||||||
#' numeric vectors called \code{x} and \code{y}. They should be of the same
|
|
||||||
#' length and correspond to each point. It is assumed that previous to the
|
|
||||||
#' first value in \code{y} the \code{y} value is 1.0; and that the function
|
|
||||||
#' you provide is a right-continuous step function.} \item{You can provide a
|
|
||||||
#' function from \code{\link[stats]{stepfun}}. Note that this only supports
|
|
||||||
#' functions where \code{right = FALSE} (default), and that the first y value
|
|
||||||
#' (corresponding to y before the first x value) will be to set to 1.0
|
|
||||||
#' regardless of what is specified.}
|
|
||||||
#'
|
|
||||||
#' }
|
|
||||||
#' @param parallel A logical indicating whether multiple cores should be
|
|
||||||
#' utilized when calculating the error. Available as an option because it's
|
|
||||||
#' been observed that using Java's \code{parallelStream} can be unstable on
|
|
||||||
#' some systems. Default value is \code{TRUE}; only set to \code{FALSE} if you
|
|
||||||
#' get strange errors while predicting.
|
|
||||||
#'
|
|
||||||
#' @export
|
|
||||||
#' @references Section 4.2 of Ishwaran H, Gerds TA, Kogalur UB, Moore RD, Gange
|
|
||||||
#' SJ, Lau BM (2014). “Random Survival Forests for Competing Risks.”
|
|
||||||
#' Biostatistics, 15(4), 757–773. doi:10.1093/ biostatistics/kxu010.
|
|
||||||
#'
|
|
||||||
#' @examples
|
|
||||||
#' data <- data.frame(delta=c(1,1,0,0,2,2), T=1:6, x=1:6)
|
|
||||||
#'
|
|
||||||
#' model <- train(CR_Response(delta, T) ~ x, data, ntree=100, numberOfSplits=0, mtry=1, nodeSize=1)
|
|
||||||
#'
|
|
||||||
#' newData <- data.frame(delta=c(1,0,2,1,0,2), T=1:6, x=1:6)
|
|
||||||
#' predictions <- predict(model, newData)
|
|
||||||
#'
|
|
||||||
#' scores <- integratedBrierScore(CR_Response(data$delta, data$T), predictions, 1, 6.0)
|
|
||||||
#'
|
|
||||||
integratedBrierScore <- function(responses, predictions, event, time, censoringDistribution = NULL, parallel = TRUE){
|
|
||||||
if(length(responses$eventTime) != length(predictions)){
|
|
||||||
stop("Length of responses and predictions must be equal.")
|
|
||||||
}
|
|
||||||
|
|
||||||
java.censoringDistribution <- NULL
|
|
||||||
if(!is.null(censoringDistribution)){
|
|
||||||
java.censoringDistribution <- processCensoringDistribution(censoringDistribution)
|
|
||||||
java.censoringDistribution <- .object_Optional(java.censoringDistribution)
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
java.censoringDistribution <- .object_Optional(NULL)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
predictions.java <- lapply(predictions, function(x){return(x$javaObject)})
|
|
||||||
predictions.java <- convertRListToJava(predictions.java)
|
|
||||||
|
|
||||||
errors <- .jcall(.class_CompetingRiskUtils, "[D", "calculateIBSError",
|
|
||||||
responses$javaObject,
|
|
||||||
predictions.java,
|
|
||||||
java.censoringDistribution,
|
|
||||||
as.integer(event),
|
|
||||||
time,
|
|
||||||
parallel)
|
|
||||||
|
|
||||||
return(errors)
|
|
||||||
|
|
||||||
}
|
|
|
@ -16,11 +16,6 @@
|
||||||
#' list should correspond to one of the events in the order of event 1 to J,
|
#' list should correspond to one of the events in the order of event 1 to J,
|
||||||
#' and should be a vector of the same length as responses.
|
#' and should be a vector of the same length as responses.
|
||||||
#' @export
|
#' @export
|
||||||
#' @references Section 3.2 of Wolbers, Marcel, Paul Blanche, Michael T. Koller,
|
|
||||||
#' Jacqueline C M Witteman, and Thomas A Gerds. 2014. “Concordance for
|
|
||||||
#' Prognostic Models with Competing Risks.” Biostatistics 15 (3): 526–39.
|
|
||||||
#' https://doi.org/10.1093/biostatistics/kxt059.
|
|
||||||
#'
|
|
||||||
#' @examples
|
#' @examples
|
||||||
#' data <- data.frame(delta=c(1,1,0,0,2,2), T=1:6, x=1:6)
|
#' data <- data.frame(delta=c(1,1,0,0,2,2), T=1:6, x=1:6)
|
||||||
#'
|
#'
|
||||||
|
|
|
@ -1,11 +0,0 @@
|
||||||
# Internal function
|
|
||||||
createRightContinuousStepFunction <- function(x, y, defaultY){
|
|
||||||
x.java <- .jarray(as.numeric(x))
|
|
||||||
y.java <- .jarray(as.numeric(y))
|
|
||||||
|
|
||||||
# as.numeric is explicitly required in case integers were accidently passed
|
|
||||||
# in.
|
|
||||||
newFun <- .jnew(.class_RightContinuousStepFunction, as.numeric(x), as.numeric(y), as.numeric(defaultY))
|
|
||||||
return(newFun)
|
|
||||||
|
|
||||||
}
|
|
|
@ -29,8 +29,8 @@
|
||||||
NULL
|
NULL
|
||||||
|
|
||||||
# @rdname covariates
|
# @rdname covariates
|
||||||
Java_BooleanCovariate <- function(name, index, na.penalty){
|
Java_BooleanCovariate <- function(name, index){
|
||||||
covariate <- .jnew(.class_BooleanCovariate, name, as.integer(index), na.penalty)
|
covariate <- .jnew(.class_BooleanCovariate, name, as.integer(index))
|
||||||
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||||
|
|
||||||
return(covariate)
|
return(covariate)
|
||||||
|
@ -38,19 +38,19 @@ Java_BooleanCovariate <- function(name, index, na.penalty){
|
||||||
|
|
||||||
# @rdname covariates
|
# @rdname covariates
|
||||||
# @param levels The levels of the factor as a character vector
|
# @param levels The levels of the factor as a character vector
|
||||||
Java_FactorCovariate <- function(name, index, levels, na.penalty){
|
Java_FactorCovariate <- function(name, index, levels){
|
||||||
levelsArray <- .jarray(levels, makeResponse(.class_String))
|
levelsArray <- .jarray(levels, makeResponse(.class_String))
|
||||||
levelsList <- .jcall("java/util/Arrays", "Ljava/util/List;", "asList", .jcast(levelsArray, "[Ljava/lang/Object;"))
|
levelsList <- .jcall("java/util/Arrays", "Ljava/util/List;", "asList", .jcast(levelsArray, "[Ljava/lang/Object;"))
|
||||||
|
|
||||||
covariate <- .jnew(.class_FactorCovariate, name, as.integer(index), levelsList, na.penalty)
|
covariate <- .jnew(.class_FactorCovariate, name, as.integer(index), levelsList)
|
||||||
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||||
|
|
||||||
return(covariate)
|
return(covariate)
|
||||||
}
|
}
|
||||||
|
|
||||||
# @rdname covariates
|
# @rdname covariates
|
||||||
Java_NumericCovariate <- function(name, index, na.penalty){
|
Java_NumericCovariate <- function(name, index){
|
||||||
covariate <- .jnew(.class_NumericCovariate, name, as.integer(index), na.penalty)
|
covariate <- .jnew(.class_NumericCovariate, name, as.integer(index))
|
||||||
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||||
|
|
||||||
return(covariate)
|
return(covariate)
|
||||||
|
|
|
@ -10,20 +10,15 @@
|
||||||
.class_Collection <- "java/util/Collection"
|
.class_Collection <- "java/util/Collection"
|
||||||
.class_Serializable <- "java/io/Serializable"
|
.class_Serializable <- "java/io/Serializable"
|
||||||
.class_File <- "java/io/File"
|
.class_File <- "java/io/File"
|
||||||
.class_Random <- "java/util/Random"
|
|
||||||
.class_Class <- "java/lang/Class"
|
|
||||||
|
|
||||||
# Utility Classes
|
# Utility Classes
|
||||||
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
|
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
|
||||||
.class_RUtils <- "ca/joeltherrien/randomforest/utils/RUtils"
|
.class_RUtils <- "ca/joeltherrien/randomforest/utils/RUtils"
|
||||||
.class_Utils <- "ca/joeltherrien/randomforest/utils/Utils"
|
|
||||||
.class_CompetingRiskUtils <- "ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils"
|
.class_CompetingRiskUtils <- "ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils"
|
||||||
.class_Settings <- "ca/joeltherrien/randomforest/Settings"
|
.class_Settings <- "ca/joeltherrien/randomforest/Settings"
|
||||||
|
|
||||||
# Misc. Classes
|
# Misc. Classes
|
||||||
.class_RightContinuousStepFunction <- "ca/joeltherrien/randomforest/utils/RightContinuousStepFunction"
|
.class_RightContinuousStepFunction <- "ca/joeltherrien/randomforest/utils/RightContinuousStepFunction"
|
||||||
.class_CompetingRiskResponse <- "ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse"
|
|
||||||
.class_CompetingRiskResponseWithCensorTime <- "ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime"
|
|
||||||
|
|
||||||
# TreeTrainer & its Builder
|
# TreeTrainer & its Builder
|
||||||
.class_TreeTrainer <- "ca/joeltherrien/randomforest/tree/TreeTrainer"
|
.class_TreeTrainer <- "ca/joeltherrien/randomforest/tree/TreeTrainer"
|
||||||
|
@ -42,12 +37,9 @@
|
||||||
|
|
||||||
# Forest class
|
# Forest class
|
||||||
.class_Forest <- "ca/joeltherrien/randomforest/tree/Forest"
|
.class_Forest <- "ca/joeltherrien/randomforest/tree/Forest"
|
||||||
.class_OnlineForest <- "ca/joeltherrien/randomforest/tree/OnlineForest"
|
|
||||||
.class_OfflineForest <- "ca/joeltherrien/randomforest/tree/OfflineForest"
|
|
||||||
|
|
||||||
# ResponseCombiner classes
|
# ResponseCombiner classes
|
||||||
.class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner"
|
.class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner"
|
||||||
.class_ForestResponseCombiner <- "ca/joeltherrien/randomforest/tree/ForestResponseCombiner"
|
|
||||||
.class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner"
|
.class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner"
|
||||||
.class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner"
|
.class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner"
|
||||||
.class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner"
|
.class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner"
|
||||||
|
@ -58,20 +50,12 @@
|
||||||
.class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder"
|
.class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder"
|
||||||
.class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder"
|
.class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder"
|
||||||
|
|
||||||
# VIMP classes
|
.object_Optional <- function(forest=NULL){
|
||||||
.class_IBSCalculator <- "ca/joeltherrien/randomforest/responses/competingrisk/IBSCalculator"
|
if(is.null(forest)){
|
||||||
.class_ErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/ErrorCalculator"
|
|
||||||
.class_RegressionErrorCalculator <- "ca/joeltherrien/randomforest/tree/vimp/RegressionErrorCalculator"
|
|
||||||
.class_IBSErrorCalculatorWrapper <- "ca/joeltherrien/randomforest/tree/vimp/IBSErrorCalculatorWrapper"
|
|
||||||
.class_VariableImportanceCalculator <- "ca/joeltherrien/randomforest/tree/vimp/VariableImportanceCalculator"
|
|
||||||
|
|
||||||
|
|
||||||
.object_Optional <- function(object=NULL){
|
|
||||||
if(is.null(object)){
|
|
||||||
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "empty"))
|
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "empty"))
|
||||||
} else{
|
} else{
|
||||||
object <- .jcast(object, .class_Object)
|
forest <- .jcast(forest, .class_Object)
|
||||||
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "of", object))
|
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "of", forest))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -82,9 +66,3 @@
|
||||||
makeResponse <- function(className){
|
makeResponse <- function(className){
|
||||||
return(paste0("L", className, ";"))
|
return(paste0("L", className, ";"))
|
||||||
}
|
}
|
||||||
|
|
||||||
getJavaClass <- function(object){
|
|
||||||
class <- .jcall(object, makeResponse(.class_Class), "getClass")
|
|
||||||
className <- .jcall(class, "S", "getName")
|
|
||||||
return(className)
|
|
||||||
}
|
|
||||||
|
|
37
R/loadData.R
37
R/loadData.R
|
@ -1,4 +1,4 @@
|
||||||
loadData <- function(data, xVarNames, responses, covariateList.java = NULL, na.penalty = NULL){
|
loadData <- function(data, xVarNames, responses, covariateList.java = NULL){
|
||||||
|
|
||||||
if(class(responses) == "integer" | class(responses) == "numeric"){
|
if(class(responses) == "integer" | class(responses) == "numeric"){
|
||||||
responses <- Numeric(responses)
|
responses <- Numeric(responses)
|
||||||
|
@ -6,7 +6,7 @@ loadData <- function(data, xVarNames, responses, covariateList.java = NULL, na.p
|
||||||
|
|
||||||
# connectToData provides a pre-created covariate list we can re-use
|
# connectToData provides a pre-created covariate list we can re-use
|
||||||
if(is.null(covariateList.java)){
|
if(is.null(covariateList.java)){
|
||||||
covariateList.java <- getCovariateList(data, xVarNames, na.penalty)
|
covariateList.java <- getCovariateList(data, xVarNames)
|
||||||
}
|
}
|
||||||
|
|
||||||
textColumns <- list()
|
textColumns <- list()
|
||||||
|
@ -18,11 +18,11 @@ loadData <- function(data, xVarNames, responses, covariateList.java = NULL, na.p
|
||||||
rowList <- .jcall(.class_RUtils, makeResponse(.class_List), "importDataWithResponses",
|
rowList <- .jcall(.class_RUtils, makeResponse(.class_List), "importDataWithResponses",
|
||||||
responses$javaObject, covariateList.java, textData)
|
responses$javaObject, covariateList.java, textData)
|
||||||
|
|
||||||
return(list(covariateList = covariateList.java, dataset = rowList, responses = responses))
|
return(list(covariateList=covariateList.java, dataset=rowList))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
getCovariateList <- function(data, xvarNames, na.penalty){
|
getCovariateList <- function(data, xvarNames){
|
||||||
covariateList <- .jcast(.jnew(.class_ArrayList, length(xvarNames)), .class_List)
|
covariateList <- .jcast(.jnew(.class_ArrayList, length(xvarNames)), .class_List)
|
||||||
|
|
||||||
for(i in 1:length(xvarNames)){
|
for(i in 1:length(xvarNames)){
|
||||||
|
@ -31,14 +31,14 @@ getCovariateList <- function(data, xvarNames, na.penalty){
|
||||||
column <- data[,xName]
|
column <- data[,xName]
|
||||||
|
|
||||||
if(class(column) == "numeric" | class(column) == "integer"){
|
if(class(column) == "numeric" | class(column) == "integer"){
|
||||||
covariate <- Java_NumericCovariate(xName, i-1, na.penalty[i])
|
covariate <- Java_NumericCovariate(xName, i-1)
|
||||||
}
|
}
|
||||||
else if(class(column) == "logical"){
|
else if(class(column) == "logical"){
|
||||||
covariate <- Java_BooleanCovariate(xName, i-1, na.penalty[i])
|
covariate <- Java_BooleanCovariate(xName, i-1)
|
||||||
}
|
}
|
||||||
else if(class(column) == "factor"){
|
else if(class(column) == "factor"){
|
||||||
lvls <- levels(column)
|
lvls <- levels(column)
|
||||||
covariate <- Java_FactorCovariate(xName, i-1, lvls, na.penalty[i])
|
covariate <- Java_FactorCovariate(xName, i-1, lvls)
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
stop("Unknown column type")
|
stop("Unknown column type")
|
||||||
|
@ -54,7 +54,15 @@ getCovariateList <- function(data, xvarNames, na.penalty){
|
||||||
|
|
||||||
loadPredictionData <- function(newData, covariateList.java){
|
loadPredictionData <- function(newData, covariateList.java){
|
||||||
|
|
||||||
xVarNames <- extractCovariateNamesFromJavaList(covariateList.java)
|
xVarNames <- character(.jcall(covariateList.java, "I", "size"))
|
||||||
|
for(j in 1:length(xVarNames)){
|
||||||
|
covariate.java <- .jcast(
|
||||||
|
.jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
|
||||||
|
.class_Covariate
|
||||||
|
)
|
||||||
|
|
||||||
|
xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
|
||||||
|
}
|
||||||
|
|
||||||
if(any(!(xVarNames %in% names(newData)))){
|
if(any(!(xVarNames %in% names(newData)))){
|
||||||
varsMissing = xVarNames[!(xVarNames %in% names(newData))]
|
varsMissing = xVarNames[!(xVarNames %in% names(newData))]
|
||||||
|
@ -76,16 +84,3 @@ loadPredictionData <- function(newData, covariateList.java){
|
||||||
return(rowList)
|
return(rowList)
|
||||||
}
|
}
|
||||||
|
|
||||||
extractCovariateNamesFromJavaList <- function(covariateList.java){
|
|
||||||
xVarNames <- character(.jcall(covariateList.java, "I", "size"))
|
|
||||||
for(j in 1:length(xVarNames)){
|
|
||||||
covariate.java <- .jcast(
|
|
||||||
.jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
|
|
||||||
.class_Covariate
|
|
||||||
)
|
|
||||||
|
|
||||||
xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
|
|
||||||
}
|
|
||||||
|
|
||||||
return(xVarNames)
|
|
||||||
}
|
|
||||||
|
|
|
@ -5,12 +5,6 @@
|
||||||
#' Loads a random forest that was saved using \code{\link{saveForest}}.
|
#' Loads a random forest that was saved using \code{\link{saveForest}}.
|
||||||
#'
|
#'
|
||||||
#' @param directory The directory created that saved the previous forest.
|
#' @param directory The directory created that saved the previous forest.
|
||||||
#' @param forest.output Specifies whether the forest loaded should be loaded
|
|
||||||
#' into memory, or reflect the saved files where only one tree is loaded at a
|
|
||||||
#' time.
|
|
||||||
#' @param maxTreeNum If for some reason you only want to load the number of
|
|
||||||
#' trees up until a certain point, you can specify maxTreeNum as a single
|
|
||||||
#' number.
|
|
||||||
#' @return A JForest object; see \code{\link{train}} for details.
|
#' @return A JForest object; see \code{\link{train}} for details.
|
||||||
#' @export
|
#' @export
|
||||||
#' @seealso \code{\link{train}}, \code{\link{saveForest}}
|
#' @seealso \code{\link{train}}, \code{\link{saveForest}}
|
||||||
|
@ -26,11 +20,7 @@
|
||||||
#'
|
#'
|
||||||
#' saveForest(forest, "trees")
|
#' saveForest(forest, "trees")
|
||||||
#' new_forest <- loadForest("trees")
|
#' new_forest <- loadForest("trees")
|
||||||
loadForest <- function(directory, forest.output = c("online", "offline"), maxTreeNum = NULL){
|
loadForest <- function(directory){
|
||||||
|
|
||||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
|
||||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
|
||||||
}
|
|
||||||
|
|
||||||
# First load the response combiners and the split finders
|
# First load the response combiners and the split finders
|
||||||
nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData"))
|
nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData"))
|
||||||
|
@ -40,7 +30,7 @@ loadForest <- function(directory, forest.output = c("online", "offline"), maxTre
|
||||||
splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder)
|
splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder)
|
||||||
|
|
||||||
forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData"))
|
forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData"))
|
||||||
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ForestResponseCombiner)
|
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ResponseCombiner)
|
||||||
|
|
||||||
covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData"))
|
covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData"))
|
||||||
covariateList <- .jcast(covariateList, .class_List)
|
covariateList <- .jcast(covariateList, .class_List)
|
||||||
|
@ -52,11 +42,8 @@ loadForest <- function(directory, forest.output = c("online", "offline"), maxTre
|
||||||
params$splitFinder$javaObject <- splitFinder.java
|
params$splitFinder$javaObject <- splitFinder.java
|
||||||
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
|
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
|
||||||
|
|
||||||
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder,
|
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call,
|
||||||
params$forestResponseCombiner, covariateList, call,
|
params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes, params$randomSeed)
|
||||||
params$ntree, params$numberOfSplits, params$mtry,
|
|
||||||
params$nodeSize, params$maxNodeDepth, params$splitPureNodes,
|
|
||||||
params$randomSeed, forest.output, maxTreeNum)
|
|
||||||
|
|
||||||
return(forest)
|
return(forest)
|
||||||
|
|
||||||
|
@ -68,11 +55,8 @@ loadForest <- function(directory, forest.output = c("online", "offline"), maxTre
|
||||||
# that uses the Java version's settings yaml file to recreate the forest, but
|
# that uses the Java version's settings yaml file to recreate the forest, but
|
||||||
# I'd appreciate knowing that someone's going to use it first (email me; see
|
# I'd appreciate knowing that someone's going to use it first (email me; see
|
||||||
# README).
|
# README).
|
||||||
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder,
|
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner,
|
||||||
forestResponseCombiner, covariateList.java, call,
|
covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, randomSeed=NULL){
|
||||||
ntree, numberOfSplits, mtry, nodeSize,
|
|
||||||
maxNodeDepth = 100000, splitPureNodes=TRUE,
|
|
||||||
randomSeed=NULL, forest.output = "online", maxTreeNum = NULL){
|
|
||||||
|
|
||||||
params <- list(
|
params <- list(
|
||||||
splitFinder=splitFinder,
|
splitFinder=splitFinder,
|
||||||
|
@ -87,33 +71,7 @@ loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, sp
|
||||||
randomSeed=randomSeed
|
randomSeed=randomSeed
|
||||||
)
|
)
|
||||||
|
|
||||||
forest.java <- NULL
|
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject)
|
||||||
if(forest.output[1] == "online"){
|
|
||||||
castedForestResponseCombiner <- .jcast(forestResponseCombiner$javaObject, .class_ResponseCombiner) # OnlineForest constructor takes a ResponseCombiner
|
|
||||||
|
|
||||||
if(is.null(maxTreeNum)){
|
|
||||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
|
|
||||||
treeDirectory, castedForestResponseCombiner)
|
|
||||||
} else{
|
|
||||||
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
|
|
||||||
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
|
|
||||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
|
|
||||||
tree.file.array, castedForestResponseCombiner)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
} else{ # offline forest
|
|
||||||
if(is.null(maxTreeNum)){
|
|
||||||
path.as.file <- .jnew(.class_File, treeDirectory)
|
|
||||||
forest.java <- .jnew(.class_OfflineForest, path.as.file, forestResponseCombiner$javaObject)
|
|
||||||
} else{
|
|
||||||
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
|
|
||||||
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
|
|
||||||
forest.java <- .jnew(.class_OfflineForest, tree.file.array, forestResponseCombiner$javaObject)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params)
|
forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params)
|
||||||
class(forestObject) <- "JRandomForest"
|
class(forestObject) <- "JRandomForest"
|
||||||
|
|
17
R/misc.R
17
R/misc.R
|
@ -14,23 +14,6 @@ convertRListToJava <- function(lst){
|
||||||
return(javaList)
|
return(javaList)
|
||||||
}
|
}
|
||||||
|
|
||||||
#Internal function
|
|
||||||
convertJavaListToR <- function(javaList, class = .class_Object){
|
|
||||||
lst <- list()
|
|
||||||
|
|
||||||
javaList.length <- .jcall(javaList, "I", "size")
|
|
||||||
|
|
||||||
for(i in 0:(javaList.length - 1)){
|
|
||||||
object <- .jcall(javaList, makeResponse(.class_Object), "get", as.integer(i))
|
|
||||||
object <- .jcast(object, class)
|
|
||||||
|
|
||||||
lst[[i+1]] <- object
|
|
||||||
}
|
|
||||||
|
|
||||||
return(lst)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#' @export
|
#' @export
|
||||||
print.SplitFinder = function(x, ...) print(x$call)
|
print.SplitFinder = function(x, ...) print(x$call)
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,7 @@ predict.JRandomForest <- function(object, newData=NULL, parallel=TRUE, out.of.ba
|
||||||
predictionsJava <- .jcall(forestObject, makeResponse(.class_List), function.to.use, predictionDataList)
|
predictionsJava <- .jcall(forestObject, makeResponse(.class_List), function.to.use, predictionDataList)
|
||||||
|
|
||||||
if(predictionClass == "numeric"){
|
if(predictionClass == "numeric"){
|
||||||
predictions <- vector(length=numRows, mode="numeric")
|
predictions <- vector(length=nrow(newData), mode="numeric")
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
predictions <- list()
|
predictions <- list()
|
||||||
|
|
|
@ -1,39 +0,0 @@
|
||||||
|
|
||||||
|
|
||||||
# Internal function. Takes a censoring distribution and turns it into a
|
|
||||||
# RightContinuousStepFunction Java object.
|
|
||||||
processCensoringDistribution <- function(censoringDistribution){
|
|
||||||
|
|
||||||
if(is.numeric(censoringDistribution)){
|
|
||||||
# estimate ECDF
|
|
||||||
censoringTimes <- .jarray(censoringDistribution, "D")
|
|
||||||
java.censoringDistribution <- .jcall(.class_Utils, makeResponse(.class_RightContinuousStepFunction), "estimateOneMinusECDF", censoringTimes)
|
|
||||||
|
|
||||||
} else if(is.list(censoringDistribution)){
|
|
||||||
# First check that censoringDistribution fits the correct format
|
|
||||||
if(is.null(censoringDistribution$x) | is.null(censoringDistribution$y)){
|
|
||||||
stop("If the censoringDistribution is provided as a list, it must have an x and a y item that are numeric.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if(length(censoringDistribution$x) != length(censoringDistribution$y)){
|
|
||||||
stop("x and y in censoringDistribution must have the same length.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if(!is.numeric(censoringDistribution$x) | !is.numeric(censoringDistribution$y)){
|
|
||||||
stop("x and y in censoringDistribution must both be numeric.")
|
|
||||||
}
|
|
||||||
|
|
||||||
java.censoringDistribution <- createRightContinuousStepFunction(censoringDistribution$x, censoringDistribution$y, defaultY = 1.0)
|
|
||||||
|
|
||||||
} else if("stepfun" %in% class(censoringDistribution)){
|
|
||||||
x <- stats::knots(censoringDistribution)
|
|
||||||
y <- censoringDistribution(x)
|
|
||||||
|
|
||||||
java.censoringDistribution <- createRightContinuousStepFunction(x, y, defaultY = 1.0)
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
stop("Invalid censoringDistribution")
|
|
||||||
}
|
|
||||||
|
|
||||||
return(java.censoringDistribution)
|
|
||||||
}
|
|
|
@ -1,95 +0,0 @@
|
||||||
|
|
||||||
# Internal function that takes a formula and processes it for use in the Java
|
|
||||||
# code. existingCovariateList is optional; if not provided then a new one is
|
|
||||||
# created internally.
|
|
||||||
processFormula <- function(formula, data, covariateList.java = NULL, na.penalty = NULL){
|
|
||||||
|
|
||||||
# Having an R copy of the data loaded at the same time can be wasteful; we
|
|
||||||
# also allow users to provide an environment of the data which gets removed
|
|
||||||
# after being imported into Java
|
|
||||||
if(class(data) == "environment"){
|
|
||||||
if(is.null(data$data)){
|
|
||||||
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
|
|
||||||
}
|
|
||||||
|
|
||||||
env <- data
|
|
||||||
data <- env$data
|
|
||||||
env$data <- NULL
|
|
||||||
gc()
|
|
||||||
}
|
|
||||||
|
|
||||||
yVar <- formula[[2]]
|
|
||||||
|
|
||||||
responses <- NULL
|
|
||||||
variablesToDrop <- character(0)
|
|
||||||
|
|
||||||
# yVar is a call object; as.character(yVar) will be the different components, including the parameters.
|
|
||||||
# if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
|
|
||||||
# then we also need to explicitly evaluate it
|
|
||||||
if(class(yVar) == "call" || !(as.character(yVar) %in% colnames(data))){
|
|
||||||
# yVar is a function like CR_Response
|
|
||||||
responses <- eval(expr=yVar, envir=data)
|
|
||||||
|
|
||||||
if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
|
|
||||||
# do any of the variables match data in data? We need to track that so we can drop them later
|
|
||||||
variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
|
|
||||||
}
|
|
||||||
|
|
||||||
formula[[2]] <- NULL
|
|
||||||
|
|
||||||
} else if(class(yVar) == "name"){ # and implicitly yVar is contained in data
|
|
||||||
variablesToDrop <- as.character(yVar)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Includes responses which we may need to later cut out if `.` was used on the
|
|
||||||
# right-hand-side
|
|
||||||
filteredData <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass)
|
|
||||||
|
|
||||||
if(is.null(responses)){ # If this if-statement runs then we have a simple (i.e. numeric) response
|
|
||||||
responses <- stats::model.response(filteredData)
|
|
||||||
}
|
|
||||||
|
|
||||||
# remove any response variables on the right-hand-side
|
|
||||||
covariateData <- filteredData[, !(names(filteredData) %in% variablesToDrop), drop=FALSE]
|
|
||||||
|
|
||||||
# Now that we know how many predictor variables we have, we should check na.penalty
|
|
||||||
if(!is.null(na.penalty)){
|
|
||||||
if(!is.numeric(na.penalty) & !is.logical(na.penalty)){
|
|
||||||
stop("na.penalty must be either logical or numeric.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if(is.logical(na.penalty) & length(na.penalty) != 1 & length(na.penalty) != ncol(covariateData)){
|
|
||||||
stop("na.penalty must have length of either 1 or the number of predictor variables if logical.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if(is.numeric(na.penalty) & length(na.penalty) != 1){
|
|
||||||
stop("na.penalty must have length 1 if logical.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if(anyNA(na.penalty)){
|
|
||||||
stop("na.penalty cannot contain NAs.")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# All good; now to transform it.
|
|
||||||
if(is.numeric(na.penalty)){
|
|
||||||
na.threshold <- na.penalty
|
|
||||||
na.penalty <- apply(covariateData, 2, function(x){mean(is.na(x))}) >= na.threshold
|
|
||||||
}
|
|
||||||
else if(is.logical(na.penalty) & length(na.penalty) == 1){
|
|
||||||
na.penalty <- rep(na.penalty, times = ncol(covariateData))
|
|
||||||
}
|
|
||||||
# else{} - na.penalty is logical and the correct length; no need to do anything to it
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset <- loadData(
|
|
||||||
covariateData,
|
|
||||||
colnames(covariateData),
|
|
||||||
responses,
|
|
||||||
covariateList.java = covariateList.java,
|
|
||||||
na.penalty = na.penalty
|
|
||||||
)
|
|
||||||
|
|
||||||
return(dataset)
|
|
||||||
}
|
|
|
@ -48,7 +48,7 @@ WeightedVarianceSplitFinder <- function(){
|
||||||
#'
|
#'
|
||||||
MeanResponseCombiner <- function(){
|
MeanResponseCombiner <- function(){
|
||||||
javaObject <- .jnew(.class_MeanResponseCombiner)
|
javaObject <- .jnew(.class_MeanResponseCombiner)
|
||||||
javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
|
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||||
|
|
||||||
combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric")
|
combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric")
|
||||||
combiner$convertToRFunction <- function(javaObject, ...){
|
combiner$convertToRFunction <- function(javaObject, ...){
|
||||||
|
|
176
R/train.R
176
R/train.R
|
@ -14,11 +14,11 @@ getCores <- function(){
|
||||||
return(cores)
|
return(cores)
|
||||||
}
|
}
|
||||||
|
|
||||||
train.internal <- function(dataset, splitFinder,
|
train.internal <- function(responses, covariateData, splitFinder,
|
||||||
nodeResponseCombiner, forestResponseCombiner, ntree,
|
nodeResponseCombiner, forestResponseCombiner, ntree,
|
||||||
numberOfSplits, mtry, nodeSize, maxNodeDepth,
|
numberOfSplits, mtry, nodeSize, maxNodeDepth,
|
||||||
splitPureNodes, savePath, savePath.overwrite,
|
splitPureNodes, savePath, savePath.overwrite,
|
||||||
forest.output, cores, randomSeed, displayProgress){
|
cores, randomSeed, displayProgress){
|
||||||
|
|
||||||
# Some quick checks on parameters
|
# Some quick checks on parameters
|
||||||
ntree <- as.integer(ntree)
|
ntree <- as.integer(ntree)
|
||||||
|
@ -51,20 +51,16 @@ train.internal <- function(dataset, splitFinder,
|
||||||
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||||
}
|
}
|
||||||
|
|
||||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
|
||||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
|
||||||
}
|
|
||||||
|
|
||||||
if(is.null(splitFinder)){
|
if(is.null(splitFinder)){
|
||||||
splitFinder <- splitFinderDefault(dataset$responses)
|
splitFinder <- splitFinderDefault(responses)
|
||||||
}
|
}
|
||||||
|
|
||||||
if(is.null(nodeResponseCombiner)){
|
if(is.null(nodeResponseCombiner)){
|
||||||
nodeResponseCombiner <- nodeResponseCombinerDefault(dataset$responses)
|
nodeResponseCombiner <- nodeResponseCombinerDefault(responses)
|
||||||
}
|
}
|
||||||
|
|
||||||
if(is.null(forestResponseCombiner)){
|
if(is.null(forestResponseCombiner)){
|
||||||
forestResponseCombiner <- forestResponseCombinerDefault(dataset$responses)
|
forestResponseCombiner <- forestResponseCombinerDefault(responses)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,6 +75,20 @@ train.internal <- function(dataset, splitFinder,
|
||||||
stop("forestResponseCombiner must be a ResponseCombiner")
|
stop("forestResponseCombiner must be a ResponseCombiner")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(class(covariateData)=="environment"){
|
||||||
|
if(is.null(covariateData$data)){
|
||||||
|
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
|
||||||
|
}
|
||||||
|
dataset <- loadData(covariateData$data, colnames(covariateData$data), responses)
|
||||||
|
covariateData$data <- NULL # save memory, hopefully
|
||||||
|
gc() # explicitly try to save memory
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
dataset <- loadData(covariateData, colnames(covariateData), responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
treeTrainer <- createTreeTrainer(responseCombiner=nodeResponseCombiner,
|
treeTrainer <- createTreeTrainer(responseCombiner=nodeResponseCombiner,
|
||||||
splitFinder=splitFinder,
|
splitFinder=splitFinder,
|
||||||
covariateList=dataset$covariateList,
|
covariateList=dataset$covariateList,
|
||||||
|
@ -133,23 +143,22 @@ train.internal <- function(dataset, splitFinder,
|
||||||
params=params,
|
params=params,
|
||||||
forestCall=match.call())
|
forestCall=match.call())
|
||||||
|
|
||||||
forest.java <- NULL
|
|
||||||
if(cores > 1){
|
if(cores > 1){
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", .object_Optional(), as.integer(cores))
|
.jcall(forestTrainer, "V", "trainParallelOnDisk", .object_Optional(), as.integer(cores))
|
||||||
} else {
|
} else {
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", .object_Optional())
|
.jcall(forestTrainer, "V", "trainSerialOnDisk", .object_Optional())
|
||||||
}
|
}
|
||||||
|
|
||||||
if(forest.output[1] == "online"){
|
# Need to now load forest trees back into memory
|
||||||
forest.java <- convertToOnlineForest.Java(forest.java)
|
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forestResponseCombiner$javaObject)
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
else{ # save directly into memory
|
else{ # save directly into memory
|
||||||
if(cores > 1){
|
if(cores > 1){
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
|
||||||
} else {
|
} else {
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", .object_Optional())
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", .object_Optional())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,13 +188,13 @@ train.internal <- function(dataset, splitFinder,
|
||||||
#'
|
#'
|
||||||
#' @param formula You may specify the response and covariates as a formula
|
#' @param formula You may specify the response and covariates as a formula
|
||||||
#' instead; make sure the response in the formula is still properly
|
#' instead; make sure the response in the formula is still properly
|
||||||
#' constructed.
|
#' constructed; see \code{responses}
|
||||||
#' @param data A data.frame containing the columns of the predictors and
|
#' @param data A data.frame containing the columns of the predictors and
|
||||||
#' responses.
|
#' responses.
|
||||||
#' @param splitFinder A split finder that's used to score splits in the random
|
#' @param splitFinder A split finder that's used to score splits in the random
|
||||||
#' forest training algorithm. See \code{\link{CompetingRiskSplitFinders}} or
|
#' forest training algorithm. See \code{\link{CompetingRiskSplitFinders}}
|
||||||
#' \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one, this
|
#' or \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one,
|
||||||
#' function tries to pick one based on the response. For
|
#' this function tries to pick one based on the response. For
|
||||||
#' \code{\link{CR_Response}} without censor times, it will pick a
|
#' \code{\link{CR_Response}} without censor times, it will pick a
|
||||||
#' \code{\link{LogRankSplitFinder}}; while if censor times were provided it
|
#' \code{\link{LogRankSplitFinder}}; while if censor times were provided it
|
||||||
#' will pick \code{\link{GrayLogRankSplitFinder}}; for integer or numeric
|
#' will pick \code{\link{GrayLogRankSplitFinder}}; for integer or numeric
|
||||||
|
@ -193,19 +202,19 @@ train.internal <- function(dataset, splitFinder,
|
||||||
#' @param nodeResponseCombiner A response combiner that's used to combine
|
#' @param nodeResponseCombiner A response combiner that's used to combine
|
||||||
#' responses for each terminal node in a tree (regression example; average the
|
#' responses for each terminal node in a tree (regression example; average the
|
||||||
#' observations in each tree into a single number). See
|
#' observations in each tree into a single number). See
|
||||||
#' \code{\link{CR_ResponseCombiner}} or \code{\link{MeanResponseCombiner}}. If
|
#' \code{\link{CR_ResponseCombiner}} or
|
||||||
#' you don't specify one, this function tries to pick one based on the
|
#' \code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||||
#' response. For \code{\link{CR_Response}} it picks a
|
#' tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||||
#' \code{\link{CR_ResponseCombiner}}; for integer or numeric responses it
|
#' picks a \code{\link{CR_ResponseCombiner}}; for integer or numeric
|
||||||
#' picks a \code{\link{MeanResponseCombiner}}.
|
#' responses it picks a \code{\link{MeanResponseCombiner}}.
|
||||||
#' @param forestResponseCombiner A response combiner that's used to combine
|
#' @param forestResponseCombiner A response combiner that's used to combine
|
||||||
#' predictions across trees into one final result (regression example; average
|
#' predictions across trees into one final result (regression example; average
|
||||||
#' the prediction of each tree into a single number). See
|
#' the prediction of each tree into a single number). See
|
||||||
#' \code{\link{CR_FunctionCombiner}} or \code{\link{MeanResponseCombiner}}. If
|
#' \code{\link{CR_FunctionCombiner}} or
|
||||||
#' you don't specify one, this function tries to pick one based on the
|
#' \code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||||
#' response. For \code{\link{CR_Response}} it picks a
|
#' tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||||
#' \code{\link{CR_FunctionCombiner}}; for integer or numeric responses it
|
#' picks a \code{\link{CR_FunctionCombiner}}; for integer or numeric
|
||||||
#' picks a \code{\link{MeanResponseCombiner}}.
|
#' responses it picks a \code{\link{MeanResponseCombiner}}.
|
||||||
#' @param ntree An integer that specifies how many trees should be trained.
|
#' @param ntree An integer that specifies how many trees should be trained.
|
||||||
#' @param numberOfSplits A tuning parameter specifying how many random splits
|
#' @param numberOfSplits A tuning parameter specifying how many random splits
|
||||||
#' should be tried for a covariate; a value of 0 means all splits will be
|
#' should be tried for a covariate; a value of 0 means all splits will be
|
||||||
|
@ -222,20 +231,6 @@ train.internal <- function(dataset, splitFinder,
|
||||||
#' @param maxNodeDepth This parameter is analogous to \code{nodeSize} in that it
|
#' @param maxNodeDepth This parameter is analogous to \code{nodeSize} in that it
|
||||||
#' controls tree length; by default \code{maxNodeDepth} is an extremely high
|
#' controls tree length; by default \code{maxNodeDepth} is an extremely high
|
||||||
#' number and tree depth is controlled by \code{nodeSize}.
|
#' number and tree depth is controlled by \code{nodeSize}.
|
||||||
#' @param na.penalty This parameter controls whether predictor variables with
|
|
||||||
#' NAs should be penalized when being considered for a best split. Best splits
|
|
||||||
#' (and the associated score) are determined on only non-NA data; the penalty
|
|
||||||
#' is to take the best split identified, and to randomly assign any NAs
|
|
||||||
#' (according to the proportion of data split left and right), and then
|
|
||||||
#' recalculate the corresponding split score, when is then compared with the
|
|
||||||
#' other split candiate variables. This penalty adds some computational time,
|
|
||||||
#' so it may be disabled for some variables. \code{na.penalty} may be
|
|
||||||
#' specified as a vector of logicals indicating, for each predictor variable,
|
|
||||||
#' whether the penalty should be applied to that variable. If it's length 1
|
|
||||||
#' then it applies to all variables. Alternatively, a single numeric value may
|
|
||||||
#' be provided to indicate a threshold whereby the penalty is activated only
|
|
||||||
#' if the proportion of NAs for that variable in the training set exceeds that
|
|
||||||
#' threshold.
|
|
||||||
#' @param splitPureNodes This parameter determines whether the algorithm will
|
#' @param splitPureNodes This parameter determines whether the algorithm will
|
||||||
#' split a pure node. If set to FALSE, then before every split it will check
|
#' split a pure node. If set to FALSE, then before every split it will check
|
||||||
#' that every response is the same, and if so, not split. If set to TRUE it
|
#' that every response is the same, and if so, not split. If set to TRUE it
|
||||||
|
@ -258,10 +253,6 @@ train.internal <- function(dataset, splitFinder,
|
||||||
#' assumes (without checking) that the existing trees are from a previous run
|
#' assumes (without checking) that the existing trees are from a previous run
|
||||||
#' and starts from where it left off. This option is useful if recovering from
|
#' and starts from where it left off. This option is useful if recovering from
|
||||||
#' a crash.
|
#' a crash.
|
||||||
#' @param forest.output This parameter only applies if \code{savePath} has been
|
|
||||||
#' set; set to 'online' (default) and the saved forest will be loaded into
|
|
||||||
#' memory after being trained. Set to 'offline' and the forest is not saved
|
|
||||||
#' into memory, but can still be used in a memory unintensive manner.
|
|
||||||
#' @param cores This parameter specifies how many trees will be simultaneously
|
#' @param cores This parameter specifies how many trees will be simultaneously
|
||||||
#' trained. By default the package attempts to detect how many cores you have
|
#' trained. By default the package attempts to detect how many cores you have
|
||||||
#' by using the \code{parallel} package and using all of them. You may specify
|
#' by using the \code{parallel} package and using all of them. You may specify
|
||||||
|
@ -319,21 +310,78 @@ train.internal <- function(dataset, splitFinder,
|
||||||
#' ypred <- predict(forest, newData)
|
#' ypred <- predict(forest, newData)
|
||||||
train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
||||||
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry,
|
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry,
|
||||||
nodeSize, maxNodeDepth = 100000, na.penalty = TRUE, splitPureNodes=TRUE,
|
nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, savePath=NULL,
|
||||||
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
savePath.overwrite=c("warn", "delete", "merge"), cores = getCores(),
|
||||||
forest.output = c("online", "offline"),
|
randomSeed = NULL, displayProgress = TRUE){
|
||||||
cores = getCores(), randomSeed = NULL, displayProgress = TRUE){
|
|
||||||
|
|
||||||
dataset <- processFormula(formula, data, na.penalty = na.penalty)
|
# Having an R copy of the data loaded at the same time can be wasteful; we
|
||||||
|
# also allow users to provide an environment of the data which gets removed
|
||||||
|
# after being imported into Java
|
||||||
|
env <- NULL
|
||||||
|
if(class(data) == "environment"){
|
||||||
|
if(is.null(data$data)){
|
||||||
|
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
|
||||||
|
}
|
||||||
|
|
||||||
forest <- train.internal(dataset, splitFinder = splitFinder,
|
env <- data
|
||||||
nodeResponseCombiner = nodeResponseCombiner,
|
data <- env$data
|
||||||
forestResponseCombiner = forestResponseCombiner,
|
}
|
||||||
ntree = ntree, numberOfSplits = numberOfSplits,
|
|
||||||
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
|
yVar <- formula[[2]]
|
||||||
splitPureNodes = splitPureNodes, savePath = savePath,
|
|
||||||
savePath.overwrite = savePath.overwrite, forest.output = forest.output,
|
responses <- NULL
|
||||||
cores = cores, randomSeed = randomSeed, displayProgress = displayProgress)
|
variablesToDrop <- character(0)
|
||||||
|
|
||||||
|
# yVar is a call object; as.character(yVar) will be the different components, including the parameters.
|
||||||
|
# if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
|
||||||
|
# then we also need to explicitly evaluate it
|
||||||
|
if(class(yVar)=="call" || !(as.character(yVar) %in% colnames(data))){
|
||||||
|
# yVar is a function like CompetingRiskResponses
|
||||||
|
responses <- eval(expr=yVar, envir=data)
|
||||||
|
|
||||||
|
if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
|
||||||
|
# do any of the variables match data in data? We need to track that so we can drop them later
|
||||||
|
variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
|
||||||
|
}
|
||||||
|
|
||||||
|
formula[[2]] <- NULL
|
||||||
|
|
||||||
|
} else if(class(yVar)=="name"){ # and implicitly yVar is contained in data
|
||||||
|
variablesToDrop <- as.character(yVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Includes responses which we may need to later cut out
|
||||||
|
mf <- stats::model.frame(formula=formula, data=data, na.action=stats::na.pass)
|
||||||
|
|
||||||
|
if(is.null(responses)){
|
||||||
|
responses <- stats::model.response(mf)
|
||||||
|
}
|
||||||
|
|
||||||
|
# remove any response variables
|
||||||
|
mf <- mf[,!(names(mf) %in% variablesToDrop), drop=FALSE]
|
||||||
|
|
||||||
|
# If environment was provided instead of data
|
||||||
|
if(!is.null(env)){
|
||||||
|
env$data <- mf
|
||||||
|
rm(data)
|
||||||
|
forest <- train.internal(responses, env, splitFinder = splitFinder,
|
||||||
|
nodeResponseCombiner = nodeResponseCombiner,
|
||||||
|
forestResponseCombiner = forestResponseCombiner,
|
||||||
|
ntree = ntree, numberOfSplits = numberOfSplits,
|
||||||
|
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
|
||||||
|
splitPureNodes = splitPureNodes, savePath = savePath,
|
||||||
|
savePath.overwrite = savePath.overwrite, cores = cores,
|
||||||
|
randomSeed = randomSeed, displayProgress = displayProgress)
|
||||||
|
} else{
|
||||||
|
forest <- train.internal(responses, mf, splitFinder = splitFinder,
|
||||||
|
nodeResponseCombiner = nodeResponseCombiner,
|
||||||
|
forestResponseCombiner = forestResponseCombiner,
|
||||||
|
ntree = ntree, numberOfSplits = numberOfSplits,
|
||||||
|
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
|
||||||
|
splitPureNodes = splitPureNodes, savePath = savePath,
|
||||||
|
savePath.overwrite = savePath.overwrite, cores = cores,
|
||||||
|
randomSeed = randomSeed, displayProgress = displayProgress)
|
||||||
|
}
|
||||||
|
|
||||||
forest$call <- match.call()
|
forest$call <- match.call()
|
||||||
forest$formula <- formula
|
forest$formula <- formula
|
||||||
|
@ -381,9 +429,7 @@ createTreeTrainer <- function(responseCombiner, splitFinder, covariateList, numb
|
||||||
|
|
||||||
builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder")
|
builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder")
|
||||||
|
|
||||||
responseCombinerCasted <- .jcast(responseCombiner$javaObject, .class_ResponseCombiner) # might need to cast a ForestResponseCombiner down
|
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombiner$javaObject)
|
||||||
|
|
||||||
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombinerCasted)
|
|
||||||
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject)
|
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject)
|
||||||
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
|
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
|
||||||
builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits))
|
builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits))
|
||||||
|
|
175
R/vimp.R
175
R/vimp.R
|
@ -1,175 +0,0 @@
|
||||||
|
|
||||||
|
|
||||||
#' Variable Importance
|
|
||||||
#'
|
|
||||||
#' Calculate variable importance by recording the increase in error when a given
|
|
||||||
#' predictor is randomly permuted. Regression forests uses mean squared error;
|
|
||||||
#' competing risks uses integrated Brier score.
|
|
||||||
#'
|
|
||||||
#' @param forest The forest that was trained.
|
|
||||||
#' @param newData A test set of the data if available. If not, then out of bag
|
|
||||||
#' errors will be attempted on the training set.
|
|
||||||
#' @param randomSeed The source of randomness used to permute the values. Can be
|
|
||||||
#' left blank.
|
|
||||||
#' @param events If using competing risks forest, the events that the error
|
|
||||||
#' measure used for VIMP should be calculated on.
|
|
||||||
#' @param time If using competing risks forest, the upper bound of the
|
|
||||||
#' integrated Brier score.
|
|
||||||
#' @param censoringDistribution (Optional) If using competing risks forest, the
|
|
||||||
#' censoring distribution. See \code{\link{integratedBrierScore} for details.}
|
|
||||||
#' @param eventWeights (Optional) If using competing risks forest, weights to be
|
|
||||||
#' applied to the error for each of the \code{events}.
|
|
||||||
#'
|
|
||||||
#' @return A named numeric vector of importance values.
|
|
||||||
#' @export
|
|
||||||
#'
|
|
||||||
#' @examples
|
|
||||||
#' data(wihs)
|
|
||||||
#'
|
|
||||||
#' forest <- train(CR_Response(status, time) ~ ., wihs,
|
|
||||||
#' ntree = 100, numberOfSplits = 0, mtry=3, nodeSize = 5)
|
|
||||||
#'
|
|
||||||
#' vimp(forest, events = 1:2, time = 8.0)
|
|
||||||
#'
|
|
||||||
vimp <- function(
|
|
||||||
forest,
|
|
||||||
newData = NULL,
|
|
||||||
randomSeed = NULL,
|
|
||||||
type = c("mean", "z", "raw"),
|
|
||||||
events = NULL,
|
|
||||||
time = NULL,
|
|
||||||
censoringDistribution = NULL,
|
|
||||||
eventWeights = NULL){
|
|
||||||
|
|
||||||
if(is.null(newData) & is.null(forest$dataset)){
|
|
||||||
stop("forest doesn't have a copy of the training data loaded (this happens if you just loaded it); please manually specify newData and possibly out.of.bag")
|
|
||||||
}
|
|
||||||
|
|
||||||
# Basically we check if type is either null, length 0, or one of the invalid values.
|
|
||||||
# We can't include the last statement in the same statement as length(tyoe) < 1,
|
|
||||||
# because R checks both cases and a different error would display if length(type) == 0
|
|
||||||
typeError = is.null(type) | length(type) < 1
|
|
||||||
if(!typeError){
|
|
||||||
typeError = !(type[1] %in% c("mean", "z", "raw"))
|
|
||||||
}
|
|
||||||
if(typeError){
|
|
||||||
stop("A valid response type must be provided.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if(is.null(newData)){
|
|
||||||
data.java <- forest$dataset
|
|
||||||
out.of.bag <- TRUE
|
|
||||||
|
|
||||||
}
|
|
||||||
else{ # newData is provided
|
|
||||||
data.java <- processFormula(forest$formula, newData, forest$covariateList)$dataset
|
|
||||||
out.of.bag <- FALSE
|
|
||||||
}
|
|
||||||
|
|
||||||
predictionClass <- forest$params$forestResponseCombiner$outputClass
|
|
||||||
|
|
||||||
if(predictionClass == "CompetingRiskFunctions"){
|
|
||||||
if(is.null(time) | length(time) != 1){
|
|
||||||
stop("time must be set at length 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
errorCalculator.java <- ibsCalculatorWrapper(
|
|
||||||
events = events,
|
|
||||||
time = time,
|
|
||||||
censoringDistribution = censoringDistribution,
|
|
||||||
eventWeights = eventWeights)
|
|
||||||
|
|
||||||
} else if(predictionClass == "numeric"){
|
|
||||||
errorCalculator.java <- .jnew(.class_RegressionErrorCalculator)
|
|
||||||
errorCalculator.java <- .jcast(errorCalculator.java, .class_ErrorCalculator)
|
|
||||||
|
|
||||||
} else{
|
|
||||||
stop(paste0("VIMP not yet supported for ", predictionClass, ". If you're just using a non-custom version of largeRCRF then this is a bug and should be reported."))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
forest.trees.java <- .jcall(forest$javaObject, makeResponse(.class_List), "getTrees")
|
|
||||||
|
|
||||||
vimp.calculator <- .jnew(.class_VariableImportanceCalculator,
|
|
||||||
errorCalculator.java,
|
|
||||||
forest.trees.java,
|
|
||||||
data.java,
|
|
||||||
out.of.bag # isTrainingSet parameter
|
|
||||||
)
|
|
||||||
|
|
||||||
random.java <- NULL
|
|
||||||
if(!is.null(randomSeed)){
|
|
||||||
random.java <- .jnew(.class_Random, .jlong(as.integer(randomSeed)))
|
|
||||||
}
|
|
||||||
random.java <- .object_Optional(random.java)
|
|
||||||
|
|
||||||
covariateRList <- convertJavaListToR(forest$covariateList, class = .class_Covariate)
|
|
||||||
importanceValues <- matrix(nrow = forest$params$ntree, ncol = length(covariateRList))
|
|
||||||
colnames(importanceValues) <- extractCovariateNamesFromJavaList(forest$covariateList)
|
|
||||||
|
|
||||||
for(j in 1:length(covariateRList)){
|
|
||||||
covariateJava <- covariateRList[[j]]
|
|
||||||
covariateJava <-
|
|
||||||
|
|
||||||
importanceValues[, j] <- .jcall(vimp.calculator, "[D", "calculateVariableImportanceRaw", covariateJava, random.java)
|
|
||||||
}
|
|
||||||
|
|
||||||
if(type[1] == "raw"){
|
|
||||||
return(importanceValues)
|
|
||||||
} else if(type[1] == "mean"){
|
|
||||||
meanImportanceValues <- apply(importanceValues, 2, mean)
|
|
||||||
return(meanImportanceValues)
|
|
||||||
} else if(type[1] == "z"){
|
|
||||||
zImportanceValues <- apply(importanceValues, 2, function(x){
|
|
||||||
meanValue <- mean(x)
|
|
||||||
standardError <- sd(x)/sqrt(length(x))
|
|
||||||
return(meanValue / standardError)
|
|
||||||
})
|
|
||||||
return(zImportanceValues)
|
|
||||||
|
|
||||||
} else{
|
|
||||||
stop("A valid response type must be provided.")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
return(importance)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
# Internal function
|
|
||||||
ibsCalculatorWrapper <- function(events, time, censoringDistribution = NULL, eventWeights = NULL){
|
|
||||||
if(is.null(events)){
|
|
||||||
stop("events must be specified if using vimp on competing risks data")
|
|
||||||
}
|
|
||||||
|
|
||||||
if(is.null(time)){
|
|
||||||
stop("time must be specified if using vimp on competing risks data")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
java.censoringDistribution <- NULL
|
|
||||||
if(!is.null(censoringDistribution)){
|
|
||||||
java.censoringDistribution <- processCensoringDistribution(censoringDistribution)
|
|
||||||
java.censoringDistribution <- .object_Optional(java.censoringDistribution)
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
java.censoringDistribution <- .object_Optional(NULL)
|
|
||||||
}
|
|
||||||
|
|
||||||
ibsCalculator.java <- .jnew(.class_IBSCalculator, java.censoringDistribution)
|
|
||||||
|
|
||||||
if(is.null(eventWeights)){
|
|
||||||
eventWeights <- rep(1, times = length(events))
|
|
||||||
}
|
|
||||||
|
|
||||||
ibsCalculatorWrapper.java <- .jnew(.class_IBSErrorCalculatorWrapper,
|
|
||||||
ibsCalculator.java,
|
|
||||||
.jarray(as.integer(events)),
|
|
||||||
as.numeric(time),
|
|
||||||
.jarray(as.numeric(eventWeights)))
|
|
||||||
|
|
||||||
ibsCalculatorWrapper.java <- .jcast(ibsCalculatorWrapper.java, .class_ErrorCalculator)
|
|
||||||
return(ibsCalculatorWrapper.java)
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
12
R/wihs.R
12
R/wihs.R
|
@ -15,13 +15,9 @@
|
||||||
#' @source The data was obtained from the randomForestSRC R package.
|
#' @source The data was obtained from the randomForestSRC R package.
|
||||||
#'
|
#'
|
||||||
#' @references Bacon MC, von Wyl V, Alden C, Sharp G, Robison E, Hessol N, Gange
|
#' @references Bacon MC, von Wyl V, Alden C, Sharp G, Robison E, Hessol N, Gange
|
||||||
#' S, Barranday Y, Holman S, Weber K, Young MA (2005). “The Women’s Interagency
|
#' S, Barranday Y, Holman S, Weber K, Young MA (2005). “The Women’s
|
||||||
#' HIV Study: an Observational Cohort Brings Clinical Sciences to the Bench.”
|
#' Interagency HIV Study: an Observational Cohort Brings Clinical Sciences to
|
||||||
#' Clinical and Vaccine Immunology, 12(9), 1013–1019.
|
#' the Bench.” Clinical and Vaccine Immunology, 12(9), 1013–1019.
|
||||||
#' doi:10.1128/CDLI.12.9.1013-1019.2005.
|
#' doi:10.1128/CDLI.12.9.1013-1019.2005.
|
||||||
#'
|
|
||||||
#' Ishwaran, H., and Udaya B. Kogalur. 2018. Random Forests for Survival,
|
|
||||||
#' Regression, and Classification (Rf-Src). manual.
|
|
||||||
#' https://cran.r-project.org/package=randomForestSRC.
|
|
||||||
#'
|
#'
|
||||||
"wihs"
|
"wihs"
|
Binary file not shown.
|
@ -2,4 +2,4 @@ The Java source code for this package can be obtained at https://github.com/jath
|
||||||
|
|
||||||
* Delete the Jar file in `inst/java/`
|
* Delete the Jar file in `inst/java/`
|
||||||
* Build the Java code in its own separate directory using `mvn clean package` in the root of the directory (same folder containing `README.md`). Make sure you have [Maven](https://maven.apache.org/) installed.
|
* Build the Java code in its own separate directory using `mvn clean package` in the root of the directory (same folder containing `README.md`). Make sure you have [Maven](https://maven.apache.org/) installed.
|
||||||
* Copy the `library/target/largeRCRF-library-1.0-SNAPSHOT.jar` file produced in the Java directory into `inst/java/`
|
* Copy the `library/target/largeRCRF-library/1.0-SNAPSHOT.jar` file produced in the Java directory into `inst/java/`
|
|
@ -6,8 +6,7 @@
|
||||||
\usage{
|
\usage{
|
||||||
addTrees(forest, numTreesToAdd, savePath = NULL,
|
addTrees(forest, numTreesToAdd, savePath = NULL,
|
||||||
savePath.overwrite = c("warn", "delete", "merge"),
|
savePath.overwrite = c("warn", "delete", "merge"),
|
||||||
forest.output = c("online", "offline"), cores = getCores(),
|
cores = getCores(), displayProgress = TRUE)
|
||||||
displayProgress = TRUE)
|
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{forest}{An existing forest.}
|
\item{forest}{An existing forest.}
|
||||||
|
@ -22,11 +21,6 @@ a previously saved forest.}
|
||||||
directory, possibly containing another forest, this specifies what should
|
directory, possibly containing another forest, this specifies what should
|
||||||
be done.}
|
be done.}
|
||||||
|
|
||||||
\item{forest.output}{This parameter only applies if \code{savePath} has been
|
|
||||||
set; set to 'online' (default) and the saved forest will be loaded into
|
|
||||||
memory after being trained. Set to 'offline' and the forest is not saved
|
|
||||||
into memory, but can still be used in a memory unintensive manner.}
|
|
||||||
|
|
||||||
\item{cores}{The number of cores to be used for training the new trees.}
|
\item{cores}{The number of cores to be used for training the new trees.}
|
||||||
|
|
||||||
\item{displayProgress}{A logical indicating whether the progress should be
|
\item{displayProgress}{A logical indicating whether the progress should be
|
||||||
|
|
|
@ -1,20 +0,0 @@
|
||||||
% Generated by roxygen2: do not edit by hand
|
|
||||||
% Please edit documentation in R/convertToOnlineForest.R
|
|
||||||
\name{convertToOnlineForest}
|
|
||||||
\alias{convertToOnlineForest}
|
|
||||||
\title{Convert to Online Forest}
|
|
||||||
\usage{
|
|
||||||
convertToOnlineForest(forest)
|
|
||||||
}
|
|
||||||
\arguments{
|
|
||||||
\item{forest}{The offline forest.}
|
|
||||||
}
|
|
||||||
\value{
|
|
||||||
An online, in memory forst.
|
|
||||||
}
|
|
||||||
\description{
|
|
||||||
Some forests are too large to store in memory and have been saved to disk.
|
|
||||||
They can still be used, but their performance is much slower. If there's
|
|
||||||
enough memory, they can easily be converted into an in-memory forest that is
|
|
||||||
faster to use.
|
|
||||||
}
|
|
|
@ -1,69 +0,0 @@
|
||||||
% Generated by roxygen2: do not edit by hand
|
|
||||||
% Please edit documentation in R/cr_integratedBrierScore.R
|
|
||||||
\name{integratedBrierScore}
|
|
||||||
\alias{integratedBrierScore}
|
|
||||||
\title{Integrated Brier Score}
|
|
||||||
\usage{
|
|
||||||
integratedBrierScore(responses, predictions, event, time,
|
|
||||||
censoringDistribution = NULL, parallel = TRUE)
|
|
||||||
}
|
|
||||||
\arguments{
|
|
||||||
\item{responses}{A list of responses corresponding to the provided
|
|
||||||
mortalities; use \code{\link{CR_Response}}.}
|
|
||||||
|
|
||||||
\item{predictions}{The predictions to be tested against.}
|
|
||||||
|
|
||||||
\item{event}{The event type for the error to be calculated on.}
|
|
||||||
|
|
||||||
\item{time}{\code{time} specifies the upper bound of the integral.}
|
|
||||||
|
|
||||||
\item{censoringDistribution}{Optional; if provided then weights are
|
|
||||||
calculated on the errors. There are three ways to provide it - \itemize{
|
|
||||||
\item{If you have all the censor times and just want to use a simple
|
|
||||||
empirical estimate of the distribution, just provide a numeric vector of
|
|
||||||
all of the censor times and it will be automatically calculated.} \item{You
|
|
||||||
can directly specify the survivor function by providing a list with two
|
|
||||||
numeric vectors called \code{x} and \code{y}. They should be of the same
|
|
||||||
length and correspond to each point. It is assumed that previous to the
|
|
||||||
first value in \code{y} the \code{y} value is 1.0; and that the function
|
|
||||||
you provide is a right-continuous step function.} \item{You can provide a
|
|
||||||
function from \code{\link[stats]{stepfun}}. Note that this only supports
|
|
||||||
functions where \code{right = FALSE} (default), and that the first y value
|
|
||||||
(corresponding to y before the first x value) will be to set to 1.0
|
|
||||||
regardless of what is specified.}
|
|
||||||
|
|
||||||
}}
|
|
||||||
|
|
||||||
\item{parallel}{A logical indicating whether multiple cores should be
|
|
||||||
utilized when calculating the error. Available as an option because it's
|
|
||||||
been observed that using Java's \code{parallelStream} can be unstable on
|
|
||||||
some systems. Default value is \code{TRUE}; only set to \code{FALSE} if you
|
|
||||||
get strange errors while predicting.}
|
|
||||||
}
|
|
||||||
\value{
|
|
||||||
A numeric vector of the Integrated Brier Score for each prediction.
|
|
||||||
}
|
|
||||||
\description{
|
|
||||||
Used to calculate the Integrated Brier Score, which for the competing risks
|
|
||||||
setting is the integral of the squared difference between each observed
|
|
||||||
cumulative incidence function (CIF) for each observation and the
|
|
||||||
corresponding predicted CIF. If the survivor function (1 - CDF) of the
|
|
||||||
censoring distribution is provided, weights can be calculated to account for
|
|
||||||
the censoring.
|
|
||||||
}
|
|
||||||
\examples{
|
|
||||||
data <- data.frame(delta=c(1,1,0,0,2,2), T=1:6, x=1:6)
|
|
||||||
|
|
||||||
model <- train(CR_Response(delta, T) ~ x, data, ntree=100, numberOfSplits=0, mtry=1, nodeSize=1)
|
|
||||||
|
|
||||||
newData <- data.frame(delta=c(1,0,2,1,0,2), T=1:6, x=1:6)
|
|
||||||
predictions <- predict(model, newData)
|
|
||||||
|
|
||||||
scores <- integratedBrierScore(CR_Response(data$delta, data$T), predictions, 1, 6.0)
|
|
||||||
|
|
||||||
}
|
|
||||||
\references{
|
|
||||||
Section 4.2 of Ishwaran H, Gerds TA, Kogalur UB, Moore RD, Gange
|
|
||||||
SJ, Lau BM (2014). “Random Survival Forests for Competing Risks.”
|
|
||||||
Biostatistics, 15(4), 757–773. doi:10.1093/ biostatistics/kxu010.
|
|
||||||
}
|
|
|
@ -4,19 +4,10 @@
|
||||||
\alias{loadForest}
|
\alias{loadForest}
|
||||||
\title{Load Random Forest}
|
\title{Load Random Forest}
|
||||||
\usage{
|
\usage{
|
||||||
loadForest(directory, forest.output = c("online", "offline"),
|
loadForest(directory)
|
||||||
maxTreeNum = NULL)
|
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{directory}{The directory created that saved the previous forest.}
|
\item{directory}{The directory created that saved the previous forest.}
|
||||||
|
|
||||||
\item{forest.output}{Specifies whether the forest loaded should be loaded
|
|
||||||
into memory, or reflect the saved files where only one tree is loaded at a
|
|
||||||
time.}
|
|
||||||
|
|
||||||
\item{maxTreeNum}{If for some reason you only want to load the number of
|
|
||||||
trees up until a certain point, you can specify maxTreeNum as a single
|
|
||||||
number.}
|
|
||||||
}
|
}
|
||||||
\value{
|
\value{
|
||||||
A JForest object; see \code{\link{train}} for details.
|
A JForest object; see \code{\link{train}} for details.
|
||||||
|
|
|
@ -41,9 +41,3 @@ mortalities <- list(
|
||||||
naiveConcordance(CR_Response(newData$delta, newData$T), mortalities)
|
naiveConcordance(CR_Response(newData$delta, newData$T), mortalities)
|
||||||
|
|
||||||
}
|
}
|
||||||
\references{
|
|
||||||
Section 3.2 of Wolbers, Marcel, Paul Blanche, Michael T. Koller,
|
|
||||||
Jacqueline C M Witteman, and Thomas A Gerds. 2014. “Concordance for
|
|
||||||
Prognostic Models with Competing Risks.” Biostatistics 15 (3): 526–39.
|
|
||||||
https://doi.org/10.1093/biostatistics/kxt059.
|
|
||||||
}
|
|
||||||
|
|
55
man/train.Rd
55
man/train.Rd
|
@ -6,23 +6,22 @@
|
||||||
\usage{
|
\usage{
|
||||||
train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
||||||
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize,
|
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize,
|
||||||
maxNodeDepth = 1e+05, na.penalty = TRUE, splitPureNodes = TRUE,
|
maxNodeDepth = 1e+05, splitPureNodes = TRUE, savePath = NULL,
|
||||||
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
savePath.overwrite = c("warn", "delete", "merge"),
|
||||||
forest.output = c("online", "offline"), cores = getCores(),
|
cores = getCores(), randomSeed = NULL, displayProgress = TRUE)
|
||||||
randomSeed = NULL, displayProgress = TRUE)
|
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{formula}{You may specify the response and covariates as a formula
|
\item{formula}{You may specify the response and covariates as a formula
|
||||||
instead; make sure the response in the formula is still properly
|
instead; make sure the response in the formula is still properly
|
||||||
constructed.}
|
constructed; see \code{responses}}
|
||||||
|
|
||||||
\item{data}{A data.frame containing the columns of the predictors and
|
\item{data}{A data.frame containing the columns of the predictors and
|
||||||
responses.}
|
responses.}
|
||||||
|
|
||||||
\item{splitFinder}{A split finder that's used to score splits in the random
|
\item{splitFinder}{A split finder that's used to score splits in the random
|
||||||
forest training algorithm. See \code{\link{CompetingRiskSplitFinders}} or
|
forest training algorithm. See \code{\link{CompetingRiskSplitFinders}}
|
||||||
\code{\link{WeightedVarianceSplitFinder}}. If you don't specify one, this
|
or \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one,
|
||||||
function tries to pick one based on the response. For
|
this function tries to pick one based on the response. For
|
||||||
\code{\link{CR_Response}} without censor times, it will pick a
|
\code{\link{CR_Response}} without censor times, it will pick a
|
||||||
\code{\link{LogRankSplitFinder}}; while if censor times were provided it
|
\code{\link{LogRankSplitFinder}}; while if censor times were provided it
|
||||||
will pick \code{\link{GrayLogRankSplitFinder}}; for integer or numeric
|
will pick \code{\link{GrayLogRankSplitFinder}}; for integer or numeric
|
||||||
|
@ -31,20 +30,20 @@ responses it picks a \code{\link{WeightedVarianceSplitFinder}}.}
|
||||||
\item{nodeResponseCombiner}{A response combiner that's used to combine
|
\item{nodeResponseCombiner}{A response combiner that's used to combine
|
||||||
responses for each terminal node in a tree (regression example; average the
|
responses for each terminal node in a tree (regression example; average the
|
||||||
observations in each tree into a single number). See
|
observations in each tree into a single number). See
|
||||||
\code{\link{CR_ResponseCombiner}} or \code{\link{MeanResponseCombiner}}. If
|
\code{\link{CR_ResponseCombiner}} or
|
||||||
you don't specify one, this function tries to pick one based on the
|
\code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||||
response. For \code{\link{CR_Response}} it picks a
|
tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||||
\code{\link{CR_ResponseCombiner}}; for integer or numeric responses it
|
picks a \code{\link{CR_ResponseCombiner}}; for integer or numeric
|
||||||
picks a \code{\link{MeanResponseCombiner}}.}
|
responses it picks a \code{\link{MeanResponseCombiner}}.}
|
||||||
|
|
||||||
\item{forestResponseCombiner}{A response combiner that's used to combine
|
\item{forestResponseCombiner}{A response combiner that's used to combine
|
||||||
predictions across trees into one final result (regression example; average
|
predictions across trees into one final result (regression example; average
|
||||||
the prediction of each tree into a single number). See
|
the prediction of each tree into a single number). See
|
||||||
\code{\link{CR_FunctionCombiner}} or \code{\link{MeanResponseCombiner}}. If
|
\code{\link{CR_FunctionCombiner}} or
|
||||||
you don't specify one, this function tries to pick one based on the
|
\code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||||
response. For \code{\link{CR_Response}} it picks a
|
tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||||
\code{\link{CR_FunctionCombiner}}; for integer or numeric responses it
|
picks a \code{\link{CR_FunctionCombiner}}; for integer or numeric
|
||||||
picks a \code{\link{MeanResponseCombiner}}.}
|
responses it picks a \code{\link{MeanResponseCombiner}}.}
|
||||||
|
|
||||||
\item{ntree}{An integer that specifies how many trees should be trained.}
|
\item{ntree}{An integer that specifies how many trees should be trained.}
|
||||||
|
|
||||||
|
@ -67,21 +66,6 @@ as large as \code{nodeSize}.}
|
||||||
controls tree length; by default \code{maxNodeDepth} is an extremely high
|
controls tree length; by default \code{maxNodeDepth} is an extremely high
|
||||||
number and tree depth is controlled by \code{nodeSize}.}
|
number and tree depth is controlled by \code{nodeSize}.}
|
||||||
|
|
||||||
\item{na.penalty}{This parameter controls whether predictor variables with
|
|
||||||
NAs should be penalized when being considered for a best split. Best splits
|
|
||||||
(and the associated score) are determined on only non-NA data; the penalty
|
|
||||||
is to take the best split identified, and to randomly assign any NAs
|
|
||||||
(according to the proportion of data split left and right), and then
|
|
||||||
recalculate the corresponding split score, when is then compared with the
|
|
||||||
other split candiate variables. This penalty adds some computational time,
|
|
||||||
so it may be disabled for some variables. \code{na.penalty} may be
|
|
||||||
specified as a vector of logicals indicating, for each predictor variable,
|
|
||||||
whether the penalty should be applied to that variable. If it's length 1
|
|
||||||
then it applies to all variables. Alternatively, a single numeric value may
|
|
||||||
be provided to indicate a threshold whereby the penalty is activated only
|
|
||||||
if the proportion of NAs for that variable in the training set exceeds that
|
|
||||||
threshold.}
|
|
||||||
|
|
||||||
\item{splitPureNodes}{This parameter determines whether the algorithm will
|
\item{splitPureNodes}{This parameter determines whether the algorithm will
|
||||||
split a pure node. If set to FALSE, then before every split it will check
|
split a pure node. If set to FALSE, then before every split it will check
|
||||||
that every response is the same, and if so, not split. If set to TRUE it
|
that every response is the same, and if so, not split. If set to TRUE it
|
||||||
|
@ -107,11 +91,6 @@ assumes (without checking) that the existing trees are from a previous run
|
||||||
and starts from where it left off. This option is useful if recovering from
|
and starts from where it left off. This option is useful if recovering from
|
||||||
a crash.}
|
a crash.}
|
||||||
|
|
||||||
\item{forest.output}{This parameter only applies if \code{savePath} has been
|
|
||||||
set; set to 'online' (default) and the saved forest will be loaded into
|
|
||||||
memory after being trained. Set to 'offline' and the forest is not saved
|
|
||||||
into memory, but can still be used in a memory unintensive manner.}
|
|
||||||
|
|
||||||
\item{cores}{This parameter specifies how many trees will be simultaneously
|
\item{cores}{This parameter specifies how many trees will be simultaneously
|
||||||
trained. By default the package attempts to detect how many cores you have
|
trained. By default the package attempts to detect how many cores you have
|
||||||
by using the \code{parallel} package and using all of them. You may specify
|
by using the \code{parallel} package and using all of them. You may specify
|
||||||
|
|
48
man/vimp.Rd
48
man/vimp.Rd
|
@ -1,48 +0,0 @@
|
||||||
% Generated by roxygen2: do not edit by hand
|
|
||||||
% Please edit documentation in R/vimp.R
|
|
||||||
\name{vimp}
|
|
||||||
\alias{vimp}
|
|
||||||
\title{Variable Importance}
|
|
||||||
\usage{
|
|
||||||
vimp(forest, newData = NULL, randomSeed = NULL, type = c("mean", "z",
|
|
||||||
"raw"), events = NULL, time = NULL, censoringDistribution = NULL,
|
|
||||||
eventWeights = NULL)
|
|
||||||
}
|
|
||||||
\arguments{
|
|
||||||
\item{forest}{The forest that was trained.}
|
|
||||||
|
|
||||||
\item{newData}{A test set of the data if available. If not, then out of bag
|
|
||||||
errors will be attempted on the training set.}
|
|
||||||
|
|
||||||
\item{randomSeed}{The source of randomness used to permute the values. Can be
|
|
||||||
left blank.}
|
|
||||||
|
|
||||||
\item{events}{If using competing risks forest, the events that the error
|
|
||||||
measure used for VIMP should be calculated on.}
|
|
||||||
|
|
||||||
\item{time}{If using competing risks forest, the upper bound of the
|
|
||||||
integrated Brier score.}
|
|
||||||
|
|
||||||
\item{censoringDistribution}{(Optional) If using competing risks forest, the
|
|
||||||
censoring distribution. See \code{\link{integratedBrierScore} for details.}}
|
|
||||||
|
|
||||||
\item{eventWeights}{(Optional) If using competing risks forest, weights to be
|
|
||||||
applied to the error for each of the \code{events}.}
|
|
||||||
}
|
|
||||||
\value{
|
|
||||||
A named numeric vector of importance values.
|
|
||||||
}
|
|
||||||
\description{
|
|
||||||
Calculate variable importance by recording the increase in error when a given
|
|
||||||
predictor is randomly permuted. Regression forests uses mean squared error;
|
|
||||||
competing risks uses integrated Brier score.
|
|
||||||
}
|
|
||||||
\examples{
|
|
||||||
data(wihs)
|
|
||||||
|
|
||||||
forest <- train(CR_Response(status, time) ~ ., wihs,
|
|
||||||
ntree = 100, numberOfSplits = 0, mtry=3, nodeSize = 5)
|
|
||||||
|
|
||||||
vimp(forest, events = 1:2, time = 8.0)
|
|
||||||
|
|
||||||
}
|
|
12
man/wihs.Rd
12
man/wihs.Rd
|
@ -25,13 +25,9 @@ time may also be censored.
|
||||||
}
|
}
|
||||||
\references{
|
\references{
|
||||||
Bacon MC, von Wyl V, Alden C, Sharp G, Robison E, Hessol N, Gange
|
Bacon MC, von Wyl V, Alden C, Sharp G, Robison E, Hessol N, Gange
|
||||||
S, Barranday Y, Holman S, Weber K, Young MA (2005). “The Women’s Interagency
|
S, Barranday Y, Holman S, Weber K, Young MA (2005). “The Women’s
|
||||||
HIV Study: an Observational Cohort Brings Clinical Sciences to the Bench.”
|
Interagency HIV Study: an Observational Cohort Brings Clinical Sciences to
|
||||||
Clinical and Vaccine Immunology, 12(9), 1013–1019.
|
the Bench.” Clinical and Vaccine Immunology, 12(9), 1013–1019.
|
||||||
doi:10.1128/CDLI.12.9.1013-1019.2005.
|
doi:10.1128/CDLI.12.9.1013-1019.2005.
|
||||||
|
|
||||||
Ishwaran, H., and Udaya B. Kogalur. 2018. Random Forests for Survival,
|
|
||||||
Regression, and Classification (Rf-Src). manual.
|
|
||||||
https://cran.r-project.org/package=randomForestSRC.
|
|
||||||
}
|
}
|
||||||
\keyword{datasets}
|
\keyword{datasets}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
context("Add trees on existing forest")
|
context("Add trees on existing forest")
|
||||||
|
|
||||||
test_that("Can add trees on existing online forest", {
|
test_that("Can add trees on existing forest", {
|
||||||
|
|
||||||
trainingData <- data.frame(x=rnorm(100))
|
trainingData <- data.frame(x=rnorm(100))
|
||||||
trainingData$T <- rexp(100) + abs(trainingData$x)
|
trainingData$T <- rexp(100) + abs(trainingData$x)
|
||||||
|
@ -20,44 +20,6 @@ test_that("Can add trees on existing online forest", {
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("Can add trees on existing offline forest", {
|
|
||||||
|
|
||||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
|
||||||
unlink("trees", recursive=TRUE)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
trainingData <- data.frame(x=rnorm(100))
|
|
||||||
trainingData$T <- rexp(100) + abs(trainingData$x)
|
|
||||||
trainingData$delta <- sample(0:2, size = 100, replace=TRUE)
|
|
||||||
|
|
||||||
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50,
|
|
||||||
numberOfSplits=0, mtry=1, nodeSize=5,
|
|
||||||
forestResponseCombiner = CR_FunctionCombiner(events = 1:2, times = 0:10), # TODO - remove specifing times; this is workaround around unimplemented feature for offline forests
|
|
||||||
cores=2, displayProgress=FALSE, savePath="trees",
|
|
||||||
forest.output = "offline")
|
|
||||||
warning("TODO - need to implement feature; test workaround in the meantime")
|
|
||||||
|
|
||||||
predictions <- predict(forest)
|
|
||||||
|
|
||||||
warning_message <- "Assuming that the previous forest at savePath is the provided forest argument; if not true then your results will be suspect"
|
|
||||||
|
|
||||||
forest.more <- expect_warning(addTrees(forest, 50, cores=2, displayProgress=FALSE,
|
|
||||||
savePath="trees", savePath.overwrite = "merge",
|
|
||||||
forest.output = "offline"), fixed=warning_message) # test multi-core
|
|
||||||
|
|
||||||
predictions <- predict(forest)
|
|
||||||
|
|
||||||
forest.more <- expect_warning(addTrees(forest, 50, cores=1, displayProgress=FALSE,
|
|
||||||
savePath="trees", savePath.overwrite = "merge",
|
|
||||||
forest.output = "offline"), fixed=warning_message) # test single-core
|
|
||||||
|
|
||||||
expect_true(T) # show Ok if we got this far
|
|
||||||
|
|
||||||
unlink("trees", recursive=TRUE)
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
test_that("Test adding trees on saved forest - using delete", {
|
test_that("Test adding trees on saved forest - using delete", {
|
||||||
|
|
||||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
context("Calculate integrated Brier score without error")
|
|
||||||
|
|
||||||
# This code is more concerned that the code runs without error. The tests in the
|
|
||||||
# Java code check that the results it returns are accurate.
|
|
||||||
test_that("Can calculate Integrated Brier Score", {
|
|
||||||
|
|
||||||
sampleData <- data.frame(x=rnorm(100))
|
|
||||||
sampleData$T <- sample(0:4, size=100, replace = TRUE) # the censor distribution we provide needs to conform to the data or we can get NaNs
|
|
||||||
sampleData$delta <- sample(0:2, size = 100, replace = TRUE)
|
|
||||||
|
|
||||||
testData <- sampleData[1:5,]
|
|
||||||
trainingData <- sampleData[6:100,]
|
|
||||||
|
|
||||||
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)
|
|
||||||
|
|
||||||
predictions <- predict(forest, testData)
|
|
||||||
|
|
||||||
scores_test <- integratedBrierScore(CR_Response(testData$delta, testData$T), predictions, event = 1, time = 4,
|
|
||||||
censoringDistribution = NULL)
|
|
||||||
# Check that we don't get a crash if we calculate the error for only one observation
|
|
||||||
scores_one <- integratedBrierScore(CR_Response(testData$delta, testData$T)[1], predictions[1], event = 1, time = 4,
|
|
||||||
censoringDistribution = NULL)
|
|
||||||
|
|
||||||
# Make sure our error didn't somehow change
|
|
||||||
expect_equal(scores_one, scores_test[1])
|
|
||||||
|
|
||||||
# Provide a censoring distribution via censor times
|
|
||||||
scores_censoring1 <- integratedBrierScore(CR_Response(testData$delta, testData$T), predictions, event = 1, time = 4,
|
|
||||||
censoringDistribution = c(0,1,1,2,3,4))
|
|
||||||
scores_censoring2 <- integratedBrierScore(CR_Response(testData$delta, testData$T), predictions, event = 1, time = 4,
|
|
||||||
censoringDistribution = list(x = 0:4, y = 1 - c(1/6, 3/6, 4/6, 5/6, 6/6)))
|
|
||||||
scores_censoring3 <- integratedBrierScore(CR_Response(testData$delta, testData$T), predictions, event = 1, time = 4,
|
|
||||||
censoringDistribution = stepfun(x=0:4, y=1 - c(0, 1/6, 3/6, 4/6, 5/6, 6/6)))
|
|
||||||
|
|
||||||
expect_equal(scores_censoring1, scores_censoring2)
|
|
||||||
expect_equal(scores_censoring1, scores_censoring3)
|
|
||||||
|
|
||||||
})
|
|
|
@ -15,44 +15,3 @@ test_that("CR_Response of length 1 - no censor times", {
|
||||||
expect_true(T) # show Ok if we got this far
|
expect_true(T) # show Ok if we got this far
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("Can sub-index CR_Response - no censor times", {
|
|
||||||
|
|
||||||
x <- CR_Response(1:5, 1:5)
|
|
||||||
|
|
||||||
index <- 5
|
|
||||||
|
|
||||||
y <- x[index]
|
|
||||||
|
|
||||||
expect_equal(y$eventTime, index)
|
|
||||||
expect_equal(y$eventIndicator, index)
|
|
||||||
|
|
||||||
expect_equal(rJava::.jcall(y$javaObject, "I", "size"), 1)
|
|
||||||
oneJavaItem <- rJava::.jcall(y$javaObject, largeRCRF:::makeResponse(largeRCRF:::.class_Object), "get", 0L)
|
|
||||||
oneJavaItem <- rJava::.jcast(oneJavaItem, largeRCRF:::.class_CompetingRiskResponse)
|
|
||||||
delta <- rJava::.jcall(oneJavaItem, "I", "getDelta")
|
|
||||||
|
|
||||||
expect_equal(delta, index)
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
test_that("Can sub-index CR_Response - censor times", {
|
|
||||||
|
|
||||||
x <- CR_Response(1:5, 1:5, 1:5)
|
|
||||||
|
|
||||||
index <- 5
|
|
||||||
|
|
||||||
y <- x[index]
|
|
||||||
|
|
||||||
expect_equal(y$eventTime, index)
|
|
||||||
expect_equal(y$eventIndicator, index)
|
|
||||||
expect_equal(y$censorTime, index)
|
|
||||||
|
|
||||||
expect_equal(rJava::.jcall(y$javaObject, "I", "size"), 1)
|
|
||||||
oneJavaItem <- rJava::.jcall(y$javaObject, largeRCRF:::makeResponse(largeRCRF:::.class_Object), "get", 0L)
|
|
||||||
oneJavaItem <- rJava::.jcast(oneJavaItem, largeRCRF:::.class_CompetingRiskResponseWithCensorTime)
|
|
||||||
delta <- rJava::.jcall(oneJavaItem, "D", "getC")
|
|
||||||
|
|
||||||
expect_equal(delta, index)
|
|
||||||
|
|
||||||
})
|
|
|
@ -28,7 +28,6 @@ test_that("Regresssion doesn't crash", {
|
||||||
forest <- train(y ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)
|
forest <- train(y ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)
|
||||||
|
|
||||||
predictions <- predict(forest, testData)
|
predictions <- predict(forest, testData)
|
||||||
other_predictions <- predict(forest) # there was a bug if newData wasn't provided.
|
|
||||||
|
|
||||||
expect_true(T) # show Ok if we got this far
|
expect_true(T) # show Ok if we got this far
|
||||||
|
|
||||||
|
|
|
@ -2,11 +2,7 @@ context("Train, save, and load without error")
|
||||||
|
|
||||||
test_that("Can save & load regression example", {
|
test_that("Can save & load regression example", {
|
||||||
|
|
||||||
if(file.exists("trees_saving_loading")){
|
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist yet
|
||||||
unlink("trees_saving_loading", recursive=TRUE)
|
|
||||||
}
|
|
||||||
|
|
||||||
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist at this point
|
|
||||||
|
|
||||||
x1 <- rnorm(1000)
|
x1 <- rnorm(1000)
|
||||||
x2 <- rnorm(1000)
|
x2 <- rnorm(1000)
|
||||||
|
|
|
@ -13,7 +13,7 @@ test_that("Can save a random forest while training, and use it afterward", {
|
||||||
data <- data.frame(x1, x2, y)
|
data <- data.frame(x1, x2, y)
|
||||||
forest <- train(y ~ x1 + x2, data,
|
forest <- train(y ~ x1 + x2, data,
|
||||||
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
||||||
savePath="trees", forest.output = "online", displayProgress=FALSE)
|
savePath="trees", displayProgress=FALSE)
|
||||||
|
|
||||||
expect_true(file.exists("trees")) # Something should have been saved
|
expect_true(file.exists("trees")) # Something should have been saved
|
||||||
|
|
||||||
|
@ -26,39 +26,6 @@ test_that("Can save a random forest while training, and use it afterward", {
|
||||||
predictions <- predict(newforest, newData)
|
predictions <- predict(newforest, newData)
|
||||||
|
|
||||||
|
|
||||||
unlink("trees", recursive=TRUE)
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
test_that("Can save a random forest while training, and use it afterward with pure offline forest", {
|
|
||||||
|
|
||||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
|
||||||
unlink("trees", recursive=TRUE)
|
|
||||||
}
|
|
||||||
|
|
||||||
x1 <- rnorm(1000)
|
|
||||||
x2 <- rnorm(1000)
|
|
||||||
y <- 1 + x1 + x2 + rnorm(1000)
|
|
||||||
|
|
||||||
data <- data.frame(x1, x2, y)
|
|
||||||
forest <- train(y ~ x1 + x2, data,
|
|
||||||
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
|
||||||
savePath="trees", forest.output = "offline", displayProgress=FALSE)
|
|
||||||
|
|
||||||
expect_true(file.exists("trees")) # Something should have been saved
|
|
||||||
|
|
||||||
# try making a little prediction to verify it works
|
|
||||||
newData <- data.frame(x1=seq(from=-3, to=3, by=0.5), x2=0)
|
|
||||||
predictions <- predict(forest, newData)
|
|
||||||
|
|
||||||
# Also make sure we can load the forest too
|
|
||||||
newforest <- loadForest("trees")
|
|
||||||
predictions <- predict(newforest, newData)
|
|
||||||
|
|
||||||
# Last, make sure we can take the forest online
|
|
||||||
onlineForest <- convertToOnlineForest(forest)
|
|
||||||
predictions <- predict(onlineForest, newData)
|
|
||||||
|
|
||||||
unlink("trees", recursive=TRUE)
|
unlink("trees", recursive=TRUE)
|
||||||
|
|
||||||
})
|
})
|
|
@ -1,102 +0,0 @@
|
||||||
context("Use VIMP without error")
|
|
||||||
|
|
||||||
test_that("VIMP doesn't crash; no test dataset", {
|
|
||||||
|
|
||||||
data(wihs)
|
|
||||||
|
|
||||||
forest <- train(CR_Response(status, time) ~ ., wihs, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE)
|
|
||||||
|
|
||||||
# Run VIMP several times under different scenarios
|
|
||||||
importance <- vimp(forest, type="raw", events=1:2, time=5.0)
|
|
||||||
vimp(forest, type="raw", events=1, time=5.0)
|
|
||||||
vimp(forest, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8))
|
|
||||||
|
|
||||||
# Not much of a test, but the Java code tests more for correctness. This just
|
|
||||||
# tests that the R code runs without error.
|
|
||||||
expect_equal(ncol(importance), 4) # 4 predictors
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
test_that("VIMP doesn't crash; test dataset", {
|
|
||||||
|
|
||||||
data(wihs)
|
|
||||||
|
|
||||||
trainingData <- wihs[1:1000,]
|
|
||||||
testData <- wihs[1001:nrow(wihs),]
|
|
||||||
|
|
||||||
forest <- train(CR_Response(status, time) ~ ., trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, displayProgress=FALSE, cores=1)
|
|
||||||
|
|
||||||
# Run VIMP several times under different scenarios
|
|
||||||
importance <- vimp(forest, newData=testData, type="raw", events=1:2, time=5.0)
|
|
||||||
vimp(forest, newData=testData, type="raw", events=1, time=5.0)
|
|
||||||
vimp(forest, newData=testData, type="raw", events=1:2, time=5.0, eventWeights = c(0.2, 0.8))
|
|
||||||
|
|
||||||
# Not much of a test, but the Java code tests more for correctness. This just
|
|
||||||
# tests that the R code runs without error.
|
|
||||||
expect_equal(ncol(importance), 4) # 4 predictors
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
test_that("VIMP doesn't crash; censoring distribution; all methods equal", {
|
|
||||||
|
|
||||||
sampleData <- data.frame(x=rnorm(100))
|
|
||||||
sampleData$T <- sample(0:4, size=100, replace = TRUE) # the censor distribution we provide needs to conform to the data or we can get NaNs
|
|
||||||
sampleData$delta <- sample(0:2, size = 100, replace = TRUE)
|
|
||||||
|
|
||||||
testData <- sampleData[1:5,]
|
|
||||||
trainingData <- sampleData[6:100,]
|
|
||||||
|
|
||||||
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE)
|
|
||||||
|
|
||||||
importance1 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
|
|
||||||
censoringDistribution = c(0,1,1,2,3,4))
|
|
||||||
importance2 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
|
|
||||||
censoringDistribution = list(x = 0:4, y = 1 - c(1/6, 3/6, 4/6, 5/6, 6/6)))
|
|
||||||
importance3 <- vimp(forest, type="raw", events=1:2, time=4.0, randomSeed=50,
|
|
||||||
censoringDistribution = stepfun(x=0:4, y=1 - c(0, 1/6, 3/6, 4/6, 5/6, 6/6)))
|
|
||||||
|
|
||||||
expect_equal(importance1, importance2)
|
|
||||||
expect_equal(importance1, importance3)
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
test_that("VIMP doesn't crash; regression dataset", {
|
|
||||||
|
|
||||||
data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000))
|
|
||||||
data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000)
|
|
||||||
|
|
||||||
forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE)
|
|
||||||
|
|
||||||
importance <- vimp(forest, type="mean")
|
|
||||||
|
|
||||||
expect_true(importance["x2"] > importance["x3"])
|
|
||||||
|
|
||||||
# Not much of a test, but the Java code tests more for correctness. This just
|
|
||||||
# tests that the R code runs without error.
|
|
||||||
expect_equal(length(importance), 3) # 3 predictors
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
test_that("VIMP produces mean and z scores correctly", {
|
|
||||||
|
|
||||||
data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), x3=rnorm(1000))
|
|
||||||
data$y <- data$x1 + 3*data$x2 + 0.05*data$x3 + rnorm(1000)
|
|
||||||
|
|
||||||
forest <- train(y ~ ., data, ntree=50, numberOfSplits=100, mtry=2, nodeSize=5, displayProgress=FALSE)
|
|
||||||
|
|
||||||
actual.importance.raw <- vimp(forest, type="raw", randomSeed=5)
|
|
||||||
actual.importance.mean <- vimp(forest, type="mean", randomSeed=5)
|
|
||||||
actual.importance.z <- vimp(forest, type="z", randomSeed=5)
|
|
||||||
|
|
||||||
expected.importance.mean <- apply(actual.importance.raw, 2, mean)
|
|
||||||
expected.importance.z <- apply(actual.importance.raw, 2, function(x){
|
|
||||||
mn <- mean(x)
|
|
||||||
return( mn / (sd(x) / sqrt(length(x))) )
|
|
||||||
})
|
|
||||||
|
|
||||||
expect_equal(expected.importance.mean, actual.importance.mean)
|
|
||||||
expect_equal(expected.importance.z, actual.importance.z)
|
|
||||||
|
|
||||||
})
|
|
Loading…
Reference in a new issue