diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index 24d96c8..e353c7f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -28,6 +28,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRi import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ResponseCombiner; +import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.Utils; import com.fasterxml.jackson.databind.node.JsonNodeFactory; @@ -66,7 +67,7 @@ public class Main { final List covariates = settings.getCovariates(); if(args[1].equalsIgnoreCase("train")){ - final List dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + final List dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates); @@ -111,14 +112,14 @@ public class Main { return; } - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation()); + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation()); // Let's reduce this down to n final int n = Integer.parseInt(args[2]); Utils.reduceListToSize(dataset, n, new Random()); final File folder = new File(settings.getSaveTreeLocation()); - final Forest forest = DataLoader.loadForest(folder, responseCombiner); + final Forest forest = DataUtils.loadForest(folder, responseCombiner); final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation()); diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index 3abccdf..ddffeae 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -30,6 +30,7 @@ import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.ResponseCombiner; +import ca.joeltherrien.randomforest.utils.DataUtils; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; @@ -55,17 +56,17 @@ import java.util.function.Function; @EqualsAndHashCode public class Settings { - private static Map> RESPONSE_LOADER_MAP = new HashMap<>(); - public static Function getResponseLoaderConstructor(final String name){ + private static Map> RESPONSE_LOADER_MAP = new HashMap<>(); + public static Function getResponseLoaderConstructor(final String name){ return RESPONSE_LOADER_MAP.get(name.toLowerCase()); } - public static void registerResponseLoaderConstructor(final String name, final Function responseLoaderConstructor){ + public static void registerResponseLoaderConstructor(final String name, final Function responseLoaderConstructor){ RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor); } static{ registerResponseLoaderConstructor("double", - node -> new DataLoader.DoubleLoader(node) + node -> new DataUtils.DoubleLoader(node) ); registerResponseLoaderConstructor("CompetingRiskResponse", node -> new CompetingRiskResponse.CompetingResponseLoader(node) @@ -238,7 +239,7 @@ public class Settings { } @JsonIgnore - public DataLoader.ResponseLoader getResponseLoader(){ + public DataUtils.ResponseLoader getResponseLoader(){ final String type = yVarSettings.get("type").asText(); return getResponseLoaderConstructor(type).apply(yVarSettings); diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java index b4c89ce..e5b7b68 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java @@ -59,6 +59,13 @@ public final class NumericCovariate implements Covariate { }) .collect(Collectors.toList()); + // It's technically possible for data to be empty now due to NAs which will cause a crash + // when we use random.nextInt(maxIndex). + if(data.size() == 0){ + return null; + } + + Iterator sortedDataIterator = data.stream() .map(row -> row.getCovariateValue(this).getValue()) .iterator(); diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse.java index d20a8b9..827614b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponse.java @@ -16,7 +16,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.DataLoader; +import ca.joeltherrien.randomforest.utils.DataUtils; import com.fasterxml.jackson.databind.node.ObjectNode; import lombok.Data; import lombok.RequiredArgsConstructor; @@ -36,7 +36,7 @@ public class CompetingRiskResponse implements Serializable { @RequiredArgsConstructor - public static class CompetingResponseLoader implements DataLoader.ResponseLoader{ + public static class CompetingResponseLoader implements DataUtils.ResponseLoader{ private final String deltaName; private final String uName; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java index 056a038..f984d38 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java @@ -16,7 +16,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk; -import ca.joeltherrien.randomforest.DataLoader; +import ca.joeltherrien.randomforest.utils.DataUtils; import com.fasterxml.jackson.databind.node.ObjectNode; import lombok.Data; import lombok.EqualsAndHashCode; @@ -38,7 +38,7 @@ public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResp } @RequiredArgsConstructor - public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader{ + public static class CompetingResponseWithCensorTimeLoader implements DataUtils.ResponseLoader{ private final String deltaName; private final String uName; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index 1cc37dc..f82abef 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -28,6 +28,7 @@ public class Forest { // O = output of trees, FO = forest output. In prac private final Collection> trees; private final ResponseCombiner treeResponseCombiner; + private final List covariateList; public FO evaluate(CovariateRow row){ diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 2e31725..83964aa 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -17,18 +17,17 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Bootstrapper; +import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.utils.Utils; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Builder; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; -import java.io.ObjectOutputStream; -import java.text.NumberFormat; import java.util.ArrayList; import java.util.List; import java.util.Random; @@ -38,7 +37,6 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; -import java.util.zip.GZIPOutputStream; @Builder @AllArgsConstructor(access=AccessLevel.PRIVATE) @@ -90,6 +88,7 @@ public class ForestTrainer { return Forest.builder() .treeResponseCombiner(treeResponseCombiner) .trees(trees) + .covariateList(covariates) .build(); } @@ -112,7 +111,7 @@ public class ForestTrainer { System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees"); } - final Runnable worker = new TreeSavedWorker(data, "tree-" + formatNumber(j+1) + ".tree", treeCount); + final Runnable worker = new TreeSavedWorker(data, "tree-" + Utils.formatNumber(j+1, ntree) + ".tree", treeCount); worker.run(); } @@ -184,7 +183,7 @@ public class ForestTrainer { final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished for(int j=treeCount.get(); j { return treeTrainer.growTree(bootstrappedData, random); } - public void saveTree(final Tree tree, String name) throws IOException { - final String filename = saveTreeLocation + "/" + name; - - final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename))); - - outputStream.writeObject(tree); - - outputStream.close(); - - } private class TreeInMemoryWorker implements Runnable { @@ -250,27 +239,6 @@ public class ForestTrainer { } } - /** - * When saving trees we typically save them as tree-1.tree, tree-2.tree. This is fine until we get tree-10.tree, which - * when sorted alphabetically goes before tree-2.tree. We should instead save tree-01.tree, ... tree-10.tree. - * - * We need to set the number of 0s though based on ntree. - * - * @return - */ - private String formatNumber(int currentTreeNumber){ - final int numDigits = (int) Math.log10(ntree) + 1; - - String currentTreeNumberString = Integer.toString(currentTreeNumber); - final StringBuilder builder = new StringBuilder(); - - for(int i=0; i { final Tree tree = trainTree(bootstrapper, ThreadLocalRandom.current()); try { - saveTree(tree, filename); + DataUtils.saveObject(tree, saveTreeLocation + "/" + filename); } catch (IOException e) { System.err.println("IOException while saving " + filename); e.printStackTrace(); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java index b0da8ad..8d72efb 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java @@ -16,6 +16,7 @@ package ca.joeltherrien.randomforest.tree; +import java.io.Serializable; import java.util.Iterator; /** @@ -26,7 +27,7 @@ import java.util.Iterator; * If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending * SimpleGroupDifferentiator. */ -public interface GroupDifferentiator { +public interface GroupDifferentiator extends Serializable { SplitAndScore differentiate(Iterator> splitIterator); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java index 66c6821..8bfd277 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java @@ -16,9 +16,10 @@ package ca.joeltherrien.randomforest.tree; +import java.io.Serializable; import java.util.List; -public interface ResponseCombiner { +public interface ResponseCombiner extends Serializable { O combine(List responses); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java index b3dba61..fdf4919 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java @@ -57,4 +57,5 @@ public class Tree implements Node { public String toString(){ return rootNode.toString(); } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 98292e3..4ce3630 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -165,6 +165,11 @@ public class TreeTrainer { for(final Covariate covariate : covariatesToTry) { final Iterator iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random); + // this happens if there were only NA values in data for this covariate. Rare, but I've seen it. + if(iterator == null){ + continue; + } + final SplitAndScore candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator); if(candidateSplitAndScore != null && (bestSplitAndScore == null || diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java b/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java index bc5ea80..dbe7950 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java @@ -14,8 +14,9 @@ * along with this program. If not, see . */ -package ca.joeltherrien.randomforest; +package ca.joeltherrien.randomforest.utils; +import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.ResponseCombiner; @@ -98,12 +99,26 @@ public class DataUtils { } + public static Forest loadForest(String folder, ResponseCombiner treeResponseCombiner) throws IOException, ClassNotFoundException { + final File directory = new File(folder); + return loadForest(directory, treeResponseCombiner); + } + public static void saveObject(Serializable object, String filename) throws IOException { final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename))); outputStream.writeObject(object); outputStream.close(); } + public static Object loadObject(String filename) throws IOException, ClassNotFoundException { + final ObjectInputStream inputStream = new ObjectInputStream(new GZIPInputStream(new FileInputStream(filename))); + final Object object = inputStream.readObject(); + inputStream.close(); + + return object; + + } + @FunctionalInterface public interface ResponseLoader{ Y parse(CSVRecord record); diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java index a8bc873..7939b86 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java @@ -208,5 +208,26 @@ public final class Utils { return map; } + /** + * When saving trees we typically save them as tree-1.tree, tree-2.tree. This is fine until we get tree-10.tree, which + * when sorted alphabetically goes before tree-2.tree. We should instead save tree-01.tree, ... tree-10.tree. + * + * We need to set the number of 0s though based on ntree. + * + * @return + */ + public static String formatNumber(int currentTreeNumber, int maxNumberOfTrees){ + final int numDigits = (int) Math.log10(maxNumberOfTrees) + 1; + + String currentTreeNumberString = Integer.toString(currentTreeNumber); + final StringBuilder builder = new StringBuilder(); + + for(int i=0; i covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final File directory = new File(settings.getSaveTreeLocation()); assertFalse(directory.exists()); @@ -125,7 +126,7 @@ public class TestSavingLoading { - final Forest forest = DataLoader.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); + final Forest forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); final CovariateRow predictionRow = getPredictionRow(covariates); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 5e8df31..72c09ce 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -17,7 +17,7 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.DataLoader; +import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; @@ -126,7 +126,7 @@ public class TestCompetingRisk { final List covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final TreeTrainer treeTrainer = new TreeTrainer<>(settings, covariates); final Node node = treeTrainer.growTree(dataset, new Random()); @@ -179,7 +179,7 @@ public class TestCompetingRisk { final List covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final TreeTrainer treeTrainer = new TreeTrainer<>(settings, covariates); final Node node = treeTrainer.growTree(dataset, new Random()); @@ -229,7 +229,7 @@ public class TestCompetingRisk { final List covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); @@ -277,7 +277,7 @@ public class TestCompetingRisk { final List covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); // Let's count the events and make sure the data was correctly read. int countCensored = 0; @@ -320,7 +320,7 @@ public class TestCompetingRisk { settings.setNtree(300); // results are too variable at 100 final List covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); @@ -341,7 +341,7 @@ public class TestCompetingRisk { settings.setNtree(300); // results are too variable at 100 final List covariates = settings.getCovariates(); - final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), + final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final Forest forest = forestTrainer.trainSerial(); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java index ada9ec0..ddb56d9 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestLogRankMultipleGroupDifferentiator.java @@ -16,7 +16,7 @@ package ca.joeltherrien.randomforest.competingrisk; -import ca.joeltherrien.randomforest.DataLoader; +import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; @@ -36,7 +36,6 @@ import java.io.IOException; import java.util.Collections; import java.util.Iterator; import java.util.List; -import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -63,8 +62,8 @@ public class TestLogRankMultipleGroupDifferentiator { final List covariates = settings.getCovariates(); - final DataLoader.ResponseLoader loader = settings.getResponseLoader(); - final List> rows = DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation()); + final DataUtils.ResponseLoader loader = settings.getResponseLoader(); + final List> rows = DataUtils.loadData(covariates, loader, settings.getTrainingDataLocation()); return new Data<>(rows, covariates); } diff --git a/src/test/java/ca/joeltherrien/randomforest/covariates/NumericCovariateTest.java b/src/test/java/ca/joeltherrien/randomforest/covariates/NumericCovariateTest.java index 2af5ebe..282d9d7 100644 --- a/src/test/java/ca/joeltherrien/randomforest/covariates/NumericCovariateTest.java +++ b/src/test/java/ca/joeltherrien/randomforest/covariates/NumericCovariateTest.java @@ -35,6 +35,23 @@ public class NumericCovariateTest { return rowList; } + private List> createTestDatasetMissingValues(NumericCovariate covariate){ + final List> rowList = new ArrayList<>(); + final List covariateList = Collections.singletonList(covariate); + + final String naString = "NA"; + + rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 1, 1.0)); + rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 2, 1.0)); + rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 3, 1.0)); + rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 4, 1.0)); + rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 5, 1.0)); + rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 6, 1.0)); + rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 7, 1.0)); + + return rowList; + } + @Test public void testNumericCovariateDeterministic(){ final NumericCovariate covariate = new NumericCovariate("x", 0); @@ -184,6 +201,21 @@ public class NumericCovariateTest { } + /** + * If all the values are missing on a covariate then we shouldn't return an iterator. + * + */ + @Test + public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){ + final NumericCovariate covariate = new NumericCovariate("x", 0); + final List> dataset = createTestDatasetMissingValues(covariate); + final NumericSplitRuleUpdater updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random()); + + assertNull(updater); + } + + + private void assertContains(List subList, List greaterList){ boolean allContained = true; diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index 4be6f51..ac5dc13 100644 --- a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -16,7 +16,7 @@ package ca.joeltherrien.randomforest.csv; -import ca.joeltherrien.randomforest.DataLoader; +import ca.joeltherrien.randomforest.utils.DataUtils; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; @@ -64,9 +64,9 @@ public class TestLoadingCSV { final List covariates = settings.getCovariates(); - final DataLoader.ResponseLoader loader = settings.getResponseLoader(); + final DataUtils.ResponseLoader loader = settings.getResponseLoader(); - return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation()); + return DataUtils.loadData(covariates, loader, settings.getTrainingDataLocation()); } @Test