largeRCRF/R/saveForest.R
2019-06-30 15:07:29 -07:00

94 lines
No EOL
3.6 KiB
R

#' Save Random Forests
#'
#' Saves a random forest for later use, given that the base R
#' \code{\link[base]{save}} function doesn't work for this package.
#'
#' @param forest The forest to save.
#' @param directory The directory that should be created to save the trees in.
#' Note that if the directory already exists, an error will be displayed
#' unless \code{overwrite} is set to TRUE.
#' @param overwrite Should the function overwrite an existing forest; FALSE by
#' default.
#' @export
#' @seealso \code{\link{train}}, \code{\link{loadForest}}
#' @examples
#' # Regression Example
#' x1 <- rnorm(1000)
#' x2 <- rnorm(1000)
#' y <- 1 + x1 + x2 + rnorm(1000)
#'
#' data <- data.frame(x1, x2, y)
#' forest <- train(y ~ x1 + x2, data,
#' ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5)
#'
#' saveForest(forest, "saved_forest")
#' new_forest <- loadForest("saved_forest")
saveForest <- function(forest, directory, overwrite=FALSE){
  # Reserve the target directory before anything is written; this stops
  # with an error if it already exists and overwrite is FALSE.
  check_and_create_directory(directory, overwrite)

  # One file per tree goes into the directory first.
  saveTrees(forest, directory)

  # Then the remaining pieces of the forest: the covariate list, the
  # training parameters (which carry the response combiners and the split
  # finder) and the call recorded on the forest.
  covariates <- forest$covariateList
  trainingParams <- forest$params
  originalCall <- forest$call

  saveForestComponents(
    directory,
    covariateList = covariates,
    params = trainingParams,
    forestCall = originalCall
  )
}
# Writes every tree of a forest to its own file inside `directory`.
#
# forest:    the forest to save; `forest$javaObject` is the Java-side forest
#            and `forest$params$ntree` is taken as the number of trees to
#            write (assumed to match the size of the Java tree list).
# directory: a directory the caller has already made free for writing.
#
# Files are named "tree-<index>.tree", where <index> is the 1-based tree
# number left-padded with zeros. Called for its side effects; returns NULL
# invisibly.
saveTrees <- function(forest, directory){
  forest.java <- forest$javaObject
  numberOfTrees <- forest$params$ntree

  # Nothing to write for an empty forest. Without this guard 1:0 would
  # iterate over c(1, 0) and log10(0) would produce an infinite pad width.
  if(numberOfTrees < 1){
    return(invisible(NULL))
  }

  tree.collection.java <- .jcall(forest.java, makeResponse(.class_List), "getTrees")

  # Zero-pad the index so that all names have the same length. round() may
  # give one digit more than strictly required (e.g. 5 trees -> width 2) but
  # never fewer; it is kept so that file names stay the same as before.
  width <- round(log10(numberOfTrees)) + 1
  treeIndices <- seq_len(numberOfTrees)
  treeNames <- paste0(directory, "/tree-", formatC(treeIndices, width=width, format="d", flag="0"), ".tree")

  for(i in treeIndices){
    # The Java list is 0-indexed while R counts from 1.
    tree.java <- .jcall(tree.collection.java, makeResponse(.class_Object), "get", as.integer(i - 1))
    tree.java <- .jcast(tree.java, .class_Serializable)
    .jcall(.class_DataUtils, "V", "saveObject", tree.java, treeNames[i])
  }

  invisible(NULL)
}
# Saves everything that makes up a forest apart from its trees.
#
# directory:     directory to write into; must already exist.
# covariateList: Java object holding the covariates.
# params:        list of training parameters. Its `nodeResponseCombiner`,
#                `splitFinder` and `forestResponseCombiner` entries each
#                carry a Java object in their `javaObject` field.
# forestCall:    the call recorded on the forest.
#
# Written files:
#   <name>.jData         one per Java-backed entry of `params` (see above)
#   covariateList.jData  the covariate list
#   parameters.rData     `params`, with the Java references removed
#   call.rData           `forestCall`
#
# The Java objects go through DataUtils.saveObject on the Java side; the
# plain R values go through saveRDS().
saveForestComponents <- function(directory, covariateList, params, forestCall){
  javaBackedParams <- c("nodeResponseCombiner", "splitFinder", "forestResponseCombiner")

  for(componentName in javaBackedParams){
    component.java <- .jcast(params[[componentName]]$javaObject, .class_Serializable)
    .jcall(.class_DataUtils, "V", "saveObject", component.java, paste0(directory, "/", componentName, ".jData"))

    # Drop the Java reference from `params` itself so that it is not written
    # into parameters.rData. Previously only a local copy was cleared, which
    # had no effect on what got saved.
    # NOTE(review): assumes the loader re-attaches these references from the
    # .jData files rather than reading them from parameters.rData - confirm.
    params[[componentName]]$javaObject <- NULL
  }

  covariateList <- .jcast(covariateList, .class_Serializable)
  .jcall(.class_DataUtils, "V", "saveObject", covariateList, paste0(directory, "/covariateList.jData"))

  saveRDS(object=params, file=paste0(directory, "/parameters.rData"))
  saveRDS(object=forestCall, file=paste0(directory, "/call.rData"))
}
# Makes sure `directory` exists and is empty so a forest can be written
# into it.
#
# directory: path of the directory to create.
# overwrite: if FALSE, an existing path is an error; if TRUE, whatever is at
#            that path is deleted (including all of its contents) and a
#            fresh directory is created in its place.
#
# Stops with an error if the path exists and overwrite is FALSE, if the
# existing path cannot be removed, or if the directory cannot be created.
# Returns TRUE invisibly on success.
check_and_create_directory <- function(directory, overwrite){
  if(file.exists(directory)){
    if(!overwrite){
      stop(paste(directory, "already exists; will not modify it. Please remove/rename it or set overwrite=TRUE"))
    }

    # recursive=TRUE is required: without it unlink() never deletes a
    # directory, so files from a previously saved forest would be left
    # behind and mixed in with the new ones.
    unlink(directory, recursive=TRUE)
    if(file.exists(directory)){
      stop(paste(directory, "could not be removed; unable to overwrite it"))
    }
  }

  # dir.create() only warns on failure; turn that into an error so nothing
  # is written to a location that does not exist.
  if(!dir.create(directory)){
    stop(paste("Unable to create directory", directory))
  }

  invisible(TRUE)
}