Switch code to storing Covariate.Value using arrays instead of Maps

This commit is contained in:
Joel Therrien 2018-09-18 11:17:15 -07:00
parent de39f60314
commit aa733d5eba
23 changed files with 145 additions and 101 deletions

View file

@ -12,14 +12,13 @@ import java.util.Map;
@RequiredArgsConstructor
public class CovariateRow implements Serializable {
private final Map<String, Covariate.Value> valueMap;
private final Covariate.Value[] valueArray;
@Getter
private final int id;
public Covariate.Value<?> getCovariateValue(String name){
return valueMap.get(name);
public Covariate.Value<?> getCovariateValue(Covariate covariate){
return valueArray[covariate.getIndex()];
}
@Override
@ -28,18 +27,21 @@ public class CovariateRow implements Serializable {
}
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
final Map<String, Covariate.Value> valueMap = new HashMap<>();
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
final Map<String, Covariate> covariateMap = new HashMap<>();
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
simpleMap.forEach((name, valueStr) -> {
if(covariateMap.containsKey(name)){
valueMap.put(name, covariateMap.get(name).createValue(valueStr));
final Covariate covariate = covariateMap.get(name);
if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates
valueArray[covariate.getIndex()] = covariate.createValue(valueStr);
}
});
return new CovariateRow(valueMap, id);
return new CovariateRow(valueArray, id);
}
}

View file

@ -37,15 +37,15 @@ public class DataLoader {
int id = 1;
for(final CSVRecord record : parser){
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()];
for(final Covariate<?> covariate : covariates){
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName()));
}
final Y y = responseLoader.parse(record);
dataset.add(new Row<>(covariateValueMap, id++, y));
dataset.add(new Row<>(valueArray, id++, y));
}

View file

