New features -

Add support for making predictions without specifying training data
Add support for adding trees to an existing forest
Add support for toggling displayProgress

Also reduced the size of the package by removing some unused dependency
classes.
This commit is contained in:
Joel Therrien 2019-06-19 13:14:11 -07:00
parent 30d9060517
commit fdc708dad5
152 changed files with 526 additions and 98 deletions

View file

@ -1,7 +1,7 @@
Package: largeRCRF Package: largeRCRF
Type: Package Type: Package
Title: Large Random Competing Risk Forests, Java Implementation Run in R Title: Large Random Competing Risk Forests, Java Implementation Run in R
Version: 0.0.0.9037 Version: 0.0.0.9038
Authors@R: person("Joel", "Therrien", email = "joel@joeltherrien.ca", role = c("aut", "cre")) Authors@R: person("Joel", "Therrien", email = "joel@joeltherrien.ca", role = c("aut", "cre"))
Description: This package is used for training competing risk random forests on larger scale datasets. Description: This package is used for training competing risk random forests on larger scale datasets.
It currently only supports training models, running predictions, plotting those predictions (they are curves), It currently only supports training models, running predictions, plotting those predictions (they are curves),

View file

@ -25,6 +25,8 @@ export(LogRankSplitFinder)
export(MeanResponseCombiner) export(MeanResponseCombiner)
export(Numeric) export(Numeric)
export(WeightedVarianceSplitFinder) export(WeightedVarianceSplitFinder)
export(addTrees)
export(connectToData)
export(extractCHF) export(extractCHF)
export(extractCIF) export(extractCIF)
export(extractMortalities) export(extractMortalities)

128
R/addTrees.R Normal file
View file

@ -0,0 +1,128 @@
#' Add Trees
#'
#' Add more trees to an existing forest. Most parameters are extracted from the
#' previous forest.
#'
#' @param forest An existing forest.
#' @param numTreesToAdd The number of trees to add.
#' @param savePath If saving the forest, the directory to save to. Default is
#' \code{NULL}. Note that you need to respecify the path if you're modifying a
#' previously saved forest.
#' @param savePath.overwrite If \code{savePath} is pointing to an existing
#' directory, possibly containing another forest, this specifies what should
#' be done.
#' @param cores The number of cores to be used for training the new trees.
#' @param displayProgress A logical indicating whether the progress should be
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
#' some automated situations.
#'
#' @return A new forest with the original and additional trees.
#' @export
#'
addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), cores = getCores(), displayProgress = TRUE){
if(is.null(forest$dataset)){
stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
}
numTreesToAdd <- as.integer(numTreesToAdd)
if(numTreesToAdd <= 0){
stop("numTreesToAdd must be a positive integer")
}
if(is.null(savePath.overwrite) | length(savePath.overwrite)==0 | !(savePath.overwrite[1] %in% c("warn", "delete", "merge"))){
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
}
newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd)
treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
splitFinder=forest$params$splitFinder,
covariateList=forest$covariateList,
numberOfSplits=forest$params$numberOfSplits,
nodeSize=forest$params$nodeSize,
maxNodeDepth=forest$params$maxNodeDepth,
mtry=forest$params$mtry,
splitPureNodes=forest$params$splitPureNodes)
forestTrainer <- createForestTrainer(treeTrainer=treeTrainer,
covariateList=forest$covariateList,
treeResponseCombiner=forest$params$forestResponseCombiner,
dataset=forest$dataset,
ntree=forest$params$ntree + numTreesToAdd,
randomSeed=forest$params$randomSeed,
saveTreeLocation=savePath,
displayProgress=displayProgress)
params <- list(
splitFinder=forest$params$splitFinder,
nodeResponseCombiner=forest$params$nodeResponseCombiner,
forestResponseCombiner=forest$params$forestResponseCombiner,
ntree=forest$params$ntree + numTreesToAdd,
numberOfSplits=forest$params$numberOfSplits,
mtry=forest$params$mtry,
nodeSize=forest$params$nodeSize,
splitPureNodes=forest$params$splitPureNodes,
maxNodeDepth = forest$params$maxNodeDepth,
randomSeed=forest$params$randomSeed
)
initial.forest.optional <- .object_Optional(forest$javaObject)
# We'll be saving an offline version of the forest
if(!is.null(savePath)){
if(file.exists(savePath)){ # we might have to remove the folder or display an error
if(savePath.overwrite[1] == "warn"){
stop(paste(savePath, "already exists; will not modify it. Please remove/rename it or set the savePath.overwrite to either 'delete' or 'merge'"))
} else if(savePath.overwrite[1] == "delete"){
unlink(savePath, recursive=TRUE)
} else if(savePath.overwrite[1] == "merge"){
warning("Assuming that the previous forest at savePath is the provided forest argument; if not true then your results will be suspect")
initial.forest.optional <- .object_Optional(NULL) # Java backend requires we be explicit about whether we're providing an in-memory initial forest or starting from a previous directory
}
}
if(savePath.overwrite[1] != "merge"){
dir.create(savePath)
}
# First save forest components (so that if the training crashes mid-way through it can theoretically be recovered by the user)
saveForestComponents(savePath,
covariateList=forest$covariateList,
params=params,
forestCall=match.call())
if(cores > 1){
.jcall(forestTrainer, "V", "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
} else {
.jcall(forestTrainer, "V", "trainSerialOnDisk", initial.forest.optional)
}
# Need to now load forest trees back into memory
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forest$params$forestResponseCombiner$javaObject)
}
else{ # save directly into memory
if(cores > 1){
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
} else {
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", initial.forest.optional)
}
}
forestObject <- list(call=match.call(), params=params, javaObject=forest.java, covariateList=forest$covariateList, dataset=forest$dataset)
class(forestObject) <- "JRandomForest"
return(forestObject)
}

