WIP - Some changes to how trees are saved.
This commit is contained in:
parent
76614ee68b
commit
76b2cdd3c4
18 changed files with 123 additions and 69 deletions
|
@ -28,6 +28,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRi
|
|||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.utils.StepFunction;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||
|
@ -66,7 +67,7 @@ public class Main {
|
|||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
if(args[1].equalsIgnoreCase("train")){
|
||||
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final List<Row> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final ForestTrainer forestTrainer = new ForestTrainer(settings, dataset, covariates);
|
||||
|
||||
|
@ -111,14 +112,14 @@ public class Main {
|
|||
return;
|
||||
}
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getValidationDataLocation());
|
||||
|
||||
// Let's reduce this down to n
|
||||
final int n = Integer.parseInt(args[2]);
|
||||
Utils.reduceListToSize(dataset, n, new Random());
|
||||
|
||||
final File folder = new File(settings.getSaveTreeLocation());
|
||||
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
||||
final Forest<?, CompetingRiskFunctions> forest = DataUtils.loadForest(folder, responseCombiner);
|
||||
|
||||
final boolean useBootstrapPredictions = settings.getTrainingDataLocation().equals(settings.getValidationDataLocation());
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
|||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
@ -55,17 +56,17 @@ import java.util.function.Function;
|
|||
@EqualsAndHashCode
|
||||
public class Settings {
|
||||
|
||||
private static Map<String, Function<ObjectNode, DataLoader.ResponseLoader>> RESPONSE_LOADER_MAP = new HashMap<>();
|
||||
public static Function<ObjectNode, DataLoader.ResponseLoader> getResponseLoaderConstructor(final String name){
|
||||
private static Map<String, Function<ObjectNode, DataUtils.ResponseLoader>> RESPONSE_LOADER_MAP = new HashMap<>();
|
||||
public static Function<ObjectNode, DataUtils.ResponseLoader> getResponseLoaderConstructor(final String name){
|
||||
return RESPONSE_LOADER_MAP.get(name.toLowerCase());
|
||||
}
|
||||
public static void registerResponseLoaderConstructor(final String name, final Function<ObjectNode, DataLoader.ResponseLoader> responseLoaderConstructor){
|
||||
public static void registerResponseLoaderConstructor(final String name, final Function<ObjectNode, DataUtils.ResponseLoader> responseLoaderConstructor){
|
||||
RESPONSE_LOADER_MAP.put(name.toLowerCase(), responseLoaderConstructor);
|
||||
}
|
||||
|
||||
static{
|
||||
registerResponseLoaderConstructor("double",
|
||||
node -> new DataLoader.DoubleLoader(node)
|
||||
node -> new DataUtils.DoubleLoader(node)
|
||||
);
|
||||
registerResponseLoaderConstructor("CompetingRiskResponse",
|
||||
node -> new CompetingRiskResponse.CompetingResponseLoader(node)
|
||||
|
@ -238,7 +239,7 @@ public class Settings {
|
|||
}
|
||||
|
||||
@JsonIgnore
|
||||
public DataLoader.ResponseLoader getResponseLoader(){
|
||||
public DataUtils.ResponseLoader getResponseLoader(){
|
||||
final String type = yVarSettings.get("type").asText();
|
||||
|
||||
return getResponseLoaderConstructor(type).apply(yVarSettings);
|
||||
|
|
|
@ -59,6 +59,13 @@ public final class NumericCovariate implements Covariate<Double> {
|
|||
})
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// It's technically possible for data to be empty now due to NAs which will cause a crash
|
||||
// when we use random.nextInt(maxIndex).
|
||||
if(data.size() == 0){
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
Iterator<Double> sortedDataIterator = data.stream()
|
||||
.map(row -> row.getCovariateValue(this).getValue())
|
||||
.iterator();
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.DataLoader;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import lombok.Data;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
@ -36,7 +36,7 @@ public class CompetingRiskResponse implements Serializable {
|
|||
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public static class CompetingResponseLoader implements DataLoader.ResponseLoader<CompetingRiskResponse>{
|
||||
public static class CompetingResponseLoader implements DataUtils.ResponseLoader<CompetingRiskResponse>{
|
||||
|
||||
private final String deltaName;
|
||||
private final String uName;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.DataLoader;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
@ -38,7 +38,7 @@ public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResp
|
|||
}
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public static class CompetingResponseWithCensorTimeLoader implements DataLoader.ResponseLoader<CompetingRiskResponseWithCensorTime>{
|
||||
public static class CompetingResponseWithCensorTimeLoader implements DataUtils.ResponseLoader<CompetingRiskResponseWithCensorTime>{
|
||||
|
||||
private final String deltaName;
|
||||
private final String uName;
|
||||
|
|
|
@ -28,6 +28,7 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
|||
|
||||
private final Collection<Tree<O>> trees;
|
||||
private final ResponseCombiner<O, FO> treeResponseCombiner;
|
||||
private final List<Covariate> covariateList;
|
||||
|
||||
public FO evaluate(CovariateRow row){
|
||||
|
||||
|
|
|
@ -17,18 +17,17 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.Bootstrapper;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.text.NumberFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
@ -38,7 +37,6 @@ import java.util.concurrent.ThreadLocalRandom;
|
|||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import java.util.zip.GZIPOutputStream;
|
||||
|
||||
@Builder
|
||||
@AllArgsConstructor(access=AccessLevel.PRIVATE)
|
||||
|
@ -90,6 +88,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
return Forest.<TO, FO>builder()
|
||||
.treeResponseCombiner(treeResponseCombiner)
|
||||
.trees(trees)
|
||||
.covariateList(covariates)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
@ -112,7 +111,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
System.out.print("\rFinished " + treeCount.get() + "/" + ntree + " trees");
|
||||
}
|
||||
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + formatNumber(j+1) + ".tree", treeCount);
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + Utils.formatNumber(j+1, ntree) + ".tree", treeCount);
|
||||
worker.run();
|
||||
|
||||
}
|
||||
|
@ -184,7 +183,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
|
||||
|
||||
for(int j=treeCount.get(); j<ntree; j++){
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + formatNumber(j+1) + ".tree", treeCount);
|
||||
final Runnable worker = new TreeSavedWorker(data, "tree-" + Utils.formatNumber(j+1, ntree) + ".tree", treeCount);
|
||||
executorService.execute(worker);
|
||||
}
|
||||
|
||||
|
@ -215,16 +214,6 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
return treeTrainer.growTree(bootstrappedData, random);
|
||||
}
|
||||
|
||||
public void saveTree(final Tree<TO> tree, String name) throws IOException {
|
||||
final String filename = saveTreeLocation + "/" + name;
|
||||
|
||||
final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename)));
|
||||
|
||||
outputStream.writeObject(tree);
|
||||
|
||||
outputStream.close();
|
||||
|
||||
}
|
||||
|
||||
private class TreeInMemoryWorker implements Runnable {
|
||||
|
||||
|
@ -250,27 +239,6 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When saving trees we typically save them as tree-1.tree, tree-2.tree. This is fine until we get tree-10.tree, which
|
||||
* when sorted alphabetically goes before tree-2.tree. We should instead save tree-01.tree, ... tree-10.tree.
|
||||
*
|
||||
* We need to set the number of 0s though based on ntree.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private String formatNumber(int currentTreeNumber){
|
||||
final int numDigits = (int) Math.log10(ntree) + 1;
|
||||
|
||||
String currentTreeNumberString = Integer.toString(currentTreeNumber);
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
|
||||
for(int i=0; i<numDigits-currentTreeNumberString.length(); i++){
|
||||
builder.append('0');
|
||||
}
|
||||
builder.append(currentTreeNumberString);
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
private class TreeSavedWorker implements Runnable {
|
||||
|
||||
|
@ -291,7 +259,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
|
||||
|
||||
try {
|
||||
saveTree(tree, filename);
|
||||
DataUtils.saveObject(tree, saveTreeLocation + "/" + filename);
|
||||
} catch (IOException e) {
|
||||
System.err.println("IOException while saving " + filename);
|
||||
e.printStackTrace();
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
|
@ -26,7 +27,7 @@ import java.util.Iterator;
|
|||
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
|
||||
* SimpleGroupDifferentiator.
|
||||
*/
|
||||
public interface GroupDifferentiator<Y> {
|
||||
public interface GroupDifferentiator<Y> extends Serializable {
|
||||
|
||||
SplitAndScore<Y, ?> differentiate(Iterator<Split<Y, ?>> splitIterator);
|
||||
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
public interface ResponseCombiner<I, O> {
|
||||
public interface ResponseCombiner<I, O> extends Serializable {
|
||||
|
||||
O combine(List<I> responses);
|
||||
|
||||
|
|
|
@ -57,4 +57,5 @@ public class Tree<Y> implements Node<Y> {
|
|||
public String toString(){
|
||||
return rootNode.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -165,6 +165,11 @@ public class TreeTrainer<Y, O> {
|
|||
for(final Covariate covariate : covariatesToTry) {
|
||||
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
|
||||
|
||||
// this happens if there were only NA values in data for this covariate. Rare, but I've seen it.
|
||||
if(iterator == null){
|
||||
continue;
|
||||
}
|
||||
|
||||
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
|
||||
|
||||
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
|
||||
|
|
|
@ -14,8 +14,9 @@
|
|||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package ca.joeltherrien.randomforest;
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
|
@ -98,12 +99,26 @@ public class DataUtils {
|
|||
|
||||
}
|
||||
|
||||
public static <O, FO> Forest<O, FO> loadForest(String folder, ResponseCombiner<O, FO> treeResponseCombiner) throws IOException, ClassNotFoundException {
|
||||
final File directory = new File(folder);
|
||||
return loadForest(directory, treeResponseCombiner);
|
||||
}
|
||||
|
||||
public static void saveObject(Serializable object, String filename) throws IOException {
|
||||
final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename)));
|
||||
outputStream.writeObject(object);
|
||||
outputStream.close();
|
||||
}
|
||||
|
||||
public static Object loadObject(String filename) throws IOException, ClassNotFoundException {
|
||||
final ObjectInputStream inputStream = new ObjectInputStream(new GZIPInputStream(new FileInputStream(filename)));
|
||||
final Object object = inputStream.readObject();
|
||||
inputStream.close();
|
||||
|
||||
return object;
|
||||
|
||||
}
|
||||
|
||||
@FunctionalInterface
|
||||
public interface ResponseLoader<Y>{
|
||||
Y parse(CSVRecord record);
|
||||
|
|
|
@ -208,5 +208,26 @@ public final class Utils {
|
|||
return map;
|
||||
}
|
||||
|
||||
/**
|
||||
* When saving trees we typically save them as tree-1.tree, tree-2.tree. This is fine until we get tree-10.tree, which
|
||||
* when sorted alphabetically goes before tree-2.tree. We should instead save tree-01.tree, ... tree-10.tree.
|
||||
*
|
||||
* We need to set the number of 0s though based on ntree.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public static String formatNumber(int currentTreeNumber, int maxNumberOfTrees){
|
||||
final int numDigits = (int) Math.log10(maxNumberOfTrees) + 1;
|
||||
|
||||
String currentTreeNumberString = Integer.toString(currentTreeNumber);
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
|
||||
for(int i=0; i<numDigits-currentTreeNumberString.length(); i++){
|
||||
builder.append('0');
|
||||
}
|
||||
builder.append(currentTreeNumberString);
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctio
|
|||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import com.fasterxml.jackson.databind.node.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
@ -109,7 +110,7 @@ public class TestSavingLoading {
|
|||
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
||||
final Settings settings = getSettings();
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final File directory = new File(settings.getSaveTreeLocation());
|
||||
assertFalse(directory.exists());
|
||||
|
@ -125,7 +126,7 @@ public class TestSavingLoading {
|
|||
|
||||
|
||||
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataLoader.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||
|
||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.DataLoader;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
|
@ -126,7 +126,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
|
||||
|
@ -179,7 +179,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
|
||||
|
@ -229,7 +229,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
|
@ -277,7 +277,7 @@ public class TestCompetingRisk {
|
|||
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
// Let's count the events and make sure the data was correctly read.
|
||||
int countCensored = 0;
|
||||
|
@ -320,7 +320,7 @@ public class TestCompetingRisk {
|
|||
settings.setNtree(300); // results are too variable at 100
|
||||
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(),
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
|
||||
settings.getTrainingDataLocation());
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
|
@ -341,7 +341,7 @@ public class TestCompetingRisk {
|
|||
settings.setNtree(300); // results are too variable at 100
|
||||
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(),
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(),
|
||||
settings.getTrainingDataLocation());
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.DataLoader;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
|
@ -36,7 +36,6 @@ import java.io.IOException;
|
|||
import java.util.Collections;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
|
@ -63,8 +62,8 @@ public class TestLogRankMultipleGroupDifferentiator {
|
|||
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||
final List<Row<CompetingRiskResponse>> rows = DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation());
|
||||
final DataUtils.ResponseLoader loader = settings.getResponseLoader();
|
||||
final List<Row<CompetingRiskResponse>> rows = DataUtils.loadData(covariates, loader, settings.getTrainingDataLocation());
|
||||
|
||||
return new Data<>(rows, covariates);
|
||||
}
|
||||
|
|
|
@ -35,6 +35,23 @@ public class NumericCovariateTest {
|
|||
return rowList;
|
||||
}
|
||||
|
||||
private List<Row<Double>> createTestDatasetMissingValues(NumericCovariate covariate){
|
||||
final List<Row<Double>> rowList = new ArrayList<>();
|
||||
final List<Covariate> covariateList = Collections.singletonList(covariate);
|
||||
|
||||
final String naString = "NA";
|
||||
|
||||
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 1, 1.0));
|
||||
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 2, 1.0));
|
||||
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 3, 1.0));
|
||||
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 4, 1.0));
|
||||
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 5, 1.0));
|
||||
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 6, 1.0));
|
||||
rowList.add(Row.createSimple(Utils.easyMap("x", naString), covariateList, 7, 1.0));
|
||||
|
||||
return rowList;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNumericCovariateDeterministic(){
|
||||
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
||||
|
@ -184,6 +201,21 @@ public class NumericCovariateTest {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* If all the values are missing on a covariate then we shouldn't return an iterator.
|
||||
*
|
||||
*/
|
||||
@Test
|
||||
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
|
||||
final NumericCovariate covariate = new NumericCovariate("x", 0);
|
||||
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
|
||||
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());
|
||||
|
||||
assertNull(updater);
|
||||
}
|
||||
|
||||
|
||||
|
||||
private <T> void assertContains(List<T> subList, List<T> greaterList){
|
||||
boolean allContained = true;
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.csv;
|
||||
|
||||
import ca.joeltherrien.randomforest.DataLoader;
|
||||
import ca.joeltherrien.randomforest.utils.DataUtils;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
|
@ -64,9 +64,9 @@ public class TestLoadingCSV {
|
|||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
|
||||
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||
final DataUtils.ResponseLoader loader = settings.getResponseLoader();
|
||||
|
||||
return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation());
|
||||
return DataUtils.loadData(covariates, loader, settings.getTrainingDataLocation());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in a new issue