Add support for 'offline' forests
This commit is contained in:
parent
3fa2fef82a
commit
589df1a18b
16 changed files with 270 additions and 38 deletions
|
@ -26,6 +26,7 @@ export(Numeric)
|
|||
export(WeightedVarianceSplitFinder)
|
||||
export(addTrees)
|
||||
export(connectToData)
|
||||
export(convertToOnlineForest)
|
||||
export(extractCHF)
|
||||
export(extractCIF)
|
||||
export(extractMortalities)
|
||||
|
|
29
R/addTrees.R
29
R/addTrees.R
|
@ -14,6 +14,10 @@
|
|||
#' @param savePath.overwrite If \code{savePath} is pointing to an existing
|
||||
#' directory, possibly containing another forest, this specifies what should
|
||||
#' be done.
|
||||
#' @param forest.output This parameter only applies if \code{savePath} has been
|
||||
#' set; set to 'online' (default) and the saved forest will be loaded into
|
||||
#' memory after being trained. Set to 'offline' and the forest is not saved
|
||||
#' into memory, but can still be used in a memory unintensive manner.
|
||||
#' @param cores The number of cores to be used for training the new trees.
|
||||
#' @param displayProgress A logical indicating whether the progress should be
|
||||
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
|
||||
|
@ -22,7 +26,11 @@
|
|||
#' @return A new forest with the original and additional trees.
|
||||
#' @export
|
||||
#'
|
||||
addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), cores = getCores(), displayProgress = TRUE){
|
||||
addTrees <- function(forest, numTreesToAdd, savePath = NULL,
|
||||
savePath.overwrite = c("warn", "delete", "merge"),
|
||||
forest.output = c("online", "offline"),
|
||||
cores = getCores(), displayProgress = TRUE){
|
||||
|
||||
if(is.null(forest$dataset)){
|
||||
stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
|
||||
}
|
||||
|
@ -37,6 +45,10 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite
|
|||
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||
}
|
||||
|
||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||
}
|
||||
|
||||
newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd)
|
||||
|
||||
treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
|
||||
|
@ -98,22 +110,23 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite
|
|||
params=params,
|
||||
forestCall=match.call())
|
||||
|
||||
forest.java <- NULL
|
||||
if(cores > 1){
|
||||
.jcall(forestTrainer, "V", "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
|
||||
} else {
|
||||
.jcall(forestTrainer, "V", "trainSerialOnDisk", initial.forest.optional)
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", initial.forest.optional)
|
||||
}
|
||||
|
||||
# Need to now load forest trees back into memory
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forest$params$forestResponseCombiner$javaObject)
|
||||
|
||||
if(forest.output[1] == "online"){
|
||||
forest.java <- convertToOnlineForest.Java(forest.java)
|
||||
}
|
||||
|
||||
}
|
||||
else{ # save directly into memory
|
||||
if(cores > 1){
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
|
||||
} else {
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", initial.forest.optional)
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", initial.forest.optional)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
38
R/convertToOnlineForest.R
Normal file
38
R/convertToOnlineForest.R
Normal file
|
@ -0,0 +1,38 @@
|
|||
|
||||
#' Convert to Online Forest
|
||||
#'
|
||||
#' Some forests are too large to store in memory and have been saved to disk.
|
||||
#' They can still be used, but their performance is much slower. If there's
|
||||
#' enough memory, they can easily be converted into an in-memory forest that is
|
||||
#' faster to use.
|
||||
#'
|
||||
#' @param forest The offline forest.
|
||||
#'
|
||||
#' @return An online, in memory forst.
|
||||
#' @export
|
||||
#'
|
||||
convertToOnlineForest <- function(forest){
|
||||
old.forest.object <- forest$javaObject
|
||||
|
||||
if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OnlineForest"){
|
||||
|
||||
warning("forest is already in-memory")
|
||||
return(forest)
|
||||
|
||||
} else if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OfflineForest"){
|
||||
|
||||
forest$javaObject <- convertToOnlineForest.Java(old.forest.object)
|
||||
return(forest)
|
||||
|
||||
} else{
|
||||
stop("'forest' is not an online or offline forest")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
# Internal function
|
||||
convertToOnlineForest.Java <- function(forest.java){
|
||||
offline.forest <- .jcast(forest.java, .class_OfflineForest)
|
||||
online.forest <- .jcall(offline.forest, makeResponse(.class_OnlineForest), "createOnlineCopy")
|
||||
return(online.forest)
|
||||
}
|
|
@ -42,7 +42,7 @@ CR_FunctionCombiner <- function(events, times = NULL){
|
|||
}
|
||||
|
||||
javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray)
|
||||
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||
javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
|
||||
|
||||
combiner <- list(javaObject=javaObject,
|
||||
call=match.call(),
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
.class_Serializable <- "java/io/Serializable"
|
||||
.class_File <- "java/io/File"
|
||||
.class_Random <- "java/util/Random"
|
||||
.class_Class <- "java/lang/Class"
|
||||
|
||||
# Utility Classes
|
||||
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
|
||||
|
@ -41,9 +42,12 @@
|
|||
|
||||
# Forest class
|
||||
.class_Forest <- "ca/joeltherrien/randomforest/tree/Forest"
|
||||
.class_OnlineForest <- "ca/joeltherrien/randomforest/tree/OnlineForest"
|
||||
.class_OfflineForest <- "ca/joeltherrien/randomforest/tree/OfflineForest"
|
||||
|
||||
# ResponseCombiner classes
|
||||
.class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner"
|
||||
.class_ForestResponseCombiner <- "ca/joeltherrien/randomforest/tree/ForestResponseCombiner"
|
||||
.class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner"
|
||||
.class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner"
|
||||
.class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner"
|
||||
|
@ -77,4 +81,10 @@
|
|||
# This function does that.
|
||||
makeResponse <- function(className){
|
||||
return(paste0("L", className, ";"))
|
||||
}
|
||||
}
|
||||
|
||||
getJavaClass <- function(object){
|
||||
class <- .jcall(object, makeResponse(.class_Class), "getClass")
|
||||
className <- .jcall(class, "S", "getName")
|
||||
return(className)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,12 @@
|
|||
#' Loads a random forest that was saved using \code{\link{saveForest}}.
|
||||
#'
|
||||
#' @param directory The directory created that saved the previous forest.
|
||||
#' @param forest.output Specifies whether the forest loaded should be loaded
|
||||
#' into memory, or reflect the saved files where only one tree is loaded at a
|
||||
#' time.
|
||||
#' @param maxTreeNum If for some reason you only want to load the number of
|
||||
#' trees up until a certain point, you can specify maxTreeNum as a single
|
||||
#' number.
|
||||
#' @return A JForest object; see \code{\link{train}} for details.
|
||||
#' @export
|
||||
#' @seealso \code{\link{train}}, \code{\link{saveForest}}
|
||||
|
@ -20,7 +26,11 @@
|
|||
#'
|
||||
#' saveForest(forest, "trees")
|
||||
#' new_forest <- loadForest("trees")
|
||||
loadForest <- function(directory){
|
||||
loadForest <- function(directory, forest.output = c("online", "offline"), maxTreeNum = NULL){
|
||||
|
||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||
}
|
||||
|
||||
# First load the response combiners and the split finders
|
||||
nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData"))
|
||||
|
@ -30,7 +40,7 @@ loadForest <- function(directory){
|
|||
splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder)
|
||||
|
||||
forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData"))
|
||||
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ResponseCombiner)
|
||||
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ForestResponseCombiner)
|
||||
|
||||
covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData"))
|
||||
covariateList <- .jcast(covariateList, .class_List)
|
||||
|
@ -42,8 +52,11 @@ loadForest <- function(directory){
|
|||
params$splitFinder$javaObject <- splitFinder.java
|
||||
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
|
||||
|
||||
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call,
|
||||
params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes, params$randomSeed)
|
||||
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder,
|
||||
params$forestResponseCombiner, covariateList, call,
|
||||
params$ntree, params$numberOfSplits, params$mtry,
|
||||
params$nodeSize, params$maxNodeDepth, params$splitPureNodes,
|
||||
params$randomSeed, forest.output, maxTreeNum)
|
||||
|
||||
return(forest)
|
||||
|
||||
|
@ -55,8 +68,11 @@ loadForest <- function(directory){
|
|||
# that uses the Java version's settings yaml file to recreate the forest, but
|
||||
# I'd appreciate knowing that someone's going to use it first (email me; see
|
||||
# README).
|
||||
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner,
|
||||
covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, randomSeed=NULL){
|
||||
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder,
|
||||
forestResponseCombiner, covariateList.java, call,
|
||||
ntree, numberOfSplits, mtry, nodeSize,
|
||||
maxNodeDepth = 100000, splitPureNodes=TRUE,
|
||||
randomSeed=NULL, forest.output = "online", maxTreeNum = NULL){
|
||||
|
||||
params <- list(
|
||||
splitFinder=splitFinder,
|
||||
|
@ -71,7 +87,33 @@ loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, sp
|
|||
randomSeed=randomSeed
|
||||
)
|
||||
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject)
|
||||
forest.java <- NULL
|
||||
if(forest.output[1] == "online"){
|
||||
castedForestResponseCombiner <- .jcast(forestResponseCombiner$javaObject, .class_ResponseCombiner) # OnlineForest constructor takes a ResponseCombiner
|
||||
|
||||
if(is.null(maxTreeNum)){
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
|
||||
treeDirectory, castedForestResponseCombiner)
|
||||
} else{
|
||||
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
|
||||
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
|
||||
tree.file.array, castedForestResponseCombiner)
|
||||
|
||||
}
|
||||
|
||||
} else{ # offline forest
|
||||
if(is.null(maxTreeNum)){
|
||||
path.as.file <- .jnew(.class_File, treeDirectory)
|
||||
forest.java <- .jnew(.class_OfflineForest, path.as.file, forestResponseCombiner$javaObject)
|
||||
} else{
|
||||
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
|
||||
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
|
||||
forest.java <- .jnew(.class_OfflineForest, tree.file.array, forestResponseCombiner$javaObject)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params)
|
||||
class(forestObject) <- "JRandomForest"
|
||||
|
|
|
@ -48,7 +48,7 @@ WeightedVarianceSplitFinder <- function(){
|
|||
#'
|
||||
MeanResponseCombiner <- function(){
|
||||
javaObject <- .jnew(.class_MeanResponseCombiner)
|
||||
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||
javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
|
||||
|
||||
combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric")
|
||||
combiner$convertToRFunction <- function(javaObject, ...){
|
||||
|
|
40
R/train.R
40
R/train.R
|
@ -18,7 +18,7 @@ train.internal <- function(dataset, splitFinder,
|
|||
nodeResponseCombiner, forestResponseCombiner, ntree,
|
||||
numberOfSplits, mtry, nodeSize, maxNodeDepth,
|
||||
splitPureNodes, savePath, savePath.overwrite,
|
||||
cores, randomSeed, displayProgress){
|
||||
forest.output, cores, randomSeed, displayProgress){
|
||||
|
||||
# Some quick checks on parameters
|
||||
ntree <- as.integer(ntree)
|
||||
|
@ -51,6 +51,10 @@ train.internal <- function(dataset, splitFinder,
|
|||
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||
}
|
||||
|
||||
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||
}
|
||||
|
||||
if(is.null(splitFinder)){
|
||||
splitFinder <- splitFinderDefault(dataset$responses)
|
||||
}
|
||||
|
@ -129,22 +133,23 @@ train.internal <- function(dataset, splitFinder,
|
|||
params=params,
|
||||
forestCall=match.call())
|
||||
|
||||
forest.java <- NULL
|
||||
if(cores > 1){
|
||||
.jcall(forestTrainer, "V", "trainParallelOnDisk", .object_Optional(), as.integer(cores))
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", .object_Optional(), as.integer(cores))
|
||||
} else {
|
||||
.jcall(forestTrainer, "V", "trainSerialOnDisk", .object_Optional())
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", .object_Optional())
|
||||
}
|
||||
|
||||
# Need to now load forest trees back into memory
|
||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forestResponseCombiner$javaObject)
|
||||
|
||||
|
||||
if(forest.output[1] == "online"){
|
||||
forest.java <- convertToOnlineForest.Java(forest.java)
|
||||
}
|
||||
|
||||
}
|
||||
else{ # save directly into memory
|
||||
if(cores > 1){
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
|
||||
} else {
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", .object_Optional())
|
||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", .object_Optional())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,6 +258,10 @@ train.internal <- function(dataset, splitFinder,
|
|||
#' assumes (without checking) that the existing trees are from a previous run
|
||||
#' and starts from where it left off. This option is useful if recovering from
|
||||
#' a crash.
|
||||
#' @param forest.output This parameter only applies if \code{savePath} has been
|
||||
#' set; set to 'online' (default) and the saved forest will be loaded into
|
||||
#' memory after being trained. Set to 'offline' and the forest is not saved
|
||||
#' into memory, but can still be used in a memory unintensive manner.
|
||||
#' @param cores This parameter specifies how many trees will be simultaneously
|
||||
#' trained. By default the package attempts to detect how many cores you have
|
||||
#' by using the \code{parallel} package and using all of them. You may specify
|
||||
|
@ -311,7 +320,8 @@ train.internal <- function(dataset, splitFinder,
|
|||
train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
||||
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry,
|
||||
nodeSize, maxNodeDepth = 100000, na.penalty = TRUE, splitPureNodes=TRUE,
|
||||
savePath=NULL, savePath.overwrite=c("warn", "delete", "merge"),
|
||||
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
||||
forest.output = c("online", "offline"),
|
||||
cores = getCores(), randomSeed = NULL, displayProgress = TRUE){
|
||||
|
||||
dataset <- processFormula(formula, data, na.penalty = na.penalty)
|
||||
|
@ -322,8 +332,8 @@ train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL
|
|||
ntree = ntree, numberOfSplits = numberOfSplits,
|
||||
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
|
||||
splitPureNodes = splitPureNodes, savePath = savePath,
|
||||
savePath.overwrite = savePath.overwrite, cores = cores,
|
||||
randomSeed = randomSeed, displayProgress = displayProgress)
|
||||
savePath.overwrite = savePath.overwrite, forest.output = forest.output,
|
||||
cores = cores, randomSeed = randomSeed, displayProgress = displayProgress)
|
||||
|
||||
forest$call <- match.call()
|
||||
forest$formula <- formula
|
||||
|
@ -370,8 +380,10 @@ createTreeTrainer <- function(responseCombiner, splitFinder, covariateList, numb
|
|||
builderClassReturned <- makeResponse(.class_TreeTrainer_Builder)
|
||||
|
||||
builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder")
|
||||
|
||||
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombiner$javaObject)
|
||||
|
||||
responseCombinerCasted <- .jcast(responseCombiner$javaObject, .class_ResponseCombiner) # might need to cast a ForestResponseCombiner down
|
||||
|
||||
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombinerCasted)
|
||||
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject)
|
||||
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
|
||||
builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits))
|
||||
|
|
Binary file not shown.
|
@ -6,7 +6,8 @@
|
|||
\usage{
|
||||
addTrees(forest, numTreesToAdd, savePath = NULL,
|
||||
savePath.overwrite = c("warn", "delete", "merge"),
|
||||
cores = getCores(), displayProgress = TRUE)
|
||||
forest.output = c("online", "offline"), cores = getCores(),
|
||||
displayProgress = TRUE)
|
||||
}
|
||||
\arguments{
|
||||
\item{forest}{An existing forest.}
|
||||
|
@ -21,6 +22,11 @@ a previously saved forest.}
|
|||
directory, possibly containing another forest, this specifies what should
|
||||
be done.}
|
||||
|
||||
\item{forest.output}{This parameter only applies if \code{savePath} has been
|
||||
set; set to 'online' (default) and the saved forest will be loaded into
|
||||
memory after being trained. Set to 'offline' and the forest is not saved
|
||||
into memory, but can still be used in a memory unintensive manner.}
|
||||
|
||||
\item{cores}{The number of cores to be used for training the new trees.}
|
||||
|
||||
\item{displayProgress}{A logical indicating whether the progress should be
|
||||
|
|
20
man/convertToOnlineForest.Rd
Normal file
20
man/convertToOnlineForest.Rd
Normal file
|
@ -0,0 +1,20 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/convertToOnlineForest.R
|
||||
\name{convertToOnlineForest}
|
||||
\alias{convertToOnlineForest}
|
||||
\title{Convert to Online Forest}
|
||||
\usage{
|
||||
convertToOnlineForest(forest)
|
||||
}
|
||||
\arguments{
|
||||
\item{forest}{The offline forest.}
|
||||
}
|
||||
\value{
|
||||
An online, in memory forst.
|
||||
}
|
||||
\description{
|
||||
Some forests are too large to store in memory and have been saved to disk.
|
||||
They can still be used, but their performance is much slower. If there's
|
||||
enough memory, they can easily be converted into an in-memory forest that is
|
||||
faster to use.
|
||||
}
|
|
@ -4,10 +4,19 @@
|
|||
\alias{loadForest}
|
||||
\title{Load Random Forest}
|
||||
\usage{
|
||||
loadForest(directory)
|
||||
loadForest(directory, forest.output = c("online", "offline"),
|
||||
maxTreeNum = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{directory}{The directory created that saved the previous forest.}
|
||||
|
||||
\item{forest.output}{Specifies whether the forest loaded should be loaded
|
||||
into memory, or reflect the saved files where only one tree is loaded at a
|
||||
time.}
|
||||
|
||||
\item{maxTreeNum}{If for some reason you only want to load the number of
|
||||
trees up until a certain point, you can specify maxTreeNum as a single
|
||||
number.}
|
||||
}
|
||||
\value{
|
||||
A JForest object; see \code{\link{train}} for details.
|
||||
|
|
|
@ -8,7 +8,8 @@ train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
|||
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize,
|
||||
maxNodeDepth = 1e+05, na.penalty = TRUE, splitPureNodes = TRUE,
|
||||
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
||||
cores = getCores(), randomSeed = NULL, displayProgress = TRUE)
|
||||
forest.output = c("online", "offline"), cores = getCores(),
|
||||
randomSeed = NULL, displayProgress = TRUE)
|
||||
}
|
||||
\arguments{
|
||||
\item{formula}{You may specify the response and covariates as a formula
|
||||
|
@ -106,6 +107,11 @@ assumes (without checking) that the existing trees are from a previous run
|
|||
and starts from where it left off. This option is useful if recovering from
|
||||
a crash.}
|
||||
|
||||
\item{forest.output}{This parameter only applies if \code{savePath} has been
|
||||
set; set to 'online' (default) and the saved forest will be loaded into
|
||||
memory after being trained. Set to 'offline' and the forest is not saved
|
||||
into memory, but can still be used in a memory unintensive manner.}
|
||||
|
||||
\item{cores}{This parameter specifies how many trees will be simultaneously
|
||||
trained. By default the package attempts to detect how many cores you have
|
||||
by using the \code{parallel} package and using all of them. You may specify
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
context("Add trees on existing forest")
|
||||
|
||||
test_that("Can add trees on existing forest", {
|
||||
test_that("Can add trees on existing online forest", {
|
||||
|
||||
trainingData <- data.frame(x=rnorm(100))
|
||||
trainingData$T <- rexp(100) + abs(trainingData$x)
|
||||
|
@ -20,6 +20,44 @@ test_that("Can add trees on existing forest", {
|
|||
|
||||
})
|
||||
|
||||
test_that("Can add trees on existing offline forest", {
|
||||
|
||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||
unlink("trees", recursive=TRUE)
|
||||
}
|
||||
|
||||
|
||||
trainingData <- data.frame(x=rnorm(100))
|
||||
trainingData$T <- rexp(100) + abs(trainingData$x)
|
||||
trainingData$delta <- sample(0:2, size = 100, replace=TRUE)
|
||||
|
||||
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50,
|
||||
numberOfSplits=0, mtry=1, nodeSize=5,
|
||||
forestResponseCombiner = CR_FunctionCombiner(events = 1:2, times = 0:10), # TODO - remove specifing times; this is workaround around unimplemented feature for offline forests
|
||||
cores=2, displayProgress=FALSE, savePath="trees",
|
||||
forest.output = "offline")
|
||||
warning("TODO - need to implement feature; test workaround in the meantime")
|
||||
|
||||
predictions <- predict(forest)
|
||||
|
||||
warning_message <- "Assuming that the previous forest at savePath is the provided forest argument; if not true then your results will be suspect"
|
||||
|
||||
forest.more <- expect_warning(addTrees(forest, 50, cores=2, displayProgress=FALSE,
|
||||
savePath="trees", savePath.overwrite = "merge",
|
||||
forest.output = "offline"), fixed=warning_message) # test multi-core
|
||||
|
||||
predictions <- predict(forest)
|
||||
|
||||
forest.more <- expect_warning(addTrees(forest, 50, cores=1, displayProgress=FALSE,
|
||||
savePath="trees", savePath.overwrite = "merge",
|
||||
forest.output = "offline"), fixed=warning_message) # test single-core
|
||||
|
||||
expect_true(T) # show Ok if we got this far
|
||||
|
||||
unlink("trees", recursive=TRUE)
|
||||
|
||||
})
|
||||
|
||||
test_that("Test adding trees on saved forest - using delete", {
|
||||
|
||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||
|
|
|
@ -2,7 +2,11 @@ context("Train, save, and load without error")
|
|||
|
||||
test_that("Can save & load regression example", {
|
||||
|
||||
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist yet
|
||||
if(file.exists("trees_saving_loading")){
|
||||
unlink("trees_saving_loading", recursive=TRUE)
|
||||
}
|
||||
|
||||
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist at this point
|
||||
|
||||
x1 <- rnorm(1000)
|
||||
x2 <- rnorm(1000)
|
||||
|
|
|
@ -13,7 +13,7 @@ test_that("Can save a random forest while training, and use it afterward", {
|
|||
data <- data.frame(x1, x2, y)
|
||||
forest <- train(y ~ x1 + x2, data,
|
||||
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
||||
savePath="trees", displayProgress=FALSE)
|
||||
savePath="trees", forest.output = "online", displayProgress=FALSE)
|
||||
|
||||
expect_true(file.exists("trees")) # Something should have been saved
|
||||
|
||||
|
@ -26,6 +26,39 @@ test_that("Can save a random forest while training, and use it afterward", {
|
|||
predictions <- predict(newforest, newData)
|
||||
|
||||
|
||||
unlink("trees", recursive=TRUE)
|
||||
|
||||
})
|
||||
|
||||
test_that("Can save a random forest while training, and use it afterward with pure offline forest", {
|
||||
|
||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||
unlink("trees", recursive=TRUE)
|
||||
}
|
||||
|
||||
x1 <- rnorm(1000)
|
||||
x2 <- rnorm(1000)
|
||||
y <- 1 + x1 + x2 + rnorm(1000)
|
||||
|
||||
data <- data.frame(x1, x2, y)
|
||||
forest <- train(y ~ x1 + x2, data,
|
||||
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
||||
savePath="trees", forest.output = "offline", displayProgress=FALSE)
|
||||
|
||||
expect_true(file.exists("trees")) # Something should have been saved
|
||||
|
||||
# try making a little prediction to verify it works
|
||||
newData <- data.frame(x1=seq(from=-3, to=3, by=0.5), x2=0)
|
||||
predictions <- predict(forest, newData)
|
||||
|
||||
# Also make sure we can load the forest too
|
||||
newforest <- loadForest("trees")
|
||||
predictions <- predict(newforest, newData)
|
||||
|
||||
# Last, make sure we can take the forest online
|
||||
onlineForest <- convertToOnlineForest(forest)
|
||||
predictions <- predict(onlineForest, newData)
|
||||
|
||||
unlink("trees", recursive=TRUE)
|
||||
|
||||
})
|
Loading…
Reference in a new issue