Initial commit for pre-release development version
This commit is contained in:
commit
9217c72cf9
1170 changed files with 2798 additions and 0 deletions
2
.Rbuildignore
Normal file
2
.Rbuildignore
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
^.*\.Rproj$
|
||||||
|
^\.Rproj\.user$
|
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
*.Rproj
|
||||||
|
.Rproj.user
|
||||||
|
copyJar
|
19
DESCRIPTION
Normal file
19
DESCRIPTION
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
Package: largeRCRF
|
||||||
|
Type: Package
|
||||||
|
Title: Large Random Competing Risk Forests, Java Implementation Run in R
|
||||||
|
Version: 0.0.0.9036
|
||||||
|
Authors@R: person("Joel", "Therrien", email = "joel@joeltherrien.ca", role = c("aut", "cre"))
|
||||||
|
Description: This package is used for training competing risk random forests on larger scale datasets.
|
||||||
|
It currently only supports training models, running predictions, plotting those predictions (they are curves),
|
||||||
|
and some simple error analysis using concordance measures.
|
||||||
|
License: GPL-3
|
||||||
|
Encoding: UTF-8
|
||||||
|
LazyData: true
|
||||||
|
Imports:
|
||||||
|
rJava (>= 0.9-9)
|
||||||
|
Suggests:
|
||||||
|
parallel,
|
||||||
|
testthat
|
||||||
|
Depends: R (>= 3.4.2)
|
||||||
|
SystemRequirements: Java JDK 1.8 or higher
|
||||||
|
RoxygenNote: 6.1.1
|
39
NAMESPACE
Normal file
39
NAMESPACE
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
# Generated by roxygen2: do not edit by hand
|
||||||
|
|
||||||
|
S3method(extractCHF,CompetingRiskFunctions)
|
||||||
|
S3method(extractCHF,CompetingRiskFunctions.List)
|
||||||
|
S3method(extractCIF,CompetingRiskFunctions)
|
||||||
|
S3method(extractCIF,CompetingRiskFunctions.List)
|
||||||
|
S3method(extractMortalities,CompetingRiskFunctions)
|
||||||
|
S3method(extractMortalities,CompetingRiskFunctions.List)
|
||||||
|
S3method(extractSurvivorCurve,CompetingRiskFunctions)
|
||||||
|
S3method(extractSurvivorCurve,CompetingRiskFunctions.List)
|
||||||
|
S3method(plot,JMatrixPlottable)
|
||||||
|
S3method(predict,JRandomForest)
|
||||||
|
S3method(print,CompetingRiskFunctions)
|
||||||
|
S3method(print,CompetingRiskFunctions.List)
|
||||||
|
S3method(print,JRandomForest)
|
||||||
|
S3method(print,ResponseCombiner)
|
||||||
|
S3method(print,SplitFinder)
|
||||||
|
S3method(train,default)
|
||||||
|
S3method(train,formula)
|
||||||
|
export(CR_FunctionCombiner)
|
||||||
|
export(CR_Response)
|
||||||
|
export(CR_ResponseCombiner)
|
||||||
|
export(GrayLogRankSplitFinder)
|
||||||
|
export(LogRankSplitFinder)
|
||||||
|
export(MeanResponseCombiner)
|
||||||
|
export(Numeric)
|
||||||
|
export(WeightedVarianceSplitFinder)
|
||||||
|
export(convertRListToJava)
|
||||||
|
export(extractCHF)
|
||||||
|
export(extractCIF)
|
||||||
|
export(extractMortalities)
|
||||||
|
export(extractSurvivorCurve)
|
||||||
|
export(load_covariate_list_from_settings)
|
||||||
|
export(load_forest)
|
||||||
|
export(load_forest_args_provided)
|
||||||
|
export(naiveConcordance)
|
||||||
|
export(save_forest)
|
||||||
|
export(train)
|
||||||
|
import(rJava)
|
106
R/CR_Response.R
Normal file
106
R/CR_Response.R
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
|
||||||
|
#' Competing Risk Response
|
||||||
|
#'
|
||||||
|
#' Takes vectors of event time and event type and turns it into the internal
|
||||||
|
#' objects used throughout the package. The result of this function shouldn't be
|
||||||
|
#' used directly, but should instead by provided as the \code{y} parameter in
|
||||||
|
#' \code{\link{train}}.
|
||||||
|
#'
|
||||||
|
#' @param delta A vector of integers detailing the event that occurred. A value
|
||||||
|
#' of 0 denotes that censoring occurred first and that time was recorded.
|
||||||
|
#' @param u A vector of numerics detailing the recorded event times (possibly
|
||||||
|
#' censored).
|
||||||
|
#' @param C If the censoring times are known for all observations, they can be
|
||||||
|
#' included which allows for \code{\link{GrayLogRankSplitFinder}} to be used.
|
||||||
|
#' Default is \code{NULL}.
|
||||||
|
#'
|
||||||
|
#' @details To be clear, if T1,...TJ are the J different competing risks, and C
|
||||||
|
#' is the censoring time, then \code{u[i] = min(T1[i], ...TJ[i], C[i])}; and
|
||||||
|
#' \code{delta[i]} denotes which time was the minimum, with a value of 0 if
|
||||||
|
#' C[i] was the smallest.
|
||||||
|
#' @export
|
||||||
|
#' @examples
|
||||||
|
#' T1 <- rexp(10)
|
||||||
|
#' T2 <- rweibull(10, 2, 2)
|
||||||
|
#' C <- rexp(10)
|
||||||
|
#'
|
||||||
|
#' u <- pmin(T1, T2, C)
|
||||||
|
#' delta <- ifelse(u == T1, 1, ifelse(u == T2, 2, 0))
|
||||||
|
#'
|
||||||
|
#' responses <- CR_Response(delta, u)
|
||||||
|
#' # Then use responses in train
|
||||||
|
CR_Response <- function(delta, u, C = NULL){
|
||||||
|
if(is.null(C)){
|
||||||
|
return(Java_CompetingRiskResponses(delta, u))
|
||||||
|
} else{
|
||||||
|
return(Java_CompetingRiskResponsesWithCensorTimes(delta, u, C))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Internal function
|
||||||
|
Java_CompetingRiskResponses <- function(delta, u){
|
||||||
|
|
||||||
|
if(length(delta) != length(u)){
|
||||||
|
stop("delta and u must be of the same length")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(anyNA(delta) | is.null(delta)){
|
||||||
|
stop("delta must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(anyNA(u) | is.null(u)){
|
||||||
|
stop("u must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
delta <- as.integer(delta)
|
||||||
|
u <- as.double(u)
|
||||||
|
|
||||||
|
delta.java <- .jarray(delta, contents.class="I")
|
||||||
|
u.java <- .jarray(u, contents.class="D")
|
||||||
|
|
||||||
|
responses.java.list <- .jcall(.class_RUtils, makeResponse(.class_List),
|
||||||
|
"importCompetingRiskResponses", delta.java, u.java)
|
||||||
|
|
||||||
|
responses <- list(javaObject=responses.java.list, eventIndicator=delta, eventTime=u)
|
||||||
|
class(responses) <- "CompetingRiskResponses"
|
||||||
|
|
||||||
|
return(responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Internal function
|
||||||
|
Java_CompetingRiskResponsesWithCensorTimes <- function(delta, u, C){
|
||||||
|
|
||||||
|
if(length(delta) != length(u) | length(u) != length(C)){
|
||||||
|
stop("delta, u, and C must be of the same length")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(anyNA(delta) | is.null(delta)){
|
||||||
|
stop("delta must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(anyNA(u) | is.null(u)){
|
||||||
|
stop("u must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(anyNA(C) | is.null(C)){
|
||||||
|
stop("C must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
delta <- as.integer(delta)
|
||||||
|
u <- as.double(u)
|
||||||
|
C <- as.double(C)
|
||||||
|
|
||||||
|
delta.java <- .jarray(delta, contents.class="I")
|
||||||
|
u.java <- .jarray(u, contents.class="D")
|
||||||
|
C.java <- .jarray(C, contents.class="D")
|
||||||
|
|
||||||
|
responses.java.list <- .jcall(.class_RUtils, makeResponse(.class_List),
|
||||||
|
"importCompetingRiskResponsesWithCensorTimes", delta.java, u.java, C.java)
|
||||||
|
|
||||||
|
responses <- list(javaObject=responses.java.list, eventIndicator=delta, eventTime=u, censorTime=C)
|
||||||
|
class(responses) <- "CompetingRiskResponsesWithCensorTimes"
|
||||||
|
|
||||||
|
return(responses)
|
||||||
|
}
|
||||||
|
|
25
R/Numeric.R
Normal file
25
R/Numeric.R
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
|
||||||
|
#' Numeric
|
||||||
|
#'
|
||||||
|
#' An internal function that converts an R vector of numerics or integers into an R list containing java.lang.Double objects. This method does not need to be used directly by the user, as \code{\link{train}} will automatically handle numeric responses if you're working in the regression settings.
|
||||||
|
#' @param y The R vector of numbers
|
||||||
|
#' @export
|
||||||
|
#' @return An R list containing rJava Doubles.
|
||||||
|
#' @keywords internal
|
||||||
|
#' @examples
|
||||||
|
#' x <- Numeric(1:5)
|
||||||
|
#' class(x[[1]])
|
||||||
|
Numeric <- function(y){
|
||||||
|
y <- as.double(y)
|
||||||
|
|
||||||
|
javaList <- .jcall(.class_RUtils,
|
||||||
|
makeResponse(.class_List),
|
||||||
|
"importNumericResponse",
|
||||||
|
y)
|
||||||
|
|
||||||
|
responses <- list(javaObject=javaList, y=y)
|
||||||
|
|
||||||
|
class(responses) <- "JNumeric"
|
||||||
|
|
||||||
|
return(responses)
|
||||||
|
}
|
202
R/cr_components.R
Normal file
202
R/cr_components.R
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
|
||||||
|
#' Competing Risk Function Combiner
|
||||||
|
#'
|
||||||
|
#' Creates a CompetingRiskFunctionCombiner rJava object, which is used
|
||||||
|
#' internally for constructing a forest. The forest uses it when creating
|
||||||
|
#' predictions to average the cumulative incidence curves, cause-specific
|
||||||
|
#' cumulative hazard functions, and Kaplan-Meier curves generated by each tree
|
||||||
|
#' into individual functions.
|
||||||
|
#'
|
||||||
|
#' The user only needs to pass this object into \code{\link{train}} as the
|
||||||
|
#' \code{forestResponseCombiner} parameter.
|
||||||
|
#'
|
||||||
|
#' @return A response combiner object to be used in \code{\link{train}}; not
|
||||||
|
#' useful on its own. However, internally, a response combiner object is a
|
||||||
|
#' list consisting of the following objects: \describe{
|
||||||
|
#' \item{\code{javaObject}}{The java object used in the algorithm}
|
||||||
|
#' \item{\code{call}}{The call (used in \code{print})}
|
||||||
|
#' \item{\code{outputClass}}{The R class of the outputs; used in
|
||||||
|
#' \code{\link{predict.JRandomForest}}} \item{\code{convertToRFunction}}{An R
|
||||||
|
#' function that converts a Java prediction from the combiner into R output
|
||||||
|
#' that is readable by a user.} }
|
||||||
|
#'
|
||||||
|
#' @param events A vector of integers specifying which competing risk events's
|
||||||
|
#' functions should be processed. This should correspond to all of the
|
||||||
|
#' competing risk events that can occur, from 1 to the largest number.
|
||||||
|
#' @param times An optional numeric vector that forces the output functions to
|
||||||
|
#' only update at these time points. Pre-specifying the values can result in
|
||||||
|
#' faster performance when predicting, however if the times are not exhaustive
|
||||||
|
#' then the resulting curves will not update at that point (they'll be flat).
|
||||||
|
#' If left blank, the package will default to using all of the time points.
|
||||||
|
#' @export
|
||||||
|
#' @examples
|
||||||
|
#' T1 <- rexp(1000)
|
||||||
|
#' T2 <- rweibull(1000, 1, 2)
|
||||||
|
#' C <- rexp(1000)
|
||||||
|
#'
|
||||||
|
#' u <- round(pmin(T1, T2, C))
|
||||||
|
#' # ...
|
||||||
|
#'
|
||||||
|
#' forestCombiner <- CR_FunctionCombiner(1:2) # there are two possible events
|
||||||
|
#' # or, since we know that u is always an integer
|
||||||
|
#' forestCombiner <- CR_FunctionCombiner(1:2, 0:max(u))
|
||||||
|
CR_FunctionCombiner <- function(events, times = NULL){
|
||||||
|
# need to first change events into array of int
|
||||||
|
eventArray <- .jarray(events, "I")
|
||||||
|
|
||||||
|
if(is.null(times)){
|
||||||
|
timeArray <- .jnull(class="[D")
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
timeArray <- .jarray(as.numeric(times), "D")
|
||||||
|
}
|
||||||
|
|
||||||
|
javaObject <- .jnew(.class_CompetingRiskFunctionCombiner, eventArray, timeArray)
|
||||||
|
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||||
|
|
||||||
|
combiner <- list(javaObject=javaObject,
|
||||||
|
call=match.call(),
|
||||||
|
events=events,
|
||||||
|
outputClass="CompetingRiskFunctions",
|
||||||
|
convertToRFunction=convertCompetingRiskFunctions)
|
||||||
|
class(combiner) <- "ResponseCombiner"
|
||||||
|
|
||||||
|
return(combiner)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Competing Risk Response Combiner
|
||||||
|
#'
|
||||||
|
#' Creates a CompetingRiskResponseCombiner rJava object, which is used
|
||||||
|
#' internally for constructing a forest. It is used when each tree in the forest
|
||||||
|
#' is constructed, as it combines response level information (u & delta) into
|
||||||
|
#' functions such as cumulative incidence curves, cause-specific cumulative
|
||||||
|
#' hazard functions, and an overall Kaplan-Meier curve. This combination is done
|
||||||
|
#' for each terminal node for each tree.
|
||||||
|
#'
|
||||||
|
#' The user only needs to pass this object into \code{\link{train}} as the
|
||||||
|
#' \code{nodeResponseCombiner} parameter.
|
||||||
|
#'
|
||||||
|
#' @return A response combiner object to be used in \code{\link{train}}; not
|
||||||
|
#' useful on its own. However, internally, a response combiner object is a
|
||||||
|
#' list consisting of the following objects: \describe{
|
||||||
|
#' \item{\code{javaObject}}{The java object used in the algorithm}
|
||||||
|
#' \item{\code{call}}{The call (used in \code{print})}
|
||||||
|
#' \item{\code{outputClass}}{The R class of the outputs; used in
|
||||||
|
#' \code{\link{predict.JRandomForest}}} \item{\code{convertToRFunction}}{An R
|
||||||
|
#' function that converts a Java prediction from the combiner into R output
|
||||||
|
#' that is readable by a user.} }
|
||||||
|
#'
|
||||||
|
#' @param events A vector of integers specifying which competing risk events's
|
||||||
|
#' functions should be processed. This should correspond to all of the
|
||||||
|
#' competing risk events that can occur, from 1 to the largest number.
|
||||||
|
#' @export
|
||||||
|
#' @examples
|
||||||
|
#' T1 <- rexp(1000)
|
||||||
|
#' T2 <- rweibull(1000, 1, 2)
|
||||||
|
#' C <- rexp(1000)
|
||||||
|
#'
|
||||||
|
#' u <- round(pmin(T1, T2, C))
|
||||||
|
#' # ...
|
||||||
|
#'
|
||||||
|
#' forestCombiner <- CR_ResponseCombiner(1:2) # there are two possible events
|
||||||
|
CR_ResponseCombiner <- function(events){
|
||||||
|
# need to first change events into array of int
|
||||||
|
eventArray <- .jarray(events, "I")
|
||||||
|
|
||||||
|
|
||||||
|
javaObject <- .jnew(.class_CompetingRiskResponseCombiner, eventArray)
|
||||||
|
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||||
|
|
||||||
|
combiner <- list(javaObject=javaObject,
|
||||||
|
call=match.call(),
|
||||||
|
outputClass="CompetingRiskFunctions",
|
||||||
|
convertToRFunction=convertCompetingRiskFunctions
|
||||||
|
)
|
||||||
|
class(combiner) <- "ResponseCombiner"
|
||||||
|
|
||||||
|
return(combiner)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#' Competing Risk Split Finders
|
||||||
|
#'
|
||||||
|
#' Creates a SplitFinder rJava Java object, which is then used internally when
|
||||||
|
#' training a competing risk random forest. The split finder is responsible for
|
||||||
|
#' finding the best split according to the logic of the split finder.
|
||||||
|
#'
|
||||||
|
#' These split finders require that the response be \code{\link{CR_Response}}.
|
||||||
|
#'
|
||||||
|
#' The user only needs to pass this object into \code{\link{train}} as the
|
||||||
|
#' \code{splitFinder} parameter.
|
||||||
|
#'
|
||||||
|
#' @return An internal rJava Java object used in \code{\link{train}}.
|
||||||
|
#' @note The Gray log-rank split finder \strong{requires} that the response
|
||||||
|
#' include the censoring time.
|
||||||
|
#' @param events A vector of integers specifying which competing risk events
|
||||||
|
#' should be focused on when determining differences. Currently, equal weights
|
||||||
|
#' will be assigned to all included groups.
|
||||||
|
#' @param eventsOfFocus The split finder will only maximize differences
|
||||||
|
#' between the two groups with respect to these specified events. Default is
|
||||||
|
#' \code{NULL}, which will cause the split finder to focus on all events
|
||||||
|
#' included in \code{events}.
|
||||||
|
#' @details Roughly speaking, the Gray log-rank split finder looks at
|
||||||
|
#' differences between the cumulative incidence functions of the two groups,
|
||||||
|
#' while the plain log-rank split finder look at differences between the
|
||||||
|
#' cause-specific hazard functions. See the references for a more detailed
|
||||||
|
#' discussion.
|
||||||
|
#' @references Kogalur, U., Ishwaran, H. Random Forests for Survival,
|
||||||
|
#' Regression, and Classification: A Parallel Package for a General
|
||||||
|
#' Implemention of Breiman's Random Forests: Theory and Specifications. URL
|
||||||
|
#' https://kogalur.github.io/randomForestSRC/theory.html#section8.2
|
||||||
|
#'
|
||||||
|
#' Ishwaran, H., et. al. (2014) Random survival forests for competing risks,
|
||||||
|
#' Biostatistics (2014), 15, 4, pp. 757–773
|
||||||
|
#'
|
||||||
|
#' @name CompetingRiskSplitFinders
|
||||||
|
NULL
|
||||||
|
|
||||||
|
#' @rdname CompetingRiskSplitFinders
|
||||||
|
#' @export
|
||||||
|
#' @examples splitFinder <- GrayLogRankSplitFinder(1:2)
|
||||||
|
GrayLogRankSplitFinder <- function(events, eventsOfFocus = NULL){
|
||||||
|
return(java.LogRankSplitFinder(events, eventsOfFocus, .class_GrayLogRankSplitFinder, match.call()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @rdname CompetingRiskSplitFinders
|
||||||
|
#' @export
|
||||||
|
#' @examples splitFinder <- LogRankSplitFinder(1:2)
|
||||||
|
LogRankSplitFinder <- function(events, eventsOfFocus = NULL){
|
||||||
|
return(java.LogRankSplitFinder(events, eventsOfFocus, .class_LogRankSplitFinder, match.call()))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Internal function for creating a competing risk split finder
|
||||||
|
java.LogRankSplitFinder <- function(events, eventsOfFocus, java.class, call = match.call()){
|
||||||
|
events <- sort(events)
|
||||||
|
|
||||||
|
if(is.null(eventsOfFocus)){
|
||||||
|
eventsOfFocus <- events
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check the events
|
||||||
|
if(any(diff(events) != 1) | min(events) != 1){
|
||||||
|
stop("The events provided for creating a log rank split finder do not run from 1 to the maximum")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(any(!(eventsOfFocus %in% events))){
|
||||||
|
stop("There's an event of focus for the log rank split finder that's not included in the events vector")
|
||||||
|
}
|
||||||
|
|
||||||
|
events <- .jarray(as.integer(events))
|
||||||
|
eventsOfFocus <- .jarray(as.integer(eventsOfFocus))
|
||||||
|
|
||||||
|
javaObject <- .jnew(java.class, eventsOfFocus, events)
|
||||||
|
javaObject <- .jcast(javaObject, .class_SplitFinder)
|
||||||
|
|
||||||
|
splitFinder <- list(javaObject=javaObject, call=call)
|
||||||
|
class(splitFinder) <- "SplitFinder"
|
||||||
|
|
||||||
|
return(splitFinder)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
52
R/cr_naiveConcordance.R
Normal file
52
R/cr_naiveConcordance.R
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
|
||||||
|
#' Naive Concordance
|
||||||
|
#'
|
||||||
|
#' Used to calculate a concordance index error. The user needs to supply a list
|
||||||
|
#' of mortalities, with each item in the list being a vector for the specific
|
||||||
|
#' events. To calculate mortalities a user should look to
|
||||||
|
#' \code{\link{extractMortalities}}.
|
||||||
|
#'
|
||||||
|
#' @return A vector of 1 minus the concordance scores, with each element
|
||||||
|
#' corresponding to one of the events. To be clear, the lower the score the
|
||||||
|
#' more accurate the model was.
|
||||||
|
#'
|
||||||
|
#' @param responses A list of responses corresponding to the provided
|
||||||
|
#' mortalities; use \code{\link{CR_Response}}.
|
||||||
|
#' @param predictedMortalities A list of mortality vectors; each element of the
|
||||||
|
#' list should correspond to one of the events in the order of event 1 to J,
|
||||||
|
#' and should be a vector of the same length as responses.
|
||||||
|
#' @export
|
||||||
|
naiveConcordance <- function(responses, predictedMortalities){
|
||||||
|
if(is.null(responses)){
|
||||||
|
stop("responses cannot be null")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(is.null(predictedMortalities)){
|
||||||
|
stop("predictedMortalities cannot be null")
|
||||||
|
}
|
||||||
|
if(!is.list(predictedMortalities)){
|
||||||
|
stop("predictedMortalities must be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
responseList = responses$javaObject
|
||||||
|
responseLength = .jcall(responseList, "I", "size")
|
||||||
|
|
||||||
|
events = as.integer(1:length(predictedMortalities))
|
||||||
|
|
||||||
|
concordances = numeric(length(predictedMortalities))
|
||||||
|
|
||||||
|
for(event in events){
|
||||||
|
if(length(predictedMortalities[[event]]) != responseLength){
|
||||||
|
stop("Every mortality vector in predictedMortalities must be the same length as responses")
|
||||||
|
}
|
||||||
|
|
||||||
|
# Need to turn predictedMortalities into an array of doubles
|
||||||
|
mortality = .jarray(predictedMortalities[[event]], "D")
|
||||||
|
|
||||||
|
concordances[event] = 1 - .jcall(.class_CompetingRiskUtils, "D", "calculateConcordance", responseList, mortality, event)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return(concordances)
|
||||||
|
|
||||||
|
}
|
142
R/cr_predictions.R
Normal file
142
R/cr_predictions.R
Normal file
|
@ -0,0 +1,142 @@
|
||||||
|
|
||||||
|
|
||||||
|
convertCompetingRiskFunctionsSlow <- function(javaObject, forest){
|
||||||
|
events <- forest$params$forestResponseCombiner$events
|
||||||
|
lst <- list(javaObject = javaObject, events = events)
|
||||||
|
|
||||||
|
rightContinuousStepFunctionResponseClass <- makeResponse(.class_RightContinuousStepFunction)
|
||||||
|
|
||||||
|
kaplanMeier <- .jcall(javaObject, rightContinuousStepFunctionResponseClass, "getSurvivalCurve")
|
||||||
|
|
||||||
|
lst$time.interest <- .jcall(.class_RUtils, "[D", "extractTimes", kaplanMeier)
|
||||||
|
lst$survivorCurve <- .jcall(.class_RUtils, "[D", "extractY", kaplanMeier)
|
||||||
|
|
||||||
|
lst$cif <- matrix(nrow=length(lst$time.interest), ncol=length(events))
|
||||||
|
lst$chf <- matrix(nrow=length(lst$time.interest), ncol=length(events))
|
||||||
|
|
||||||
|
for(i in events){
|
||||||
|
cif <- .jcall(javaObject, rightContinuousStepFunctionResponseClass, "getCumulativeIncidenceFunction", as.integer(i))
|
||||||
|
lst$cif[,i] <- .jcall(.class_RUtils, "[D", "extractY", cif)
|
||||||
|
|
||||||
|
chf <- .jcall(javaObject, rightContinuousStepFunctionResponseClass, "getCauseSpecificHazardFunction", as.integer(i))
|
||||||
|
lst$chf[,i] <- .jcall(.class_RUtils, "[D", "extractY", chf)
|
||||||
|
}
|
||||||
|
|
||||||
|
class(lst) <- "CompetingRiskFunctions"
|
||||||
|
return(lst)
|
||||||
|
}
|
||||||
|
|
||||||
|
convertCompetingRiskFunctions <- compiler::cmpfun(convertCompetingRiskFunctionsSlow)
|
||||||
|
|
||||||
|
|
||||||
|
#' Competing Risk Predictions
|
||||||
|
#'
|
||||||
|
#' @param x The predictions output from a competing risk random forest.
|
||||||
|
#' @param event The event who's CIF/CHF/Mortality you are interested in.
|
||||||
|
#' @param time The time to evaluate the mortality for (relevant only for
|
||||||
|
#' \code{extractMortalities}).
|
||||||
|
#'
|
||||||
|
#' @name CompetingRiskPredictions
|
||||||
|
NULL
|
||||||
|
|
||||||
|
#' @rdname CompetingRiskPredictions
|
||||||
|
#' @export
|
||||||
|
#' @description
|
||||||
|
#' \code{extractCIF} extracts the cumulative incidence function for a prediction.
|
||||||
|
extractCIF <- function (x, event) {
|
||||||
|
UseMethod("extractCIF", x)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
extractCIF.CompetingRiskFunctions <- function(prediction, event){
|
||||||
|
fun <- stepfun(prediction$time.interest, c(0, prediction$cif[,event]))
|
||||||
|
|
||||||
|
class(fun) <- "function"
|
||||||
|
attr(fun, "call") <- sys.call()
|
||||||
|
return(fun)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
extractCIF.CompetingRiskFunctions.List <- function(predictions, event){
|
||||||
|
return(lapply(predictions, extractCIF.CompetingRiskFunctions, event))
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @rdname CompetingRiskPredictions
|
||||||
|
#' @export
|
||||||
|
#' @description
|
||||||
|
#' \code{extractCHF} extracts the cause-specific cumulative hazard function for a prediction.
|
||||||
|
extractCHF <- function (x, event) {
|
||||||
|
UseMethod("extractCHF", x)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
extractCHF.CompetingRiskFunctions <- function(prediction, event){
|
||||||
|
fun <- stepfun(prediction$time.interest, c(0, prediction$chf[,event]))
|
||||||
|
|
||||||
|
class(fun) <- "function"
|
||||||
|
attr(fun, "call") <- sys.call()
|
||||||
|
return(fun)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
extractCHF.CompetingRiskFunctions.List <- function(predictions, event){
|
||||||
|
return(lapply(predictions, extractCHF.CompetingRiskFunctions, event))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#' @rdname CompetingRiskPredictions
|
||||||
|
#' @export
|
||||||
|
#' @description \code{extractSurvivorCurve} extracts the Kaplan-Meier estimator
|
||||||
|
#' of the overall survivor curve for a prediction.
|
||||||
|
extractSurvivorCurve <- function (x) {
|
||||||
|
UseMethod("extractSurvivorCurve", x)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
extractSurvivorCurve.CompetingRiskFunctions <- function(prediction){
|
||||||
|
fun <- stepfun(prediction$time.interest, c(1, prediction$survivorCurve))
|
||||||
|
|
||||||
|
class(fun) <- "function"
|
||||||
|
attr(fun, "call") <- sys.call()
|
||||||
|
return(fun)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
extractSurvivorCurve.CompetingRiskFunctions.List <- function(predictions){
|
||||||
|
return(lapply(predictions, extractSurvivorCurve.CompetingRiskFunctions))
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @rdname CompetingRiskPredictions
|
||||||
|
#' @export
|
||||||
|
#' @description \code{extractMortalities} extracts the cause-specific
|
||||||
|
#' mortalities for a function, which here is the CIF integrated from 0 to
|
||||||
|
#' \code{time}.
|
||||||
|
extractMortalities <- function(x, event, time){
|
||||||
|
UseMethod("extractMortalities", x)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
extractMortalities.CompetingRiskFunctions <- function(prediction, event, time){
|
||||||
|
if(is.null(event) | anyNA(event)){
|
||||||
|
stop("event must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(is.null(time) | anyNA(time)){
|
||||||
|
stop("time must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
return(.jcall(prediction$javaObject, "D", "calculateEventSpecificMortality", as.integer(event), time))
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
extractMortalities.CompetingRiskFunctions.List <- function(predictions, event, time){
|
||||||
|
if(is.null(event) | anyNA(event)){
|
||||||
|
stop("event must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(is.null(time) | anyNA(time)){
|
||||||
|
stop("time must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
return(as.numeric(lapply(predictions, extractMortalities.CompetingRiskFunctions, event, time)))
|
||||||
|
}
|
57
R/create_java_covariates.R
Normal file
57
R/create_java_covariates.R
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
# These functions are not exported, so I won't create their documentation either.
|
||||||
|
# I.e. it's not a mistake that the documentation below lacks the " ' " on each line.
|
||||||
|
|
||||||
|
# Covariates
|
||||||
|
#
|
||||||
|
# Creates a covariate for use in the Java code. These functions don't need to
|
||||||
|
# be directly run by a user, as loadData and train will detect, create and use
|
||||||
|
# such covariate objects.
|
||||||
|
#
|
||||||
|
# @name covariates
|
||||||
|
#
|
||||||
|
# @param name The name of the covariate, that later values will be placed
|
||||||
|
# under.
|
||||||
|
# @return An internal rJava object for later internal use.
|
||||||
|
# @keywords internal
|
||||||
|
# @examples
|
||||||
|
# # This is unnecessary for a user to do
|
||||||
|
#
|
||||||
|
# # Create a covariate
|
||||||
|
# booleanCovariate <- Java_BooleanCovariate("x1")
|
||||||
|
# factorCovariate <- Java_FactorCovariate("x2", c("cat", "dog", "mouse"))
|
||||||
|
# numericCovariate <- Java_NumericCovariate("x3")
|
||||||
|
#
|
||||||
|
# # Call the approriate Java method
|
||||||
|
# # The Java createValue method always takes in a String
|
||||||
|
# value1 <- .jcall(booleanCovariate, "Lca/joeltherrien/randomforest/covariates/Covariate$Value;", "createValue", "true")
|
||||||
|
# value2 <- .jcall(factorCovariate, "Lca/joeltherrien/randomforest/covariates/Covariate$Value;", "createValue", "dog")
|
||||||
|
# value3 <- .jcall(numericCovariate, "Lca/joeltherrien/randomforest/covariates/Covariate$Value;", "createValue", "3.14")
|
||||||
|
NULL
|
||||||
|
|
||||||
|
# @rdname covariates
|
||||||
|
Java_BooleanCovariate <- function(name, index){
|
||||||
|
covariate <- .jnew(.class_BooleanCovariate, name, as.integer(index))
|
||||||
|
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||||
|
|
||||||
|
return(covariate)
|
||||||
|
}
|
||||||
|
|
||||||
|
# @rdname covariates
|
||||||
|
# @param levels The levels of the factor as a character vector
|
||||||
|
Java_FactorCovariate <- function(name, index, levels){
|
||||||
|
levelsArray <- .jarray(levels, makeResponse(.class_String))
|
||||||
|
levelsList <- .jcall("java/util/Arrays", "Ljava/util/List;", "asList", .jcast(levelsArray, "[Ljava/lang/Object;"))
|
||||||
|
|
||||||
|
covariate <- .jnew(.class_FactorCovariate, name, as.integer(index), levelsList)
|
||||||
|
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||||
|
|
||||||
|
return(covariate)
|
||||||
|
}
|
||||||
|
|
||||||
|
# @rdname covariates
|
||||||
|
Java_NumericCovariate <- function(name, index){
|
||||||
|
covariate <- .jnew(.class_NumericCovariate, name, as.integer(index))
|
||||||
|
covariate <- .jcast(covariate, .class_Object) # needed for later adding it into Java Lists
|
||||||
|
|
||||||
|
return(covariate)
|
||||||
|
}
|
52
R/defaults.R
Normal file
52
R/defaults.R
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
splitFinderDefault <- function(responses){
|
||||||
|
if(class(responses) == "CompetingRiskResponses"){
|
||||||
|
# get all of the events
|
||||||
|
deltas <- unique(sort(responses$eventIndicator))
|
||||||
|
deltas <- deltas[!(deltas %in% as.integer(0))]
|
||||||
|
|
||||||
|
return(LogRankSplitFinder(deltas))
|
||||||
|
} else if(class(responses) == "CompetingRiskResponsesWithCensorTimes"){
|
||||||
|
# get all of the events
|
||||||
|
deltas <- sort(unique(responses$eventIndicator))
|
||||||
|
deltas <- deltas[!(deltas %in% as.integer(0))]
|
||||||
|
|
||||||
|
return(GrayLogRankSplitFinder(deltas))
|
||||||
|
}
|
||||||
|
else if(class(responses) == "numeric" | class(responses) == "integer" | class(responses) == "JNumeric"){
|
||||||
|
return(WeightedVarianceSplitFinder())
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
stop("Unable to determine an appropriate split finder for this response; please specify one manually.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
nodeResponseCombinerDefault <- function(responses){
|
||||||
|
if(class(responses) == "CompetingRiskResponses" | class(responses) == "CompetingRiskResponsesWithCensorTimes"){
|
||||||
|
# get all of the events
|
||||||
|
deltas <- unique(sort(responses$eventIndicator))
|
||||||
|
deltas <- deltas[!(deltas %in% as.integer(0))]
|
||||||
|
|
||||||
|
return(CR_ResponseCombiner(deltas))
|
||||||
|
} else if(class(responses) == "numeric" | class(responses) == "integer" | class(responses) == "JNumeric"){
|
||||||
|
return(MeanResponseCombiner())
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
stop("Unable to determine an appropriate node response combiner for this response; please specify one manually")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
forestResponseCombinerDefault <- function(responses){
|
||||||
|
if(class(responses) == "CompetingRiskResponses" | class(responses) == "CompetingRiskResponsesWithCensorTimes"){
|
||||||
|
# get all of the events
|
||||||
|
deltas <- unique(sort(responses$eventIndicator))
|
||||||
|
deltas <- deltas[!(deltas %in% as.integer(0))]
|
||||||
|
|
||||||
|
return(CR_FunctionCombiner(deltas))
|
||||||
|
} else if(class(responses) == "numeric" | class(responses) == "integer" | class(responses) == "JNumeric"){
|
||||||
|
return(MeanResponseCombiner())
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
stop("Unable to determine an appropriate forest response combiner for this response; please specify one manually.")
|
||||||
|
}
|
||||||
|
}
|
57
R/java_classes_directory.R
Normal file
57
R/java_classes_directory.R
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
# This file keeps track of the different Java classes used
|
||||||
|
# Whenever refactoring happens in the Java code, this file should be updated and (hopefully) nothing will break.
|
||||||
|
|
||||||
|
# General Java objects
|
||||||
|
.class_Object <- "java/lang/Object"
|
||||||
|
.class_String <- "java/lang/String"
|
||||||
|
.class_List <- "java/util/List"
|
||||||
|
.class_ArrayList <- "java/util/ArrayList"
|
||||||
|
.class_Collection <- "java/util/Collection"
|
||||||
|
.class_Serializable <- "java/io/Serializable"
|
||||||
|
.class_File <- "java/io/File"
|
||||||
|
|
||||||
|
# Utility Classes
|
||||||
|
.class_DataUtils <- "ca/joeltherrien/randomforest/utils/DataUtils"
|
||||||
|
.class_RUtils <- "ca/joeltherrien/randomforest/utils/RUtils"
|
||||||
|
.class_CompetingRiskUtils <- "ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils"
|
||||||
|
.class_Settings <- "ca/joeltherrien/randomforest/Settings"
|
||||||
|
|
||||||
|
# Misc. Classes
|
||||||
|
.class_RightContinuousStepFunction <- "ca/joeltherrien/randomforest/utils/RightContinuousStepFunction"
|
||||||
|
|
||||||
|
# TreeTrainer & its Builder
|
||||||
|
.class_TreeTrainer <- "ca/joeltherrien/randomforest/tree/TreeTrainer"
|
||||||
|
.class_TreeTrainer_Builder <- "ca/joeltherrien/randomforest/tree/TreeTrainer$TreeTrainerBuilder"
|
||||||
|
|
||||||
|
# ForestTrainer & its Builder
|
||||||
|
.class_ForestTrainer <- "ca/joeltherrien/randomforest/tree/ForestTrainer"
|
||||||
|
.class_ForestTrainer_Builder <- "ca/joeltherrien/randomforest/tree/ForestTrainer$ForestTrainerBuilder"
|
||||||
|
|
||||||
|
|
||||||
|
# Covariate classes
|
||||||
|
.class_Covariate <- "ca/joeltherrien/randomforest/covariates/Covariate"
|
||||||
|
.class_BooleanCovariate <- "ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate"
|
||||||
|
.class_FactorCovariate <- "ca/joeltherrien/randomforest/covariates/factor/FactorCovariate"
|
||||||
|
.class_NumericCovariate <- "ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate"
|
||||||
|
|
||||||
|
# Forest class
|
||||||
|
.class_Forest <- "ca/joeltherrien/randomforest/tree/Forest"
|
||||||
|
|
||||||
|
# ResponseCombiner classes
|
||||||
|
.class_ResponseCombiner <- "ca/joeltherrien/randomforest/tree/ResponseCombiner"
|
||||||
|
.class_CompetingRiskResponseCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskResponseCombiner"
|
||||||
|
.class_CompetingRiskFunctionCombiner <- "ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner"
|
||||||
|
.class_MeanResponseCombiner <- "ca/joeltherrien/randomforest/responses/regression/MeanResponseCombiner"
|
||||||
|
|
||||||
|
# SplitFinder classes
|
||||||
|
.class_SplitFinder <- "ca/joeltherrien/randomforest/tree/SplitFinder"
|
||||||
|
.class_GrayLogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/GrayLogRankSplitFinder"
|
||||||
|
.class_LogRankSplitFinder <- "ca/joeltherrien/randomforest/responses/competingrisk/splitfinder/LogRankSplitFinder"
|
||||||
|
.class_WeightedVarianceSplitFinder <- "ca/joeltherrien/randomforest/responses/regression/WeightedVarianceSplitFinder"
|
||||||
|
|
||||||
|
# When a class object is returned, rJava often often wants L prepended and ; appended.
|
||||||
|
# So a list that returns "java/lang/Object" should show "Ljava/lang/Object;"
|
||||||
|
# This function does that
|
||||||
|
makeResponse <- function(className){
|
||||||
|
return(paste0("L", className, ";"))
|
||||||
|
}
|
83
R/loadData.R
Normal file
83
R/loadData.R
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
loadData <- function(data, xVarNames, responses){
|
||||||
|
|
||||||
|
if(class(responses) == "integer" | class(responses) == "numeric"){
|
||||||
|
responses <- Numeric(responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
covariateList.java <- getCovariateList(data, xVarNames)
|
||||||
|
|
||||||
|
textColumns <- list()
|
||||||
|
for(j in 1:length(xVarNames)){
|
||||||
|
textColumns[[j]] <- .jarray(as.character(data[,xVarNames[j]]), "S")
|
||||||
|
}
|
||||||
|
textData <- convertRListToJava(textColumns)
|
||||||
|
|
||||||
|
rowList <- .jcall(.class_RUtils, makeResponse(.class_List), "importDataWithResponses",
|
||||||
|
responses$javaObject, covariateList.java, textData)
|
||||||
|
|
||||||
|
return(list(covariateList=covariateList.java, dataset=rowList))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
getCovariateList <- function(data, xvarNames){
|
||||||
|
covariateList <- .jcast(.jnew(.class_ArrayList, length(xvarNames)), .class_List)
|
||||||
|
|
||||||
|
for(i in 1:length(xvarNames)){
|
||||||
|
xName = xvarNames[i]
|
||||||
|
|
||||||
|
column <- data[,xName]
|
||||||
|
|
||||||
|
if(class(column) == "numeric" | class(column) == "integer"){
|
||||||
|
covariate <- Java_NumericCovariate(xName, i-1)
|
||||||
|
}
|
||||||
|
else if(class(column) == "logical"){
|
||||||
|
covariate <- Java_BooleanCovariate(xName, i-1)
|
||||||
|
}
|
||||||
|
else if(class(column) == "factor"){
|
||||||
|
lvls <- levels(column)
|
||||||
|
covariate <- Java_FactorCovariate(xName, i-1, lvls)
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
stop("Unknown column type")
|
||||||
|
}
|
||||||
|
|
||||||
|
.jcall(covariateList, "Z", "add", covariate)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return(covariateList)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
loadPredictionData <- function(newData, covariateList.java){
|
||||||
|
|
||||||
|
xVarNames <- character(.jcall(covariateList.java, "I", "size"))
|
||||||
|
for(j in 1:length(xVarNames)){
|
||||||
|
covariate.java <- .jcast(
|
||||||
|
.jcall(covariateList.java, makeResponse(.class_Object), "get", as.integer(j-1)),
|
||||||
|
.class_Covariate
|
||||||
|
)
|
||||||
|
|
||||||
|
xVarNames[j] <- .jcall(covariate.java, makeResponse(.class_String), "getName")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(any(!(xVarNames %in% names(newData)))){
|
||||||
|
varsMissing = xVarNames[!(xVarNames %in% names(newData))]
|
||||||
|
|
||||||
|
error <- paste0("The following covariates are not present in newdata: ", paste(varsMissing, collapse = ", "))
|
||||||
|
stop(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
textColumns <- list()
|
||||||
|
for(j in 1:length(xVarNames)){
|
||||||
|
textColumns[[j]] <- .jarray(as.character(newData[,xVarNames[j]]), "S")
|
||||||
|
}
|
||||||
|
textData <- convertRListToJava(textColumns)
|
||||||
|
|
||||||
|
rowList <- .jcall(.class_RUtils, makeResponse(.class_List),
|
||||||
|
"importData", covariateList.java, textData)
|
||||||
|
|
||||||
|
|
||||||
|
return(rowList)
|
||||||
|
}
|
||||||
|
|
75
R/load_forest.R
Normal file
75
R/load_forest.R
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
|
||||||
|
|
||||||
|
#' Load Random Forest
|
||||||
|
#'
|
||||||
|
#' Loads a random forest that was saved using \code{\link{save_forest}}.
|
||||||
|
#'
|
||||||
|
#' @param forest The directory created that saved the previous forest.
|
||||||
|
#' @return A JForest object; see \code{\link{train}} for details.
|
||||||
|
#' @export
|
||||||
|
#' @seealso \code{\link{train}}, \code{\link{save_forest}}
|
||||||
|
#' @examples
|
||||||
|
#' # Regression Example
|
||||||
|
#' x1 <- rnorm(1000)
|
||||||
|
#' x2 <- rnorm(1000)
|
||||||
|
#' y <- 1 + x1 + x2 + rnorm(1000)
|
||||||
|
#'
|
||||||
|
#' data <- data.frame(x1, x2, y)
|
||||||
|
#' forest <- train(y ~ x1 + x2, data,
|
||||||
|
#' ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5)
|
||||||
|
#'
|
||||||
|
#' save_forest(forest, "trees")
|
||||||
|
#' new_forest <- load_forest("trees")
|
||||||
|
load_forest <- function(directory){
|
||||||
|
|
||||||
|
# First load the response combiners and the split finders
|
||||||
|
nodeResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/nodeResponseCombiner.jData"))
|
||||||
|
nodeResponseCombiner.java <- .jcast(nodeResponseCombiner.java, .class_ResponseCombiner)
|
||||||
|
|
||||||
|
splitFinder.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/splitFinder.jData"))
|
||||||
|
splitFinder.java <- .jcast(splitFinder.java, .class_SplitFinder)
|
||||||
|
|
||||||
|
forestResponseCombiner.java <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/forestResponseCombiner.jData"))
|
||||||
|
forestResponseCombiner.java <- .jcast(forestResponseCombiner.java, .class_ResponseCombiner)
|
||||||
|
|
||||||
|
covariateList <- .jcall(.class_DataUtils, makeResponse(.class_Object), "loadObject", paste0(directory, "/covariateList.jData"))
|
||||||
|
covariateList <- .jcast(covariateList, .class_List)
|
||||||
|
|
||||||
|
params <- readRDS(paste0(directory, "/parameters.rData"))
|
||||||
|
call <- readRDS(paste0(directory, "/call.rData"))
|
||||||
|
|
||||||
|
params$nodeResponseCombiner$javaObject <- nodeResponseCombiner.java
|
||||||
|
params$splitFinder$javaObject <- splitFinder.java
|
||||||
|
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
|
||||||
|
|
||||||
|
forest <- load_forest_args_provided(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, call,
|
||||||
|
params$ntree, params$numberOfSplits, params$mtry, params$nodeSize, params$maxNodeDepth, params$splitPureNodes)
|
||||||
|
|
||||||
|
return(forest)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
load_forest_args_provided <- function(treeDirectory, nodeResponseCombiner, splitFinder, forestResponseCombiner,
|
||||||
|
covariateList.java, call, ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE){
|
||||||
|
|
||||||
|
params <- list(
|
||||||
|
splitFinder=splitFinder,
|
||||||
|
nodeResponseCombiner=nodeResponseCombiner,
|
||||||
|
forestResponseCombiner=forestResponseCombiner,
|
||||||
|
ntree=ntree,
|
||||||
|
numberOfSplits=numberOfSplits,
|
||||||
|
mtry=mtry,
|
||||||
|
nodeSize=nodeSize,
|
||||||
|
splitPureNodes=splitPureNodes,
|
||||||
|
maxNodeDepth = maxNodeDepth
|
||||||
|
)
|
||||||
|
|
||||||
|
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", treeDirectory, forestResponseCombiner$javaObject)
|
||||||
|
|
||||||
|
forestObject <- list(call=call, javaObject=forest.java, covariateList=covariateList.java, params=params)
|
||||||
|
class(forestObject) <- "JRandomForest"
|
||||||
|
|
||||||
|
return(forestObject)
|
||||||
|
|
||||||
|
}
|
101
R/misc.R
Normal file
101
R/misc.R
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
#' convertRListToJava
|
||||||
|
#'
|
||||||
|
#' An internal function that converts an R list of rJava objects into a
|
||||||
|
#' java.util.List rJava object containing those objects. It's used internally,
|
||||||
|
#' and is only available because it's used in some examples that demonstrate what
|
||||||
|
#' other objects do.
|
||||||
|
#' @param lst The R list containing rJava objects
|
||||||
|
#' @export
|
||||||
|
#' @return An rJava List object to be used internally.
|
||||||
|
#' @keywords internal
|
||||||
|
#' @examples
|
||||||
|
#' x <- Numeric(1:5)
|
||||||
|
#' class(x)
|
||||||
|
#' x <- convertRListToJava(x)
|
||||||
|
#' class(x)
|
||||||
|
convertRListToJava <- function(lst){
|
||||||
|
javaList <- .jnew(.class_ArrayList, as.integer(length(lst)))
|
||||||
|
javaList <- .jcast(javaList, .class_List)
|
||||||
|
|
||||||
|
for (item in lst){
|
||||||
|
if (class(item) != "jobjRef" & class(item) != "jarrayRef"){
|
||||||
|
stop("All items in the list must be rJava Java objects")
|
||||||
|
}
|
||||||
|
|
||||||
|
.jcall(javaList, "Z", "add", .jcast(item, .class_Object))
|
||||||
|
}
|
||||||
|
|
||||||
|
return(javaList)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
print.SplitFinder = function(splitFinder) print(splitFinder$call)
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
print.ResponseCombiner = function(combiner) print(combiner$call)
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
print.JRandomForest <- function(forest){
|
||||||
|
cat("Call:\n")
|
||||||
|
print(forest$call)
|
||||||
|
cat("\nParameters:\n")
|
||||||
|
cat("\tSplit Finder: "); print(forest$params$splitFinder$call)
|
||||||
|
cat("\tTerminal Node Response Combiner: "); print(forest$params$nodeResponseCombiner$call)
|
||||||
|
cat("\tForest Response Combiner: "); print(forest$params$forestResponseCombiner$call)
|
||||||
|
cat("\t# of trees: "); cat(forest$params$ntree); cat("\n")
|
||||||
|
cat("\t# of Splits: "); cat(forest$params$numberOfSplits); cat("\n")
|
||||||
|
cat("\t# of Covariates to try: "); cat(forest$params$mtry); cat("\n")
|
||||||
|
cat("\tNode Size: "); cat(forest$params$nodeSize); cat("\n")
|
||||||
|
cat("\tMax Node Depth: "); cat(forest$params$maxNodeDepth); cat("\n")
|
||||||
|
|
||||||
|
cat("Try using me with predict() or one of the relevant commands to determine error\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
print.CompetingRiskFunctions.List <- function(lst){
|
||||||
|
cat("Number of predictions: ")
|
||||||
|
cat(length(lst))
|
||||||
|
|
||||||
|
cat("\n\nSee the help page ?CompetingRiskPredictions for a list of relevant functions on how to use this object.\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
print.CompetingRiskFunctions <- function(functions){
|
||||||
|
mx <- ncol(functions$cif)
|
||||||
|
cat(mx); cat(" CIFs available\n")
|
||||||
|
cat(mx); cat(" CHFs available\n")
|
||||||
|
cat("An overall survival curve available\n")
|
||||||
|
cat("\nSee the help page ?CompetingRiskPredictions for a list of relevant functions on how to use this object.\n")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
plot.JMatrixPlottable <- function(mat, add=FALSE, type="s", xlab="Time", ylab=NULL, col="black", ...){
|
||||||
|
if(!add){
|
||||||
|
if(is.null(ylab)){
|
||||||
|
matType <- attr(mat, "type")
|
||||||
|
event <- attr(mat, "event")
|
||||||
|
|
||||||
|
if(matType == "cif"){
|
||||||
|
ylab <- paste0("CIF-", event, "(t)")
|
||||||
|
}
|
||||||
|
else if(matType == "chf"){
|
||||||
|
ylab <- paste0("CHF(t)-", event, "(t)")
|
||||||
|
}
|
||||||
|
else if(matType == "kaplanMeier"){
|
||||||
|
ylab <- "S-hat(t)"
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
ylab <- "Y"
|
||||||
|
warning("Unknown type attribute in plottable object")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
plot(mat[,2] ~ mat[,1], col=col, type=type, xlab=xlab, ylab=ylab, ...)
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
points(mat[,2] ~ mat[,1], col=col, type=type, xlab=xlab, ylab=ylab, ...)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
96
R/predict.R
Normal file
96
R/predict.R
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
|
||||||
|
|
||||||
|
#' Predict
|
||||||
|
#'
|
||||||
|
#' Predict on the random forest.
|
||||||
|
#'
|
||||||
|
#' @param forest A forest that was previously \code{\link{train}}ed
|
||||||
|
#' @param newData The new data containing all of the previous predictor
|
||||||
|
#' covariates. Note that even if predictions are being made on the training
|
||||||
|
#' set, the dataset must be specified. \code{largeRCRF} doesn't keep track of
|
||||||
|
#' the dataset after the forest is trained.
|
||||||
|
#' @param parallel A logical indicating whether multiple cores should be
|
||||||
|
#' utilized when making the predictions. Available as an option because it's
|
||||||
|
#' been observed by this author that using Java's \code{parallelStream} can be
|
||||||
|
#' unstable on some systems. Default value is \code{TRUE}.
|
||||||
|
#' @param out.of.bag A logical indicating whether predictions should be based on
|
||||||
|
#' 'out of bag' trees; set only to \code{TRUE} if you're running predictions
|
||||||
|
#' on data that was used in the training. Default value is \code{FALSE}.
|
||||||
|
#' @return A list of responses corresponding with each row of \code{newData} if
|
||||||
|
#' it's a non-regression random forest; otherwise it returns a numeric vector.
|
||||||
|
#' @export
|
||||||
|
#' @examples
|
||||||
|
#' # Regression Example
|
||||||
|
#' x1 <- rnorm(1000)
|
||||||
|
#' x2 <- rnorm(1000)
|
||||||
|
#' y <- 1 + x1 + x2 + rnorm(1000)
|
||||||
|
#'
|
||||||
|
#' data <- data.frame(x1, x2, y)
|
||||||
|
#' forest <- train(y ~ x1 + x2, data, WeightedVarianceSplitFinder(), MeanResponseCombiner(), MeanResponseCombiner(), ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5)
|
||||||
|
#'
|
||||||
|
#' # Fix x2 to be 0
|
||||||
|
#' newData <- data.frame(x1 = seq(from=-2, to=2, by=0.5), x2 = 0)
|
||||||
|
#' ypred <- predict(forest, newData)
|
||||||
|
#'
|
||||||
|
#' plot(ypred ~ newData$x1, type="l")
|
||||||
|
#'
|
||||||
|
#' # Competing Risk Example
|
||||||
|
#' x1 <- abs(rnorm(1000))
|
||||||
|
#' x2 <- abs(rnorm(1000))
|
||||||
|
#'
|
||||||
|
#' T1 <- rexp(1000, rate=x1)
|
||||||
|
#' T2 <- rweibull(1000, shape=x1, scale=x2)
|
||||||
|
#' C <- rexp(1000)
|
||||||
|
#' u <- pmin(T1, T2, C)
|
||||||
|
#' delta <- ifelse(u==T1, 1, ifelse(u==T2, 2, 0))
|
||||||
|
#'
|
||||||
|
#' data <- data.frame(x1, x2)
|
||||||
|
#'
|
||||||
|
#' forest <- train(CR_Response(delta, u) ~ x1 + x2, data,
|
||||||
|
#' LogRankSplitFinder(1:2), CompetingRiskResponseCombiner(1:2), CompetingRiskFunctionCombiner(1:2), ntree=100, numberOfSplits=5, mtry=1, nodeSize=10)
|
||||||
|
#' newData <- data.frame(x1 = c(-1, 0, 1), x2 = 0)
|
||||||
|
#' ypred <- predict(forest, newData)
|
||||||
|
predict.JRandomForest <- function(forest, newData=NULL, parallel=TRUE, out.of.bag=FALSE){
|
||||||
|
if(is.null(newData)){
|
||||||
|
stop("newData must be specified, even if predictions are on the training set")
|
||||||
|
}
|
||||||
|
|
||||||
|
forestObject <- forest$javaObject
|
||||||
|
covariateList <- forest$covariateList
|
||||||
|
predictionClass <- forest$params$forestResponseCombiner$outputClass
|
||||||
|
convertToRFunction <- forest$params$forestResponseCombiner$convertToRFunction
|
||||||
|
|
||||||
|
predictionDataList <- loadPredictionData(newData, covariateList)
|
||||||
|
|
||||||
|
if(parallel){
|
||||||
|
function.to.use <- "evaluate"
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
function.to.use <- "evaluateSerial"
|
||||||
|
}
|
||||||
|
|
||||||
|
if(out.of.bag){
|
||||||
|
function.to.use <- paste0(function.to.use, "OOB")
|
||||||
|
}
|
||||||
|
|
||||||
|
predictionsJava <- .jcall(forestObject, makeResponse(.class_List), function.to.use, predictionDataList)
|
||||||
|
|
||||||
|
if(predictionClass == "numeric"){
|
||||||
|
predictions <- vector(length=nrow(newData), mode="numeric")
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
predictions <- list()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for(i in 1:nrow(newData)){
|
||||||
|
prediction <- .jcall(predictionsJava, makeResponse(.class_Object), "get", as.integer(i-1))
|
||||||
|
prediction <- convertToRFunction(prediction, forest)
|
||||||
|
|
||||||
|
predictions[[i]] <- prediction
|
||||||
|
}
|
||||||
|
|
||||||
|
class(predictions) <- paste0(predictionClass, ".List")
|
||||||
|
|
||||||
|
return(predictions)
|
||||||
|
}
|
37
R/recover_forest.R
Normal file
37
R/recover_forest.R
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
|
||||||
|
recover_forest_predictable <- function(tree_directory, settingsPath) {
|
||||||
|
|
||||||
|
settings.java <- load_settings(settingsPath)
|
||||||
|
|
||||||
|
nodeResponseCombiner.java <- .jcall(settings.java, makeResponse(.class_ResponseCombiner), "getResponseCombiner")
|
||||||
|
splitFinder.java <- .jcall(settings.java, makeResponse(.class_SplitFinder), "getSplitFinder")
|
||||||
|
forestResponseCombiner.java <- .jcall(settings.java, makeResponse(.class_ResponseCombiner), "getTreeCombiner")
|
||||||
|
|
||||||
|
covariateList <- .jcall(settings.java, makeResponse(.class_List), "getCovariates")
|
||||||
|
|
||||||
|
params <- readRDS(paste0(directory, "/parameters.rData"))
|
||||||
|
call <- readRDS(paste0(directory, "/call.rData"))
|
||||||
|
|
||||||
|
params$nodeResponseCombiner$javaObject <- nodeResponseCombiner.java
|
||||||
|
params$splitFinder$javaObject <- splitFinder.java
|
||||||
|
params$forestResponseCombiner$javaObject <- forestResponseCombiner.java
|
||||||
|
|
||||||
|
forest <- load_forest_args_provided(directory, params$nodeResponseCombiner, params$splitFinder, params$forestResponseCombiner, covariateList, params, call)
|
||||||
|
|
||||||
|
return(forest)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
load_settings <- function(settingsPath) {
|
||||||
|
settingsFile <- .jnew(.class_File, settingsPath)
|
||||||
|
settings.java <- .jcall(.class_Settings, makeResponse(.class_Settings), "load", settingsFile)
|
||||||
|
|
||||||
|
return(settings.java)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
load_covariate_list_from_settings <- function(settingsPath){
|
||||||
|
settings.java = load_settings(settingsPath)
|
||||||
|
covariateList <- .jcall(settings.java, makeResponse(.class_List), "getCovariates")
|
||||||
|
return(covariateList)
|
||||||
|
}
|
80
R/regressionComponents.R
Normal file
80
R/regressionComponents.R
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
|
||||||
|
#' WeightedVarianceSplitFinder
|
||||||
|
#'
|
||||||
|
#' This split finder is used in regression random forests. When a split is made,
|
||||||
|
#' this finder computes the sample variance in each group (divided by n, not
|
||||||
|
#' n-1); it then minimizes the the sum of these variances, each of them weighted
|
||||||
|
#' by their sample size divided by the total sample size of that node.
|
||||||
|
#'
|
||||||
|
#' @note There are other split finders that are used in regression random
|
||||||
|
#' forests that are not included in this package. This package is oriented
|
||||||
|
#' toward the competing risk side of survival analysis; the regression options
|
||||||
|
#' are provided as an example of how extensible the back-end Java package is.
|
||||||
|
#' If you are interested in using this package for regression (or other uses),
|
||||||
|
#' feel free to write your own components. It's really not hard to write these
|
||||||
|
#' components; the WeightedVarianceSplitFinder Java class is quite short; most
|
||||||
|
#' of the code is to reuse calculations from previous considered splits.
|
||||||
|
#' @export
|
||||||
|
#' @return A split finder object to be used in \code{\link{train}}; not
|
||||||
|
#' useful on its own.
|
||||||
|
#' @examples
|
||||||
|
#' splitFinder <- WeightedVarianceSplitFinder()
|
||||||
|
#' # You would then use it in train()
|
||||||
|
#'
|
||||||
|
#' @references https://kogalur.github.io/randomForestSRC/theory.html#section8.3
|
||||||
|
WeightedVarianceSplitFinder <- function(){
|
||||||
|
javaObject <- .jnew(.class_WeightedVarianceSplitFinder)
|
||||||
|
javaObject <- .jcast(javaObject, .class_SplitFinder)
|
||||||
|
|
||||||
|
splitFinder <- list(javaObject=javaObject, call=match.call())
|
||||||
|
class(splitFinder) <- "SplitFinder"
|
||||||
|
|
||||||
|
return(splitFinder)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' MeanResponseCombiner
|
||||||
|
#'
|
||||||
|
#' This response combiner is used in regression random forests, where the
|
||||||
|
#' response in the data is a single number that needs to be averaged in each
|
||||||
|
#' terminal node, and then averaged across trees. This response combiner is
|
||||||
|
#' appropriate as an argument for both the \code{nodeResponseCombiner} and
|
||||||
|
#' \code{forestResponseCombiner} parameters in \code{\link{train}} when doing
|
||||||
|
#' regression.
|
||||||
|
#' @export
|
||||||
|
#' @return A response combiner object to be used in \code{\link{train}}; not
|
||||||
|
#' useful on its own. However, internally, a response combiner object is a
|
||||||
|
#' list consisting of the following objects:
|
||||||
|
#' \describe{
|
||||||
|
#' \item{\code{javaObject}}{The java object used in the algorithm}
|
||||||
|
#' \item{\code{call}}{The call (used in \code{print})}
|
||||||
|
#' \item{\code{outputClass}}{The R class of the outputs; used in \code{\link{predict.JRandomForest}}}
|
||||||
|
#' \item{\code{convertToRFunction}}{An R function that converts a Java prediction from the combiner into R output that is readable by a user.}
|
||||||
|
#' }
|
||||||
|
#'
|
||||||
|
#' @examples
|
||||||
|
#' responseCombiner <- MeanResponseCombiner()
|
||||||
|
#' # You would then use it in train()
|
||||||
|
#'
|
||||||
|
#' # However; I'll show an internal Java method to make it clear what it does
|
||||||
|
#' # Note that you should never have to do the following
|
||||||
|
#' x <- 1:3
|
||||||
|
#' x <- convertRListToJava(Numeric(x))
|
||||||
|
#'
|
||||||
|
#' # will output a Java object containing 2
|
||||||
|
#' output <- rJava::.jcall(responseCombiner$javaObject, "Ljava/lang/Double;", "combine", x)
|
||||||
|
#' responseCombiner$convertToRFunction(output)
|
||||||
|
#'
|
||||||
|
MeanResponseCombiner <- function(){
|
||||||
|
javaObject <- .jnew(.class_MeanResponseCombiner)
|
||||||
|
javaObject <- .jcast(javaObject, .class_ResponseCombiner)
|
||||||
|
|
||||||
|
combiner <- list(javaObject=javaObject, call=match.call(), outputClass="numeric")
|
||||||
|
combiner$convertToRFunction <- function(javaObject, ...){
|
||||||
|
return(.jcall(javaObject, "D", "doubleValue"))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class(combiner) <- "ResponseCombiner"
|
||||||
|
|
||||||
|
return(combiner)
|
||||||
|
}
|
94
R/save_forest.R
Normal file
94
R/save_forest.R
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
|
||||||
|
|
||||||
|
#' Save Random Forests
|
||||||
|
#'
|
||||||
|
#' Saves a random forest for later use, given that the base R
|
||||||
|
#' \code{\link{base::save}} function doesn't work for this package.
|
||||||
|
#'
|
||||||
|
#' @param forest The forest to save.
|
||||||
|
#' @param directory The directory that should be created to save the trees in.
|
||||||
|
#' Note that if the directory already exists, an error will be displayed
|
||||||
|
#' unless \code{overwrite} is set to TRUE.
|
||||||
|
#' @param overwrite Should the function overwrite an existing forest; FALSE by
|
||||||
|
#' default.
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
#' @seealso \code{\link{train}}, \code{\link{load_forest}}
|
||||||
|
#' @examples
|
||||||
|
#' # Regression Example
|
||||||
|
#' x1 <- rnorm(1000)
|
||||||
|
#' x2 <- rnorm(1000)
|
||||||
|
#' y <- 1 + x1 + x2 + rnorm(1000)
|
||||||
|
#'
|
||||||
|
#' data <- data.frame(x1, x2, y)
|
||||||
|
#' forest <- train(y ~ x1 + x2, data,
|
||||||
|
#' ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5)
|
||||||
|
#'
|
||||||
|
#' save_forest(forest, "trees")
|
||||||
|
#' new_forest <- load_forest("trees")
|
||||||
|
save_forest <- function(forest, directory, overwrite=FALSE){
|
||||||
|
check_and_create_directory(directory, overwrite)
|
||||||
|
|
||||||
|
saveTrees(forest, directory)
|
||||||
|
|
||||||
|
# Next save the response combiners and the split finders
|
||||||
|
saveForestComponents(directory,
|
||||||
|
covariateList=forest$covariateList,
|
||||||
|
params=forest$params,
|
||||||
|
forestCall=forest$call)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
saveTrees <- function(forest, directory){
|
||||||
|
# This function assumes that directory is free for us to write in.
|
||||||
|
|
||||||
|
forest.java <- forest$javaObject
|
||||||
|
|
||||||
|
# First save the trees
|
||||||
|
tree.collection.java <- .jcall(forest.java, makeResponse(.class_List), "getTrees")
|
||||||
|
numberOfTrees <- forest$params$ntree
|
||||||
|
width = round(log10(numberOfTrees))+1
|
||||||
|
treeNames <- paste0(directory, "/tree-", formatC(1:numberOfTrees, width=width, format="d", flag="0"), ".tree")
|
||||||
|
for(i in 1:numberOfTrees){
|
||||||
|
treeName <-treeNames[i]
|
||||||
|
tree.java <- .jcall(tree.collection.java, makeResponse(.class_Object), "get", as.integer(i-1))
|
||||||
|
tree.java <- .jcast(tree.java, .class_Serializable)
|
||||||
|
.jcall(.class_DataUtils, "V", "saveObject", tree.java, treeName)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
saveForestComponents <- function(directory, covariateList, params, forestCall){
|
||||||
|
|
||||||
|
nodeResponseCombiner <- params$nodeResponseCombiner
|
||||||
|
nodeResponseCombiner.java <- .jcast(nodeResponseCombiner$javaObject, .class_Serializable)
|
||||||
|
.jcall(.class_DataUtils, "V", "saveObject", nodeResponseCombiner.java, paste0(directory, "/nodeResponseCombiner.jData"))
|
||||||
|
nodeResponseCombiner$javaObject <- NULL
|
||||||
|
|
||||||
|
splitFinder <- params$splitFinder
|
||||||
|
splitFinder.java <- .jcast(splitFinder$javaObject, .class_Serializable)
|
||||||
|
.jcall(.class_DataUtils, "V", "saveObject", splitFinder.java, paste0(directory, "/splitFinder.jData"))
|
||||||
|
splitFinder$javaObject <- NULL
|
||||||
|
|
||||||
|
forestResponseCombiner <- params$forestResponseCombiner
|
||||||
|
forestResponseCombiner.java <- .jcast(forestResponseCombiner$javaObject, .class_Serializable)
|
||||||
|
.jcall(.class_DataUtils, "V", "saveObject", forestResponseCombiner.java, paste0(directory, "/forestResponseCombiner.jData"))
|
||||||
|
forestResponseCombiner$javaObject <- NULL
|
||||||
|
|
||||||
|
covariateList <- .jcast(covariateList, .class_Serializable)
|
||||||
|
.jcall(.class_DataUtils, "V", "saveObject", covariateList, paste0(directory, "/covariateList.jData"))
|
||||||
|
|
||||||
|
saveRDS(object=params, file=paste0(directory, "/parameters.rData"))
|
||||||
|
saveRDS(object=forestCall, file=paste0(directory, "/call.rData"))
|
||||||
|
}
|
||||||
|
|
||||||
|
check_and_create_directory <- function(directory, overwrite){
|
||||||
|
if(file.exists(directory) & !overwrite){
|
||||||
|
stop(paste(directory, "already exists; will not modify it. Please remove/rename it or set overwrite=TRUE"))
|
||||||
|
}
|
||||||
|
else if(file.exists(directory) & overwrite){
|
||||||
|
unlink(directory)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir.create(directory)
|
||||||
|
}
|
417
R/train.R
Normal file
417
R/train.R
Normal file
|
@ -0,0 +1,417 @@
|
||||||
|
|
||||||
|
getCores <- function(){
|
||||||
|
cores <- NA
|
||||||
|
if (requireNamespace("parallel", quietly = TRUE)){
|
||||||
|
cores <- parallel::detectCores()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is.na(cores)){
|
||||||
|
message("Unable to detect how many cores are available; defaulting to only using one. Feel free to override this by pre-specifying the cores argument.")
|
||||||
|
cores <- 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return(cores)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Train Random Forests
|
||||||
|
#'
|
||||||
|
#' Trains the random forest. The type of response the random forest can be
|
||||||
|
#' trained on varies depending on the \code{splitFinder},
|
||||||
|
#' \code{nodeResponseCombiner}, and the \code{forestResponseCombiner}
|
||||||
|
#' parameters. Make sure these are compatible with each other, and with the
|
||||||
|
#' response you plug in. \code{splitFinder} should work on the responses you are
|
||||||
|
#' providing; \code{nodeResponseCombiner} should combine these responses into
|
||||||
|
#' some intermediate product, and \code{forestResponseCombiner} combines these
|
||||||
|
#' intermediate products into the final output product.
|
||||||
|
#'
|
||||||
|
#' @param responses An R list of the responses. See \code{\link{CR_Response}}
|
||||||
|
#' for an example function.
|
||||||
|
#' @param covariateData A data.frame containing only the columns of the
|
||||||
|
#' covariates you wish to use in your training (unless you're using the
|
||||||
|
#' \code{formula} version of \code{train}, in which case it should contain the
|
||||||
|
#' response as well).
|
||||||
|
#' @param splitFinder A split finder that's used to score splits in the random
|
||||||
|
#' forest training algorithm. See \code{\link{Competing Risk Split Finders}}
|
||||||
|
#' or \code{\link{WeightedVarianceSplitFinder}}. If you don't specify one,
|
||||||
|
#' this function tries to pick one based on the response. For
|
||||||
|
#' \code{\link{CR_Response}} wihtout censor times, it will pick a
|
||||||
|
#' \code{\link{LogRankSplitFinder}}; while if censor times were provided it
|
||||||
|
#' will pick \code{\link{GrayLogRankSplitFinder}}; for integer or numeric
|
||||||
|
#' responses it picks a \code{\link{WeightedVarianceSplitFinder}}.
|
||||||
|
#' @param nodeResponseCombiner A response combiner that's used to combine
|
||||||
|
#' responses for each terminal node in a tree (regression example; average the
|
||||||
|
#' observations in each tree into a single number). See
|
||||||
|
#' \code{\link{CompetingRiskResponseCombiner}} or
|
||||||
|
#' \code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||||
|
#' tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||||
|
#' picks a \code{\link{CompetingRiskResponseCombiner}}; for integer or numeric
|
||||||
|
#' responses it picks a \code{\link{MeanResponseCombiner}}.
|
||||||
|
#' @param forestResponseCombiner A response combiner that's used to combine
|
||||||
|
#' predictions across trees into one final result (regression example; average
|
||||||
|
#' the prediction of each tree into a single number). See
|
||||||
|
#' \code{\link{CompetingRiskFunctionCombiner}} or
|
||||||
|
#' \code{\link{MeanResponseCombiner}}. If you don't specify one, this function
|
||||||
|
#' tries to pick one based on the response. For \code{\link{CR_Response}} it
|
||||||
|
#' picks a \code{\link{CompetingRiskFunctionCombiner}}; for integer or numeric
|
||||||
|
#' responses it picks a \code{\link{MeanResponseCombiner}}.
|
||||||
|
#' @param ntree An integer that specifies how many trees should be trained.
|
||||||
|
#' @param numberOfSplits A tuning parameter specifying how many random splits
|
||||||
|
#' should be tried for a covariate; a value of 0 means all splits will be
|
||||||
|
#' tried (with an exception for factors, who might have too many splits to
|
||||||
|
#' feasibly compute).
|
||||||
|
#' @param mtry A tuning parameter specifying how many covariates will be
|
||||||
|
#' randomly chosen to be tried in the splitting process. This value must be at
|
||||||
|
#' least 1.
|
||||||
|
#' @param nodeSize The algorithm will not attempt to split a node that has
|
||||||
|
#' observations less than 2*\code{nodeSize}; this results in terminal nodes
|
||||||
|
#' having a size of roughly \code{nodeSize} (true sizes may be both smaller or
|
||||||
|
#' greater). This value must be at least 1.
|
||||||
|
#' @param maxNodeDepth This parameter is analogous to \code{nodeSize} in that it
|
||||||
|
#' helps keep trees shorter; by default maxNodeDepth is an extremely high
|
||||||
|
#' number and tree depth is controlled by \code{nodeSize}.
|
||||||
|
#' @param splitPureNodes This parameter determines whether the algorithm will
|
||||||
|
#' split a pure node. If set to FALSE, then before every split it will check
|
||||||
|
#' that every response is the same, and if so, not split. If set to TRUE it
|
||||||
|
#' forgoes that check and just splits. Prediction accuracy won't change under
|
||||||
|
#' any sensible \code{nodeResponseCombiner} as all terminal nodes from a split
|
||||||
|
#' pure node should give the same prediction, so this parameter only affects
|
||||||
|
#' performance. If your response is continuous you'll likely experience faster
|
||||||
|
#' train times by setting it to TRUE. Default value is TRUE.
|
||||||
|
#' @param savePath If set, this parameter will save each tree of the random
|
||||||
|
#' forest in this directory as the forest is trained. Use this parameter if
|
||||||
|
#' you need to save memory while training. See also \code{\link{load_forest}}
|
||||||
|
#' @param savePath.overwrite This parameter controls the behaviour for what
|
||||||
|
#' happens if \code{savePath} is pointing to an existing directory. If set to
|
||||||
|
#' \code{warn} (default) then \code{train} refuses to proceed. If set to
|
||||||
|
#' \code{delete} then all the contents in that folder are deleted for the new
|
||||||
|
#' forest to be trained. Note that all contents are deleted, even those files
|
||||||
|
#' not related to \code{largeRCRF}. Use only if you're sure it's safe. If set
|
||||||
|
#' to \code{merge}, then the files describing the forest (such as its
|
||||||
|
#' parameters) are overwritten but the saved trees are not. The algorithm
|
||||||
|
#' assumes (without checking) that the existing trees are from a previous run
|
||||||
|
#' and starts from where it left off. This option is useful if recovering from
|
||||||
|
#' a crash.
|
||||||
|
#' @param cores This parameter specifies how many trees will be simultaneously
|
||||||
|
#' trained. By default the package attempts to detect how many cores you have
|
||||||
|
#' by using the \code{parallel} package, and using all of them. You may
|
||||||
|
#' specify a lower number if you wish. It is not recommended to specify a
|
||||||
|
#' number greater than the number of available cores as this will hurt
|
||||||
|
#' performance with no available benefit.
|
||||||
|
#' @param randomSeed This parameter specifies a random seed if reproducible,
|
||||||
|
#' deterministic forests are desired. The number o1
|
||||||
|
#' @export
|
||||||
|
#' @return A \code{JRandomForest} object. You may call \code{predict} or
|
||||||
|
#' \code{print} on it.
|
||||||
|
#' @seealso \code{\link{predict.JRandomForest}}
|
||||||
|
#' @note If saving memory is a concern, you can replace \code{covariateData}
|
||||||
|
#' with an environment containing one element called \code{data} as the actual
|
||||||
|
#' dataset. After the data has been imported into Java, but before the forest
|
||||||
|
#' training begins, the dataset in the environment is deleted, freeing up
|
||||||
|
#' memory in R.
|
||||||
|
#' @examples
|
||||||
|
#' # Regression Example
|
||||||
|
#' x1 <- rnorm(1000)
|
||||||
|
#' x2 <- rnorm(1000)
|
||||||
|
#' y <- 1 + x1 + x2 + rnorm(1000)
|
||||||
|
#'
|
||||||
|
#' data <- data.frame(x1, x2, y)
|
||||||
|
#' forest <- train(y ~ x1 + x2, data, WeightedVarianceSplitFinder(), MeanResponseCombiner(), MeanResponseCombiner(), ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5)
|
||||||
|
#'
|
||||||
|
#' # Fix x2 to be 0
|
||||||
|
#' newData <- data.frame(x1 = seq(from=-2, to=2, by=0.5), x2 = 0)
|
||||||
|
#' ypred <- predict(forest, newData)
|
||||||
|
#'
|
||||||
|
#' plot(ypred ~ newData$x1, type="l")
|
||||||
|
#'
|
||||||
|
#' # Competing Risk Example
|
||||||
|
#' x1 <- abs(rnorm(1000))
|
||||||
|
#' x2 <- abs(rnorm(1000))
|
||||||
|
#'
|
||||||
|
#' T1 <- rexp(1000, rate=x1)
|
||||||
|
#' T2 <- rweibull(1000, shape=x1, scale=x2)
|
||||||
|
#' C <- rexp(1000)
|
||||||
|
#' u <- pmin(T1, T2, C)
|
||||||
|
#' delta <- ifelse(u==T1, 1, ifelse(u==T2, 2, 0))
|
||||||
|
#'
|
||||||
|
#' data <- data.frame(x1, x2)
|
||||||
|
#'
|
||||||
|
#' forest <- train(CompetingRiskResponses(delta, u) ~ x1 + x2, data,
|
||||||
|
#' LogRankSplitFinder(1:2), CompetingRiskResponseCombiner(1:2), CompetingRiskFunctionCombiner(1:2), ntree=100, numberOfSplits=5, mtry=1, nodeSize=10)
|
||||||
|
#' newData <- data.frame(x1 = c(-1, 0, 1), x2 = 0)
|
||||||
|
#' ypred <- predict(forest, newData)
|
||||||
|
train <- function(x, ...) UseMethod("train")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#' @rdname train
|
||||||
|
#' @export
|
||||||
|
train.default <- function(responses, covariateData, splitFinder = splitFinderDefault(responses), nodeResponseCombiner = nodeResponseCombinerDefault(responses), forestResponseCombiner = forestResponseCombinerDefault(responses), ntree, numberOfSplits, mtry, nodeSize, maxNodeDepth = 100000, splitPureNodes=TRUE, savePath=NULL, savePath.overwrite=c("warn", "delete", "merge"), cores = getCores(), randomSeed = NULL){
|
||||||
|
|
||||||
|
# Some quick checks on parameters
|
||||||
|
ntree <- as.integer(ntree)
|
||||||
|
numberOfSplits <- as.integer(numberOfSplits)
|
||||||
|
mtry <- as.integer(mtry)
|
||||||
|
nodeSize <- as.integer(nodeSize)
|
||||||
|
maxNodeDepth <- as.integer(maxNodeDepth)
|
||||||
|
cores <- as.integer(cores)
|
||||||
|
|
||||||
|
if (ntree <= 0){
|
||||||
|
stop("ntree must be strictly positive.")
|
||||||
|
}
|
||||||
|
if (numberOfSplits < 0){
|
||||||
|
stop("numberOfSplits cannot be negative.")
|
||||||
|
}
|
||||||
|
if (mtry <= 0){
|
||||||
|
stop("mtry must be strictly positive. If you want to try all covariates, you can set it to be very large.")
|
||||||
|
}
|
||||||
|
if (nodeSize <= 0){
|
||||||
|
stop("nodeSize must be strictly positive.")
|
||||||
|
}
|
||||||
|
if (maxNodeDepth <= 0){
|
||||||
|
stop("maxNodeDepth must be strictly positive")
|
||||||
|
}
|
||||||
|
if (cores <= 0){
|
||||||
|
stop("cores must be strictly positive")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(is.null(savePath.overwrite) | length(savePath.overwrite)==0 | !(savePath.overwrite[1] %in% c("warn", "delete", "merge"))){
|
||||||
|
stop("savePath.overwrite must be one of c(\"warn\", \"delete\", \"merge\")")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if(class(nodeResponseCombiner) != "ResponseCombiner"){
|
||||||
|
stop("nodeResponseCombiner must be a ResponseCombiner")
|
||||||
|
}
|
||||||
|
if(class(splitFinder) != "SplitFinder"){
|
||||||
|
stop("splitFinder must be a SplitFinder")
|
||||||
|
}
|
||||||
|
if(class(forestResponseCombiner) != "ResponseCombiner"){
|
||||||
|
stop("forestResponseCombiner must be a ResponseCombiner")
|
||||||
|
}
|
||||||
|
|
||||||
|
if(class(covariateData)=="environment"){
|
||||||
|
if(is.null(covariateData$data)){
|
||||||
|
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
|
||||||
|
}
|
||||||
|
dataset <- loadData(covariateData$data, colnames(covariateData$data), responses)
|
||||||
|
covariateData$data <- NULL # save memory, hopefully
|
||||||
|
gc() # explicitly try to save memory
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
dataset <- loadData(covariateData, colnames(covariateData), responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
treeTrainer <- createTreeTrainer(responseCombiner=nodeResponseCombiner,
|
||||||
|
splitFinder=splitFinder,
|
||||||
|
covariateList=dataset$covariateList,
|
||||||
|
numberOfSplits=numberOfSplits,
|
||||||
|
nodeSize=nodeSize,
|
||||||
|
maxNodeDepth=maxNodeDepth,
|
||||||
|
mtry=mtry,
|
||||||
|
splitPureNodes=splitPureNodes)
|
||||||
|
|
||||||
|
forestTrainer <- createForestTrainer(treeTrainer=treeTrainer,
|
||||||
|
covariateList=dataset$covariateList,
|
||||||
|
treeResponseCombiner=forestResponseCombiner,
|
||||||
|
dataset=dataset$dataset,
|
||||||
|
ntree=ntree,
|
||||||
|
randomSeed=randomSeed,
|
||||||
|
saveTreeLocation=savePath)
|
||||||
|
|
||||||
|
params <- list(
|
||||||
|
splitFinder=splitFinder,
|
||||||
|
nodeResponseCombiner=nodeResponseCombiner,
|
||||||
|
forestResponseCombiner=forestResponseCombiner,
|
||||||
|
ntree=ntree,
|
||||||
|
numberOfSplits=numberOfSplits,
|
||||||
|
mtry=mtry,
|
||||||
|
nodeSize=nodeSize,
|
||||||
|
splitPureNodes=splitPureNodes,
|
||||||
|
maxNodeDepth = maxNodeDepth,
|
||||||
|
savePath=savePath
|
||||||
|
)
|
||||||
|
|
||||||
|
# We'll be saving an offline version of the forest
|
||||||
|
if(!is.null(savePath)){
|
||||||
|
|
||||||
|
if(file.exists(savePath)){ # we might have to remove the folder or display an error
|
||||||
|
|
||||||
|
if(savePath.overwrite[1] == "warn"){
|
||||||
|
stop(paste(savePath, "already exists; will not modify it. Please remove/rename it or set the savePath.overwrite to either 'delete' or 'merge'"))
|
||||||
|
} else if(savePath.overwrite[1] == "delete"){
|
||||||
|
unlink(savePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if(savePath.overwrite[1] != "merge"){
|
||||||
|
dir.create(savePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
# First save forest components (so that if the training crashes mid-way through it can theoretically be recovered by the user)
|
||||||
|
saveForestComponents(savePath,
|
||||||
|
covariateList=dataset$covariateList,
|
||||||
|
params=params,
|
||||||
|
forestCall=match.call())
|
||||||
|
|
||||||
|
if(cores > 1){
|
||||||
|
.jcall(forestTrainer, "V", "trainParallelOnDisk", as.integer(cores))
|
||||||
|
} else {
|
||||||
|
.jcall(forestTrainer, "V", "trainSerialOnDisk")
|
||||||
|
}
|
||||||
|
|
||||||
|
# Need to now load forest trees back into memory
|
||||||
|
forest.java <- .jcall(.class_DataUtils, makeResponse(.class_Forest), "loadForest", savePath, forestResponseCombiner$javaObject)
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
else{ # save directly into memory
|
||||||
|
if(cores > 1){
|
||||||
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainParallelInMemory", as.integer(cores))
|
||||||
|
} else {
|
||||||
|
forest.java <- .jcall(forestTrainer, makeResponse(.class_Forest), "trainSerialInMemory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
forestObject <- list(call=match.call(), params=params, javaObject=forest.java, covariateList=dataset$covariateList)
|
||||||
|
|
||||||
|
# TODO - remove redundant code if tests pass
|
||||||
|
#forestObject$params <- list(
|
||||||
|
# splitFinder=splitFinder,
|
||||||
|
# nodeResponseCombiner=nodeResponseCombiner,
|
||||||
|
# forestResponseCombiner=forestResponseCombiner,
|
||||||
|
# ntree=ntree,
|
||||||
|
# numberOfSplits=numberOfSplits,
|
||||||
|
# mtry=mtry,
|
||||||
|
# nodeSize=nodeSize,
|
||||||
|
# splitPureNodes=splitPureNodes,
|
||||||
|
# maxNodeDepth = maxNodeDepth,
|
||||||
|
# savePath=savePath
|
||||||
|
#)
|
||||||
|
|
||||||
|
class(forestObject) <- "JRandomForest"
|
||||||
|
return(forestObject)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#' @rdname train
|
||||||
|
#' @export
|
||||||
|
#' @param formula You may specify the response and covariates as a formula instead; make sure the response in the formula is still properly constructed; see \code{responses}
|
||||||
|
train.formula <- function(formula, covariateData, ...){
|
||||||
|
|
||||||
|
# Having an R copy of the data loaded at the same time can be wasteful; we
|
||||||
|
# also allow users to provide an environment of the data which gets removed
|
||||||
|
# after being imported into Java
|
||||||
|
env <- NULL
|
||||||
|
if(class(covariateData) == "environment"){
|
||||||
|
if(is.null(covariateData$data)){
|
||||||
|
stop("When providing an environment with the dataset, the environment must contain an item called 'data'")
|
||||||
|
}
|
||||||
|
|
||||||
|
env <- covariateData
|
||||||
|
covariateData <- env$data
|
||||||
|
}
|
||||||
|
|
||||||
|
yVar <- formula[[2]]
|
||||||
|
|
||||||
|
responses <- NULL
|
||||||
|
variablesToDrop <- character(0)
|
||||||
|
|
||||||
|
# yVar is a call object; as.character(yVar) will be the different components, including the parameters.
|
||||||
|
# if the length of yVar is > 1 then it's a function call. If the length is 1, and it's not in covariateData,
|
||||||
|
# then we also need to explicitly evaluate it
|
||||||
|
if(class(yVar)=="call" || !(as.character(yVar) %in% colnames(covariateData))){
|
||||||
|
# yVar is a function like CompetingRiskResponses
|
||||||
|
responses <- eval(expr=yVar, envir=covariateData)
|
||||||
|
|
||||||
|
if(class(formula[[3]]) == "name" && as.character(formula[[3]])=="."){
|
||||||
|
# do any of the variables match data in covariateData? We need to track that so we can drop them later
|
||||||
|
variablesToDrop <- as.character(yVar)[as.character(yVar) %in% names(covariateData)]
|
||||||
|
}
|
||||||
|
|
||||||
|
formula[[2]] <- NULL
|
||||||
|
|
||||||
|
} else if(class(yVar)=="name"){ # and implicitly yVar is contained in covariateData
|
||||||
|
variablesToDrop <- as.character(yVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Includes responses which we may need to later cut out
|
||||||
|
mf <- model.frame(formula=formula, data=covariateData, na.action=na.pass)
|
||||||
|
|
||||||
|
if(is.null(responses)){
|
||||||
|
responses <- model.response(mf)
|
||||||
|
}
|
||||||
|
|
||||||
|
# remove any response variables
|
||||||
|
mf <- mf[,!(names(mf) %in% variablesToDrop), drop=FALSE]
|
||||||
|
|
||||||
|
# If environment was provided instead of data
|
||||||
|
if(!is.null(env)){
|
||||||
|
env$data <- mf
|
||||||
|
rm(covariateData)
|
||||||
|
forest <- train.default(responses, env, ...)
|
||||||
|
} else{
|
||||||
|
forest <- train.default(responses, mf, ...)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
forest$call <- match.call()
|
||||||
|
forest$formula <- formula
|
||||||
|
|
||||||
|
return(forest)
|
||||||
|
}
|
||||||
|
|
||||||
|
createForestTrainer <- function(treeTrainer, covariateList, treeResponseCombiner, dataset, ntree, randomSeed, saveTreeLocation){
|
||||||
|
builderClassReturned <- makeResponse(.class_ForestTrainer_Builder)
|
||||||
|
|
||||||
|
builder <- .jcall(.class_ForestTrainer, builderClassReturned, "builder")
|
||||||
|
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "treeTrainer", treeTrainer)
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "treeResponseCombiner", treeResponseCombiner$javaObject)
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "data", dataset)
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "ntree", as.integer(ntree))
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "displayProgress", TRUE)
|
||||||
|
|
||||||
|
if(!is.null(randomSeed)){
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "randomSeed", .jlong(randomSeed))
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "randomSeed", .jlong(as.integer(Sys.time())))
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!is.null(saveTreeLocation)){
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "saveTreeLocation", saveTreeLocation)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
forestTrainer <- .jcall(builder, makeResponse(.class_ForestTrainer), "build")
|
||||||
|
return(forestTrainer)
|
||||||
|
}
|
||||||
|
|
||||||
|
createTreeTrainer <- function(responseCombiner, splitFinder, covariateList, numberOfSplits, nodeSize, maxNodeDepth, mtry, splitPureNodes){
|
||||||
|
builderClassReturned <- makeResponse(.class_TreeTrainer_Builder)
|
||||||
|
|
||||||
|
builder <- .jcall(.class_TreeTrainer, builderClassReturned, "builder")
|
||||||
|
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "responseCombiner", responseCombiner$javaObject)
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "splitFinder", splitFinder$javaObject)
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "covariates", covariateList)
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "numberOfSplits", as.integer(numberOfSplits))
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "nodeSize", as.integer(nodeSize))
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "maxNodeDepth", as.integer(maxNodeDepth))
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "mtry", as.integer(mtry))
|
||||||
|
builder <- .jcall(builder, builderClassReturned, "checkNodePurity", !splitPureNodes)
|
||||||
|
|
||||||
|
treeTrainer <- .jcall(builder, makeResponse(.class_TreeTrainer), "build")
|
||||||
|
return(treeTrainer)
|
||||||
|
}
|
12
R/wrapFunction.R
Normal file
12
R/wrapFunction.R
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
wrapFunction <- function(mf){
|
||||||
|
f <- function(x){
|
||||||
|
|
||||||
|
y <- vector(mode="numeric", length=length(x))
|
||||||
|
for(i in 1:length(x)){
|
||||||
|
y[i] <- .jcall(mf, "D", "evaluate", x[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return(y)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
7
R/zzz.R
Normal file
7
R/zzz.R
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
.onLoad <- function(libname, pkgname) {
|
||||||
|
# rJava needs to be initialized with the path to the class files
|
||||||
|
.jpackage(pkgname, lib.loc=libname, morePaths = "inst/java/")
|
||||||
|
}
|
||||||
|
|
||||||
|
#' @import rJava
|
||||||
|
NULL
|
BIN
inst/java/ca/joeltherrien/randomforest/Bootstrapper.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/Bootstrapper.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/CovariateRow.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/CovariateRow.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/Main.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/Main.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/Row.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/Row.class
Normal file
Binary file not shown.
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/Settings.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/Settings.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/VisibleForTesting.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/VisibleForTesting.class
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/Forest.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/Forest.class
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/ForestTrainer.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/ForestTrainer.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/Node.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/Node.class
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/Split.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/Split.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/SplitAndScore.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/SplitAndScore.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/SplitFinder.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/SplitFinder.class
Normal file
Binary file not shown.
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/SplitNode.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/SplitNode.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/TerminalNode.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/TerminalNode.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/Tree.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/Tree.class
Normal file
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/TreeTrainer$1.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/TreeTrainer$1.class
Normal file
Binary file not shown.
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/tree/TreeTrainer.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/tree/TreeTrainer.class
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/utils/DataUtils.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/utils/DataUtils.class
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
inst/java/ca/joeltherrien/randomforest/utils/MathFunction.class
Normal file
BIN
inst/java/ca/joeltherrien/randomforest/utils/MathFunction.class
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue