diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java index 0bdbcd7..5f36543 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java @@ -79,6 +79,10 @@ public final class Utils { return endIndex - 1; } + if(x[startIndex] > time){ + return -1; + } + if(range < 200){ for(int i = startIndex; i < endIndex; i++){ if(x[i] > time){ @@ -90,7 +94,7 @@ public final class Utils { // else - final int middle = range / 2; + final int middle = startIndex + range / 2; final double middleTime = x[middle]; if(middleTime < time){ // go right diff --git a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java index ab4c83a..6ef6f1b 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -7,6 +7,8 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.function.DoubleSupplier; +import java.util.stream.DoubleStream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -115,4 +117,40 @@ public class TestUtils { } } + @Test + public void testBinarySearchLessThan(){ + /* + There was a bug where I didn't add startIndex to range/2 for middle; no other tests caught it! + */ + final int n = 10000; + + double[] x = DoubleStream.generate(new DoubleSequenceGenerator()).limit(n).toArray(); + + + for(int i = 0; i < n; i=i+100){ + final int index = Utils.binarySearchLessThan(0, n, x, i); + final int indexOff = Utils.binarySearchLessThan(0, n, x, ((double) i) + 1.5); + + assertEquals(i, index); + assertEquals(i+1, indexOff); + } + + final int indexTooFar = Utils.binarySearchLessThan(0, n, x, n + 100); + assertEquals(n-1, indexTooFar); + + final int indexTooEarly = Utils.binarySearchLessThan(0, n, x, -100); + assertEquals(-1, indexTooEarly); + + + } + + private static class DoubleSequenceGenerator implements DoubleSupplier { + private double previous = 0.0; + + @Override + public double getAsDouble() { + return previous++; + } + } + }