largeRCRF/R/addTrees.R

142 lines
6.2 KiB
R
Raw Normal View History

#' Add Trees
#'
#' Add more trees to an existing forest. Most parameters are extracted from the
#' previous forest.
#'
#' @param forest An existing forest. Its training dataset must be attached
#'   (\code{forest$dataset} must not be \code{NULL}).
#' @param numTreesToAdd The number of trees to add. It is coerced with
#'   \code{as.integer} and must be a single positive value.
#' @param savePath If saving the forest, the directory to save to. Default is
#'   \code{NULL}. Note that you need to re-specify the path if you're modifying
#'   a previously saved forest.
#' @param savePath.overwrite If \code{savePath} is pointing to an existing
#'   directory, possibly containing another forest, this specifies what should
#'   be done: \code{"warn"} (default) stops with an error and leaves the
#'   directory untouched; \code{"delete"} removes the directory before
#'   training; \code{"merge"} keeps the directory and assumes that it contains
#'   the forest provided in \code{forest}.
#' @param forest.output This parameter only applies if \code{savePath} has been
#'   set; set to 'online' (default) and the saved forest will be loaded into
#'   memory after being trained. Set to 'offline' and the forest is not loaded
#'   into memory, but can still be used in a memory unintensive manner.
#' @param cores The number of cores to be used for training the new trees.
#'   Training runs in parallel only when this is greater than 1.
#' @param displayProgress A logical indicating whether the progress should be
#'   displayed to console; default is \code{TRUE}. Useful to set to FALSE in
#'   some automated situations.
#'
#' @return A new forest (a list of class \code{"JRandomForest"}) with the
#'   original and additional trees.
#' @export
#'
addTrees <- function(forest, numTreesToAdd, savePath = NULL,
                     savePath.overwrite = c("warn", "delete", "merge"),
                     forest.output = c("online", "offline"),
                     cores = getCores(), displayProgress = TRUE){

  if(is.null(forest$dataset)){
    stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
  }

  # Guard against NA / zero-length / vector input before the comparison, so
  # the user gets this message rather than a cryptic `if` failure.
  numTreesToAdd <- as.integer(numTreesToAdd)
  if(length(numTreesToAdd) != 1 || is.na(numTreesToAdd) || numTreesToAdd <= 0){
    stop("numTreesToAdd must be a positive integer")
  }

  # `||` short-circuits; with the elementwise `|` a NULL or zero-length
  # argument produced a zero-length condition and `if` failed before the
  # intended error message could be raised.
  if(length(savePath.overwrite) == 0 ||
     !(savePath.overwrite[1] %in% c("warn", "delete", "merge"))){
    stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
  }

  if(length(forest.output) == 0 ||
     !(forest.output[1] %in% c("online", "offline"))){
    stop("forest.output must be one of c(\"online\", \"offline\")")
  }

  # Total size of the forest once the additional trees have been trained.
  newTreeCount <- forest$params$ntree + numTreesToAdd

  # Rebuild the trainers from the settings stored on the existing forest.
  treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
                                   splitFinder=forest$params$splitFinder,
                                   covariateList=forest$covariateList,
                                   numberOfSplits=forest$params$numberOfSplits,
                                   nodeSize=forest$params$nodeSize,
                                   maxNodeDepth=forest$params$maxNodeDepth,
                                   mtry=forest$params$mtry,
                                   splitPureNodes=forest$params$splitPureNodes)

  forestTrainer <- createForestTrainer(treeTrainer=treeTrainer,
                                       covariateList=forest$covariateList,
                                       treeResponseCombiner=forest$params$forestResponseCombiner,
                                       dataset=forest$dataset,
                                       ntree=newTreeCount,
                                       randomSeed=forest$params$randomSeed,
                                       saveTreeLocation=savePath,
                                       displayProgress=displayProgress)

  # Same parameters as the original forest, except for the larger tree count.
  params <- list(
    splitFinder=forest$params$splitFinder,
    nodeResponseCombiner=forest$params$nodeResponseCombiner,
    forestResponseCombiner=forest$params$forestResponseCombiner,
    ntree=newTreeCount,
    numberOfSplits=forest$params$numberOfSplits,
    mtry=forest$params$mtry,
    nodeSize=forest$params$nodeSize,
    splitPureNodes=forest$params$splitPureNodes,
    maxNodeDepth=forest$params$maxNodeDepth,
    randomSeed=forest$params$randomSeed
  )

  # By default training starts from the in-memory forest.
  initial.forest.optional <- .object_Optional(forest$javaObject)

  # We'll be saving an offline version of the forest
  if(!is.null(savePath)){

    if(file.exists(savePath)){ # we might have to remove the folder or display an error
      if(savePath.overwrite[1] == "warn"){
        stop(paste(savePath, "already exists; will not modify it. Please remove/rename it or set the savePath.overwrite to either 'delete' or 'merge'"))
      } else if(savePath.overwrite[1] == "delete"){
        unlink(savePath, recursive=TRUE)
      } else if(savePath.overwrite[1] == "merge"){
        warning("Assuming that the previous forest at savePath is the provided forest argument; if not true then your results will be suspect")
        initial.forest.optional <- .object_Optional(NULL) # Java backend requires we be explicit about whether we're providing an in-memory initial forest or starting from a previous directory
      }
    }

    # Create the directory unless we are merging into one that already exists.
    # Previously "merge" with a not-yet-existing savePath skipped dir.create,
    # so the components below were written towards a missing directory.
    if(savePath.overwrite[1] != "merge" || !file.exists(savePath)){
      dir.create(savePath)
    }

    # First save forest components (so that if the training crashes mid-way through it can theoretically be recovered by the user)
    saveForestComponents(savePath,
                         covariateList=forest$covariateList,
                         params=params,
                         forestCall=match.call())

    if(cores > 1){
      forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
    } else {
      forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", initial.forest.optional)
    }

    # Only load the trained forest into memory when asked to.
    if(forest.output[1] == "online"){
      forest.java <- convertToOnlineForest.Java(forest.java)
    }

  }
  else{ # save directly into memory
    if(cores > 1){
      forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
    } else {
      forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", initial.forest.optional)
    }
  }

  forestObject <- list(call=match.call(), params=params, javaObject=forest.java, covariateList=forest$covariateList, dataset=forest$dataset)
  class(forestObject) <- "JRandomForest"

  return(forestObject)
}