Switch code to storing Covariate.Value using arrays instead of Maps

This commit is contained in:
Joel Therrien 2018-09-18 11:17:15 -07:00
parent de39f60314
commit aa733d5eba
23 changed files with 145 additions and 101 deletions

View file

@ -12,14 +12,13 @@ import java.util.Map;
@RequiredArgsConstructor @RequiredArgsConstructor
public class CovariateRow implements Serializable { public class CovariateRow implements Serializable {
private final Map<String, Covariate.Value> valueMap; private final Covariate.Value[] valueArray;
@Getter @Getter
private final int id; private final int id;
public Covariate.Value<?> getCovariateValue(String name){ public Covariate.Value<?> getCovariateValue(Covariate covariate){
return valueMap.get(name); return valueArray[covariate.getIndex()];
} }
@Override @Override
@ -28,18 +27,21 @@ public class CovariateRow implements Serializable {
} }
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){ public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
final Map<String, Covariate.Value> valueMap = new HashMap<>(); final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
final Map<String, Covariate> covariateMap = new HashMap<>(); final Map<String, Covariate> covariateMap = new HashMap<>();
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate)); covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
simpleMap.forEach((name, valueStr) -> { simpleMap.forEach((name, valueStr) -> {
if(covariateMap.containsKey(name)){ final Covariate covariate = covariateMap.get(name);
valueMap.put(name, covariateMap.get(name).createValue(valueStr));
if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates
valueArray[covariate.getIndex()] = covariate.createValue(valueStr);
} }
}); });
return new CovariateRow(valueMap, id); return new CovariateRow(valueArray, id);
} }
} }

View file

@ -37,15 +37,15 @@ public class DataLoader {
int id = 1; int id = 1;
for(final CSVRecord record : parser){ for(final CSVRecord record : parser){
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>(); final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()];
for(final Covariate<?> covariate : covariates){ for(final Covariate<?> covariate : covariates){
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName()))); valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName()));
} }
final Y y = responseLoader.parse(record); final Y y = responseLoader.parse(record);
dataset.add(new Row<>(covariateValueMap, id++, y)); dataset.add(new Row<>(valueArray, id++, y));
} }

View file

