diff --git a/NAMESPACE b/NAMESPACE index 402bb00..ac71073 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -26,6 +26,7 @@ export(Numeric) export(WeightedVarianceSplitFinder) export(addTrees) export(connectToData) +export(convertToOnlineForest) export(extractCHF) export(extractCIF) export(extractMortalities) diff --git a/R/addTrees.R b/R/addTrees.R index bb7f41a..79f82ae 100644 --- a/R/addTrees.R +++ b/R/addTrees.R @@ -14,6 +14,10 @@ #' @param savePath.overwrite If \code{savePath} is pointing to an existing #' directory, possibly containing another forest, this specifies what should #' be done. +#' @param forest.output This parameter only applies if \code{savePath} has been +#' set; set to 'online' (default) and the saved forest will be loaded into +#' memory after being trained. Set to 'offline' and the forest is not saved +#' into memory, but can still be used in a memory unintensive manner. #' @param cores The number of cores to be used for training the new trees. #' @param displayProgress A logical indicating whether the progress should be #' displayed to console; default is \code{TRUE}. Useful to set to FALSE in @@ -22,7 +26,11 @@ #' @return A new forest with the original and additional trees. 
#' @export #' -addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), cores = getCores(), displayProgress = TRUE){ +addTrees <- function(forest, numTreesToAdd, savePath = NULL, + savePath.overwrite = c("warn", "delete", "merge"), + forest.output = c("online", "offline"), + cores = getCores(), displayProgress = TRUE){ + if(is.null(forest$dataset)){ stop("Training dataset must be connected to forest before more trees can be added; this can be done manually by using connectToData") } @@ -37,6 +45,10 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")") } + if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){ + stop("forest.output must be one of c(\"online\", \"offline\")") + } + newTreeCount <- forest$params$ntree + as.integer(numTreesToAdd) treeTrainer <- createTreeTrainer(responseCombiner=forest$params$nodeResponseCombiner, @@ -98,22 +110,23 @@ addTrees <- function(forest, numTreesToAdd, savePath = NULL, savePath.overwrite params=params, forestCall=match.call()) + forest.java <- NULL if(cores > 1){ - .jcall(forestTrainer, "V", "trainParallelOnDisk", initial.forest.optional, as.integer(cores)) + forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", initial.forest.optional, as.integer(cores)) } else { - .jcall(forestTrainer, "V", "trainSerialOnDisk", initial.forest.optional) + forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", initial.forest.optional) } - # Need to now load forest trees back into memory - forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forest$params$forestResponseCombiner$javaObject) - + if(forest.output[1] == "online"){ + forest.java <- convertToOnlineForest.Java(forest.java) + } } else{ # save directly into 
memory if(cores > 1){ - forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", initial.forest.optional, as.integer(cores)) + forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", initial.forest.optional, as.integer(cores)) } else { - forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", initial.forest.optional) + forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", initial.forest.optional) } } diff --git a/R/convertToOnlineForest.R b/R/convertToOnlineForest.R new file mode 100644 index 0000000..5e32b82 --- /dev/null +++ b/R/convertToOnlineForest.R @@ -0,0 +1,38 @@ + +#' Convert to Online Forest +#' +#' Some forests are too large to store in memory and have been saved to disk. +#' They can still be used, but their performance is much slower. If there's +#' enough memory, they can easily be converted into an in-memory forest that is +#' faster to use. +#' +#' @param forest The offline forest. +#' +#' @return An online, in-memory forest. 
+#' @export +#' +convertToOnlineForest <- function(forest){ + old.forest.object <- forest$javaObject + + if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OnlineForest"){ + + warning("forest is already in-memory") + return(forest) + + } else if(getJavaClass(old.forest.object) == "ca.joeltherrien.randomforest.tree.OfflineForest"){ + + forest$javaObject <- convertToOnlineForest.Java(old.forest.object) + return(forest) + + } else{ + stop("'forest' is not an online or offline forest") + } + +} + +# Internal function +convertToOnlineForest.Java <- function(forest.java){ + offline.forest <- .jcast(forest.java, .class_OfflineForest) + online.forest <- .jcall(offline.forest, makeResponse(.class_OnlineForest), "createOnlineCopy") + return(online.forest) +} \ No newline at end of file diff --git a/R/cr_components.R b/R/cr_components.R index 2e7817d..d8fe18e 100644 --- a/R/cr_components.R +++ b/R/cr_components.R @@ -42,7 +42,7 @@ CR_FunctionCombiner <- function(events, times = NULL){ } javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray) - javaObject <- .jcast(javaObject, .class_ResponseCombiner) + javaObject <- .jcast(javaObject, .class_ForestResponseCombiner) combiner <- list(javaObject=javaObject, call=match.call(), diff --git a/R/java_classes_directory.R b/R/java_classes_directory.R index 8ffe241..759d567 100644 --- a/R/java_classes_directory.R +++ b/R/java_classes_directory.R @@ -11,6 +11,7 @@ .class_Serializable <- "java/io/Serializable" .class_File <- "java/io/File" .class_Random <- "java/util/Random" +.class_Class <- "java/lang/Class" # Utility Classes .class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils" @@ -41,9 +42,12 @@ # Forest class .class_Forest <- "ca/joeltherrien/randomforest/tree/Forest" +.class_OnlineForest <- "ca/joeltherrien/randomforest/tree/OnlineForest" +.class_OfflineForest <- "ca/joeltherrien/randomforest/tree/OfflineForest" # ResponseCombiner classes .class_ResponseCombiner <- 
"ca/joeltherrien/randomforest/tree/ResponseCombiner" +.class_ForestResponseCombiner <- "ca/joeltherrien/randomforest/tree/ForestResponseCombiner" .class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner" .class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner" .class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner" @@ -77,4 +81,10 @@ # This function does that. makeResponse <- function(className){ return(paste0("L", className, ";")) -} \ No newline at end of file +} + +getJavaClass <- function(object){ + class <- .jcall(object, makeResponse(.class_Class), "getClass") + className <- .jcall(class, "S", "getName") + return(className) +} diff --git a/R/loadForest.R b/R/loadForest.R index ded2cde..9af28a6 100644 --- a/R/loadForest.R +++ b/R/loadForest.R @@ -5,6 +5,12 @@ #' Loads a random forest that was saved using \code{\link{saveForest}}. #' #' @param directory The directory created that saved the previous forest. +#' @param forest.output Specifies whether the forest loaded should be loaded +#' into memory, or reflect the saved files where only one tree is loaded at a +#' time. +#' @param maxTreeNum If for some reason you only want to load the number of +#' trees up until a certain point, you can specify maxTreeNum as a single +#' number. #' @return A JForest object; see \code{\link{train}} for details. 
#' @export #' @seealso \code{\link{train}}, \code{\link{saveForest}} @@ -20,7 +26,11 @@ #' #' saveForest(forest, "trees") #' new_forest <- loadForest("trees") -loadForest <- function(directory){ +loadForest <- function(directory, forest.output = c("online", "offline"), maxTreeNum = NULL){ + + if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){ + stop("forest.output must be one of c(\"online\", \"offline\")") + } # First load the response combiners and the split finders nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData")) @@ -30,7 +40,7 @@ loadForest <- function(directory){ splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder) forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData")) - forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ResponseCombiner) + forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ForestResponseCombiner) covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData")) covariateList <- .jcast(covariateList, .class_List) @@ -42,8 +52,11 @@ loadForest <- function(directory){ params$splitFinder$javaObject <- splitFinder.java params$forestResponseCombiner$javaObject <- forestResponseCombiner.java - forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call, - params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes, params$randomSeed) + forest <- loadForestArgumentsSpecified(directory, params$nodeResponseCombiner, params$splitFinder, + params$forestResponseCombiner, covariateList, call, + params$ntree, params$numberOfSplits, params$mtry, + 
params$nodeSize, params$maxNodeDepth, params$splitPureNodes, + params$randomSeed, forest.output, maxTreeNum) return(forest) @@ -55,8 +68,11 @@ loadForest <- function(directory){ # that uses the Java version's settings yaml file to recreate the forest, but # I'd appreciate knowing that someone's going to use it first (email me; see # README). -loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner, - covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, randomSeed=NULL){ +loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, splitFinder, + forestResponseCombiner, covariateList.java, call, + ntree, numberOfSplits, mtry, nodeSize, + maxNodeDepth = 100000, splitPureNodes=TRUE, + randomSeed=NULL, forest.output = "online", maxTreeNum = NULL){ params <- list( splitFinder=splitFinder, @@ -71,7 +87,33 @@ loadForestArgumentsSpecified <- function(treeDirectory, nodeResponseCombiner, sp randomSeed=randomSeed ) - forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject) + forest.java <- NULL + if(forest.output[1] == "online"){ + castedForestResponseCombiner <- .jcast(forestResponseCombiner$javaObject, .class_ResponseCombiner) # OnlineForest constructor takes a ResponseCombiner + + if(is.null(maxTreeNum)){ + forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest", + treeDirectory, castedForestResponseCombiner) + } else{ + tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray", + treeDirectory, as.integer(maxTreeNum), evalArray = FALSE) + forest.java <- .jcall(.class_DataUtils, makeResponse(.class_OnlineForest), "loadOnlineForest", + tree.file.array, castedForestResponseCombiner) + + } + + } else{ # offline forest + if(is.null(maxTreeNum)){ + path.as.file <- .jnew(.class_File, 
treeDirectory) + forest.java <- .jnew(.class_OfflineForest, path.as.file, forestResponseCombiner$javaObject) + } else{ + tree.file.array <- .jcall(.class_RUtils, paste0("[", makeResponse(.class_File)), "getTreeFileArray", + treeDirectory, as.integer(maxTreeNum), evalArray = FALSE) + forest.java <- .jnew(.class_OfflineForest, tree.file.array, forestResponseCombiner$javaObject) + } + } + + forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params) class(forestObject) <- "JRandomForest" diff --git a/R/regressionComponents.R b/R/regressionComponents.R index 681ab6f..d2ce1e8 100644 --- a/R/regressionComponents.R +++ b/R/regressionComponents.R @@ -48,7 +48,7 @@ WeightedVarianceSplitFinder <- function(){ #' MeanResponseCombiner <- function(){ javaObject <- .jnew(.class_MeanResponseCombiner) - javaObject <- .jcast(javaObject, .class_ResponseCombiner) + javaObject <- .jcast(javaObject, .class_ForestResponseCombiner) combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric") combiner$convertToRFunction <- function(javaObject, ...){ diff --git a/R/train.R b/R/train.R index a513413..319b322 100644 --- a/R/train.R +++ b/R/train.R @@ -18,7 +18,7 @@ train.internal <- function(dataset, splitFinder, nodeResponseCombiner, forestResponseCombiner, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth, splitPureNodes, savePath, savePath.overwrite, - cores, randomSeed, displayProgress){ + forest.output, cores, randomSeed, displayProgress){ # Some quick checks on parameters ntree <- as.integer(ntree) @@ -51,6 +51,10 @@ train.internal <- function(dataset, splitFinder, stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")") } + if(is.null(forest.output) | length(forest.output)==0 | !(forest.output[1] %in% c("online", "offline"))){ + stop("forest.output must be one of c(\"online\", \"offline\")") + } + if(is.null(splitFinder)){ splitFinder <- splitFinderDefault(dataset$responses) } @@ -129,22 +133,23 @@ 
train.internal <- function(dataset, splitFinder, params=params, forestCall=match.call()) + forest.java <- NULL if(cores > 1){ - .jcall(forestTrainer, "V", "trainParallelOnDisk", .object_Optional(), as.integer(cores)) + forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainParallelOnDisk", .object_Optional(), as.integer(cores)) } else { - .jcall(forestTrainer, "V", "trainSerialOnDisk", .object_Optional()) + forest.java <- .jcall(forestTrainer, makeResponse(.class_OfflineForest), "trainSerialOnDisk", .object_Optional()) } - # Need to now load forest trees back into memory - forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forestResponseCombiner$javaObject) - - + if(forest.output[1] == "online"){ + forest.java <- convertToOnlineForest.Java(forest.java) + } + } else{ # save directly into memory if(cores > 1){ - forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", .object_Optional(), as.integer(cores)) + forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainParallelInMemory", .object_Optional(), as.integer(cores)) } else { - forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory", .object_Optional()) + forest.java <- .jcall(forestTrainer, makeResponse(.class_OnlineForest), "trainSerialInMemory", .object_Optional()) } } @@ -253,6 +258,10 @@ train.internal <- function(dataset, splitFinder, #' assumes (without checking) that the existing trees are from a previous run #' and starts from where it left off. This option is useful if recovering from #' a crash. +#' @param forest.output This parameter only applies if \code{savePath} has been +#' set; set to 'online' (default) and the saved forest will be loaded into +#' memory after being trained. Set to 'offline' and the forest is not saved +#' into memory, but can still be used in a memory unintensive manner. 
#' @param cores This parameter specifies how many trees will be simultaneously #' trained. By default the package attempts to detect how many cores you have #' by using the \code{parallel} package and using all of them. You may specify @@ -311,7 +320,8 @@ train.internal <- function(dataset, splitFinder, train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL, forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, na.penalty = TRUE, splitPureNodes=TRUE, - savePath=NULL, savePath.overwrite=c("warn", "delete", "merge"), + savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), + forest.output = c("online", "offline"), cores = getCores(), randomSeed = NULL, displayProgress = TRUE){ dataset <- processFormula(formula, data, na.penalty = na.penalty) @@ -322,8 +332,8 @@ train <- function(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL ntree = ntree, numberOfSplits = numberOfSplits, mtry = mtry, nodeSize = nodeSize, maxNodeDepth = maxNodeDepth, splitPureNodes = splitPureNodes, savePath = savePath, - savePath.overwrite = savePath.overwrite, cores = cores, - randomSeed = randomSeed, displayProgress = displayProgress) + savePath.overwrite = savePath.overwrite, forest.output = forest.output, + cores = cores, randomSeed = randomSeed, displayProgress = displayProgress) forest$call <- match.call() forest$formula <- formula @@ -370,8 +380,10 @@ createTreeTrainer <- function(responseCombiner, splitFinder, covariateList, numb builderClassReturned <- makeResponse(.class_TreeTrainer_Builder) builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder") - - builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombiner$javaObject) + + responseCombinerCasted <- .jcast(responseCombiner$javaObject, .class_ResponseCombiner) # might need to cast a ForestResponseCombiner down + + builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombinerCasted) 
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject) builder <- .jcall(builder, builderClassReturned, "covariates", covariateList) builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits)) diff --git a/inst/java/largeRCRF-library-1.0-SNAPSHOT.jar b/inst/java/largeRCRF-library-1.0-SNAPSHOT.jar index 9de61f0..27f483a 100644 Binary files a/inst/java/largeRCRF-library-1.0-SNAPSHOT.jar and b/inst/java/largeRCRF-library-1.0-SNAPSHOT.jar differ diff --git a/man/addTrees.Rd b/man/addTrees.Rd index 3004baf..94cfad2 100644 --- a/man/addTrees.Rd +++ b/man/addTrees.Rd @@ -6,7 +6,8 @@ \usage{ addTrees(forest, numTreesToAdd, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), - cores = getCores(), displayProgress = TRUE) + forest.output = c("online", "offline"), cores = getCores(), + displayProgress = TRUE) } \arguments{ \item{forest}{An existing forest.} @@ -21,6 +22,11 @@ a previously saved forest.} directory, possibly containing another forest, this specifies what should be done.} +\item{forest.output}{This parameter only applies if \code{savePath} has been +set; set to 'online' (default) and the saved forest will be loaded into +memory after being trained. 
Set to 'offline' and the forest is not saved +into memory, but can still be used in a memory unintensive manner.} + \item{cores}{The number of cores to be used for training the new trees.} \item{displayProgress}{A logical indicating whether the progress should be diff --git a/man/convertToOnlineForest.Rd b/man/convertToOnlineForest.Rd new file mode 100644 index 0000000..6e47cc7 --- /dev/null +++ b/man/convertToOnlineForest.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/convertToOnlineForest.R +\name{convertToOnlineForest} +\alias{convertToOnlineForest} +\title{Convert to Online Forest} +\usage{ +convertToOnlineForest(forest) +} +\arguments{ +\item{forest}{The offline forest.} +} +\value{ +An online, in-memory forest. +} +\description{ +Some forests are too large to store in memory and have been saved to disk. +They can still be used, but their performance is much slower. If there's +enough memory, they can easily be converted into an in-memory forest that is +faster to use. +} diff --git a/man/loadForest.Rd b/man/loadForest.Rd index e571378..d283e12 100644 --- a/man/loadForest.Rd +++ b/man/loadForest.Rd @@ -4,10 +4,19 @@ \alias{loadForest} \title{Load Random Forest} \usage{ -loadForest(directory) +loadForest(directory, forest.output = c("online", "offline"), + maxTreeNum = NULL) } \arguments{ \item{directory}{The directory created that saved the previous forest.} + +\item{forest.output}{Specifies whether the forest loaded should be loaded +into memory, or reflect the saved files where only one tree is loaded at a +time.} + +\item{maxTreeNum}{If for some reason you only want to load the number of +trees up until a certain point, you can specify maxTreeNum as a single +number.} } \value{ A JForest object; see \code{\link{train}} for details. 
diff --git a/man/train.Rd b/man/train.Rd index 09ea41a..55b34ef 100644 --- a/man/train.Rd +++ b/man/train.Rd @@ -8,7 +8,8 @@ train(formula, data, splitFinder = NULL, nodeResponseCombiner = NULL, forestResponseCombiner = NULL, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 1e+05, na.penalty = TRUE, splitPureNodes = TRUE, savePath = NULL, savePath.overwrite = c("warn", "delete", "merge"), - cores = getCores(), randomSeed = NULL, displayProgress = TRUE) + forest.output = c("online", "offline"), cores = getCores(), + randomSeed = NULL, displayProgress = TRUE) } \arguments{ \item{formula}{You may specify the response and covariates as a formula @@ -106,6 +107,11 @@ assumes (without checking) that the existing trees are from a previous run and starts from where it left off. This option is useful if recovering from a crash.} +\item{forest.output}{This parameter only applies if \code{savePath} has been +set; set to 'online' (default) and the saved forest will be loaded into +memory after being trained. Set to 'offline' and the forest is not saved +into memory, but can still be used in a memory unintensive manner.} + \item{cores}{This parameter specifies how many trees will be simultaneously trained. By default the package attempts to detect how many cores you have by using the \code{parallel} package and using all of them. 
You may specify diff --git a/tests/testthat/test_adding_new_trees.R b/tests/testthat/test_adding_new_trees.R index c3bcf92..1b7398f 100644 --- a/tests/testthat/test_adding_new_trees.R +++ b/tests/testthat/test_adding_new_trees.R @@ -1,6 +1,6 @@ context("Add trees on existing forest") -test_that("Can add trees on existing forest", { +test_that("Can add trees on existing online forest", { trainingData <- data.frame(x=rnorm(100)) trainingData$T <- rexp(100) + abs(trainingData$x) @@ -20,6 +20,44 @@ test_that("Can add trees on existing forest", { }) +test_that("Can add trees on existing offline forest", { + + if(file.exists("trees")){ # folder could exist from a previous failed test; delete it + unlink("trees", recursive=TRUE) + } + + + trainingData <- data.frame(x=rnorm(100)) + trainingData$T <- rexp(100) + abs(trainingData$x) + trainingData$delta <- sample(0:2, size = 100, replace=TRUE) + + forest <- train(CR_Response(delta, T) ~ x, trainingData, ntree=50, + numberOfSplits=0, mtry=1, nodeSize=5, + forestResponseCombiner = CR_FunctionCombiner(events = 1:2, times = 0:10), # TODO - remove specifying times; this is a workaround for an unimplemented feature for offline forests + cores=2, displayProgress=FALSE, savePath="trees", + forest.output = "offline") + warning("TODO - need to implement feature; test workaround in the meantime") + + predictions <- predict(forest) + + warning_message <- "Assuming that the previous forest at savePath is the provided forest argument; if not true then your results will be suspect" + + forest.more <- expect_warning(addTrees(forest, 50, cores=2, displayProgress=FALSE, + savePath="trees", savePath.overwrite = "merge", + forest.output = "offline"), fixed=warning_message) # test multi-core + + predictions <- predict(forest) + + forest.more <- expect_warning(addTrees(forest, 50, cores=1, displayProgress=FALSE, + savePath="trees", savePath.overwrite = "merge", + forest.output = "offline"), fixed=warning_message) # test single-core + 
expect_true(T) # show Ok if we got this far + + unlink("trees", recursive=TRUE) + +}) + test_that("Test adding trees on saved forest - using delete", { if(file.exists("trees")){ # folder could exist from a previous failed test; delete it diff --git a/tests/testthat/test_saving_loading.R b/tests/testthat/test_saving_loading.R index 1d7a425..2eb93ed 100644 --- a/tests/testthat/test_saving_loading.R +++ b/tests/testthat/test_saving_loading.R @@ -2,7 +2,11 @@ context("Train, save, and load without error") test_that("Can save & load regression example", { - expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist yet + if(file.exists("trees_saving_loading")){ + unlink("trees_saving_loading", recursive=TRUE) + } + + expect_false(file.exists("trees_saving_loading")) # Folder shouldn't exist at this point x1 <- rnorm(1000) x2 <- rnorm(1000) diff --git a/tests/testthat/test_saving_offline.R b/tests/testthat/test_saving_offline.R index 1823d34..918c24d 100644 --- a/tests/testthat/test_saving_offline.R +++ b/tests/testthat/test_saving_offline.R @@ -13,7 +13,7 @@ test_that("Can save a random forest while training, and use it afterward", { data <- data.frame(x1, x2, y) forest <- train(y ~ x1 + x2, data, ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5, - savePath="trees", displayProgress=FALSE) + savePath="trees", forest.output = "online", displayProgress=FALSE) expect_true(file.exists("trees")) # Something should have been saved @@ -26,6 +26,39 @@ test_that("Can save a random forest while training, and use it afterward", { predictions <- predict(newforest, newData) + unlink("trees", recursive=TRUE) + +}) + +test_that("Can save a random forest while training, and use it afterward with pure offline forest", { + + if(file.exists("trees")){ # folder could exist from a previous failed test; delete it + unlink("trees", recursive=TRUE) + } + + x1 <- rnorm(1000) + x2 <- rnorm(1000) + y <- 1 + x1 + x2 + rnorm(1000) + + data <- data.frame(x1, x2, y) + forest <- 
train(y ~ x1 + x2, data, + ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5, + savePath="trees", forest.output = "offline", displayProgress=FALSE) + + expect_true(file.exists("trees")) # Something should have been saved + + # try making a little prediction to verify it works + newData <- data.frame(x1=seq(from=-3, to=3, by=0.5), x2=0) + predictions <- predict(forest, newData) + + # Also make sure we can load the forest too + newforest <- loadForest("trees") + predictions <- predict(newforest, newData) + + # Last, make sure we can take the forest online + onlineForest <- convertToOnlineForest(forest) + predictions <- predict(onlineForest, newData) + unlink("trees", recursive=TRUE) }) \ No newline at end of file