2019-05-31 22:13:24 +00:00
#' Predict
#'
#' Predict on the random forest.
#'
#' @param object A forest that was previously \code{\link{train}}ed
#' @param newData The new data containing all of the previous predictor
#'   covariates. Can be NULL if you want to use the training dataset, and
#'   \code{object} hasn't been loaded from the disk; otherwise you'll have to
#'   specify it.
#' @param parallel A logical indicating whether multiple cores should be
#'   utilized when making the predictions. Available as an option because it's
#'   been observed that using Java's \code{parallelStream} can be unstable on
#'   some systems. Default value is \code{TRUE}; only set to \code{FALSE} if you
#'   get strange errors while predicting.
#' @param out.of.bag A logical indicating whether predictions should be based on
#'   'out of bag' trees; set only to \code{TRUE} if you're running predictions
#'   on data that was used in the training. Default value is \code{TRUE} if
#'   \code{newData} is \code{NULL}, otherwise \code{FALSE}.
#' @param ... Other parameters that may one day get passed onto other functions;
#'   currently not used.
#' @return A list of responses corresponding with each row of \code{newData} if
#'   it's a non-regression random forest; otherwise it returns a numeric vector.
#'   In both cases the result carries an S3 class named after the forest's
#'   output class followed by \code{.List}. If the prediction data has no rows
#'   the result is empty.
#' @export
#' @examples
#' # Regression Example
#' x1 <- rnorm(1000)
#' x2 <- rnorm(1000)
#' y <- 1 + x1 + x2 + rnorm(1000)
#'
#' data <- data.frame(x1, x2, y)
#' forest <- train(y ~ x1 + x2, data, ntree=100, numberOfSplits = 5,
#' mtry = 1, nodeSize = 5)
#'
#' # Fix x2 to be 0
#' newData <- data.frame(x1 = seq(from=-2, to=2, by=0.5), x2 = 0)
#' ypred <- predict(forest, newData)
#'
#' plot(ypred ~ newData$x1, type="l")
#'
#' # Competing Risk Example
#' x1 <- abs(rnorm(1000))
#' x2 <- abs(rnorm(1000))
#'
#' T1 <- rexp(1000, rate=x1)
#' T2 <- rweibull(1000, shape=x1, scale=x2)
#' C <- rexp(1000)
#' u <- pmin(T1, T2, C)
#' delta <- ifelse(u==T1, 1, ifelse(u==T2, 2, 0))
#'
#' data <- data.frame(x1, x2)
#'
#' forest <- train(CR_Response(delta, u) ~ x1 + x2, data, ntree=100,
#' numberOfSplits=5, mtry=1, nodeSize=10)
#' newData <- data.frame(x1 = c(-1, 0, 1), x2 = 0)
#' ypred <- predict(forest, newData)
predict.JRandomForest <- function(object, newData = NULL, parallel = TRUE,
                                  out.of.bag = NULL, ...) {
  # slight renaming
  forest <- object

  usingTrainingData <- is.null(newData)

  # `&&` (not `&`): this is a scalar condition and should short-circuit.
  if (usingTrainingData && is.null(forest$dataset)) {
    stop("forest doesn't have a copy of the training data loaded (this happens if you just loaded it); please manually specify newData and possibly out.of.bag")
  }

  # Out-of-bag prediction defaults to TRUE exactly when we fall back to the
  # training data, and to FALSE when the caller supplies newData.
  if (is.null(out.of.bag)) {
    out.of.bag <- usingTrainingData
  }

  if (usingTrainingData) {
    predictionDataList <- forest$dataset
  } else {
    predictionDataList <- loadPredictionData(newData, forest$covariateList)
  }

  # "I" is the JNI signature for a Java int return value.
  numRows <- .jcall(predictionDataList, "I", "size")

  forestObject <- forest$javaObject
  predictionClass <- forest$params$forestResponseCombiner$outputClass
  convertToRFunction <- forest$params$forestResponseCombiner$convertToRFunction

  # Assemble the Java method name: evaluate / evaluateSerial, optionally
  # suffixed with OOB for out-of-bag predictions.
  javaMethod <- if (parallel) "evaluate" else "evaluateSerial"
  if (out.of.bag) {
    javaMethod <- paste0(javaMethod, "OOB")
  }

  predictionsJava <- .jcall(forestObject, makeResponse(.class_List),
                            javaMethod, predictionDataList)

  # Preallocate the result so it isn't grown one element at a time.
  if (predictionClass == "numeric") {
    predictions <- vector(mode = "numeric", length = numRows)
  } else {
    predictions <- vector(mode = "list", length = numRows)
  }

  # seq_len() rather than 1:numRows: with zero rows, 1:0 would iterate over
  # c(1, 0) and request indices 0 and -1 from an empty Java list.
  for (i in seq_len(numRows)) {
    # R is 1-indexed; the Java list is addressed from 0.
    prediction <- .jcall(predictionsJava, makeResponse(.class_Object),
                         "get", as.integer(i - 1))
    predictions[[i]] <- convertToRFunction(prediction, forest)
  }

  class(predictions) <- paste0(predictionClass, ".List")

  predictions
}