37
R/connectToData.R Normal file
View file

@ -0,0 +1,37 @@
#' Connect To Data
#'
#' When a trained forest is saved, the training dataset is not saved alongside
#' it. When it's loaded back up, it can be more convenient (and in some cases
#' necessary) to import the training dataset back into the Java environment so
#' that it's readily accessible. There are only two functions that look for the
#' training dataset: \code{predict}, where you can easily just specify an
#' alternative dataset, or \code{\link{addTrees}}, which requires the training
#' dataset be connected.
#' @param forest The forest to connect data too
#' @param responses The responses in the data; aka the left hand side of the formula
#' @param covariateData A data.frame containing all of the covariates used in the training dataset
#' @return The same forest, but connected to the training data.
#' @export
#' @examples
#' data <- data.frame(x1=rnorm(1000), x2=rnorm(1000), y=rnorm(1000))
#' forest <- train(y~x1+x2, data, ntree=100, numberOfSplits=0, nodeSize=1, mtry=1)
#' forest$dataset <- NULL # what the forest looks like after being loaded
#'
#' forest <- connectToData(forest, data$y, data)
connectToData <- function(forest, responses, covariateData){
covariateList <- forest$covariateList
numCovariates <- .jcall(covariateList, "I", "size")
covariateNames <- character(numCovariates)
for(j in 1:numCovariates){
covariate <- .jcall(covariateList, makeResponse(.class_Object), "get", as.integer(j-1))
covariate <- .jcast(covariate, .class_Covariate)
covariateNames[j] <- .jcall(covariate, makeResponse(.class_String), "getName")
}
forest$dataset <- loadData(covariateData, covariateNames, responses, covariateList)$dataset
return(forest)
}

View file

