diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/MathFunction.java b/src/main/java/ca/joeltherrien/randomforest/utils/MathFunction.java index 3425787..61d4509 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/MathFunction.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/MathFunction.java @@ -4,9 +4,7 @@ import lombok.Getter; import java.io.Serializable; import java.util.Collections; -import java.util.Comparator; import java.util.List; -import java.util.Optional; /** * Represents a function represented by discrete points. We assume that the function is a stepwise continuous function, @@ -35,29 +33,67 @@ public class MathFunction implements Serializable { } public Point evaluate(double time){ - Point point = defaultValue; - - for(final Point currentPoint: points){ - if(currentPoint.getTime() > time){ - break; - } - point = currentPoint; + int index = binarySearch(points, time); + if(index < 0){ + return defaultValue; + } + else{ + return points.get(index); } - - return point; } public Point evaluatePrevious(double time){ - Point point = defaultValue; - for(final Point currentPoint: points){ - if(currentPoint.getTime() >= time){ - break; - } - point = currentPoint; + int index = binarySearch(points, time) - 1; + if(index < 0){ + return defaultValue; + } + else{ + return points.get(index); } - return point; + + } + + /** + * Returns the index of the largest (in terms of time) Point that is <= the provided time value. + * + * @param points + * @param time + * @return The index of the largest Point who's time is <= the time parameter. + */ + private static int binarySearch(List points, double time){ + final int pointSize = points.size(); + + if(pointSize == 0 || points.get(pointSize-1).getTime() <= time){ + // we're already too far + return pointSize - 1; + } + + if(pointSize < 200){ + for(int i = 0; i < pointSize; i++){ + if(points.get(i).getTime() > time){ + return i - 1; + } + } + } + + // else + + + final int middle = pointSize / 2; + final double middleTime = points.get(middle).getTime(); + if(middleTime < time){ + // go right + return binarySearch(points.subList(middle, pointSize), time) + middle; + } + else if(middleTime > time){ + // go left + return binarySearch(points.subList(0, middle), time); + } + else{ // middleTime == time + return middle; + } } @Override