diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index 5835e89..9d6913f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -5,6 +5,7 @@ import lombok.Builder; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.stream.Collectors; @Builder @@ -16,17 +17,29 @@ public class Forest { // O = output of trees, FO = forest output. In prac public FO evaluate(CovariateRow row){ return treeResponseCombiner.combine( - trees.parallelStream() + trees.stream() .map(node -> node.evaluate(row)) .collect(Collectors.toList()) ); } + /** + * Used primarily in the R package interface to avoid R loops; and for easier parallelization. + * + * @param rowList List of CovariateRows to evaluate + * @return A List of predictions. + */ + public List evaluate(List rowList){ + return rowList.parallelStream() + .map(this::evaluate) + .collect(Collectors.toList()); + } + public FO evaluateOOB(CovariateRow row){ return treeResponseCombiner.combine( - trees.parallelStream() + trees.stream() .filter(tree -> !tree.idInBootstrapSample(row.getId())) .map(node -> node.evaluate(row)) .collect(Collectors.toList()) diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java b/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java new file mode 100644 index 0000000..419f20f --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java @@ -0,0 +1,34 @@ +package ca.joeltherrien.randomforest.utils; + +import java.util.List; + +/** + * These static methods are designed to make the R interface more performant; and to avoid using R loops. + * + */ +public final class RUtils { + + public static double[] extractTimes(final MathFunction function){ + final List pointList = function.getPoints(); + final double[] times = new double[pointList.size()]; + + for(int i=0; i pointList = function.getPoints(); + final double[] times = new double[pointList.size()]; + + for(int i=0; i