Add support for 'offline' forests

This commit is contained in:
Joel Therrien 2019-11-13 17:07:58 -08:00
parent 3fa2fef82a
commit 589df1a18b
16 changed files with 270 additions and 38 deletions

View file

@ -26,6 +26,7 @@ export(Numeric)
export(WeightedVarianceSplitFinder) export(WeightedVarianceSplitFinder)
export(addTrees) export(addTrees)
export(connectToData) export(connectToData)
export(convertToOnlineForest)
export(extractCHF) export(extractCHF)
export(extractCIF) export(extractCIF)
export(extractMortalities) export(extractMortalities)

View file

@ -14,6 +14,10 @@
#' @param savePath.overwrite If \code{savePath} is pointing to an existing #' @param savePath.overwrite If \code{savePath} is pointing to an existing
#' directory, possibly containing another forest, this specifies what should #' directory, possibly containing another forest, this specifies what should
#' be done. #' be done.
#' @param forest.output This parameter only applies if \code{savePath} has been
#' set; set to 'online' (default) and the saved forest will be loaded into
#' memory after being trained. Set to 'offline' and the forest is not saved
#' into memory, but can still be used in a memory unintensive manner.
#' @param cores The number of cores to be used for training the new trees. #' @param cores The number of cores to be used for training the new trees.
#' @param displayProgress A logical indicating whether the progress should be #' @param displayProgress A logical indicating whether the progress should be
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in #' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
@ -22,7 +26,11 @@
#' @return A new forest with the original and additional trees. #' @return A new forest with the original and additional trees.
#' @export #' @export
#' #'
addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), cores = getCores(), displayProgress = TRUE){ addTrees <- function(forest, numTreesToAdd, savePath = NULL,
savePath.overwrite = c("warn", "delete", "merge"),
forest.output = c("online", "offline"),
cores = getCores(), displayProgress = TRUE){
if(is.null(forest$dataset)){ if(is.null(forest$dataset)){
stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData") stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
} }
@ -37,6 +45,10 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")") stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
} }
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
stop("forest.output must be one of c(\"online\", \"offline\")")
}
newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd) newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd)
treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner, treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
@ -98,22 +110,23 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite
params=params, params=params,
forestCall=match.call()) forestCall=match.call())
forest.java <- NULL
if(cores > 1){ if(cores > 1){
.jcall(forestTrainer, "V", "trainParallelOnDisk", initial.forest.optional, as.integer(cores)) forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
} else { } else {
.jcall(forestTrainer, "V", "trainSerialOnDisk", initial.forest.optional) forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", initial.forest.optional)
} }
# Need to now load forest trees back into memory if(forest.output[1] == "online"){
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forest$params$forestResponseCombiner$javaObject) forest.java <- convertToOnlineForest.Java(forest.java)
}
} }
else{ # save directly into memory else{ # save directly into memory
if(cores > 1){ if(cores > 1){
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", initial.forest.optional, as.integer(cores)) forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
} else { } else {
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", initial.forest.optional) forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", initial.forest.optional)
} }
} }

38
R/convertToOnlineForest.R Normal file
View file

@ -0,0 +1,38 @@
#' Convert to Online Forest
#'
#' Some forests are too large to store in memory and have been saved to disk.
#' They can still be used, but their performance is much slower. If there's
#' enough memory, they can easily be converted into an in-memory forest that is
#' faster to use.
#'
#' @param forest The offline forest.
#'
#' @return An online, in memory forst.
#' @export
#'
convertToOnlineForest <- function(forest){
old.forest.object <- forest$javaObject
if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OnlineForest"){
warning("forest is already in-memory")
return(forest)
} else if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OfflineForest"){
forest$javaObject <- convertToOnlineForest.Java(old.forest.object)
return(forest)
} else{
stop("'forest' is not an online or offline forest")
}
}
# Internal function
convertToOnlineForest.Java <- function(forest.java){
offline.forest <- .jcast(forest.java, .class_OfflineForest)
online.forest <- .jcall(offline.forest, makeResponse(.class_OnlineForest), "createOnlineCopy")
return(online.forest)
}

View file

@ -42,7 +42,7 @@ CR_FunctionCombiner <- function(events, times = NULL){
} }
javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray) javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray)
javaObject <- .jcast(javaObject, .class_ResponseCombiner) javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
combiner <- list(javaObject=javaObject, combiner <- list(javaObject=javaObject,
call=match.call(), call=match.call(),

View file

@ -11,6 +11,7 @@
.class_Serializable <- "java/io/Serializable" .class_Serializable <- "java/io/Serializable"
.class_File <- "java/io/File" .class_File <- "java/io/File"
.class_Random <- "java/util/Random" .class_Random <- "java/util/Random"
.class_Class <- "java/lang/Class"
# Utility Classes # Utility Classes
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils" .class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
@ -41,9 +42,12 @@
# Forest class # Forest class
.class_Forest <- "ca/joeltherrien/randomforest/tree/Forest" .class_Forest <- "ca/joeltherrien/randomforest/tree/Forest"
.class_OnlineForest <- "ca/joeltherrien/randomforest/tree/OnlineForest"
.class_OfflineForest <- "ca/joeltherrien/randomforest/tree/OfflineForest"
# ResponseCombiner classes # ResponseCombiner classes
.class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner" .class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner"
.class_ForestResponseCombiner <- "ca/joeltherrien/randomforest/tree/ForestResponseCombiner"
.class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner" .class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner"
.class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner" .class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner"
.class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner" .class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner"
@ -78,3 +82,9 @@
makeResponse <- function(className){ makeResponse <- function(className){
return(paste0("L", className, ";")) return(paste0("L", className, ";"))
} }
getJavaClass <- function(object){
class <- .jcall(object, makeResponse(.class_Class), "getClass")
className <- .jcall(class, "S", "getName")
return(className)
}

View file

@ -5,6 +5,12 @@
#' Loads a random forest that was saved using \code{\link{saveForest}}. #' Loads a random forest that was saved using \code{\link{saveForest}}.
#' #'
#' @param directory The directory created that saved the previous forest. #' @param directory The directory created that saved the previous forest.
#' @param forest.output Specifies whether the forest loaded should be loaded
#' into memory, or reflect the saved files where only one tree is loaded at a
#' time.
#' @param maxTreeNum If for some reason you only want to load the number of
#' trees up until a certain point, you can specify maxTreeNum as a single
#' number.
#' @return A JForest object; see \code{\link{train}} for details. #' @return A JForest object; see \code{\link{train}} for details.
#' @export #' @export
#' @seealso \code{\link{train}}, \code{\link{saveForest}} #' @seealso \code{\link{train}}, \code{\link{saveForest}}
@ -20,7 +26,11 @@
#' #'
#' saveForest(forest, "trees") #' saveForest(forest, "trees")
#' new_forest <- loadForest("trees") #' new_forest <- loadForest("trees")
loadForest <- function(directory){ loadForest <- function(directory, forest.output = c("online", "offline"), maxTreeNum = NULL){
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
stop("forest.output must be one of c(\"online\", \"offline\")")
}
# First load the response combiners and the split finders # First load the response combiners and the split finders
nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData")) nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData"))
@ -30,7 +40,7 @@ loadForest <- function(directory){
splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder) splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder)
forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData")) forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData"))
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ResponseCombiner) forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ForestResponseCombiner)
covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData")) covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData"))
covariateList <- .jcast(covariateList, .class_List) covariateList <- .jcast(covariateList, .class_List)
@ -42,8 +52,11 @@ loadForest <- function(directory){
params$splitFinder$javaObject <- splitFinder.java params$splitFinder$javaObject <- splitFinder.java
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call, forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder,
params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes, params$randomSeed) params$forestResponseCombiner, covariateList, call,
params$ntree, params$numberOfSplits, params$mtry,
params$nodeSize, params$maxNodeDepth, params$splitPureNodes,
params$randomSeed, forest.output, maxTreeNum)
return(forest) return(forest)
@ -55,8 +68,11 @@ loadForest <- function(directory){
# that uses the Java version's settings yaml file to recreate the forest, but # that uses the Java version's settings yaml file to recreate the forest, but
# I'd appreciate knowing that someone's going to use it first (email me; see # I'd appreciate knowing that someone's going to use it first (email me; see
# README). # README).
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner, loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder,
covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, randomSeed=NULL){ forestResponseCombiner, covariateList.java, call,
ntree, numberOfSplits, mtry, nodeSize,
maxNodeDepth = 100000, splitPureNodes=TRUE,
randomSeed=NULL, forest.output = "online", maxTreeNum = NULL){
params <- list( params <- list(
splitFinder=splitFinder, splitFinder=splitFinder,
@ -71,7 +87,33 @@ loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, sp
randomSeed=randomSeed randomSeed=randomSeed
) )
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject) forest.java <- NULL
if(forest.output[1] == "online"){
castedForestResponseCombiner <- .jcast(forestResponseCombiner$javaObject, .class_ResponseCombiner) # OnlineForest constructor takes a ResponseCombiner
if(is.null(maxTreeNum)){
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
treeDirectory, castedForestResponseCombiner)
} else{
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
tree.file.array, castedForestResponseCombiner)
}
} else{ # offline forest
if(is.null(maxTreeNum)){
path.as.file <- .jnew(.class_File, treeDirectory)
forest.java <- .jnew(.class_OfflineForest, path.as.file, forestResponseCombiner$javaObject)
} else{
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
forest.java <- .jnew(.class_OfflineForest, tree.file.array, forestResponseCombiner$javaObject)
}
}
forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params) forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params)
class(forestObject) <- "JRandomForest" class(forestObject) <- "JRandomForest"

