diff --git a/library/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java b/library/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java
deleted file mode 100644
index 8b83564..0000000
--- a/library/src/main/java/ca/joeltherrien/randomforest/utils/DiscontinuousStepFunction.java
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Copyright (c) 2019 Joel Therrien.
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see .
- */
-
-package ca.joeltherrien.randomforest.utils;
-
-/**
- * Represents a function represented by discrete points. However, the function may be right-continuous or left-continuous
- * at a given point, with no consistency. This function tracks that.
- */
-public final class DiscontinuousStepFunction extends StepFunction {
-
- private final double[] y;
- private final boolean[] isLeftContinuous;
-
- /**
- * Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
- *
- * Map be null.
- */
- private final double defaultY;
-
- public DiscontinuousStepFunction(double[] x, double[] y, boolean[] isLeftContinuous, double defaultY) {
- super(x);
- this.y = y;
- this.isLeftContinuous = isLeftContinuous;
- this.defaultY = defaultY;
- }
-
- @Override
- public double evaluate(double time){
- int index = Utils.binarySearchLessThan(0, x.length, x, time);
- if(index < 0){
- return defaultY;
- }
- else{
- if(x[index] == time){
- return evaluateByIndex(index);
- }
- else{
- return y[index];
- }
- }
- }
-
-
- @Override
- public double evaluatePrevious(double time){
- int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
- if(index < 0){
- return defaultY;
- }
- else{
- if(x[index] == time){
- return evaluateByIndex(index);
- }
- else{
- return y[index];
- }
- }
- }
-
- @Override
- public double evaluateByIndex(int i) {
- if(isLeftContinuous[i]){
- i -= 1;
- }
-
- if(i < 0){
- return defaultY;
- }
-
- return y[i];
- }
-
- @Override
- public String toString(){
- final StringBuilder builder = new StringBuilder();
- builder.append("Default point: ");
- builder.append(defaultY);
- builder.append("\n");
-
- for(int i=0; i.
- */
-
-package ca.joeltherrien.randomforest.utils;
-
-import java.util.List;
-import java.util.ListIterator;
-
-/**
- * Represents a function represented by discrete points. We assume that the function is a stepwise left-continuous
- * function, constant at the value of the previous encountered point.
- *
- */
-public final class LeftContinuousStepFunction extends StepFunction {
-
- private final double[] y;
-
- /**
- * Represents the value that should be returned by evaluate if there are points prior to the time the function is being evaluated at.
- *
- * Map be null.
- */
- private final double defaultY;
-
- public LeftContinuousStepFunction(double[] x, double[] y, double defaultY) {
- super(x);
- this.y = y;
- this.defaultY = defaultY;
- }
-
- /**
- * This isn't a formal constructor because of limitations with abstract classes.
- *
- * @param pointList
- * @param defaultY
- * @return
- */
- public static LeftContinuousStepFunction constructFromPoints(final List pointList, final double defaultY){
-
- final double[] x = new double[pointList.size()];
- final double[] y = new double[pointList.size()];
-
- final ListIterator pointIterator = pointList.listIterator();
- while(pointIterator.hasNext()){
- final int index = pointIterator.nextIndex();
- final Point currentPoint = pointIterator.next();
-
- x[index] = currentPoint.getTime();
- y[index] = currentPoint.getY();
- }
-
- return new LeftContinuousStepFunction(x, y, defaultY);
-
- }
-
- @Override
- public double evaluate(double time){
- int index = Utils.binarySearchLessThan(0, x.length, x, time);
- if(index < 0){
- return defaultY;
- }
- else{
- if(x[index] == time){
- return evaluateByIndex(index-1);
- }
- else{
- return y[index];
- }
- }
- }
-
- @Override
- public double evaluatePrevious(double time){
- int index = Utils.binarySearchLessThan(0, x.length, x, time) - 1;
- if(index < 0){
- return defaultY;
- }
- else{
- if(x[index] == time){
- return evaluateByIndex(index-1);
- }
- else{
- return y[index];
- }
- }
- }
-
- @Override
- public double evaluateByIndex(int i) {
- if(i < 0){
- return defaultY;
- }
-
- return y[i];
- }
-
-
- @Override
- public String toString(){
- final StringBuilder builder = new StringBuilder();
- builder.append("Default point: ");
- builder.append(defaultY);
- builder.append("\n");
-
- for(int i=0; i combinedPoints = new ArrayList<>(leftX.length + rightX.length);
+
+ // These indexes represent the times that have *already* been processed.
+ // They start at -1 because we already processed the defaultY values.
+ int indexLeft = -1;
+ int indexRight = -1;
+
+ // This while-loop will keep going until one of the functions reaches the ends of its points
+ while(indexLeft < leftX.length-1 && indexRight < rightX.length-1){
+ final double time;
+ if(leftX[indexLeft+1] < rightX[indexRight+1]){
+ indexLeft += 1;
+
+ time = leftX[indexLeft];
+ combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
+ }
+ else if(leftX[indexLeft+1] > rightX[indexRight+1]){
+ indexRight += 1;
+
+ time = rightX[indexRight];
+ combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
+ }
+ else{ // equal times
+ indexLeft += 1;
+ indexRight += 1;
+
+ time = leftX[indexLeft];
+ combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
+ }
+ }
+
+ // At this point, at least one of function left or function right has reached the end of its points
+
+ // This while-loop occurring implies that functionRight can not move forward anymore
+ while(indexLeft < leftX.length-1){
+ indexLeft += 1;
+
+ final double time = leftX[indexLeft];
+ combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
+
+ }
+
+ // This while-loop occurring implies that functionLeft can not move forward anymore
+ while(indexRight < rightX.length-1){
+ indexRight += 1;
+
+ final double time = rightX[indexRight];
+ combinedPoints.add(new Point(time, operator.applyAsDouble(funLeft.evaluateByIndex(indexLeft), funRight.evaluateByIndex(indexRight))));
+
+ }
+
+
+ return RightContinuousStepFunction.constructFromPoints(combinedPoints, newDefaultY);
+
+ }
}
diff --git a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunctions.java b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunctions.java
index d831f2c..13ad8f8 100644
--- a/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunctions.java
+++ b/library/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestMathFunctions.java
@@ -16,7 +16,6 @@
package ca.joeltherrien.randomforest.competingrisk;
-import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import org.junit.jupiter.api.Test;
@@ -31,13 +30,6 @@ public class TestMathFunctions {
return new RightContinuousStepFunction(time, y, 0.1);
}
- private LeftContinuousStepFunction generateLeftContinuousStepFunction(){
- final double[] time = new double[]{1.0, 2.0, 3.0};
- final double[] y = new double[]{-1.0, 1.0, 0.5};
-
- return new LeftContinuousStepFunction(time, y, 0.1);
- }
-
@Test
public void testRightContinuousStepFunction(){
final RightContinuousStepFunction function = generateRightContinuousStepFunction();
@@ -56,21 +48,5 @@ public class TestMathFunctions {
}
- @Test
- public void testLeftContinuousStepFunction(){
- final LeftContinuousStepFunction function = generateLeftContinuousStepFunction();
-
- assertEquals(0.1, function.evaluate(0.5));
- assertEquals(0.1, function.evaluate(1.0));
- assertEquals(-1.0, function.evaluate(2.0));
- assertEquals(1.0, function.evaluate(3.0));
-
-
- assertEquals(0.1, function.evaluate(0.6));
- assertEquals(-1.0, function.evaluate(1.1));
- assertEquals(1.0, function.evaluate(2.1));
- assertEquals(0.5, function.evaluate(3.1));
-
- }
}
diff --git a/library/src/test/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunctionOperatorTests.java b/library/src/test/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunctionOperatorTests.java
new file mode 100644
index 0000000..3eda073
--- /dev/null
+++ b/library/src/test/java/ca/joeltherrien/randomforest/utils/RightContinuousStepFunctionOperatorTests.java
@@ -0,0 +1,180 @@
+package ca.joeltherrien.randomforest.utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RightContinuousStepFunctionOperatorTests {
+
+
+ // Idea - small and middle slightly overlap; middle and large slightly overlap.
+ // small and large never overlap (i.e. small's x values always occur before large's)
+ private final RightContinuousStepFunction smallNumbers;
+ private final RightContinuousStepFunction middleNumbers;
+ private final RightContinuousStepFunction largeNumbers;
+
+ private final double delta = 0.0000000001;
+
+ public RightContinuousStepFunctionOperatorTests(){
+ smallNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
+ new Point(1.0, 1.0),
+ new Point(2.0, 3.0),
+ new Point(3.0, 2.0),
+ new Point(4.0, 1.0)
+ ), 0.0);
+
+ middleNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
+ new Point(3.5, 4.0),
+ new Point(4.0, 3.0),
+ new Point(5.0, 2.0),
+ new Point(6.0, 1.0)
+ ), 5.0);
+
+ largeNumbers = RightContinuousStepFunction.constructFromPoints(Utils.easyList(
+ new Point(5.0, 5.0),
+ new Point(6.0, 6.0),
+ new Point(7.0, 3.0),
+ new Point(8.0, 2.0)
+ ), 3.0);
+
+ }
+
+ @Test
+ public void testDifferenceNoOverlapLargeMinusSmall(){
+ DoubleBinaryOperator operator = (a, b) -> a - b;
+
+ final RightContinuousStepFunction largeSmallDifference = RightContinuousStepFunction.biOperation(
+ largeNumbers,
+ smallNumbers,
+ operator);
+
+ assertEquals(8, largeSmallDifference.getX().length);
+ assertEquals(8, largeSmallDifference.getY().length);
+
+ final double[] offsetTimes = {-0.1, 0.0, 0.1};
+
+ for(int time = 1; time <= 9; time++){
+ for(double offsetTime : offsetTimes){
+ final double timeToEvaluateAt = (double) time + offsetTime;
+
+ final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
+ final double smallFunEvaluation = smallNumbers.evaluate(timeToEvaluateAt);
+ final double expectedDifference = operator.applyAsDouble(largeFunEvaluation, smallFunEvaluation);
+
+ final double actualEvaluation = largeSmallDifference.evaluate(timeToEvaluateAt);
+
+ assertEquals(expectedDifference, actualEvaluation, delta);
+ }
+ }
+ }
+
+ @Test
+ public void testDifferenceNoOverlapSmallMinusLarge(){
+ DoubleBinaryOperator operator = (a, b) -> a - b;
+
+ final RightContinuousStepFunction smallLargeDifference = RightContinuousStepFunction.biOperation(
+ smallNumbers,
+ largeNumbers,
+ operator);
+
+ assertEquals(8, smallLargeDifference.getX().length);
+ assertEquals(8, smallLargeDifference.getY().length);
+
+ final double[] offsetTimes = {-0.1, 0.0, 0.1};
+
+ for(int time = 1; time <= 9; time++){
+ for(double offsetTime : offsetTimes){
+ final double timeToEvaluateAt = (double) time + offsetTime;
+
+ final double smallFunEvaluation = smallNumbers.evaluate(timeToEvaluateAt);
+ final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
+ final double expectedDifference = operator.applyAsDouble(smallFunEvaluation, largeFunEvaluation);
+
+ final double actualEvaluation = smallLargeDifference.evaluate(timeToEvaluateAt);
+
+ assertEquals(expectedDifference, actualEvaluation, delta);
+ }
+ }
+ }
+
+ @Test
+ public void testDifferenceSomeOverlapLargeMinusMiddle(){
+ DoubleBinaryOperator operator = (a, b) -> a - b;
+
+ final RightContinuousStepFunction combinedFunction = RightContinuousStepFunction.biOperation(
+ largeNumbers,
+ middleNumbers,
+ operator);
+
+ assertEquals(6, combinedFunction.getX().length);
+ assertEquals(6, combinedFunction.getY().length);
+
+ final double[] offsetTimes = {-0.1, 0.0, 0.1};
+
+ for(int time = 1; time <= 9; time++){
+ for(double offsetTime : offsetTimes){
+ final double timeToEvaluateAt = (double) time + offsetTime;
+
+ final double middleFunEvaluation = middleNumbers.evaluate(timeToEvaluateAt);
+ final double largeFunEvaluation = largeNumbers.evaluate(timeToEvaluateAt);
+ final double expectedDifference = operator.applyAsDouble(largeFunEvaluation, middleFunEvaluation);
+
+ final double actualEvaluation = combinedFunction.evaluate(timeToEvaluateAt);
+
+ assertEquals(expectedDifference, actualEvaluation, delta);
+ }
+ }
+ }
+
+ @Test
+ public void testDifferenceCompleteOverlap(){
+ DoubleBinaryOperator operator = (a, b) -> a - b;
+
+ final RightContinuousStepFunction combinedFunction = RightContinuousStepFunction.biOperation(
+ middleNumbers,
+ middleNumbers,
+ operator);
+
+ assertEquals(4, combinedFunction.getX().length);
+ assertEquals(4, combinedFunction.getY().length);
+
+ final double[] offsetTimes = {-0.1, 0.0, 0.1};
+
+ for(int time = 1; time <= 9; time++){
+ for(double offsetTime : offsetTimes){
+ final double timeToEvaluateAt = (double) time + offsetTime;
+
+ final double actualEvaluation = combinedFunction.evaluate(timeToEvaluateAt);
+
+ assertEquals(0.0, actualEvaluation, delta);
+ }
+ }
+ }
+
+ @Test
+ public void testPowerFunction(){
+ final DoubleUnaryOperator operator = d -> d*d;
+
+ final RightContinuousStepFunction squaredFunction = smallNumbers.unaryOperation(operator);
+
+ assertEquals(4, squaredFunction.getX().length);
+ assertEquals(4, squaredFunction.getY().length);
+
+ final double[] offsetTimes = {-0.1, 0.0, 0.1};
+
+ for(int time = 1; time <= 9; time++){
+ for(double offsetTime : offsetTimes){
+ final double timeToEvaluateAt = (double) time + offsetTime;
+
+ final double expectedEvaluation = operator.applyAsDouble(smallNumbers.evaluate(timeToEvaluateAt));
+ final double actualEvaluation = squaredFunction.evaluate(timeToEvaluateAt);
+
+ assertEquals(expectedEvaluation, actualEvaluation, delta);
+ }
+ }
+ }
+
+}