@ -1,9 +1,6 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
@ -17,6 +14,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@ -35,8 +33,7 @@ public class Main {
final File settingsFile = new File(args[0]);
final Settings settings = Settings.load(settingsFile);
final List<Covariate> covariates = settings.getCovariates().stream()
.map(cs -> cs.build()).collect(Collectors.toList());
final List<Covariate> covariates = settings.getCovariates();
if(args[1].equalsIgnoreCase("train")){
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -168,7 +165,7 @@ public class Main {
yVarSettings.set("name", new TextNode("y"));
return Settings.builder()
.covariates(Utils.easyList(
.covariateSettings(Utils.easyList(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))

View file

@ -3,14 +3,16 @@ package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Row<Y> extends CovariateRow {
private final Y response;
public Row(Map<String, Covariate.Value> valueMap, int id, Y response){
super(valueMap, id);
public Row(final Covariate.Value[] valueArray, final int id, final Y response){
super(valueArray, id);
this.response = response;
}
@ -24,6 +26,20 @@ public class Row<Y> extends CovariateRow {
return "Row " + this.getId();
}
public static <Y> Row<Y> createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id, final Y response){
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
final Map<String, Covariate> covariateMap = new HashMap<>();
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
simpleMap.forEach((name, valueStr) -> {
final Covariate covariate = covariateMap.get(name);
if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates
valueArray[covariate.getIndex()] = covariate.createValue(valueStr);
}
});
return new Row<Y>(valueArray, id, response);
}
}

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
@ -195,7 +196,7 @@ public class Settings {
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
private List<CovariateSettings> covariates = new ArrayList<>();
private List<CovariateSettings> covariateSettings = new ArrayList<>();
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
// number of covariates to randomly try
@ -227,7 +228,7 @@ public class Settings {
//mapper.enableDefaultTyping();
// Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
this.covariates = new ArrayList<>(this.covariates);
this.covariateSettings = new ArrayList<>(this.covariateSettings);
mapper.writeValue(file, this);
}
@ -260,4 +261,14 @@ public class Settings {
return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
}
@JsonIgnore
public List<Covariate> getCovariates(){
final List<CovariateSettings> covariateSettingsList = this.getCovariateSettings();
final List<Covariate> covariates = new ArrayList<>(covariateSettingsList.size());
for(int i = 0; i < covariateSettingsList.size(); i++){
covariates.add(covariateSettingsList.get(i).build(i));
}
return covariates;
}
}

View file

@ -13,6 +13,9 @@ public final class BooleanCovariate implements Covariate<Boolean>{
@Getter
private final String name;
@Getter
private final int index;
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
@Override

View file

@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
}
@Override
public BooleanCovariate build() {
return new BooleanCovariate(name);
public BooleanCovariate build(int index) {
return new BooleanCovariate(name, index);
}
}

View file

@ -12,6 +12,8 @@ public interface Covariate<V> extends Serializable {
String getName();
int getIndex();
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
Value<V> createValue(V value);
@ -54,7 +56,7 @@ public interface Covariate<V> extends Serializable {
for(final Row<Y> row : rows) {
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
final Value<V> value = (Value<V>) row.getCovariateValue(getParent());
if(value.isNA()){
missingValueRows.add(row);
@ -76,7 +78,7 @@ public interface Covariate<V> extends Serializable {
}
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
final Value<V> value = (Value<V>) row.getCovariateValue(getParent());
if(value.isNA()){
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;

View file

@ -28,5 +28,5 @@ public abstract class CovariateSettings<V> {
this.name = name;
}
public abstract Covariate<V> build();
public abstract Covariate<V> build(int index);
}

View file

@ -1,20 +1,27 @@
package ca.joeltherrien.randomforest.covariates;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
public final class FactorCovariate implements Covariate<String>{
@Getter
private final String name;
@Getter
private final int index;
private final Map<String, FactorValue> factorLevels;
private final FactorValue naValue;
private final int numberOfPossiblePairings;
public FactorCovariate(final String name, List<String> levels){
public FactorCovariate(final String name, final int index, List<String> levels){
this.name = name;
this.index = index;
this.factorLevels = new HashMap<>();
for(final String level : levels){
@ -33,10 +40,6 @@ public final class FactorCovariate implements Covariate<String>{
}
@Override
public String getName() {
return name;
}
@Override
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {

View file

@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings<String> {
}
@Override
public FactorCovariate build() {
return new FactorCovariate(name, levels);
public FactorCovariate build(int index) {
return new FactorCovariate(name, index, levels);
}
}

View file

@ -16,6 +16,9 @@ public final class NumericCovariate implements Covariate<Double>{
@Getter
private final String name;
@Getter
private final int index;
@Override
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {

View file

@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings<Double> {
}
@Override
public NumericCovariate build() {
return new NumericCovariate(name);
public NumericCovariate build(int index) {
return new NumericCovariate(name, index);
}
}

View file

@ -113,7 +113,7 @@ public class TreeTrainer<Y, O> {
.generateSplitRules(
data
.stream()
.map(row -> row.getCovariateValue(covariate.getName()))
.map(row -> row.getCovariateValue(covariate))
.collect(Collectors.toList())
, numberToTry);

View file

@ -53,7 +53,7 @@ public class TestSavingLoading {
yVarSettings.set("delta", new TextNode("status"));
return Settings.builder()
.covariates(Utils.easyList(
.covariateSettings(Utils.easyList(
new NumericCovariateSettings("ageatfda"),
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black"),
@ -87,14 +87,10 @@ public class TestSavingLoading {
, covariates, 1);
}
public List<Covariate> getCovariates(Settings settings){
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
}
@Test
public void testSavingLoading() throws IOException, ClassNotFoundException {
final Settings settings = getSettings();
final List<Covariate> covariates = getCovariates(settings);
final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final File directory = new File(settings.getSaveTreeLocation());

View file

@ -55,7 +55,7 @@ public class TestCompetingRisk {
yVarSettings.set("delta", new TextNode("status"));
return Settings.builder()
.covariates(Utils.easyList(
.covariateSettings(Utils.easyList(
new NumericCovariateSettings("ageatfda"),
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black"),
@ -79,9 +79,6 @@ public class TestCompetingRisk {
.build();
}
public List<Covariate> getCovariates(Settings settings){
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
}
public CovariateRow getPredictionRow(List<Covariate> covariates){
return CovariateRow.createSimple(Utils.easyMap(
@ -96,12 +93,12 @@ public class TestCompetingRisk {
public void testSingleTree() throws IOException {
final Settings settings = getSettings();
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
settings.setCovariates(Utils.easyList(
settings.setCovariateSettings(Utils.easyList(
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black")
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = getCovariates(settings);
final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -154,7 +151,7 @@ public class TestCompetingRisk {
settings.setNumberOfSplits(0);
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
final List<Covariate> covariates = getCovariates(settings);
final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -199,12 +196,12 @@ public class TestCompetingRisk {
@Test
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
final Settings settings = getSettings();
settings.setCovariates(Utils.easyList(
settings.setCovariateSettings(Utils.easyList(
new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black")
));
final List<Covariate> covariates = getCovariates(settings);
final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -254,7 +251,7 @@ public class TestCompetingRisk {
public void verifyDataset() throws IOException {
final Settings settings = getSettings();
final List<Covariate> covariates = getCovariates(settings);
final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
@ -291,7 +288,7 @@ public class TestCompetingRisk {
final Settings settings = getSettings();
settings.setNtree(300); // results are too variable at 100
final List<Covariate> covariates = getCovariates(settings);
final List<Covariate> covariates = settings.getCovariates();
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();

View file

@ -1,6 +1,7 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.utils.MathFunction;
@ -54,10 +55,10 @@ public class TestCompetingRiskErrorRateCalculator {
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
final List<Row<CompetingRiskResponse>> dataset = Utils.easyList(
new Row<>(Collections.emptyMap(), 1, response1),
new Row<>(Collections.emptyMap(), 2, response2),
new Row<>(Collections.emptyMap(), 3, response3),
new Row<>(Collections.emptyMap(), 4, response4)
new Row<>(new Covariate.Value[]{}, 1, response1),
new Row<>(new Covariate.Value[]{}, 2, response2),
new Row<>(new Covariate.Value[]{}, 3, response3),
new Row<>(new Covariate.Value[]{}, 4, response4)
);
final double[] mortalityOneArray = new double[]{1, 4, 3, 9};

View file

@ -58,7 +58,7 @@ public class FactorCovariateTest {
private FactorCovariate createTestCovariate(){
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
return new FactorCovariate("pet", levels);
return new FactorCovariate("pet", 0, levels);
}

View file

@ -38,7 +38,7 @@ public class TestLoadingCSV {
final Settings settings = Settings.builder()
.trainingDataLocation(filename)
.covariates(
.covariateSettings(
Utils.easyList(new NumericCovariateSettings("x1"),
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
new BooleanCovariateSettings("x3"))
@ -46,8 +46,7 @@ public class TestLoadingCSV {
.yVarSettings(yVarSettings)
.build();
final List<Covariate> covariates = settings.getCovariates().stream()
.map(cs -> cs.build()).collect(Collectors.toList());
final List<Covariate> covariates = settings.getCovariates();
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
@ -56,46 +55,50 @@ public class TestLoadingCSV {
}
@Test
public void verifyLoadingNormal() throws IOException {
public void verifyLoadingNormal(final List<Covariate> covariates) throws IOException {
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
assertData(data);
assertData(data, covariates);
}
@Test
public void verifyLoadingGz() throws IOException {
public void verifyLoadingGz(final List<Covariate> covariates) throws IOException {
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
assertData(data);
assertData(data, covariates);
}
private void assertData(final List<Row<Double>> data){
private void assertData(final List<Row<Double>> data, final List<Covariate> covariates){
final Covariate x1 = covariates.get(0);
final Covariate x2 = covariates.get(0);
final Covariate x3 = covariates.get(0);
assertEquals(4, data.size());
Row<Double> row = data.get(0);
assertEquals(5.0, (double)row.getResponse());
assertEquals(3.0, row.getCovariateValue("x1").getValue());
assertEquals("mouse", row.getCovariateValue("x2").getValue());
assertEquals(true, row.getCovariateValue("x3").getValue());
assertEquals(3.0, row.getCovariateValue(x1).getValue());
assertEquals("mouse", row.getCovariateValue(x2).getValue());
assertEquals(true, row.getCovariateValue(x3).getValue());
row = data.get(1);
assertEquals(2.0, (double)row.getResponse());
assertEquals(1.0, row.getCovariateValue("x1").getValue());
assertEquals("dog", row.getCovariateValue("x2").getValue());
assertEquals(false, row.getCovariateValue("x3").getValue());
assertEquals(1.0, row.getCovariateValue(x1).getValue());
assertEquals("dog", row.getCovariateValue(x2).getValue());
assertEquals(false, row.getCovariateValue(x3).getValue());
row = data.get(2);
assertEquals(9.0, (double)row.getResponse());
assertEquals(1.5, row.getCovariateValue("x1").getValue());
assertEquals("cat", row.getCovariateValue("x2").getValue());
assertEquals(true, row.getCovariateValue("x3").getValue());
assertEquals(1.5, row.getCovariateValue(x1).getValue());
assertEquals("cat", row.getCovariateValue(x2).getValue());
assertEquals(true, row.getCovariateValue(x3).getValue());
row = data.get(3);
assertEquals(-3.0, (double)row.getResponse());
assertTrue(row.getCovariateValue("x1").isNA());
assertTrue(row.getCovariateValue("x2").isNA());
assertTrue(row.getCovariateValue("x3").isNA());
assertTrue(row.getCovariateValue(x1).isNA());
assertTrue(row.getCovariateValue(x2).isNA());
assertTrue(row.getCovariateValue(x3).isNA());
}
}

View file

@ -31,7 +31,7 @@ public class TestPersistence {
yVarSettings.set("name", new TextNode("y"));
final Settings settingsOriginal = Settings.builder()
.covariates(Utils.easyList(
.covariateSettings(Utils.easyList(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))

View file

@ -27,22 +27,22 @@ public class TrainForest {
final List<Covariate> covariateList = new ArrayList<>(p);
for(int j =0; j < p; j++){
final NumericCovariate covariate = new NumericCovariate("x"+j);
final NumericCovariate covariate = new NumericCovariate("x"+j, j);
covariateList.add(covariate);
}
for(int i=0; i<n; i++){
double y = 0.0;
final Map<String, Covariate.Value> map = new HashMap<>();
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
for(final Covariate covariate : covariateList) {
final double x = random.nextDouble();
y += x;
map.put(covariate.getName(), covariate.createValue(x));
valueArray[covariate.getIndex()] = covariate.createValue(y);
}
data.add(i, new Row<>(map, i, y));
data.add(i, new Row<>(valueArray, i, y));
if(y < minY){
minY = y;

View file

@ -18,15 +18,13 @@ import java.util.stream.DoubleStream;
public class TrainSingleTree {
public static void main(String[] args) {
System.out.println("Hello world!");
final Random random = new Random(123);
final int n = 1000;
final List<Row<Double>> trainingSet = new ArrayList<>(n);
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0)
@ -100,17 +98,21 @@ public class TrainSingleTree {
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
double y = generateResponse(x1.getValue(), x2.getValue());
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
final Covariate.Value[] valueArray = new Covariate.Value[2];
valueArray[0] = x1;
valueArray[1] = x2;
return new Row<>(map, id, y);
return new Row<>(valueArray, id, y);
}
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
final Covariate.Value[] valueArray = new Covariate.Value[2];
valueArray[0] = x1;
valueArray[1] = x2;
return new CovariateRow(map, id);
return new CovariateRow(valueArray, id);
}

View file

@ -28,9 +28,9 @@ public class TrainSingleTreeFactor {
final int n = 10000;
final List<Row<Double>> trainingSet = new ArrayList<>(n);
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
final FactorCovariate x3Covariate = new FactorCovariate("x3", Utils.easyList("cat", "dog", "mouse"));
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"));
final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0)
@ -128,17 +128,25 @@ public class TrainSingleTreeFactor {
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
final Covariate.Value[] valueArray = new Covariate.Value[3];
valueArray[0] = x1;
valueArray[1] = x2;
valueArray[2] = x3;
return new Row<>(map, id, y);
//final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2); // Missing x3?
return new Row<>(valueArray, id, y);
}
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2, "x3", x3);
final Covariate.Value[] valueArray = new Covariate.Value[3];
valueArray[0] = x1;
valueArray[1] = x2;
valueArray[2] = x3;
return new CovariateRow(map, id);
return new CovariateRow(valueArray, id);
}