diff --git a/R/predict.R b/R/predict.R index 8d7ff5a..81ccfc2 100644 --- a/R/predict.R +++ b/R/predict.R @@ -99,7 +99,7 @@ predict.JRandomForest <- function(object, newData=NULL, parallel=TRUE, out.of.ba predictionsJava <- .jcall(forestObject, makeResponse(.class_List), function.to.use, predictionDataList) if(predictionClass == "numeric"){ - predictions <- vector(length=nrow(newData), mode="numeric") + predictions <- vector(length=numRows, mode="numeric") } else{ predictions <- list() diff --git a/tests/testthat/test_running.R b/tests/testthat/test_running.R index b235b61..7d14faa 100644 --- a/tests/testthat/test_running.R +++ b/tests/testthat/test_running.R @@ -28,6 +28,7 @@ test_that("Regresssion doesn't crash", { forest <- train(y ~ x, trainingData, ntree=50, numberOfSplits=0, mtry=1, nodeSize=5, cores=2, displayProgress=FALSE) predictions <- predict(forest, testData) + other_predictions <- predict(forest) # there was a bug if newData wasn't provided. expect_true(T) # show Ok if we got this far