Add support for 'offline' forests
This commit is contained in:
parent
3fa2fef82a
commit
589df1a18b
16 changed files with 270 additions and 38 deletions
|
@ -26,6 +26,7 @@ export(Numeric)
|
||||||
export(WeightedVarianceSplitFinder)
|
export(WeightedVarianceSplitFinder)
|
||||||
export(addTrees)
|
export(addTrees)
|
||||||
export(connectToData)
|
export(connectToData)
|
||||||
|
export(convertToOnlineForest)
|
||||||
export(extractCHF)
|
export(extractCHF)
|
||||||
export(extractCIF)
|
export(extractCIF)
|
||||||
export(extractMortalities)
|
export(extractMortalities)
|
||||||
|
|
29
R/addTrees.R
29
R/addTrees.R
|
@ -14,6 +14,10 @@
|
||||||
#' @param savePath.overwrite If \code{savePath} is pointing to an existing
|
#' @param savePath.overwrite If \code{savePath} is pointing to an existing
|
||||||
#' directory, possibly containing another forest, this specifies what should
|
#' directory, possibly containing another forest, this specifies what should
|
||||||
#' be done.
|
#' be done.
|
||||||
|
#' @param forest.output This parameter only applies if \code{savePath} has been
|
||||||
|
#' set; set to 'online' (default) and the saved forest will be loaded into
|
||||||
|
#' memory after being trained. Set to 'offline' and the forest is not loaded
|
||||||
|
#' into memory, but can still be used in a memory-efficient manner.
|
||||||
#' @param cores The number of cores to be used for training the new trees.
|
#' @param cores The number of cores to be used for training the new trees.
|
||||||
#' @param displayProgress A logical indicating whether the progress should be
|
#' @param displayProgress A logical indicating whether the progress should be
|
||||||
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
|
#' displayed to console; default is \code{TRUE}. Useful to set to FALSE in
|
||||||
|
@ -22,7 +26,11 @@
|
||||||
#' @return A new forest with the original and additional trees.
|
#' @return A new forest with the original and additional trees.
|
||||||
#' @export
|
#' @export
|
||||||
#'
|
#'
|
||||||
addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), cores = getCores(), displayProgress = TRUE){
|
addTrees <- function(forest, numTreesToAdd, savePath = NULL,
|
||||||
|
savePath.overwrite = c("warn", "delete", "merge"),
|
||||||
|
forest.output = c("online", "offline"),
|
||||||
|
cores = getCores(), displayProgress = TRUE){
|
||||||
|
|
||||||
if(is.null(forest$dataset)){
|
if(is.null(forest$dataset)){
|
||||||
stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
|
stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData")
|
||||||
}
|
}
|
||||||
|
@ -37,6 +45,10 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite
|
||||||
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||||
|
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||||
|
}
|
||||||
|
|
||||||
newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd)
|
newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd)
|
||||||
|
|
||||||
treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
|
treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner,
|
||||||
|
@ -98,22 +110,23 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite
|
||||||
params=params,
|
params=params,
|
||||||
forestCall=match.call())
|
forestCall=match.call())
|
||||||
|
|
||||||
|
forest.java <- NULL
|
||||||
if(cores > 1){
|
if(cores > 1){
|
||||||
.jcall(forestTrainer, "V", "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", initial.forest.optional, as.integer(cores))
|
||||||
} else {
|
} else {
|
||||||
.jcall(forestTrainer, "V", "trainSerialOnDisk", initial.forest.optional)
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", initial.forest.optional)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Need to now load forest trees back into memory
|
if(forest.output[1] == "online"){
|
||||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forest$params$forestResponseCombiner$javaObject)
|
forest.java <- convertToOnlineForest.Java(forest.java)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
else{ # save directly into memory
|
else{ # save directly into memory
|
||||||
if(cores > 1){
|
if(cores > 1){
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", initial.forest.optional, as.integer(cores))
|
||||||
} else {
|
} else {
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", initial.forest.optional)
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", initial.forest.optional)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
38
R/convertToOnlineForest.R
Normal file
38
R/convertToOnlineForest.R
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
|
||||||
|
#' Convert to Online Forest
|
||||||
|
#'
|
||||||
|
#' Some forests are too large to store in memory and have been saved to disk.
|
||||||
|
#' They can still be used, but their performance is much slower. If there's
|
||||||
|
#' enough memory, they can easily be converted into an in-memory forest that is
|
||||||
|
#' faster to use.
|
||||||
|
#'
|
||||||
|
#' @param forest The offline forest.
|
||||||
|
#'
|
||||||
|
#' @return An online, in-memory forest.
|
||||||
|
#' @export
|
||||||
|
#'
|
||||||
|
convertToOnlineForest <- function(forest){
  # Converts an offline (on-disk) forest into an online (in-memory) forest.
  #
  # forest: a forest object whose `javaObject` element references either an
  #         OnlineForest or an OfflineForest on the Java side.
  #
  # Returns `forest` with `javaObject` replaced by an in-memory copy. If the
  # forest is already online it is returned unchanged, with a warning. Any
  # other kind of object raises an error.
  old.forest.object <- forest$javaObject

  # Without a Java object there is nothing to inspect; fail with this
  # function's own message instead of an opaque error from the Java call.
  if(is.null(old.forest.object)){
    stop("'forest' is not an online or offline forest")
  }

  # Look the class name up once; every lookup goes through the JVM.
  forest.class <- getJavaClass(old.forest.object)

  if(forest.class == "ca.joeltherrien.randomforest.tree.OnlineForest"){
    warning("forest is already in-memory")
    return(forest)
  }

  if(forest.class == "ca.joeltherrien.randomforest.tree.OfflineForest"){
    forest$javaObject <- convertToOnlineForest.Java(old.forest.object)
    return(forest)
  }

  stop("'forest' is not an online or offline forest")
}
|
||||||
|
|
||||||
|
# Internal function
|
||||||
|
# Internal helper: takes a Java forest reference, casts it to the
# OfflineForest class and asks it for an in-memory copy through its
# "createOnlineCopy" method. Returns the resulting OnlineForest reference.
convertToOnlineForest.Java <- function(forest.java){
  casted.forest <- .jcast(forest.java, .class_OfflineForest)
  online.signature <- makeResponse(.class_OnlineForest)
  .jcall(casted.forest, online.signature, "createOnlineCopy")
}
|
|
@ -42,7 +42,7 @@ CR_FunctionCombiner <- function(events, times = NULL){
|
||||||
}
|
}
|
||||||
|
|
||||||
javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray)
|
javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray)
|
||||||
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
|
||||||
|
|
||||||
combiner <- list(javaObject=javaObject,
|
combiner <- list(javaObject=javaObject,
|
||||||
call=match.call(),
|
call=match.call(),
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
.class_Serializable <- "java/io/Serializable"
|
.class_Serializable <- "java/io/Serializable"
|
||||||
.class_File <- "java/io/File"
|
.class_File <- "java/io/File"
|
||||||
.class_Random <- "java/util/Random"
|
.class_Random <- "java/util/Random"
|
||||||
|
.class_Class <- "java/lang/Class"
|
||||||
|
|
||||||
# Utility Classes
|
# Utility Classes
|
||||||
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
|
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
|
||||||
|
@ -41,9 +42,12 @@
|
||||||
|
|
||||||
# Forest class
|
# Forest class
|
||||||
.class_Forest <- "ca/joeltherrien/randomforest/tree/Forest"
|
.class_Forest <- "ca/joeltherrien/randomforest/tree/Forest"
|
||||||
|
.class_OnlineForest <- "ca/joeltherrien/randomforest/tree/OnlineForest"
|
||||||
|
.class_OfflineForest <- "ca/joeltherrien/randomforest/tree/OfflineForest"
|
||||||
|
|
||||||
# ResponseCombiner classes
|
# ResponseCombiner classes
|
||||||
.class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner"
|
.class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner"
|
||||||
|
.class_ForestResponseCombiner <- "ca/joeltherrien/randomforest/tree/ForestResponseCombiner"
|
||||||
.class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner"
|
.class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner"
|
||||||
.class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner"
|
.class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner"
|
||||||
.class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner"
|
.class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner"
|
||||||
|
@ -77,4 +81,10 @@
|
||||||
# This function does that.
|
# This function does that.
|
||||||
makeResponse <- function(className){
|
makeResponse <- function(className){
|
||||||
return(paste0("L", className, ";"))
|
return(paste0("L", className, ";"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getJavaClass <- function(object){
  # Returns the Java class name of `object` as a string: the object is asked
  # for its Class via "getClass", and that Class is asked for its name via
  # "getName".
  class.signature <- makeResponse(.class_Class)
  java.class <- .jcall(object, class.signature, "getClass")
  .jcall(java.class, "S", "getName")
}
|
||||||
|
|
|
@ -5,6 +5,12 @@
|
||||||
#' Loads a random forest that was saved using \code{\link{saveForest}}.
|
#' Loads a random forest that was saved using \code{\link{saveForest}}.
|
||||||
#'
|
#'
|
||||||
#' @param directory The directory created that saved the previous forest.
|
#' @param directory The directory created that saved the previous forest.
|
||||||
|
#' @param forest.output Specifies whether the forest loaded should be loaded
|
||||||
|
#' into memory, or reflect the saved files where only one tree is loaded at a
|
||||||
|
#' time.
|
||||||
|
#' @param maxTreeNum If for some reason you only want to load the number of
|
||||||
|
#' trees up until a certain point, you can specify maxTreeNum as a single
|
||||||
|
#' number.
|
||||||
#' @return A JForest object; see \code{\link{train}} for details.
|
#' @return A JForest object; see \code{\link{train}} for details.
|
||||||
#' @export
|
#' @export
|
||||||
#' @seealso \code{\link{train}}, \code{\link{saveForest}}
|
#' @seealso \code{\link{train}}, \code{\link{saveForest}}
|
||||||
|
@ -20,7 +26,11 @@
|
||||||
#'
|
#'
|
||||||
#' saveForest(forest, "trees")
|
#' saveForest(forest, "trees")
|
||||||
#' new_forest <- loadForest("trees")
|
#' new_forest <- loadForest("trees")
|
||||||
loadForest <- function(directory){
|
loadForest <- function(directory, forest.output = c("online", "offline"), maxTreeNum = NULL){
|
||||||
|
|
||||||
|
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||||
|
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||||
|
}
|
||||||
|
|
||||||
# First load the response combiners and the split finders
|
# First load the response combiners and the split finders
|
||||||
nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData"))
|
nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData"))
|
||||||
|
@ -30,7 +40,7 @@ loadForest <- function(directory){
|
||||||
splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder)
|
splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder)
|
||||||
|
|
||||||
forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData"))
|
forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData"))
|
||||||
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ResponseCombiner)
|
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ForestResponseCombiner)
|
||||||
|
|
||||||
covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData"))
|
covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData"))
|
||||||
covariateList <- .jcast(covariateList, .class_List)
|
covariateList <- .jcast(covariateList, .class_List)
|
||||||
|
@ -42,8 +52,11 @@ loadForest <- function(directory){
|
||||||
params$splitFinder$javaObject <- splitFinder.java
|
params$splitFinder$javaObject <- splitFinder.java
|
||||||
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
|
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
|
||||||
|
|
||||||
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call,
|
forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder,
|
||||||
params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes, params$randomSeed)
|
params$forestResponseCombiner, covariateList, call,
|
||||||
|
params$ntree, params$numberOfSplits, params$mtry,
|
||||||
|
params$nodeSize, params$maxNodeDepth, params$splitPureNodes,
|
||||||
|
params$randomSeed, forest.output, maxTreeNum)
|
||||||
|
|
||||||
return(forest)
|
return(forest)
|
||||||
|
|
||||||
|
@ -55,8 +68,11 @@ loadForest <- function(directory){
|
||||||
# that uses the Java version's settings yaml file to recreate the forest, but
|
# that uses the Java version's settings yaml file to recreate the forest, but
|
||||||
# I'd appreciate knowing that someone's going to use it first (email me; see
|
# I'd appreciate knowing that someone's going to use it first (email me; see
|
||||||
# README).
|
# README).
|
||||||
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner,
|
loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder,
|
||||||
covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, randomSeed=NULL){
|
forestResponseCombiner, covariateList.java, call,
|
||||||
|
ntree, numberOfSplits, mtry, nodeSize,
|
||||||
|
maxNodeDepth = 100000, splitPureNodes=TRUE,
|
||||||
|
randomSeed=NULL, forest.output = "online", maxTreeNum = NULL){
|
||||||
|
|
||||||
params <- list(
|
params <- list(
|
||||||
splitFinder=splitFinder,
|
splitFinder=splitFinder,
|
||||||
|
@ -71,7 +87,33 @@ loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, sp
|
||||||
randomSeed=randomSeed
|
randomSeed=randomSeed
|
||||||
)
|
)
|
||||||
|
|
||||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject)
|
forest.java <- NULL
|
||||||
|
if(forest.output[1] == "online"){
|
||||||
|
castedForestResponseCombiner <- .jcast(forestResponseCombiner$javaObject, .class_ResponseCombiner) # OnlineForest constructor takes a ResponseCombiner
|
||||||
|
|
||||||
|
if(is.null(maxTreeNum)){
|
||||||
|
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
|
||||||
|
treeDirectory, castedForestResponseCombiner)
|
||||||
|
} else{
|
||||||
|
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
|
||||||
|
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
|
||||||
|
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest",
|
||||||
|
tree.file.array, castedForestResponseCombiner)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
} else{ # offline forest
|
||||||
|
if(is.null(maxTreeNum)){
|
||||||
|
path.as.file <- .jnew(.class_File, treeDirectory)
|
||||||
|
forest.java <- .jnew(.class_OfflineForest, path.as.file, forestResponseCombiner$javaObject)
|
||||||
|
} else{
|
||||||
|
tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray",
|
||||||
|
treeDirectory, as.integer(maxTreeNum), evalArray = FALSE)
|
||||||
|
forest.java <- .jnew(.class_OfflineForest, tree.file.array, forestResponseCombiner$javaObject)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params)
|
forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params)
|
||||||
class(forestObject) <- "JRandomForest"
|
class(forestObject) <- "JRandomForest"
|
||||||
|
|
|
@ -48,7 +48,7 @@ WeightedVarianceSplitFinder <- function(){
|
||||||
#'
|
#'
|
||||||
MeanResponseCombiner <- function(){
|
MeanResponseCombiner <- function(){
|
||||||
javaObject <- .jnew(.class_MeanResponseCombiner)
|
javaObject <- .jnew(.class_MeanResponseCombiner)
|
||||||
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
javaObject <- .jcast(javaObject, .class_ForestResponseCombiner)
|
||||||
|
|
||||||
combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric")
|
combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric")
|
||||||
combiner$convertToRFunction <- function(javaObject, ...){
|
combiner$convertToRFunction <- function(javaObject, ...){
|
||||||
|
|
40
R/train.R
40
R/train.R
|
@ -18,7 +18,7 @@ train.internal <- function(dataset, splitFinder,
|
||||||
nodeResponseCombiner, forestResponseCombiner, ntree,
|
nodeResponseCombiner, forestResponseCombiner, ntree,
|
||||||
numberOfSplits, mtry, nodeSize, maxNodeDepth,
|
numberOfSplits, mtry, nodeSize, maxNodeDepth,
|
||||||
splitPureNodes, savePath, savePath.overwrite,
|
splitPureNodes, savePath, savePath.overwrite,
|
||||||
cores, randomSeed, displayProgress){
|
forest.output, cores, randomSeed, displayProgress){
|
||||||
|
|
||||||
# Some quick checks on parameters
|
# Some quick checks on parameters
|
||||||
ntree <- as.integer(ntree)
|
ntree <- as.integer(ntree)
|
||||||
|
@ -51,6 +51,10 @@ train.internal <- function(dataset, splitFinder,
|
||||||
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){
|
||||||
|
stop("forest.output must be one of c(\"online\", \"offline\")")
|
||||||
|
}
|
||||||
|
|
||||||
if(is.null(splitFinder)){
|
if(is.null(splitFinder)){
|
||||||
splitFinder <- splitFinderDefault(dataset$responses)
|
splitFinder <- splitFinderDefault(dataset$responses)
|
||||||
}
|
}
|
||||||
|
@ -129,22 +133,23 @@ train.internal <- function(dataset, splitFinder,
|
||||||
params=params,
|
params=params,
|
||||||
forestCall=match.call())
|
forestCall=match.call())
|
||||||
|
|
||||||
|
forest.java <- NULL
|
||||||
if(cores > 1){
|
if(cores > 1){
|
||||||
.jcall(forestTrainer, "V", "trainParallelOnDisk", .object_Optional(), as.integer(cores))
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", .object_Optional(), as.integer(cores))
|
||||||
} else {
|
} else {
|
||||||
.jcall(forestTrainer, "V", "trainSerialOnDisk", .object_Optional())
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", .object_Optional())
|
||||||
}
|
}
|
||||||
|
|
||||||
# Need to now load forest trees back into memory
|
if(forest.output[1] == "online"){
|
||||||
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forestResponseCombiner$javaObject)
|
forest.java <- convertToOnlineForest.Java(forest.java)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
else{ # save directly into memory
|
else{ # save directly into memory
|
||||||
if(cores > 1){
|
if(cores > 1){
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", .object_Optional(), as.integer(cores))
|
||||||
} else {
|
} else {
|
||||||
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", .object_Optional())
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", .object_Optional())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -253,6 +258,10 @@ train.internal <- function(dataset, splitFinder,
|
||||||
#' assumes (without checking) that the existing trees are from a previous run
|
#' assumes (without checking) that the existing trees are from a previous run
|
||||||
#' and starts from where it left off. This option is useful if recovering from
|
#' and starts from where it left off. This option is useful if recovering from
|
||||||
#' a crash.
|
#' a crash.
|
||||||
|
#' @param forest.output This parameter only applies if \code{savePath} has been
|
||||||
|
#' set; set to 'online' (default) and the saved forest will be loaded into
|
||||||
|
#' memory after being trained. Set to 'offline' and the forest is not loaded
|
||||||
|
#' into memory, but can still be used in a memory-efficient manner.
|
||||||
#' @param cores This parameter specifies how many trees will be simultaneously
|
#' @param cores This parameter specifies how many trees will be simultaneously
|
||||||
#' trained. By default the package attempts to detect how many cores you have
|
#' trained. By default the package attempts to detect how many cores you have
|
||||||
#' by using the \code{parallel} package and using all of them. You may specify
|
#' by using the \code{parallel} package and using all of them. You may specify
|
||||||
|
@ -311,7 +320,8 @@ train.internal <- function(dataset, splitFinder,
|
||||||
train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
||||||
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry,
|
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry,
|
||||||
nodeSize, maxNodeDepth = 100000, na.penalty = TRUE, splitPureNodes=TRUE,
|
nodeSize, maxNodeDepth = 100000, na.penalty = TRUE, splitPureNodes=TRUE,
|
||||||
savePath=NULL, savePath.overwrite=c("warn", "delete", "merge"),
|
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
||||||
|
forest.output = c("online", "offline"),
|
||||||
cores = getCores(), randomSeed = NULL, displayProgress = TRUE){
|
cores = getCores(), randomSeed = NULL, displayProgress = TRUE){
|
||||||
|
|
||||||
dataset <- processFormula(formula, data, na.penalty = na.penalty)
|
dataset <- processFormula(formula, data, na.penalty = na.penalty)
|
||||||
|
@ -322,8 +332,8 @@ train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL
|
||||||
ntree = ntree, numberOfSplits = numberOfSplits,
|
ntree = ntree, numberOfSplits = numberOfSplits,
|
||||||
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
|
mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth,
|
||||||
splitPureNodes = splitPureNodes, savePath = savePath,
|
splitPureNodes = splitPureNodes, savePath = savePath,
|
||||||
savePath.overwrite = savePath.overwrite, cores = cores,
|
savePath.overwrite = savePath.overwrite, forest.output = forest.output,
|
||||||
randomSeed = randomSeed, displayProgress = displayProgress)
|
cores = cores, randomSeed = randomSeed, displayProgress = displayProgress)
|
||||||
|
|
||||||
forest$call <- match.call()
|
forest$call <- match.call()
|
||||||
forest$formula <- formula
|
forest$formula <- formula
|
||||||
|
@ -370,8 +380,10 @@ createTreeTrainer <- function(responseCombiner, splitFinder, covariateList, numb
|
||||||
builderClassReturned <- makeResponse(.class_TreeTrainer_Builder)
|
builderClassReturned <- makeResponse(.class_TreeTrainer_Builder)
|
||||||
|
|
||||||
builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder")
|
builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder")
|
||||||
|
|
||||||
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombiner$javaObject)
|
responseCombinerCasted <- .jcast(responseCombiner$javaObject, .class_ResponseCombiner) # might need to cast a ForestResponseCombiner down
|
||||||
|
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombinerCasted)
|
||||||
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject)
|
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject)
|
||||||
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
|
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
|
||||||
builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits))
|
builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits))
|
||||||
|
|
Binary file not shown.
|
@ -6,7 +6,8 @@
|
||||||
\usage{
|
\usage{
|
||||||
addTrees(forest, numTreesToAdd, savePath = NULL,
|
addTrees(forest, numTreesToAdd, savePath = NULL,
|
||||||
savePath.overwrite = c("warn", "delete", "merge"),
|
savePath.overwrite = c("warn", "delete", "merge"),
|
||||||
cores = getCores(), displayProgress = TRUE)
|
forest.output = c("online", "offline"), cores = getCores(),
|
||||||
|
displayProgress = TRUE)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{forest}{An existing forest.}
|
\item{forest}{An existing forest.}
|
||||||
|
@ -21,6 +22,11 @@ a previously saved forest.}
|
||||||
directory, possibly containing another forest, this specifies what should
|
directory, possibly containing another forest, this specifies what should
|
||||||
be done.}
|
be done.}
|
||||||
|
|
||||||
|
\item{forest.output}{This parameter only applies if \code{savePath} has been
|
||||||
|
set; set to 'online' (default) and the saved forest will be loaded into
|
||||||
|
memory after being trained. Set to 'offline' and the forest is not loaded
|
||||||
|
into memory, but can still be used in a memory-efficient manner.}
|
||||||
|
|
||||||
\item{cores}{The number of cores to be used for training the new trees.}
|
\item{cores}{The number of cores to be used for training the new trees.}
|
||||||
|
|
||||||
\item{displayProgress}{A logical indicating whether the progress should be
|
\item{displayProgress}{A logical indicating whether the progress should be
|
||||||
|
|
20
man/convertToOnlineForest.Rd
Normal file
20
man/convertToOnlineForest.Rd
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
% Generated by roxygen2: do not edit by hand
|
||||||
|
% Please edit documentation in R/convertToOnlineForest.R
|
||||||
|
\name{convertToOnlineForest}
|
||||||
|
\alias{convertToOnlineForest}
|
||||||
|
\title{Convert to Online Forest}
|
||||||
|
\usage{
|
||||||
|
convertToOnlineForest(forest)
|
||||||
|
}
|
||||||
|
\arguments{
|
||||||
|
\item{forest}{The offline forest.}
|
||||||
|
}
|
||||||
|
\value{
|
||||||
|
An online, in-memory forest.
|
||||||
|
}
|
||||||
|
\description{
|
||||||
|
Some forests are too large to store in memory and have been saved to disk.
|
||||||
|
They can still be used, but their performance is much slower. If there's
|
||||||
|
enough memory, they can easily be converted into an in-memory forest that is
|
||||||
|
faster to use.
|
||||||
|
}
|
|
@ -4,10 +4,19 @@
|
||||||
\alias{loadForest}
|
\alias{loadForest}
|
||||||
\title{Load Random Forest}
|
\title{Load Random Forest}
|
||||||
\usage{
|
\usage{
|
||||||
loadForest(directory)
|
loadForest(directory, forest.output = c("online", "offline"),
|
||||||
|
maxTreeNum = NULL)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{directory}{The directory created that saved the previous forest.}
|
\item{directory}{The directory created that saved the previous forest.}
|
||||||
|
|
||||||
|
\item{forest.output}{Specifies whether the forest loaded should be loaded
|
||||||
|
into memory, or reflect the saved files where only one tree is loaded at a
|
||||||
|
time.}
|
||||||
|
|
||||||
|
\item{maxTreeNum}{If for some reason you only want to load the number of
|
||||||
|
trees up until a certain point, you can specify maxTreeNum as a single
|
||||||
|
number.}
|
||||||
}
|
}
|
||||||
\value{
|
\value{
|
||||||
A JForest object; see \code{\link{train}} for details.
|
A JForest object; see \code{\link{train}} for details.
|
||||||
|
|
|
@ -8,7 +8,8 @@ train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL,
|
||||||
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize,
|
forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize,
|
||||||
maxNodeDepth = 1e+05, na.penalty = TRUE, splitPureNodes = TRUE,
|
maxNodeDepth = 1e+05, na.penalty = TRUE, splitPureNodes = TRUE,
|
||||||
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"),
|
||||||
cores = getCores(), randomSeed = NULL, displayProgress = TRUE)
|
forest.output = c("online", "offline"), cores = getCores(),
|
||||||
|
randomSeed = NULL, displayProgress = TRUE)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{formula}{You may specify the response and covariates as a formula
|
\item{formula}{You may specify the response and covariates as a formula
|
||||||
|
@ -106,6 +107,11 @@ assumes (without checking) that the existing trees are from a previous run
|
||||||
and starts from where it left off. This option is useful if recovering from
|
and starts from where it left off. This option is useful if recovering from
|
||||||
a crash.}
|
a crash.}
|
||||||
|
|
||||||
|
\item{forest.output}{This parameter only applies if \code{savePath} has been
|
||||||
|
set; set to 'online' (default) and the saved forest will be loaded into
|
||||||
|
memory after being trained. Set to 'offline' and the forest is not loaded
|
||||||
|
into memory, but can still be used in a memory-efficient manner.}
|
||||||
|
|
||||||
\item{cores}{This parameter specifies how many trees will be simultaneously
|
\item{cores}{This parameter specifies how many trees will be simultaneously
|
||||||
trained. By default the package attempts to detect how many cores you have
|
trained. By default the package attempts to detect how many cores you have
|
||||||
by using the \code{parallel} package and using all of them. You may specify
|
by using the \code{parallel} package and using all of them. You may specify
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
context("Add trees on existing forest")
|
context("Add trees on existing forest")
|
||||||
|
|
||||||
test_that("Can add trees on existing forest", {
|
test_that("Can add trees on existing online forest", {
|
||||||
|
|
||||||
trainingData <- data.frame(x=rnorm(100))
|
trainingData <- data.frame(x=rnorm(100))
|
||||||
trainingData$T <- rexp(100) + abs(trainingData$x)
|
trainingData$T <- rexp(100) + abs(trainingData$x)
|
||||||
|
@ -20,6 +20,44 @@ test_that("Can add trees on existing forest", {
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("Can add trees on existing offline forest", {
|
||||||
|
|
||||||
|
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||||
|
unlink("trees", recursive=TRUE)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
trainingData <- data.frame(x=rnorm(100))
|
||||||
|
trainingData$T <- rexp(100) + abs(trainingData$x)
|
||||||
|
trainingData$delta <- sample(0:2, size = 100, replace=TRUE)
|
||||||
|
|
||||||
|
forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50,
|
||||||
|
numberOfSplits=0, mtry=1, nodeSize=5,
|
||||||
|
forestResponseCombiner = CR_FunctionCombiner(events = 1:2, times = 0:10), # TODO - remove specifying times; this is a workaround for an unimplemented feature for offline forests
|
||||||
|
cores=2, displayProgress=FALSE, savePath="trees",
|
||||||
|
forest.output = "offline")
|
||||||
|
warning("TODO - need to implement feature; test workaround in the meantime")
|
||||||
|
|
||||||
|
predictions <- predict(forest)
|
||||||
|
|
||||||
|
warning_message <- "Assuming that the previous forest at savePath is the provided forest argument; if not true then your results will be suspect"
|
||||||
|
|
||||||
|
forest.more <- expect_warning(addTrees(forest, 50, cores=2, displayProgress=FALSE,
|
||||||
|
savePath="trees", savePath.overwrite = "merge",
|
||||||
|
forest.output = "offline"), fixed=warning_message) # test multi-core
|
||||||
|
|
||||||
|
predictions <- predict(forest)
|
||||||
|
|
||||||
|
forest.more <- expect_warning(addTrees(forest, 50, cores=1, displayProgress=FALSE,
|
||||||
|
savePath="trees", savePath.overwrite = "merge",
|
||||||
|
forest.output = "offline"), fixed=warning_message) # test single-core
|
||||||
|
|
||||||
|
expect_true(T) # show Ok if we got this far
|
||||||
|
|
||||||
|
unlink("trees", recursive=TRUE)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
test_that("Test adding trees on saved forest - using delete", {
|
test_that("Test adding trees on saved forest - using delete", {
|
||||||
|
|
||||||
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||||
|
|
|
@ -2,7 +2,11 @@ context("Train, save, and load without error")
|
||||||
|
|
||||||
test_that("Can save & load regression example", {
|
test_that("Can save & load regression example", {
|
||||||
|
|
||||||
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist yet
|
if(file.exists("trees_saving_loading")){
|
||||||
|
unlink("trees_saving_loading", recursive=TRUE)
|
||||||
|
}
|
||||||
|
|
||||||
|
expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist at this point
|
||||||
|
|
||||||
x1 <- rnorm(1000)
|
x1 <- rnorm(1000)
|
||||||
x2 <- rnorm(1000)
|
x2 <- rnorm(1000)
|
||||||
|
|
|
@ -13,7 +13,7 @@ test_that("Can save a random forest while training, and use it afterward", {
|
||||||
data <- data.frame(x1, x2, y)
|
data <- data.frame(x1, x2, y)
|
||||||
forest <- train(y ~ x1 + x2, data,
|
forest <- train(y ~ x1 + x2, data,
|
||||||
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
||||||
savePath="trees", displayProgress=FALSE)
|
savePath="trees", forest.output = "online", displayProgress=FALSE)
|
||||||
|
|
||||||
expect_true(file.exists("trees")) # Something should have been saved
|
expect_true(file.exists("trees")) # Something should have been saved
|
||||||
|
|
||||||
|
@ -26,6 +26,39 @@ test_that("Can save a random forest while training, and use it afterward", {
|
||||||
predictions <- predict(newforest, newData)
|
predictions <- predict(newforest, newData)
|
||||||
|
|
||||||
|
|
||||||
|
unlink("trees", recursive=TRUE)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("Can save a random forest while training, and use it afterward with pure offline forest", {
|
||||||
|
|
||||||
|
if(file.exists("trees")){ # folder could exist from a previous failed test; delete it
|
||||||
|
unlink("trees", recursive=TRUE)
|
||||||
|
}
|
||||||
|
|
||||||
|
x1 <- rnorm(1000)
|
||||||
|
x2 <- rnorm(1000)
|
||||||
|
y <- 1 + x1 + x2 + rnorm(1000)
|
||||||
|
|
||||||
|
data <- data.frame(x1, x2, y)
|
||||||
|
forest <- train(y ~ x1 + x2, data,
|
||||||
|
ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5,
|
||||||
|
savePath="trees", forest.output = "offline", displayProgress=FALSE)
|
||||||
|
|
||||||
|
expect_true(file.exists("trees")) # Something should have been saved
|
||||||
|
|
||||||
|
# try making a little prediction to verify it works
|
||||||
|
newData <- data.frame(x1=seq(from=-3, to=3, by=0.5), x2=0)
|
||||||
|
predictions <- predict(forest, newData)
|
||||||
|
|
||||||
|
# Also make sure we can load the forest too
|
||||||
|
newforest <- loadForest("trees")
|
||||||
|
predictions <- predict(newforest, newData)
|
||||||
|
|
||||||
|
# Last, make sure we can take the forest online
|
||||||
|
onlineForest <- convertToOnlineForest(forest)
|
||||||
|
predictions <- predict(onlineForest, newData)
|
||||||
|
|
||||||
unlink("trees", recursive=TRUE)
|
unlink("trees", recursive=TRUE)
|
||||||
|
|
||||||
})
|
})
|
Loading…
Reference in a new issue