@ -50,6 +50,16 @@
.class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder" .class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder"
.class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder" .class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder"
.object_Optional <- function(forest=NULL){
if(is.null(forest)){
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "empty"))
} else{
forest <- .jcast(forest, .class_Object)
return(.jcall("java/util/Optional", "Ljava/util/Optional;", "of", forest))
}
}
# When a class object is returned, rJava often often wants L prepended and ; appended. # When a class object is returned, rJava often often wants L prepended and ; appended.
# So a list that returns "java/lang/Object" should show "Ljava/lang/Object;" # So a list that returns "java/lang/Object" should show "Ljava/lang/Object;"
# This function does that. # This function does that.

View file

@ -1,10 +1,13 @@
loadData <- function(data, xVarNames, responses){ loadData <- function(data, xVarNames, responses, covariateList.java = NULL){
if(class(responses) == "integer" | class(responses) == "numeric"){ if(class(responses) == "integer" | class(responses) == "numeric"){
responses <- Numeric(responses) responses <- Numeric(responses)
} }
# connectToData provides a pre-created covariate list we can re-use
if(is.null(covariateList.java)){
covariateList.java <- getCovariateList(data, xVarNames) covariateList.java <- getCovariateList(data, xVarNames)
}
textColumns <- list() textColumns <- list()
for(j in 1:length(xVarNames)){ for(j in 1:length(xVarNames)){

View file

@ -43,7 +43,7 @@ loadForest <- function(directory){
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call, forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call,
params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes) params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes, params$randomSeed)
return(forest) return(forest)
@ -56,7 +56,7 @@ loadForest <- function(directory){
# I'd appreciate knowing that someone's going to use it first (email me; see # I'd appreciate knowing that someone's going to use it first (email me; see
# README). # README).
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner, loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner,
covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE){ covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, randomSeed=NULL){
params <- list( params <- list(
splitFinder=splitFinder, splitFinder=splitFinder,
@ -67,7 +67,8 @@ loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, sp
mtry=mtry, mtry=mtry,
nodeSize=nodeSize, nodeSize=nodeSize,
splitPureNodes=splitPureNodes, splitPureNodes=splitPureNodes,
maxNodeDepth = maxNodeDepth maxNodeDepth=maxNodeDepth,
randomSeed=randomSeed
) )
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject) forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject)

View file

@ -6,9 +6,9 @@
#' #'
#' @param forest A forest that was previously \code{\link{train}}ed #' @param forest A forest that was previously \code{\link{train}}ed
#' @param newData The new data containing all of the previous predictor #' @param newData The new data containing all of the previous predictor
#' covariates. Note that even if predictions are being made on the training #' covariates. Can be NULL if you want to use the training dataset, and
#' set, the dataset must be specified. \code{largeRCRF} doesn't keep track of #' \code{forest} hasn't been loaded from the disk; otherwise you'll have to
#' the dataset after the forest is trained. #' specify it.
#' @param parallel A logical indicating whether multiple cores should be #' @param parallel A logical indicating whether multiple cores should be
#' utilized when making the predictions. Available as an option because it's #' utilized when making the predictions. Available as an option because it's
#' been observed that using Java's \code{parallelStream} can be unstable on #' been observed that using Java's \code{parallelStream} can be unstable on
@ -16,7 +16,8 @@
#' get strange errors while predicting. #' get strange errors while predicting.
#' @param out.of.bag A logical indicating whether predictions should be based on #' @param out.of.bag A logical indicating whether predictions should be based on
#' 'out of bag' trees; set only to \code{TRUE} if you're running predictions #' 'out of bag' trees; set only to \code{TRUE} if you're running predictions
#' on data that was used in the training. Default value is \code{FALSE}. #' on data that was used in the training. Default value is \code{TRUE} if
#' \code{newData} is \code{NULL}, otherwise \code{FALSE}.
#' @return A list of responses corresponding with each row of \code{newData} if #' @return A list of responses corresponding with each row of \code{newData} if
#' it's a non-regression random forest; otherwise it returns a numeric vector. #' it's a non-regression random forest; otherwise it returns a numeric vector.
#' @export #' @export
@ -50,18 +51,33 @@
#' forest <- train(CR_Response(delta, u) ~ x1 + x2, data, ntree=100, numberOfSplits=5, mtry=1, nodeSize=10) #' forest <- train(CR_Response(delta, u) ~ x1 + x2, data, ntree=100, numberOfSplits=5, mtry=1, nodeSize=10)
#' newData <- data.frame(x1 = c(-1, 0, 1), x2 = 0) #' newData <- data.frame(x1 = c(-1, 0, 1), x2 = 0)
#' ypred <- predict(forest, newData) #' ypred <- predict(forest, newData)
predict.JRandomForest <- function(forest, newData=NULL, parallel=TRUE, out.of.bag=FALSE){ predict.JRandomForest <- function(forest, newData=NULL, parallel=TRUE, out.of.bag=NULL){
if(is.null(newData)){
stop("newData must be specified, even if predictions are on the training set") if(is.null(newData) & is.null(forest$dataset)){
stop("forest doesn't have a copy of the training data loaded (this happens if you just loaded it); please manually specify newData and possibly out.of.bag")
} }
if(is.null(newData)){
predictionDataList <- forest$dataset
if(is.null(out.of.bag)){
out.of.bag <- TRUE
}
}
else{ # newData is provided
if(is.null(out.of.bag)){
out.of.bag <- FALSE
}
predictionDataList <- loadPredictionData(newData, forest$covariateList)
}
numRows <- .jcall(predictionDataList, "I", "size")
forestObject <- forest$javaObject forestObject <- forest$javaObject
covariateList <- forest$covariateList
predictionClass <- forest$params$forestResponseCombiner$outputClass predictionClass <- forest$params$forestResponseCombiner$outputClass
convertToRFunction <- forest$params$forestResponseCombiner$convertToRFunction convertToRFunction <- forest$params$forestResponseCombiner$convertToRFunction
predictionDataList <- loadPredictionData(newData, covariateList)
if(parallel){ if(parallel){
function.to.use <- "evaluate" function.to.use <- "evaluate"
} }
@ -82,8 +98,7 @@ predict.JRandomForest <- function(forest, newData=NULL, parallel=TRUE, out.of.ba
predictions <- list() predictions <- list()
} }
for(i in 1:numRows){
for(i in 1:nrow(newData)){
prediction <- .jcall(predictionsJava, makeResponse(.class_Object), "get", as.integer(i-1)) prediction <- .jcall(predictionsJava, makeResponse(.class_Object), "get", as.integer(i-1))
prediction <- convertToRFunction(prediction, forest) prediction <- convertToRFunction(prediction, forest)

View file

@ -30,10 +30,12 @@ getCores <- function(){
#' #'
#' @param responses An R list of the responses. See \code{\link{CR_Response}} #' @param responses An R list of the responses. See \code{\link{CR_Response}}
#' for an example function. #' for an example function.
#' @param data A data.frame containing the columns of the predictors and
#' responses; not relevant if you're not using the formula version of
#' \code{train}.
#' @param covariateData A data.frame containing only the columns of the #' @param covariateData A data.frame containing only the columns of the
#' covariates you wish to use in your training (unless you're using the #' covariates you wish to use in your training (not relevant if you're using
#' \code{formula} version of \code{train}, in which case it should contain the #' the formula version of \code{train}).
#' response as well).
#' @param splitFinder A split finder that's used to score splits in the random #' @param splitFinder A split finder that's used to score splits in the random
#' forest training algorithm. See \code{\link{Competing Risk Split Finders}} #' forest training algorithm. See \code{\link{Competing Risk Split Finders}}
#' or \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one, #' or \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one,
@ -78,10 +80,10 @@ getCores <- function(){
#' split a pure node. If set to FALSE, then before every split it will check #' split a pure node. If set to FALSE, then before every split it will check
#' that every response is the same, and if so, not split. If set to TRUE it #' that every response is the same, and if so, not split. If set to TRUE it
#' forgoes that check and splits it. Prediction accuracy won't change under #' forgoes that check and splits it. Prediction accuracy won't change under
#' any sensible \code{nodeResponseCombiner}; as all terminal nodes from a split #' any sensible \code{nodeResponseCombiner}; as all terminal nodes from a
#' pure node should give the same prediction, so this parameter only affects #' split pure node should give the same prediction, so this parameter only
#' performance. If your response is continuous you'll likely experience faster #' affects performance. If your response is continuous you'll likely
#' train times by setting it to TRUE. Default value is TRUE. #' experience faster train times by setting it to TRUE. Default value is TRUE.
#' @param savePath If set, this parameter will save each tree of the random #' @param savePath If set, this parameter will save each tree of the random
#' forest in this directory as the forest is trained. Use this parameter if #' forest in this directory as the forest is trained. Use this parameter if
#' you need to save memory while training. See also \code{\link{loadForest}} #' you need to save memory while training. See also \code{\link{loadForest}}
@ -98,21 +100,24 @@ getCores <- function(){
#' a crash. #' a crash.
#' @param cores This parameter specifies how many trees will be simultaneously #' @param cores This parameter specifies how many trees will be simultaneously
#' trained. By default the package attempts to detect how many cores you have #' trained. By default the package attempts to detect how many cores you have
#' by using the \code{parallel} package and using all of them. You may #' by using the \code{parallel} package and using all of them. You may specify
#' specify a lower number if you wish. It is not recommended to specify a #' a lower number if you wish. It is not recommended to specify a number
#' number greater than the number of available cores as this will hurt #' greater than the number of available cores as this will hurt performance
#' performance with no available benefit. #' with no available benefit.
#' @param randomSeed This parameter specifies a random seed if reproducible, #' @param randomSeed This parameter specifies a random seed if reproducible,
#' deterministic forests are desired. #' deterministic forests are desired.
#' @param displayProgress A logical indicating whether the progress should be
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
#' some automated situations.
#' @export #' @export
#' @return A \code{JRandomForest} object. You may call \code{predict} or #' @return A \code{JRandomForest} object. You may call \code{predict} or
#' \code{print} on it. #' \code{print} on it.
#' @seealso \code{\link{predict.JRandomForest}} #' @seealso \code{\link{predict.JRandomForest}}
#' @note If saving memory is a concern, you can replace \code{covariateData} #' @note If saving memory is a concern, you can replace \code{covariateData} or
#' with an environment containing one element called \code{data} as the actual #' \code{data} with an environment containing one element called \code{data}
#' dataset. After the data has been imported into Java, but before the forest #' as the actual dataset. After the data has been imported into Java, but
#' training begins, the dataset in the environment is deleted, freeing up #' before the forest training begins, the dataset in the environment is
#' memory in R. #' deleted, freeing up memory in R.
#' @examples #' @examples
#' # Regression Example #' # Regression Example
#' x1 <- rnorm(1000) #' x1 <- rnorm(1000)
@ -150,7 +155,7 @@ train <- function(x, ...) UseMethod("train")
#' @rdname train #' @rdname train
#' @export #' @export
train.default <- function(responses, covariateData, splitFinder = splitFinderDefault(responses), nodeResponseCombiner = nodeResponseCombinerDefault(responses), forestResponseCombiner = forestResponseCombinerDefault(responses), ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, savePath=NULL, savePath.overwrite=c("warn", "delete", "merge"), cores = getCores(), randomSeed = NULL){ train.default <- function(responses, covariateData, splitFinder = splitFinderDefault(responses), nodeResponseCombiner = nodeResponseCombinerDefault(responses), forestResponseCombiner = forestResponseCombinerDefault(responses), ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, savePath=NULL, savePath.overwrite=c("warn", "delete", "merge"), cores = getCores(), randomSeed = NULL, displayProgress = TRUE){
# Some quick checks on parameters # Some quick checks on parameters
ntree <- as.integer(ntree) ntree <- as.integer(ntree)
@ -223,7 +228,8 @@ train.default <- function(responses, covariateData, splitFinder = splitFinderDef
dataset=dataset$dataset, dataset=dataset$dataset,
ntree=ntree, ntree=ntree,
randomSeed=randomSeed, randomSeed=randomSeed,
saveTreeLocation=savePath) saveTreeLocation=savePath,
displayProgress=displayProgress)
params <- list( params <- list(
splitFinder=splitFinder, splitFinder=splitFinder,
@ -235,7 +241,7 @@ train.default <- function(responses, covariateData, splitFinder = splitFinderDef
nodeSize=nodeSize, nodeSize=nodeSize,
splitPureNodes=splitPureNodes, splitPureNodes=splitPureNodes,
maxNodeDepth = maxNodeDepth, maxNodeDepth = maxNodeDepth,
savePath=savePath randomSeed=randomSeed
) )
# We'll be saving an offline version of the forest # We'll be saving an offline version of the forest
@ -262,9 +268,9 @@ train.default <- function(responses, covariateData, splitFinder = splitFinderDef
forestCall=match.call()) forestCall=match.call())
if(cores > 1){ if(cores > 1){
.jcall(forestTrainer, "V", "trainParallelOnDisk", as.integer(cores)) .jcall(forestTrainer, "V", "trainParallelOnDisk", .object_Optional(), as.integer(cores))
} else { } else {
.jcall(forestTrainer, "V", "trainSerialOnDisk") .jcall(forestTrainer, "V", "trainSerialOnDisk", .object_Optional())
} }
# Need to now load forest trees back into memory # Need to now load forest trees back into memory
@ -274,16 +280,16 @@ train.default <- function(responses, covariateData, splitFinder = splitFinderDef
} }
else{ # save directly into memory else{ # save directly into memory
if(cores > 1){ if(cores > 1){
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", as.integer(cores)) forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
} else { } else {
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory") forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", .object_Optional())
} }
} }
forestObject <- list(call=match.call(), params=params, javaObject=forest.java, covariateList=dataset$covariateList) forestObject <- list(call=match.call(), params=params, javaObject=forest.java, covariateList=dataset$covariateList, dataset=dataset$dataset)
class(forestObject) <- "JRandomForest" class(forestObject) <- "JRandomForest"
return(forestObject) return(forestObject)
@ -298,19 +304,19 @@ train.default <- function(responses, covariateData, splitFinder = splitFinderDef
#' @param formula You may specify the response and covariates as a formula #' @param formula You may specify the response and covariates as a formula
#' instead; make sure the response in the formula is still properly #' instead; make sure the response in the formula is still properly
#' constructed; see \code{responses} #' constructed; see \code{responses}
train.formula <- function(formula, covariateData, ...){ train.formula <- function(formula, data, ...){
# Having an R copy of the data loaded at the same time can be wasteful; we # Having an R copy of the data loaded at the same time can be wasteful; we
# also allow users to provide an environment of the data which gets removed # also allow users to provide an environment of the data which gets removed
# after being imported into Java # after being imported into Java
env <- NULL env <- NULL
if(class(covariateData) == "environment"){ if(class(data) == "environment"){
if(is.null(covariateData$data)){ if(is.null(data$data)){
stop("When providing an environment with the dataset, the environment must contain an item called 'data'") stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
} }
env <- covariateData env <- data
covariateData <- env$data data <- env$data
} }
yVar <- formula[[2]] yVar <- formula[[2]]
@ -319,25 +325,25 @@ train.formula <- function(formula, covariateData, ...){
variablesToDrop <- character(0) variablesToDrop <- character(0)
# yVar is a call object; as.character(yVar) will be the different components, including the parameters. # yVar is a call object; as.character(yVar) will be the different components, including the parameters.
# if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in covariateData, # if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in data,
# then we also need to explicitly evaluate it # then we also need to explicitly evaluate it
if(class(yVar)=="call" || !(as.character(yVar) %in% colnames(covariateData))){ if(class(yVar)=="call" || !(as.character(yVar) %in% colnames(data))){
# yVar is a function like CompetingRiskResponses # yVar is a function like CompetingRiskResponses
responses <- eval(expr=yVar, envir=covariateData) responses <- eval(expr=yVar, envir=data)
if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){ if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
# do any of the variables match data in covariateData? We need to track that so we can drop them later # do any of the variables match data in data? We need to track that so we can drop them later
variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(covariateData)] variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(data)]
} }
formula[[2]] <- NULL formula[[2]] <- NULL
} else if(class(yVar)=="name"){ # and implicitly yVar is contained in covariateData } else if(class(yVar)=="name"){ # and implicitly yVar is contained in data
variablesToDrop <- as.character(yVar) variablesToDrop <- as.character(yVar)
} }
# Includes responses which we may need to later cut out # Includes responses which we may need to later cut out
mf <- model.frame(formula=formula, data=covariateData, na.action=na.pass) mf <- model.frame(formula=formula, data=data, na.action=na.pass)
if(is.null(responses)){ if(is.null(responses)){
responses <- model.response(mf) responses <- model.response(mf)
@ -349,7 +355,7 @@ train.formula <- function(formula, covariateData, ...){
# If environment was provided instead of data # If environment was provided instead of data
if(!is.null(env)){ if(!is.null(env)){
env$data <- mf env$data <- mf
rm(covariateData) rm(data)
forest <- train.default(responses, env, ...) forest <- train.default(responses, env, ...)
} else{ } else{
forest <- train.default(responses, mf, ...) forest <- train.default(responses, mf, ...)
@ -363,7 +369,14 @@ train.formula <- function(formula, covariateData, ...){
return(forest) return(forest)
} }
createForestTrainer <- function(treeTrainer, covariateList, treeResponseCombiner, dataset, ntree, randomSeed, saveTreeLocation){ createForestTrainer <- function(treeTrainer,
covariateList,
treeResponseCombiner,
dataset,
ntree,
randomSeed,
saveTreeLocation,
displayProgress){
builderClassReturned <- makeResponse(.class_ForestTrainer_Builder) builderClassReturned <- makeResponse(.class_ForestTrainer_Builder)
builder <- .jcall(.class_ForestTrainer, builderClassReturned, "builder") builder <- .jcall(.class_ForestTrainer, builderClassReturned, "builder")
@ -373,7 +386,7 @@ createForestTrainer <- function(treeTrainer, covariateList, treeResponseCombiner
builder <- .jcall(builder, builderClassReturned, "treeResponseCombiner", treeResponseCombiner$javaObject) builder <- .jcall(builder, builderClassReturned, "treeResponseCombiner", treeResponseCombiner$javaObject)
builder <- .jcall(builder, builderClassReturned, "data", dataset) builder <- .jcall(builder, builderClassReturned, "data", dataset)
builder <- .jcall(builder, builderClassReturned, "ntree", as.integer(ntree)) builder <- .jcall(builder, builderClassReturned, "ntree", as.integer(ntree))
builder <- .jcall(builder, builderClassReturned, "displayProgress", TRUE) builder <- .jcall(builder, builderClassReturned, "displayProgress", displayProgress)
if(!is.null(randomSeed)){ if(!is.null(randomSeed)){
builder <- .jcall(builder, builderClassReturned, "randomSeed", .jlong(randomSeed)) builder <- .jcall(builder, builderClassReturned, "randomSeed", .jlong(randomSeed))

Some files were not shown because too many files have changed in this diff Show more