WIP - Some changes to how trees are saved.

This commit is contained in:
Joel Therrien 2019-03-25 10:59:55 -07:00
parent 76614ee68b
commit 76b2cdd3c4
18 changed files with 123 additions and 69 deletions

View file

@ -28,6 +28,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRi
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
@ -66,7 +67,7 @@ public class Main {
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
if(args[1].equalsIgnoreCase("train")){ if(args[1].equalsIgnoreCase("train")){
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates); final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
@ -111,14 +112,14 @@ public class Main {
return; return;
} }
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
// Let's reduce this down to n // Let's reduce this down to n
final int n = Integer.parseInt(args[2]); final int n = Integer.parseInt(args[2]);
Utils.reduceListToSize(dataset, n, new Random()); Utils.reduceListToSize(dataset, n, new Random());
final File folder = new File(settings.getSaveTreeLocation()); final File folder = new File(settings.getSaveTreeLocation());
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner); final Forest<?, CompetingRiskFunctions> forest = DataUtils.loadForest(folder, responseCombiner);
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation()); final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());

View file

@ -30,6 +30,7 @@ import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.DataUtils;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
@ -55,17 +56,17 @@ import java.util.function.Function;
@EqualsAndHashCode @EqualsAndHashCode
public class Settings { public class Settings {
private static Map<String, Function<ObjectNode, DataLoader.ResponseLoader>> RESPONSE_LOADER_MAP = new HashMap<>(); private static Map<String, Function<ObjectNode, DataUtils.ResponseLoader>> RESPONSE_LOADER_MAP = new HashMap<>();
public static Function<ObjectNode, DataLoader.ResponseLoader> getResponseLoaderConstructor(final String name){ public static Function<ObjectNode, DataUtils.ResponseLoader> getResponseLoaderConstructor(final String name){
return RESPONSE_LOADER_MAP.get(name.toLowerCase()); return RESPONSE_LOADER_MAP.get(name.toLowerCase());
} }
public static void registerResponseLoaderConstructor(final String name, final Function<ObjectNode, DataLoader.ResponseLoader> responseLoaderConstructor){ public static void registerResponseLoaderConstructor(final String name, final Function<ObjectNode, DataUtils.ResponseLoader> responseLoaderConstructor){
RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor); RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor);
} }
static{ static{
registerResponseLoaderConstructor("double", registerResponseLoaderConstructor("double",
node -> new DataLoader.DoubleLoader(node) node -> new DataUtils.DoubleLoader(node)
); );
registerResponseLoaderConstructor("CompetingRiskResponse", registerResponseLoaderConstructor("CompetingRiskResponse",
node -> new CompetingRiskResponse.CompetingResponseLoader(node) node -> new CompetingRiskResponse.CompetingResponseLoader(node)
@ -238,7 +239,7 @@ public class Settings {
} }
@JsonIgnore @JsonIgnore
public DataLoader.ResponseLoader getResponseLoader(){ public DataUtils.ResponseLoader getResponseLoader(){
final String type = yVarSettings.get("type").asText(); final String type = yVarSettings.get("type").asText();
return getResponseLoaderConstructor(type).apply(yVarSettings); return getResponseLoaderConstructor(type).apply(yVarSettings);

View file

@ -59,6 +59,13 @@ public final class NumericCovariate implements Covariate<Double> {
}) })
.collect(Collectors.toList()); .collect(Collectors.toList());
// It's technically possible for data to be empty now due to NAs which will cause a crash
// when we use random.nextInt(maxIndex).
if(data.size() == 0){
return null;
}
Iterator<Double> sortedDataIterator = data.stream() Iterator<Double> sortedDataIterator = data.stream()
.map(row -> row.getCovariateValue(this).getValue()) .map(row -> row.getCovariateValue(this).getValue())
.iterator(); .iterator();

View file

@ -16,7 +16,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.utils.DataUtils;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.Data; import lombok.Data;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@ -36,7 +36,7 @@ public class CompetingRiskResponse implements Serializable {
@RequiredArgsConstructor @RequiredArgsConstructor
public static class CompetingResponseLoader implements DataLoader.ResponseLoader<CompetingRiskResponse>{ public static class CompetingResponseLoader implements DataUtils.ResponseLoader<CompetingRiskResponse>{
private final String deltaName; private final String deltaName;
private final String uName; private final String uName;

View file

@ -16,7 +16,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.utils.DataUtils;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
@ -38,7 +38,7 @@ public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResp
} }
@RequiredArgsConstructor @RequiredArgsConstructor
public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader<CompetingRiskResponseWithCensorTime>{ public static class CompetingResponseWithCensorTimeLoader implements DataUtils.ResponseLoader<CompetingRiskResponseWithCensorTime>{
private final String deltaName; private final String deltaName;
private final String uName; private final String uName;

View file

@ -28,6 +28,7 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
private final Collection<Tree<O>> trees; private final Collection<Tree<O>> trees;
private final ResponseCombiner<O, FO> treeResponseCombiner; private final ResponseCombiner<O, FO> treeResponseCombiner;
private final List<Covariate> covariateList;
public FO evaluate(CovariateRow row){ public FO evaluate(CovariateRow row){

View file

@ -17,18 +17,17 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Bootstrapper; import ca.joeltherrien.randomforest.Bootstrapper;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.utils.Utils;
import lombok.AccessLevel; import lombok.AccessLevel;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import java.io.File; import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectOutputStream;
import java.text.NumberFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
@ -38,7 +37,6 @@ import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;
@Builder @Builder
@AllArgsConstructor(access=AccessLevel.PRIVATE) @AllArgsConstructor(access=AccessLevel.PRIVATE)
@ -90,6 +88,7 @@ public class ForestTrainer<Y, TO, FO> {
return Forest.<TO, FO>builder() return Forest.<TO, FO>builder()
.treeResponseCombiner(treeResponseCombiner) .treeResponseCombiner(treeResponseCombiner)
.trees(trees) .trees(trees)
.covariateList(covariates)
.build(); .build();
} }
@ -112,7 +111,7 @@ public class ForestTrainer<Y, TO, FO> {
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees"); System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
} }
final Runnable worker = new TreeSavedWorker(data, "tree-" + formatNumber(j+1) + ".tree", treeCount); final Runnable worker = new TreeSavedWorker(data, "tree-" + Utils.formatNumber(j+1, ntree) + ".tree", treeCount);
worker.run(); worker.run();
} }
@ -184,7 +183,7 @@ public class ForestTrainer<Y, TO, FO> {
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
for(int j=treeCount.get(); j<ntree; j++){ for(int j=treeCount.get(); j<ntree; j++){
final Runnable worker = new TreeSavedWorker(data, "tree-" + formatNumber(j+1) + ".tree", treeCount); final Runnable worker = new TreeSavedWorker(data, "tree-" + Utils.formatNumber(j+1, ntree) + ".tree", treeCount);
executorService.execute(worker); executorService.execute(worker);
} }
@ -215,16 +214,6 @@ public class ForestTrainer<Y, TO, FO> {
return treeTrainer.growTree(bootstrappedData, random); return treeTrainer.growTree(bootstrappedData, random);
} }
public void saveTree(final Tree<TO> tree, String name) throws IOException {
final String filename = saveTreeLocation + "/" + name;
final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename)));
outputStream.writeObject(tree);
outputStream.close();
}
private class TreeInMemoryWorker implements Runnable { private class TreeInMemoryWorker implements Runnable {
@ -250,27 +239,6 @@ public class ForestTrainer<Y, TO, FO> {
} }
} }
/**
* When saving trees we typically save them as tree-1.tree, tree-2.tree. This is fine until we get tree-10.tree, which
* when sorted alphabetically goes before tree-2.tree. We should instead save tree-01.tree, ... tree-10.tree.
*
* We need to set the number of 0s though based on ntree.
*
* @return
*/
private String formatNumber(int currentTreeNumber){
final int numDigits = (int) Math.log10(ntree) + 1;
String currentTreeNumberString = Integer.toString(currentTreeNumber);
final StringBuilder builder = new StringBuilder();
for(int i=0; i<numDigits-currentTreeNumberString.length(); i++){
builder.append('0');
}
builder.append(currentTreeNumberString);
return builder.toString();
}
private class TreeSavedWorker implements Runnable { private class TreeSavedWorker implements Runnable {
@ -291,7 +259,7 @@ public class ForestTrainer<Y, TO, FO> {
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current()); final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
try { try {
saveTree(tree, filename); DataUtils.saveObject(tree, saveTreeLocation + "/" + filename);
} catch (IOException e) { } catch (IOException e) {
System.err.println("IOException while saving " + filename); System.err.println("IOException while saving " + filename);
e.printStackTrace(); e.printStackTrace();

View file

@ -16,6 +16,7 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import java.io.Serializable;
import java.util.Iterator; import java.util.Iterator;
/** /**
@ -26,7 +27,7 @@ import java.util.Iterator;
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending * If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
* SimpleGroupDifferentiator. * SimpleGroupDifferentiator.
*/ */
public interface GroupDifferentiator<Y> { public interface GroupDifferentiator<Y> extends Serializable {
SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator); SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator);

View file

@ -16,9 +16,10 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import java.io.Serializable;
import java.util.List; import java.util.List;
public interface ResponseCombiner<I, O> { public interface ResponseCombiner<I, O> extends Serializable {
O combine(List<I> responses); O combine(List<I> responses);

View file

@ -57,4 +57,5 @@ public class Tree<Y> implements Node<Y> {
public String toString(){ public String toString(){
return rootNode.toString(); return rootNode.toString();
} }
} }

View file

@ -165,6 +165,11 @@ public class TreeTrainer<Y, O> {
for(final Covariate covariate : covariatesToTry) { for(final Covariate covariate : covariatesToTry) {
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random); final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
// this happens if there were only NA values in data for this covariate. Rare, but I've seen it.
if(iterator == null){
continue;
}
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator); final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
if(candidateSplitAndScore != null && (bestSplitAndScore == null || if(candidateSplitAndScore != null && (bestSplitAndScore == null ||

View file

@ -14,8 +14,9 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
@ -98,12 +99,26 @@ public class DataUtils {
} }
public static <O, FO> Forest<O, FO> loadForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
final File directory = new File(folder);
return loadForest(directory, treeResponseCombiner);
}
public static void saveObject(Serializable object, String filename) throws IOException { public static void saveObject(Serializable object, String filename) throws IOException {
final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename))); final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename)));
outputStream.writeObject(object); outputStream.writeObject(object);
outputStream.close(); outputStream.close();
} }
public static Object loadObject(String filename) throws IOException, ClassNotFoundException {
final ObjectInputStream inputStream = new ObjectInputStream(new GZIPInputStream(new FileInputStream(filename)));
final Object object = inputStream.readObject();
inputStream.close();
return object;
}
@FunctionalInterface @FunctionalInterface
public interface ResponseLoader<Y>{ public interface ResponseLoader<Y>{
Y parse(CSVRecord record); Y parse(CSVRecord record);

View file

@ -208,5 +208,26 @@ public final class Utils {
return map; return map;
} }
/**
* When saving trees we typically save them as tree-1.tree, tree-2.tree. This is fine until we get tree-10.tree, which
* when sorted alphabetically goes before tree-2.tree. We should instead save tree-01.tree, ... tree-10.tree.
*
* We need to set the number of 0s though based on ntree.
*
* @return
*/
public static String formatNumber(int currentTreeNumber, int maxNumberOfTrees){
final int numDigits = (int) Math.log10(maxNumberOfTrees) + 1;
String currentTreeNumberString = Integer.toString(currentTreeNumber);
final StringBuilder builder = new StringBuilder();
for(int i=0; i<numDigits-currentTreeNumberString.length(); i++){
builder.append('0');
}
builder.append(currentTreeNumberString);
return builder.toString();
}
} }

View file

@ -24,6 +24,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctio
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*; import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -109,7 +110,7 @@ public class TestSavingLoading {
public void testSavingLoading() throws IOException, ClassNotFoundException { public void testSavingLoading() throws IOException, ClassNotFoundException {
final Settings settings = getSettings(); final Settings settings = getSettings();
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final File directory = new File(settings.getSaveTreeLocation()); final File directory = new File(settings.getSaveTreeLocation());
assertFalse(directory.exists()); assertFalse(directory.exists());
@ -125,7 +126,7 @@ public class TestSavingLoading {
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null)); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
final CovariateRow predictionRow = getPredictionRow(covariates); final CovariateRow predictionRow = getPredictionRow(covariates);

View file

@ -17,7 +17,7 @@
package ca.joeltherrien.randomforest.competingrisk; package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
@ -126,7 +126,7 @@ public class TestCompetingRisk {
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates); final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random()); final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
@ -179,7 +179,7 @@ public class TestCompetingRisk {
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates); final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random()); final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
@ -229,7 +229,7 @@ public class TestCompetingRisk {
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
@ -277,7 +277,7 @@ public class TestCompetingRisk {
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
// Let's count the events and make sure the data was correctly read. // Let's count the events and make sure the data was correctly read.
int countCensored = 0; int countCensored = 0;
@ -320,7 +320,7 @@ public class TestCompetingRisk {
settings.setNtree(300); // results are too variable at 100 settings.setNtree(300); // results are too variable at 100
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
settings.getTrainingDataLocation()); settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
@ -341,7 +341,7 @@ public class TestCompetingRisk {
settings.setNtree(300); // results are too variable at 100 settings.setNtree(300); // results are too variable at 100
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
settings.getTrainingDataLocation()); settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();

View file

@ -16,7 +16,7 @@
package ca.joeltherrien.randomforest.competingrisk; package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
@ -36,7 +36,6 @@ import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@ -63,8 +62,8 @@ public class TestLogRankMultipleGroupDifferentiator {
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final DataLoader.ResponseLoader loader = settings.getResponseLoader(); final DataUtils.ResponseLoader loader = settings.getResponseLoader();
final List<Row<CompetingRiskResponse>> rows = DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> rows = DataUtils.loadData(covariates, loader, settings.getTrainingDataLocation());
return new Data<>(rows, covariates); return new Data<>(rows, covariates);
} }

View file

@ -35,6 +35,23 @@ public class NumericCovariateTest {
return rowList; return rowList;
} }
private List<Row<Double>> createTestDatasetMissingValues(NumericCovariate covariate){
final List<Row<Double>> rowList = new ArrayList<>();
final List<Covariate> covariateList = Collections.singletonList(covariate);
final String naString = "NA";
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 1, 1.0));
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 2, 1.0));
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 3, 1.0));
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 4, 1.0));
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 5, 1.0));
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 6, 1.0));
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 7, 1.0));
return rowList;
}
@Test @Test
public void testNumericCovariateDeterministic(){ public void testNumericCovariateDeterministic(){
final NumericCovariate covariate = new NumericCovariate("x", 0); final NumericCovariate covariate = new NumericCovariate("x", 0);
@ -184,6 +201,21 @@ public class NumericCovariateTest {
} }
/**
* If all the values are missing on a covariate then we shouldn't return an iterator.
*
*/
@Test
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
final NumericCovariate covariate = new NumericCovariate("x", 0);
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());
assertNull(updater);
}
private <T> void assertContains(List<T> subList, List<T> greaterList){ private <T> void assertContains(List<T> subList, List<T> greaterList){
boolean allContained = true; boolean allContained = true;

View file

@ -16,7 +16,7 @@
package ca.joeltherrien.randomforest.csv; package ca.joeltherrien.randomforest.csv;
import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
@ -64,9 +64,9 @@ public class TestLoadingCSV {
final List<Covariate> covariates = settings.getCovariates(); final List<Covariate> covariates = settings.getCovariates();
final DataLoader.ResponseLoader loader = settings.getResponseLoader(); final DataUtils.ResponseLoader loader = settings.getResponseLoader();
return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation()); return DataUtils.loadData(covariates, loader, settings.getTrainingDataLocation());
} }
@Test @Test