View file

@ -48,7 +48,7 @@ WeightedVarianceSplitFinder <- function(){
#' #'
MeanResponseCombiner <- function(){ MeanResponseCombiner <- function(){
javaObject <- .jnew(.class_MeanResponseCombiner) javaObject <- .jnew(.class_MeanResponseCombiner)
javaObject <- .jcast(javaObject, .class_ResponseCombiner) javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric") combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric")
combiner$convertToRFunction <- function(javaObject, ...){ combiner$convertToRFunction <- function(javaObject, ...){

View file

@ -18,7 +18,7 @@ train.internal <- function(dataset, splitFinder,
nodeResponseCombiner, forestResponseCombiner, ntree, nodeResponseCombiner, forestResponseCombiner, ntree,
numberOfSplits, mtry, nodeSize, maxNodeDepth, numberOfSplits, mtry, nodeSize, maxNodeDepth,
splitPureNodes, savePath, savePath.overwrite, splitPureNodes, savePath, savePath.overwrite,
cores, randomSeed, displayProgress){ forest.output, cores, randomSeed, displayProgress){
# Some quick checks on parameters # Some quick checks on parameters
ntree <- as.integer(ntree) ntree <- as.integer(ntree)
@ -51,6 +51,10 @@ train.internal <- function(dataset, splitFinder,
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")") stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
} }
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
stop("forest.output must be one of c(\"online\", \"offline\")")
}
if(is.null(splitFinder)){ if(is.null(splitFinder)){
splitFinder <- splitFinderDefault(dataset$responses) splitFinder <- splitFinderDefault(dataset$responses)
} }
@ -129,22 +133,23 @@ train.internal <- function(dataset, splitFinder,
params=params, params=params,
forestCall=match.call()) forestCall=match.call())
forest.java <- NULL
if(cores > 1){ if(cores > 1){
.jcall(forestTrainer, "V", "trainParallelOnDisk", .object_Optional(), as.integer(cores)) forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", .object_Optional(), as.integer(cores))
} else { } else {
.jcall(forestTrainer, "V", "trainSerialOnDisk", .object_Optional()) forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", .object_Optional())
} }
# Need to now load forest trees back into memory if(forest.output[1] == "online"){
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forestResponseCombiner$javaObject) forest.java <- convertToOnlineForest.Java(forest.java)
}
} }
else{ # save directly into memory else{ # save directly into memory
if(cores > 1){ if(cores > 1){
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", .object_Optional(), as.integer(cores)) forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
} else { } else {
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", .object_Optional()) forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", .object_Optional())
} }
} }
@ -253,6 +258,10 @@ train.internal <- function(dataset, splitFinder,
#' assumes (without checking) that the existing trees are from a previous run #' assumes (without checking) that the existing trees are from a previous run
#' and starts from where it left off. This option is useful if recovering from #' and starts from where it left off. This option is useful if recovering from
#' a crash. #' a crash.
#' @param forest.output This parameter only applies if \code{savePath} has been
#' set; set to 'online' (default) and the saved forest will be loaded into
#' memory after being trained. Set to 'offline' and the forest is not saved
#' into memory, but can still be used in a memory unintensive manner.
#' @param cores This parameter specifies how many trees will be simultaneously #' @param cores This parameter specifies how many trees will be simultaneously
#' trained. By default the package attempts to detect how many cores you have #' trained. By default the package attempts to detect how many cores you have
#' by using the \code{parallel} package and using all of them. You may specify #' by using the \code{parallel} package and using all of them. You may specify
@ -311,7 +320,8 @@ train.internal <- function(dataset, splitFinder,
train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL, train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, forestResponseCombiner = NULL, ntree, numberOfSplits, mtry,
nodeSize, maxNodeDepth = 100000, na.penalty = TRUE, splitPureNodes=TRUE, nodeSize, maxNodeDepth = 100000, na.penalty = TRUE, splitPureNodes=TRUE,
savePath=NULL, savePath.overwrite=c("warn", "delete", "merge"), savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
forest.output = c("online", "offline"),
cores = getCores(), randomSeed = NULL, displayProgress = TRUE){ cores = getCores(), randomSeed = NULL, displayProgress = TRUE){
dataset <- processFormula(formula, data, na.penalty = na.penalty) dataset <- processFormula(formula, data, na.penalty = na.penalty)
@ -322,8 +332,8 @@ train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL
ntree = ntree, numberOfSplits = numberOfSplits, ntree = ntree, numberOfSplits = numberOfSplits,
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth, mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
splitPureNodes = splitPureNodes, savePath = savePath, splitPureNodes = splitPureNodes, savePath = savePath,
savePath.overwrite = savePath.overwrite, cores = cores, savePath.overwrite = savePath.overwrite, forest.output = forest.output,
randomSeed = randomSeed, displayProgress = displayProgress) cores = cores, randomSeed = randomSeed, displayProgress = displayProgress)
forest$call <- match.call() forest$call <- match.call()
forest$formula <- formula forest$formula <- formula
@ -371,7 +381,9 @@ createTreeTrainer <- function(responseCombiner, splitFinder, covariateList, numb
builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder") builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder")
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombiner$javaObject) responseCombinerCasted <- .jcast(responseCombiner$javaObject, .class_ResponseCombiner) # might need to cast a ForestResponseCombiner down
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombinerCasted)
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject) builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject)
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList) builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits)) builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits))

