This project is now purely a library; the code for running directly from the command line will be moved into a new project. This was important because we were pulling large dependencies into the R code that weren't needed, and those dependencies created some minor licensing inconveniences.
233 lines
8 KiB
Java
233 lines
8 KiB
Java
/*
|
|
* Copyright (c) 2019 Joel Therrien.
|
|
* This program is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package ca.joeltherrien.randomforest;
|
|
|
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
|
import ca.joeltherrien.randomforest.utils.ResponseLoader;
|
|
import ca.joeltherrien.randomforest.utils.StepFunction;
|
|
import ca.joeltherrien.randomforest.utils.Utils;
|
|
import org.apache.commons.csv.CSVFormat;
|
|
import org.apache.commons.csv.CSVParser;
|
|
import org.apache.commons.csv.CSVRecord;
|
|
import org.junit.jupiter.api.Test;
|
|
|
|
import java.io.*;
|
|
import java.util.ArrayList;
|
|
import java.util.HashSet;
|
|
import java.util.List;
|
|
import java.util.Random;
|
|
import java.util.function.DoubleSupplier;
|
|
import java.util.stream.DoubleStream;
|
|
import java.util.zip.GZIPInputStream;
|
|
|
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
|
|
public class TestUtils {
|
|
|
|
/*
|
|
This code is copied from the runnable part of largeRCRF-Java; it's not included in the non-test code of the library
|
|
because we want to avoid packaging dependencies into the library that the R code won't use.
|
|
*/
|
|
public static <Y> List<Row<Y>> loadData(final List<Covariate> covariates, final ResponseLoader<Y> responseLoader, String filename) throws IOException {
|
|
|
|
final List<Row<Y>> dataset = new ArrayList<>();
|
|
|
|
final Reader input;
|
|
if(filename.endsWith(".gz")){
|
|
final FileInputStream inputStream = new FileInputStream(filename);
|
|
final GZIPInputStream gzipInputStream = new GZIPInputStream(inputStream);
|
|
|
|
input = new InputStreamReader(gzipInputStream);
|
|
}
|
|
else{
|
|
input = new FileReader(filename);
|
|
}
|
|
|
|
|
|
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
|
|
|
|
|
|
int id = 1;
|
|
for(final CSVRecord record : parser){
|
|
final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()];
|
|
|
|
for(final Covariate<?> covariate : covariates){
|
|
valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName()));
|
|
}
|
|
|
|
final Y y = responseLoader.parse(record);
|
|
|
|
dataset.add(new Row<>(valueArray, id++, y));
|
|
|
|
}
|
|
|
|
return dataset;
|
|
|
|
}
|
|
|
|
public static void closeEnough(double expected, double actual, double margin){
|
|
assertTrue(Math.abs(expected - actual) < margin, "Expected " + expected + " but saw " + actual);
|
|
}
|
|
|
|
/**
|
|
* We know the function is cumulative; make sure it is ordered correctly and that that function is monotone.
|
|
*
|
|
* @param function
|
|
*/
|
|
public static void assertCumulativeFunction(StepFunction function){
|
|
Double previousTime = null;
|
|
Double previousY = null;
|
|
|
|
final double[] times = function.getX();
|
|
|
|
for(int i=0; i<times.length; i++){
|
|
final double time = times[i];
|
|
final double y = function.evaluateByIndex(i);
|
|
|
|
if(previousTime != null){
|
|
assertTrue(previousTime < time, "Points should be ordered and strictly different");
|
|
assertTrue(previousY <= y, "Cumulative incidence functions are monotone");
|
|
}
|
|
|
|
previousTime = time;
|
|
previousY = y;
|
|
|
|
}
|
|
}
|
|
|
|
public static void assertSurvivalCurve(StepFunction function){
|
|
Double previousTime = null;
|
|
Double previousY = null;
|
|
|
|
final double[] times = function.getX();
|
|
|
|
for(int i=0; i<times.length; i++){
|
|
final double time = times[i];
|
|
final double y = function.evaluateByIndex(i);
|
|
|
|
if(previousTime != null){
|
|
assertTrue(previousTime < time, "Points should be ordered and strictly different");
|
|
assertTrue(previousY >= y, "Survival functions are monotone");
|
|
}
|
|
|
|
previousTime = time;
|
|
previousY = y;
|
|
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testOneMinusECDF(){
|
|
final double[] times = new double[]{1.0, 1.0, 2.0, 3.0, 3.0, 50.0};
|
|
final StepFunction survivalCurve = Utils.estimateOneMinusECDF(times);
|
|
|
|
final double margin = 0.000001;
|
|
closeEnough(1.0, survivalCurve.evaluate(0.0), margin);
|
|
|
|
closeEnough(1.0, survivalCurve.evaluatePrevious(1.0), margin);
|
|
closeEnough(4.0/6.0, survivalCurve.evaluate(1.0), margin);
|
|
|
|
closeEnough(4.0/6.0, survivalCurve.evaluatePrevious(2.0), margin);
|
|
closeEnough(3.0/6.0, survivalCurve.evaluate(2.0), margin);
|
|
|
|
closeEnough(3.0/6.0, survivalCurve.evaluatePrevious(3.0), margin);
|
|
closeEnough(1.0/6.0, survivalCurve.evaluate(3.0), margin);
|
|
|
|
closeEnough(1.0/6.0, survivalCurve.evaluatePrevious(50.0), margin);
|
|
closeEnough(0.0, survivalCurve.evaluate(50.0), margin);
|
|
|
|
assertSurvivalCurve(survivalCurve);
|
|
|
|
}
|
|
|
|
@Test
|
|
public void reduceListToSize(){
|
|
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
|
final Random random = new Random();
|
|
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
|
final List<Integer> testList1 = new ArrayList<>(testList);
|
|
// Test when removing elements
|
|
Utils.reduceListToSize(testList1, 7, random);
|
|
assertEquals(7, testList1.size()); // verify proper size
|
|
assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique
|
|
|
|
|
|
final List<Integer> testList2 = new ArrayList<>(testList);
|
|
// Test when adding elements
|
|
Utils.reduceListToSize(testList2, 3, random);
|
|
assertEquals(3, testList2.size()); // verify proper size
|
|
assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique
|
|
|
|
final List<Integer> testList3 = new ArrayList<>(testList);
|
|
// verify no change
|
|
Utils.reduceListToSize(testList3, 15, random);
|
|
assertEquals(10, testList3.size()); // verify proper size
|
|
assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique
|
|
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testBinarySearchLessThan(){
|
|
/*
|
|
There was a bug where I didn't add startIndex to range/2 for middle; no other tests caught it!
|
|
*/
|
|
final int n = 10000;
|
|
|
|
double[] x = DoubleStream.generate(new DoubleSequenceGenerator()).limit(n).toArray();
|
|
|
|
|
|
for(int i = 0; i < n; i=i+100){
|
|
final int index = Utils.binarySearchLessThan(0, n, x, i);
|
|
final int indexOff = Utils.binarySearchLessThan(0, n, x, ((double) i) + 1.5);
|
|
|
|
assertEquals(i, index);
|
|
assertEquals(i+1, indexOff);
|
|
}
|
|
|
|
final int indexTooFar = Utils.binarySearchLessThan(0, n, x, n + 100);
|
|
assertEquals(n-1, indexTooFar);
|
|
|
|
final int indexTooEarly = Utils.binarySearchLessThan(0, n, x, -100);
|
|
assertEquals(-1, indexTooEarly);
|
|
|
|
|
|
}
|
|
|
|
private static class DoubleSequenceGenerator implements DoubleSupplier {
|
|
private double previous = 0.0;
|
|
|
|
@Override
|
|
public double getAsDouble() {
|
|
return previous++;
|
|
}
|
|
}
|
|
|
|
public static void removeFolder(File file){
|
|
if(file.isFile()){
|
|
file.delete();
|
|
}
|
|
else{
|
|
for(final File inner : file.listFiles()){
|
|
removeFolder(inner);
|
|
}
|
|
file.delete();
|
|
}
|
|
}
|
|
|
|
}
|