@ -1,9 +1,6 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
@ -17,6 +14,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode; import com.fasterxml.jackson.databind.node.TextNode;
import java.io.*; import java.io.*;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -35,8 +33,7 @@ public class Main {
final File settingsFile = new File(args[0]); final File settingsFile = new File(args[0]);
final Settings settings = Settings.load(settingsFile); final Settings settings = Settings.load(settingsFile);
final List<Covariate> covariates = settings.getCovariates().stream() final List<Covariate> covariates = settings.getCovariates();
.map(cs -> cs.build()).collect(Collectors.toList());
if(args[1].equalsIgnoreCase("train")){ if(args[1].equalsIgnoreCase("train")){
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -168,7 +165,7 @@ public class Main {
yVarSettings.set("name", new TextNode("y")); yVarSettings.set("name", new TextNode("y"));
return Settings.builder() return Settings.builder()
.covariates(Utils.easyList( .covariateSettings(Utils.easyList(
new NumericCovariateSettings("x1"), new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"), new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))

View file

@ -3,14 +3,16 @@ package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
public class Row<Y> extends CovariateRow { public class Row<Y> extends CovariateRow {
private final Y response; private final Y response;
public Row(Map<String, Covariate.Value> valueMap, int id, Y response){ public Row(final Covariate.Value[] valueArray, final int id, final Y response){
super(valueMap, id); super(valueArray, id);
this.response = response; this.response = response;
} }
@ -23,7 +25,21 @@ public class Row<Y> extends CovariateRow {
public String toString() { public String toString() {
return "Row " + this.getId(); return "Row " + this.getId();
} }
public static <Y> Row<Y> createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id, final Y response){
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
final Map<String, Covariate> covariateMap = new HashMap<>();
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
simpleMap.forEach((name, valueStr) -> {
final Covariate covariate = covariateMap.get(name);
if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates
valueArray[covariate.getIndex()] = covariate.createValue(valueStr);
}
});
return new Row<Y>(valueArray, id, response);
}
} }

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings; import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
@ -195,7 +196,7 @@ public class Settings {
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private List<CovariateSettings> covariates = new ArrayList<>(); private List<CovariateSettings> covariateSettings = new ArrayList<>();
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
// number of covariates to randomly try // number of covariates to randomly try
@ -227,7 +228,7 @@ public class Settings {
//mapper.enableDefaultTyping(); //mapper.enableDefaultTyping();
// Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...) // Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
this.covariates = new ArrayList<>(this.covariates); this.covariateSettings = new ArrayList<>(this.covariateSettings);
mapper.writeValue(file, this); mapper.writeValue(file, this);
} }
@ -260,4 +261,14 @@ public class Settings {
return getResponseCombinerConstructor(type).apply(treeCombinerSettings); return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
} }
@JsonIgnore
public List<Covariate> getCovariates(){
final List<CovariateSettings> covariateSettingsList = this.getCovariateSettings();
final List<Covariate> covariates = new ArrayList<>(covariateSettingsList.size());
for(int i = 0; i < covariateSettingsList.size(); i++){
covariates.add(covariateSettingsList.get(i).build(i));
}
return covariates;
}
} }

View file

@ -13,6 +13,9 @@ public final class BooleanCovariate implements Covariate<Boolean>{
@Getter @Getter
private final String name; private final String name;
@Getter
private final int index;
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
@Override @Override

View file

@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
} }
@Override @Override
public BooleanCovariate build() { public BooleanCovariate build(int index) {
return new BooleanCovariate(name); return new BooleanCovariate(name, index);
} }
} }

View file

@ -12,6 +12,8 @@ public interface Covariate<V> extends Serializable {
String getName(); String getName();
int getIndex();
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number); Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
Value<V> createValue(V value); Value<V> createValue(V value);
@ -54,7 +56,7 @@ public interface Covariate<V> extends Serializable {
for(final Row<Y> row : rows) { for(final Row<Y> row : rows) {
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName()); final Value<V> value = (Value<V>) row.getCovariateValue(getParent());
if(value.isNA()){ if(value.isNA()){
missingValueRows.add(row); missingValueRows.add(row);
@ -76,7 +78,7 @@ public interface Covariate<V> extends Serializable {
} }
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName()); final Value<V> value = (Value<V>) row.getCovariateValue(getParent());
if(value.isNA()){ if(value.isNA()){
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand; return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;

View file

@ -28,5 +28,5 @@ public abstract class CovariateSettings<V> {
this.name = name; this.name = name;
} }
public abstract Covariate<V> build(); public abstract Covariate<V> build(int index);
} }

View file

@ -1,20 +1,27 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter;
import java.util.*; import java.util.*;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
public final class FactorCovariate implements Covariate<String>{ public final class FactorCovariate implements Covariate<String>{
@Getter
private final String name; private final String name;
@Getter
private final int index;
private final Map<String, FactorValue> factorLevels; private final Map<String, FactorValue> factorLevels;
private final FactorValue naValue; private final FactorValue naValue;
private final int numberOfPossiblePairings; private final int numberOfPossiblePairings;
public FactorCovariate(final String name, List<String> levels){ public FactorCovariate(final String name, final int index, List<String> levels){
this.name = name; this.name = name;
this.index = index;
this.factorLevels = new HashMap<>(); this.factorLevels = new HashMap<>();
for(final String level : levels){ for(final String level : levels){
@ -33,10 +40,6 @@ public final class FactorCovariate implements Covariate<String>{
} }
@Override
public String getName() {
return name;
}
@Override @Override
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) { public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {

View file

@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings<String> {
} }
@Override @Override
public FactorCovariate build() { public FactorCovariate build(int index) {
return new FactorCovariate(name, levels); return new FactorCovariate(name, index, levels);
} }
} }

View file

@ -16,6 +16,9 @@ public final class NumericCovariate implements Covariate<Double>{
@Getter @Getter
private final String name; private final String name;
@Getter
private final int index;
@Override @Override
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) { public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {

View file

@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings<Double> {
} }
@Override @Override
public NumericCovariate build() { public NumericCovariate build(int index) {
return new NumericCovariate(name); return new NumericCovariate(name, index);
} }
} }

View file

@ -113,7 +113,7 @@ public class TreeTrainer<Y, O> {
.generateSplitRules( .generateSplitRules(
data data
.stream() .stream()
.map(row -> row.getCovariateValue(covariate.getName())) .map(row -> row.getCovariateValue(covariate))
.collect(Collectors.toList()) .collect(Collectors.toList())
, numberToTry); , numberToTry);

View file

@ -53,7 +53,7 @@ public class TestSavingLoading {
yVarSettings.set("delta", new TextNode("status")); yVarSettings.set("delta", new TextNode("status"));
return Settings.builder() return Settings.builder()
.covariates(Utils.easyList( .covariateSettings(Utils.easyList(
new NumericCovariateSettings("ageatfda"), new NumericCovariateSettings("ageatfda"),
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black"), new BooleanCovariateSettings("black"),
@ -87,14 +87,10 @@ public class TestSavingLoading {
, covariates, 1); , covariates, 1);
} }
public List<Covariate> getCovariates(Settings settings){
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
}
@Test @Test
public void testSavingLoading() throws IOException, ClassNotFoundException { public void testSavingLoading() throws IOException, ClassNotFoundException {
final Settings settings = getSettings(); final Settings settings = getSettings();
final List<Covariate> covariates = getCovariates(settings); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final File directory = new File(settings.getSaveTreeLocation()); final File directory = new File(settings.getSaveTreeLocation());

View file

@ -55,7 +55,7 @@ public class TestCompetingRisk {
yVarSettings.set("delta", new TextNode("status")); yVarSettings.set("delta", new TextNode("status"));
return Settings.builder() return Settings.builder()
.covariates(Utils.easyList( .covariateSettings(Utils.easyList(
new NumericCovariateSettings("ageatfda"), new NumericCovariateSettings("ageatfda"),
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black"), new BooleanCovariateSettings("black"),
@ -79,9 +79,6 @@ public class TestCompetingRisk {
.build(); .build();
} }
public List<Covariate> getCovariates(Settings settings){
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
}
public CovariateRow getPredictionRow(List<Covariate> covariates){ public CovariateRow getPredictionRow(List<Covariate> covariates){
return CovariateRow.createSimple(Utils.easyMap( return CovariateRow.createSimple(Utils.easyMap(
@ -96,12 +93,12 @@ public class TestCompetingRisk {
public void testSingleTree() throws IOException { public void testSingleTree() throws IOException {
final Settings settings = getSettings(); final Settings settings = getSettings();
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv"); settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
settings.setCovariates(Utils.easyList( settings.setCovariateSettings(Utils.easyList(
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black") new BooleanCovariateSettings("black")
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree. )); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = getCovariates(settings); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -154,7 +151,7 @@ public class TestCompetingRisk {
settings.setNumberOfSplits(0); settings.setNumberOfSplits(0);
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv"); settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
final List<Covariate> covariates = getCovariates(settings); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -199,12 +196,12 @@ public class TestCompetingRisk {
@Test @Test
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException { public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
final Settings settings = getSettings(); final Settings settings = getSettings();
settings.setCovariates(Utils.easyList( settings.setCovariateSettings(Utils.easyList(
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black") new BooleanCovariateSettings("black")
)); ));
final List<Covariate> covariates = getCovariates(settings); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -254,7 +251,7 @@ public class TestCompetingRisk {
public void verifyDataset() throws IOException { public void verifyDataset() throws IOException {
final Settings settings = getSettings(); final Settings settings = getSettings();
final List<Covariate> covariates = getCovariates(settings); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -291,7 +288,7 @@ public class TestCompetingRisk {
final Settings settings = getSettings(); final Settings settings = getSettings();
settings.setNtree(300); // results are too variable at 100 settings.setNtree(300); // results are too variable at 100
final List<Covariate> covariates = getCovariates(settings); final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial(); final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();

View file

@ -1,6 +1,7 @@
package ca.joeltherrien.randomforest.competingrisk; package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.MathFunction;
@ -54,10 +55,10 @@ public class TestCompetingRiskErrorRateCalculator {
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0); final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
final List<Row<CompetingRiskResponse>> dataset = Utils.easyList( final List<Row<CompetingRiskResponse>> dataset = Utils.easyList(
new Row<>(Collections.emptyMap(), 1, response1), new Row<>(new Covariate.Value[]{}, 1, response1),
new Row<>(Collections.emptyMap(), 2, response2), new Row<>(new Covariate.Value[]{}, 2, response2),
new Row<>(Collections.emptyMap(), 3, response3), new Row<>(new Covariate.Value[]{}, 3, response3),
new Row<>(Collections.emptyMap(), 4, response4) new Row<>(new Covariate.Value[]{}, 4, response4)
); );
final double[] mortalityOneArray = new double[]{1, 4, 3, 9}; final double[] mortalityOneArray = new double[]{1, 4, 3, 9};

View file

@ -58,7 +58,7 @@ public class FactorCovariateTest {
private FactorCovariate createTestCovariate(){ private FactorCovariate createTestCovariate(){
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE"); final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
return new FactorCovariate("pet", levels); return new FactorCovariate("pet", 0, levels);
} }

View file

@ -38,7 +38,7 @@ public class TestLoadingCSV {
final Settings settings = Settings.builder() final Settings settings = Settings.builder()
.trainingDataLocation(filename) .trainingDataLocation(filename)
.covariates( .covariateSettings(
Utils.easyList(new NumericCovariateSettings("x1"), Utils.easyList(new NumericCovariateSettings("x1"),
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")), new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
new BooleanCovariateSettings("x3")) new BooleanCovariateSettings("x3"))
@ -46,8 +46,7 @@ public class TestLoadingCSV {
.yVarSettings(yVarSettings) .yVarSettings(yVarSettings)
.build(); .build();
final List<Covariate> covariates = settings.getCovariates().stream() final List<Covariate> covariates = settings.getCovariates();
.map(cs -> cs.build()).collect(Collectors.toList());
final DataLoader.ResponseLoader loader = settings.getResponseLoader(); final DataLoader.ResponseLoader loader = settings.getResponseLoader();
@ -56,46 +55,50 @@ public class TestLoadingCSV {
} }
@Test @Test
public void verifyLoadingNormal() throws IOException { public void verifyLoadingNormal(final List<Covariate> covariates) throws IOException {
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv"); final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
assertData(data); assertData(data, covariates);
} }
@Test @Test
public void verifyLoadingGz() throws IOException { public void verifyLoadingGz(final List<Covariate> covariates) throws IOException {
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz"); final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
assertData(data); assertData(data, covariates);
} }
private void assertData(final List<Row<Double>> data){ private void assertData(final List<Row<Double>> data, final List<Covariate> covariates){
final Covariate x1 = covariates.get(0);
final Covariate x2 = covariates.get(0);
final Covariate x3 = covariates.get(0);
assertEquals(4, data.size()); assertEquals(4, data.size());
Row<Double> row = data.get(0); Row<Double> row = data.get(0);
assertEquals(5.0, (double)row.getResponse()); assertEquals(5.0, (double)row.getResponse());
assertEquals(3.0, row.getCovariateValue("x1").getValue()); assertEquals(3.0, row.getCovariateValue(x1).getValue());
assertEquals("mouse", row.getCovariateValue("x2").getValue()); assertEquals("mouse", row.getCovariateValue(x2).getValue());
assertEquals(true, row.getCovariateValue("x3").getValue()); assertEquals(true, row.getCovariateValue(x3).getValue());
row = data.get(1); row = data.get(1);
assertEquals(2.0, (double)row.getResponse()); assertEquals(2.0, (double)row.getResponse());
assertEquals(1.0, row.getCovariateValue("x1").getValue()); assertEquals(1.0, row.getCovariateValue(x1).getValue());
assertEquals("dog", row.getCovariateValue("x2").getValue()); assertEquals("dog", row.getCovariateValue(x2).getValue());
assertEquals(false, row.getCovariateValue("x3").getValue()); assertEquals(false, row.getCovariateValue(x3).getValue());
row = data.get(2); row = data.get(2);
assertEquals(9.0, (double)row.getResponse()); assertEquals(9.0, (double)row.getResponse());
assertEquals(1.5, row.getCovariateValue("x1").getValue()); assertEquals(1.5, row.getCovariateValue(x1).getValue());
assertEquals("cat", row.getCovariateValue("x2").getValue()); assertEquals("cat", row.getCovariateValue(x2).getValue());
assertEquals(true, row.getCovariateValue("x3").getValue()); assertEquals(true, row.getCovariateValue(x3).getValue());
row = data.get(3); row = data.get(3);
assertEquals(-3.0, (double)row.getResponse()); assertEquals(-3.0, (double)row.getResponse());
assertTrue(row.getCovariateValue("x1").isNA()); assertTrue(row.getCovariateValue(x1).isNA());
assertTrue(row.getCovariateValue("x2").isNA()); assertTrue(row.getCovariateValue(x2).isNA());
assertTrue(row.getCovariateValue("x3").isNA()); assertTrue(row.getCovariateValue(x3).isNA());
} }
} }

View file

@ -31,7 +31,7 @@ public class TestPersistence {
yVarSettings.set("name", new TextNode("y")); yVarSettings.set("name", new TextNode("y"));
final Settings settingsOriginal = Settings.builder() final Settings settingsOriginal = Settings.builder()
.covariates(Utils.easyList( .covariateSettings(Utils.easyList(
new NumericCovariateSettings("x1"), new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"), new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))

View file

@ -27,22 +27,22 @@ public class TrainForest {
final List<Covariate> covariateList = new ArrayList<>(p); final List<Covariate> covariateList = new ArrayList<>(p);
for(int j =0; j < p; j++){ for(int j =0; j < p; j++){
final NumericCovariate covariate = new NumericCovariate("x"+j); final NumericCovariate covariate = new NumericCovariate("x"+j, j);
covariateList.add(covariate); covariateList.add(covariate);
} }
for(int i=0; i<n; i++){ for(int i=0; i<n; i++){
double y = 0.0; double y = 0.0;
final Map<String, Covariate.Value> map = new HashMap<>(); final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
for(final Covariate covariate : covariateList) { for(final Covariate covariate : covariateList) {
final double x = random.nextDouble(); final double x = random.nextDouble();
y += x; y += x;
map.put(covariate.getName(), covariate.createValue(x)); valueArray[covariate.getIndex()] = covariate.createValue(y);
} }
data.add(i, new Row<>(map, i, y)); data.add(i, new Row<>(valueArray, i, y));
if(y < minY){ if(y < minY){
minY = y; minY = y;

View file

@ -18,15 +18,13 @@ import java.util.stream.DoubleStream;
public class TrainSingleTree { public class TrainSingleTree {
public static void main(String[] args) { public static void main(String[] args) {
System.out.println("Hello world!");
final Random random = new Random(123); final Random random = new Random(123);
final int n = 1000; final int n = 1000;
final List<Row<Double>> trainingSet = new ArrayList<>(n); final List<Row<Double>> trainingSet = new ArrayList<>(n);
final Covariate<Double> x1Covariate = new NumericCovariate("x1"); final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
final Covariate<Double> x2Covariate = new NumericCovariate("x2"); final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
final List<Covariate.Value<Double>> x1List = DoubleStream final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0) .generate(() -> random.nextDouble()*10.0)
@ -100,17 +98,21 @@ public class TrainSingleTree {
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){ public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
double y = generateResponse(x1.getValue(), x2.getValue()); double y = generateResponse(x1.getValue(), x2.getValue());
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2); final Covariate.Value[] valueArray = new Covariate.Value[2];
valueArray[0] = x1;
valueArray[1] = x2;
return new Row<>(map, id, y); return new Row<>(valueArray, id, y);
} }
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){ public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2); final Covariate.Value[] valueArray = new Covariate.Value[2];
valueArray[0] = x1;
valueArray[1] = x2;
return new CovariateRow(map, id); return new CovariateRow(valueArray, id);
} }

View file

@ -28,9 +28,9 @@ public class TrainSingleTreeFactor {
final int n = 10000; final int n = 10000;
final List<Row<Double>> trainingSet = new ArrayList<>(n); final List<Row<Double>> trainingSet = new ArrayList<>(n);
final Covariate<Double> x1Covariate = new NumericCovariate("x1"); final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
final Covariate<Double> x2Covariate = new NumericCovariate("x2"); final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
final FactorCovariate x3Covariate = new FactorCovariate("x3", Utils.easyList("cat", "dog", "mouse")); final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"));
final List<Covariate.Value<Double>> x1List = DoubleStream final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0) .generate(() -> random.nextDouble()*10.0)
@ -128,17 +128,25 @@ public class TrainSingleTreeFactor {
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){ public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue()); double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2); final Covariate.Value[] valueArray = new Covariate.Value[3];
valueArray[0] = x1;
valueArray[1] = x2;
valueArray[2] = x3;
return new Row<>(map, id, y); //final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2); // Missing x3?
return new Row<>(valueArray, id, y);
} }
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){ public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2, "x3", x3); final Covariate.Value[] valueArray = new Covariate.Value[3];
valueArray[0] = x1;
valueArray[1] = x2;
valueArray[2] = x3;
return new CovariateRow(map, id); return new CovariateRow(valueArray, id);
} }