View file

@ -6,7 +6,8 @@
\usage{ \usage{
addTrees(forest, numTreesToAdd, savePath = NULL, addTrees(forest, numTreesToAdd, savePath = NULL,
savePath.overwrite = c("warn", "delete", "merge"), savePath.overwrite = c("warn", "delete", "merge"),
cores = getCores(), displayProgress = TRUE) forest.output = c("online", "offline"), cores = getCores(),
displayProgress = TRUE)
} }
\arguments{ \arguments{
\item{forest}{An existing forest.} \item{forest}{An existing forest.}
@ -21,6 +22,11 @@ a previously saved forest.}
directory, possibly containing another forest, this specifies what should directory, possibly containing another forest, this specifies what should
be done.} be done.}
\item{forest.output}{This parameter only applies if \code{savePath} has been
set; set to 'online' (default) and the saved forest will be loaded into
memory after being trained. Set to 'offline' and the forest is not saved
into memory, but can still be used in a memory unintensive manner.}
\item{cores}{The number of cores to be used for training the new trees.} \item{cores}{The number of cores to be used for training the new trees.}
\item{displayProgress}{A logical indicating whether the progress should be \item{displayProgress}{A logical indicating whether the progress should be

View file

@ -0,0 +1,20 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/convertToOnlineForest.R
\name{convertToOnlineForest}
\alias{convertToOnlineForest}
\title{Convert to Online Forest}
\usage{
convertToOnlineForest(forest)
}
\arguments{
\item{forest}{The offline forest.}
}
\value{
An online, in memory forst.
}
\description{
Some forests are too large to store in memory and have been saved to disk.
They can still be used, but their performance is much slower. If there's
enough memory, they can easily be converted into an in-memory forest that is
faster to use.
}

View file

@ -4,10 +4,19 @@
\alias{loadForest} \alias{loadForest}
\title{Load Random Forest} \title{Load Random Forest}
\usage{ \usage{
loadForest(directory) loadForest(directory, forest.output = c("online", "offline"),
maxTreeNum = NULL)
} }
\arguments{ \arguments{
\item{directory}{The directory created that saved the previous forest.} \item{directory}{The directory created that saved the previous forest.}
\item{forest.output}{Specifies whether the forest loaded should be loaded
into memory, or reflect the saved files where only one tree is loaded at a
time.}
\item{maxTreeNum}{If for some reason you only want to load the number of
trees up until a certain point, you can specify maxTreeNum as a single
number.}
} }
\value{ \value{
A JForest object; see \code{\link{train}} for details. A JForest object; see \code{\link{train}} for details.

View file

@ -8,7 +8,8 @@ train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize, forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize,
maxNodeDepth = 1e+05, na.penalty = TRUE, splitPureNodes = TRUE, maxNodeDepth = 1e+05, na.penalty = TRUE, splitPureNodes = TRUE,
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
cores = getCores(), randomSeed = NULL, displayProgress = TRUE) forest.output = c("online", "offline"), cores = getCores(),
randomSeed = NULL, displayProgress = TRUE)
} }
\arguments{ \arguments{
\item{formula}{You may specify the response and covariates as a formula \item{formula}{You may specify the response and covariates as a formula
@ -106,6 +107,11 @@ assumes (without checking) that the existing trees are from a previous run
and starts from where it left off. This option is useful if recovering from and starts from where it left off. This option is useful if recovering from
a crash.} a crash.}
\item{forest.output}{This parameter only applies if \code{savePath} has been
set; set to 'online' (default) and the saved forest will be loaded into
memory after being trained. Set to 'offline' and the forest is not saved
into memory, but can still be used in a memory unintensive manner.}
\item{cores}{This parameter specifies how many trees will be simultaneously \item{cores}{This parameter specifies how many trees will be simultaneously
trained. By default the package attempts to detect how many cores you have trained. By default the package attempts to detect how many cores you have
by using the \code{parallel} package and using all of them. You may specify by using the \code{parallel} package and using all of them. You may specify

View file

@ -1,6 +1,6 @@
context("Add trees on existing forest") context("Add trees on existing forest")
test_that("Can add trees on existing forest", { test_that("Can add trees on existing online forest", {
trainingData <- data.frame(x=rnorm(100)) trainingData <- data.frame(x=rnorm(100))
trainingData$T <- rexp(100) + abs(trainingData$x) trainingData$T <- rexp(100) + abs(trainingData$x)
@ -20,6 +20,44 @@ test_that("Can add trees on existing forest", {
}) })
test_that("Can add trees on existing offline forest", {
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
unlink("trees", recursive=TRUE)
}
trainingData <- data.frame(x=rnorm(100))
trainingData$T <- rexp(100) + abs(trainingData$x)
trainingData$delta <- sample(0:2, size = 100, replace=TRUE)
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50,
numberOfSplits=0, mtry=1, nodeSize=5,
forestResponseCombiner = CR_FunctionCombiner(events = 1:2, times = 0:10), # TODO - remove specifing times; this is workaround around unimplemented feature for offline forests
cores=2, displayProgress=FALSE, savePath="trees",
forest.output = "offline")
warning("TODO - need to implement feature; test workaround in the meantime")
predictions <- predict(forest)
warning_message <- "Assuming that the previous forest at savePath is the provided forest argument; if not true then your results will be suspect"
forest.more <- expect_warning(addTrees(forest, 50, cores=2, displayProgress=FALSE,
savePath="trees", savePath.overwrite = "merge",
forest.output = "offline"), fixed=warning_message) # test multi-core
predictions <- predict(forest)
forest.more <- expect_warning(addTrees(forest, 50, cores=1, displayProgress=FALSE,
savePath="trees", savePath.overwrite = "merge",
forest.output = "offline"), fixed=warning_message) # test single-core
expect_true(T) # show Ok if we got this far
unlink("trees", recursive=TRUE)
})
test_that("Test adding trees on saved forest - using delete", { test_that("Test adding trees on saved forest - using delete", {
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it if(file.exists("trees")){ # folder could exist from a previous failed test; delete it

View file

@ -2,7 +2,11 @@ context("Train, save, and load without error")
test_that("Can save & load regression example", { test_that("Can save & load regression example", {
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist yet if(file.exists("trees_saving_loading")){
unlink("trees_saving_loading", recursive=TRUE)
}
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist at this point
x1 <- rnorm(1000) x1 <- rnorm(1000)
x2 <- rnorm(1000) x2 <- rnorm(1000)

View file

@ -13,7 +13,7 @@ test_that("Can save a random forest while training, and use it afterward", {
data <- data.frame(x1, x2, y) data <- data.frame(x1, x2, y)
forest <- train(y ~ x1 + x2, data, forest <- train(y ~ x1 + x2, data,
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5, ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
savePath="trees", displayProgress=FALSE) savePath="trees", forest.output = "online", displayProgress=FALSE)
expect_true(file.exists("trees")) # Something should have been saved expect_true(file.exists("trees")) # Something should have been saved
@ -26,6 +26,39 @@ test_that("Can save a random forest while training, and use it afterward", {
predictions <- predict(newforest, newData) predictions <- predict(newforest, newData)
unlink("trees", recursive=TRUE)
})
test_that("Can save a random forest while training, and use it afterward with pure offline forest", {
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
unlink("trees", recursive=TRUE)
}
x1 <- rnorm(1000)
x2 <- rnorm(1000)
y <- 1 + x1 + x2 + rnorm(1000)
data <- data.frame(x1, x2, y)
forest <- train(y ~ x1 + x2, data,
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
savePath="trees", forest.output = "offline", displayProgress=FALSE)
expect_true(file.exists("trees")) # Something should have been saved
# try making a little prediction to verify it works
newData <- data.frame(x1=seq(from=-3, to=3, by=0.5), x2=0)
predictions <- predict(forest, newData)
# Also make sure we can load the forest too
newforest <- loadForest("trees")
predictions <- predict(newforest, newData)
# Last, make sure we can take the forest online
onlineForest <- convertToOnlineForest(forest)
predictions <- predict(onlineForest, newData)
unlink("trees", recursive=TRUE) unlink("trees", recursive=TRUE)
}) })