Switch code to storing Covariate.Value using arrays instead of Maps
This commit is contained in:
parent
de39f60314
commit
aa733d5eba
23 changed files with 145 additions and 101 deletions
|
@ -12,14 +12,13 @@ import java.util.Map;
|
|||
@RequiredArgsConstructor
|
||||
public class CovariateRow implements Serializable {
|
||||
|
||||
private final Map<String, Covariate.Value> valueMap;
|
||||
private final Covariate.Value[] valueArray;
|
||||
|
||||
@Getter
|
||||
private final int id;
|
||||
|
||||
public Covariate.Value<?> getCovariateValue(String name){
|
||||
return valueMap.get(name);
|
||||
|
||||
public Covariate.Value<?> getCovariateValue(Covariate covariate){
|
||||
return valueArray[covariate.getIndex()];
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -28,18 +27,21 @@ public class CovariateRow implements Serializable {
|
|||
}
|
||||
|
||||
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
||||
final Map<String, Covariate.Value> valueMap = new HashMap<>();
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||
final Map<String, Covariate> covariateMap = new HashMap<>();
|
||||
|
||||
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
|
||||
|
||||
simpleMap.forEach((name, valueStr) -> {
|
||||
if(covariateMap.containsKey(name)){
|
||||
valueMap.put(name, covariateMap.get(name).createValue(valueStr));
|
||||
final Covariate covariate = covariateMap.get(name);
|
||||
|
||||
if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates
|
||||
valueArray[covariate.getIndex()] = covariate.createValue(valueStr);
|
||||
}
|
||||
|
||||
});
|
||||
|
||||
return new CovariateRow(valueMap, id);
|
||||
return new CovariateRow(valueArray, id);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -37,15 +37,15 @@ public class DataLoader {
|
|||
|
||||
int id = 1;
|
||||
for(final CSVRecord record : parser){
|
||||
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()];
|
||||
|
||||
for(final Covariate<?> covariate : covariates){
|
||||
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
|
||||
valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName()));
|
||||
}
|
||||
|
||||
final Y y = responseLoader.parse(record);
|
||||
|
||||
dataset.add(new Row<>(covariateValueMap, id++, y));
|
||||
dataset.add(new Row<>(valueArray, id++, y));
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.*;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
|
||||
|
@ -17,6 +14,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode;
|
|||
import com.fasterxml.jackson.databind.node.TextNode;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
@ -35,8 +33,7 @@ public class Main {
|
|||
final File settingsFile = new File(args[0]);
|
||||
final Settings settings = Settings.load(settingsFile);
|
||||
|
||||
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
if(args[1].equalsIgnoreCase("train")){
|
||||
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
@ -168,7 +165,7 @@ public class Main {
|
|||
yVarSettings.set("name", new TextNode("y"));
|
||||
|
||||
return Settings.builder()
|
||||
.covariates(Utils.easyList(
|
||||
.covariateSettings(Utils.easyList(
|
||||
new NumericCovariateSettings("x1"),
|
||||
new BooleanCovariateSettings("x2"),
|
||||
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||
|
|
|
@ -3,14 +3,16 @@ package ca.joeltherrien.randomforest;
|
|||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class Row<Y> extends CovariateRow {
|
||||
|
||||
private final Y response;
|
||||
|
||||
public Row(Map<String, Covariate.Value> valueMap, int id, Y response){
|
||||
super(valueMap, id);
|
||||
public Row(final Covariate.Value[] valueArray, final int id, final Y response){
|
||||
super(valueArray, id);
|
||||
this.response = response;
|
||||
}
|
||||
|
||||
|
@ -24,6 +26,20 @@ public class Row<Y> extends CovariateRow {
|
|||
return "Row " + this.getId();
|
||||
}
|
||||
|
||||
public static <Y> Row<Y> createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id, final Y response){
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||
final Map<String, Covariate> covariateMap = new HashMap<>();
|
||||
|
||||
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
|
||||
|
||||
simpleMap.forEach((name, valueStr) -> {
|
||||
final Covariate covariate = covariateMap.get(name);
|
||||
if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates
|
||||
valueArray[covariate.getIndex()] = covariate.createValue(valueStr);
|
||||
}
|
||||
});
|
||||
|
||||
return new Row<Y>(valueArray, id, response);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||
|
@ -195,7 +196,7 @@ public class Settings {
|
|||
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
|
||||
private List<CovariateSettings> covariates = new ArrayList<>();
|
||||
private List<CovariateSettings> covariateSettings = new ArrayList<>();
|
||||
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||
|
||||
// number of covariates to randomly try
|
||||
|
@ -227,7 +228,7 @@ public class Settings {
|
|||
//mapper.enableDefaultTyping();
|
||||
|
||||
// Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
|
||||
this.covariates = new ArrayList<>(this.covariates);
|
||||
this.covariateSettings = new ArrayList<>(this.covariateSettings);
|
||||
|
||||
mapper.writeValue(file, this);
|
||||
}
|
||||
|
@ -260,4 +261,14 @@ public class Settings {
|
|||
return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
|
||||
}
|
||||
|
||||
@JsonIgnore
|
||||
public List<Covariate> getCovariates(){
|
||||
final List<CovariateSettings> covariateSettingsList = this.getCovariateSettings();
|
||||
final List<Covariate> covariates = new ArrayList<>(covariateSettingsList.size());
|
||||
for(int i = 0; i < covariateSettingsList.size(); i++){
|
||||
covariates.add(covariateSettingsList.get(i).build(i));
|
||||
}
|
||||
return covariates;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -13,6 +13,9 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
|||
@Getter
|
||||
private final String name;
|
||||
|
||||
@Getter
|
||||
private final int index;
|
||||
|
||||
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
||||
|
||||
@Override
|
||||
|
|
|
@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
|
|||
}
|
||||
|
||||
@Override
|
||||
public BooleanCovariate build() {
|
||||
return new BooleanCovariate(name);
|
||||
public BooleanCovariate build(int index) {
|
||||
return new BooleanCovariate(name, index);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,8 @@ public interface Covariate<V> extends Serializable {
|
|||
|
||||
String getName();
|
||||
|
||||
int getIndex();
|
||||
|
||||
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
|
||||
|
||||
Value<V> createValue(V value);
|
||||
|
@ -54,7 +56,7 @@ public interface Covariate<V> extends Serializable {
|
|||
|
||||
|
||||
for(final Row<Y> row : rows) {
|
||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent());
|
||||
|
||||
if(value.isNA()){
|
||||
missingValueRows.add(row);
|
||||
|
@ -76,7 +78,7 @@ public interface Covariate<V> extends Serializable {
|
|||
}
|
||||
|
||||
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent());
|
||||
|
||||
if(value.isNA()){
|
||||
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
|
||||
|
|
|
@ -28,5 +28,5 @@ public abstract class CovariateSettings<V> {
|
|||
this.name = name;
|
||||
}
|
||||
|
||||
public abstract Covariate<V> build();
|
||||
public abstract Covariate<V> build(int index);
|
||||
}
|
||||
|
|
|
@ -1,20 +1,27 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
public final class FactorCovariate implements Covariate<String>{
|
||||
|
||||
@Getter
|
||||
private final String name;
|
||||
|
||||
@Getter
|
||||
private final int index;
|
||||
|
||||
private final Map<String, FactorValue> factorLevels;
|
||||
private final FactorValue naValue;
|
||||
private final int numberOfPossiblePairings;
|
||||
|
||||
|
||||
public FactorCovariate(final String name, List<String> levels){
|
||||
public FactorCovariate(final String name, final int index, List<String> levels){
|
||||
this.name = name;
|
||||
this.index = index;
|
||||
this.factorLevels = new HashMap<>();
|
||||
|
||||
for(final String level : levels){
|
||||
|
@ -33,10 +40,6 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
||||
|
|
|
@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings<String> {
|
|||
}
|
||||
|
||||
@Override
|
||||
public FactorCovariate build() {
|
||||
return new FactorCovariate(name, levels);
|
||||
public FactorCovariate build(int index) {
|
||||
return new FactorCovariate(name, index, levels);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,9 @@ public final class NumericCovariate implements Covariate<Double>{
|
|||
@Getter
|
||||
private final String name;
|
||||
|
||||
@Getter
|
||||
private final int index;
|
||||
|
||||
@Override
|
||||
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings<Double> {
|
|||
}
|
||||
|
||||
@Override
|
||||
public NumericCovariate build() {
|
||||
return new NumericCovariate(name);
|
||||
public NumericCovariate build(int index) {
|
||||
return new NumericCovariate(name, index);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -113,7 +113,7 @@ public class TreeTrainer<Y, O> {
|
|||
.generateSplitRules(
|
||||
data
|
||||
.stream()
|
||||
.map(row -> row.getCovariateValue(covariate.getName()))
|
||||
.map(row -> row.getCovariateValue(covariate))
|
||||
.collect(Collectors.toList())
|
||||
, numberToTry);
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ public class TestSavingLoading {
|
|||
yVarSettings.set("delta", new TextNode("status"));
|
||||
|
||||
return Settings.builder()
|
||||
.covariates(Utils.easyList(
|
||||
.covariateSettings(Utils.easyList(
|
||||
new NumericCovariateSettings("ageatfda"),
|
||||
new BooleanCovariateSettings("idu"),
|
||||
new BooleanCovariateSettings("black"),
|
||||
|
@ -87,14 +87,10 @@ public class TestSavingLoading {
|
|||
, covariates, 1);
|
||||
}
|
||||
|
||||
public List<Covariate> getCovariates(Settings settings){
|
||||
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
||||
final Settings settings = getSettings();
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final File directory = new File(settings.getSaveTreeLocation());
|
||||
|
|
|
@ -55,7 +55,7 @@ public class TestCompetingRisk {
|
|||
yVarSettings.set("delta", new TextNode("status"));
|
||||
|
||||
return Settings.builder()
|
||||
.covariates(Utils.easyList(
|
||||
.covariateSettings(Utils.easyList(
|
||||
new NumericCovariateSettings("ageatfda"),
|
||||
new BooleanCovariateSettings("idu"),
|
||||
new BooleanCovariateSettings("black"),
|
||||
|
@ -79,9 +79,6 @@ public class TestCompetingRisk {
|
|||
.build();
|
||||
}
|
||||
|
||||
public List<Covariate> getCovariates(Settings settings){
|
||||
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
||||
return CovariateRow.createSimple(Utils.easyMap(
|
||||
|
@ -96,12 +93,12 @@ public class TestCompetingRisk {
|
|||
public void testSingleTree() throws IOException {
|
||||
final Settings settings = getSettings();
|
||||
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
|
||||
settings.setCovariates(Utils.easyList(
|
||||
settings.setCovariateSettings(Utils.easyList(
|
||||
new BooleanCovariateSettings("idu"),
|
||||
new BooleanCovariateSettings("black")
|
||||
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
|
@ -154,7 +151,7 @@ public class TestCompetingRisk {
|
|||
settings.setNumberOfSplits(0);
|
||||
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
|
||||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
|
@ -199,12 +196,12 @@ public class TestCompetingRisk {
|
|||
@Test
|
||||
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
||||
final Settings settings = getSettings();
|
||||
settings.setCovariates(Utils.easyList(
|
||||
settings.setCovariateSettings(Utils.easyList(
|
||||
new BooleanCovariateSettings("idu"),
|
||||
new BooleanCovariateSettings("black")
|
||||
));
|
||||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
|
@ -254,7 +251,7 @@ public class TestCompetingRisk {
|
|||
public void verifyDataset() throws IOException {
|
||||
final Settings settings = getSettings();
|
||||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
|
@ -291,7 +288,7 @@ public class TestCompetingRisk {
|
|||
final Settings settings = getSettings();
|
||||
settings.setNtree(300); // results are too variable at 100
|
||||
|
||||
final List<Covariate> covariates = getCovariates(settings);
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||
|
@ -54,10 +55,10 @@ public class TestCompetingRiskErrorRateCalculator {
|
|||
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
||||
|
||||
final List<Row<CompetingRiskResponse>> dataset = Utils.easyList(
|
||||
new Row<>(Collections.emptyMap(), 1, response1),
|
||||
new Row<>(Collections.emptyMap(), 2, response2),
|
||||
new Row<>(Collections.emptyMap(), 3, response3),
|
||||
new Row<>(Collections.emptyMap(), 4, response4)
|
||||
new Row<>(new Covariate.Value[]{}, 1, response1),
|
||||
new Row<>(new Covariate.Value[]{}, 2, response2),
|
||||
new Row<>(new Covariate.Value[]{}, 3, response3),
|
||||
new Row<>(new Covariate.Value[]{}, 4, response4)
|
||||
);
|
||||
|
||||
final double[] mortalityOneArray = new double[]{1, 4, 3, 9};
|
||||
|
|
|
@ -58,7 +58,7 @@ public class FactorCovariateTest {
|
|||
private FactorCovariate createTestCovariate(){
|
||||
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
|
||||
|
||||
return new FactorCovariate("pet", levels);
|
||||
return new FactorCovariate("pet", 0, levels);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ public class TestLoadingCSV {
|
|||
|
||||
final Settings settings = Settings.builder()
|
||||
.trainingDataLocation(filename)
|
||||
.covariates(
|
||||
.covariateSettings(
|
||||
Utils.easyList(new NumericCovariateSettings("x1"),
|
||||
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
||||
new BooleanCovariateSettings("x3"))
|
||||
|
@ -46,8 +46,7 @@ public class TestLoadingCSV {
|
|||
.yVarSettings(yVarSettings)
|
||||
.build();
|
||||
|
||||
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||
final List<Covariate> covariates = settings.getCovariates();
|
||||
|
||||
|
||||
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||
|
@ -56,46 +55,50 @@ public class TestLoadingCSV {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void verifyLoadingNormal() throws IOException {
|
||||
public void verifyLoadingNormal(final List<Covariate> covariates) throws IOException {
|
||||
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
|
||||
|
||||
assertData(data);
|
||||
assertData(data, covariates);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void verifyLoadingGz() throws IOException {
|
||||
public void verifyLoadingGz(final List<Covariate> covariates) throws IOException {
|
||||
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
|
||||
|
||||
assertData(data);
|
||||
assertData(data, covariates);
|
||||
}
|
||||
|
||||
|
||||
private void assertData(final List<Row<Double>> data){
|
||||
private void assertData(final List<Row<Double>> data, final List<Covariate> covariates){
|
||||
final Covariate x1 = covariates.get(0);
|
||||
final Covariate x2 = covariates.get(0);
|
||||
final Covariate x3 = covariates.get(0);
|
||||
|
||||
assertEquals(4, data.size());
|
||||
|
||||
Row<Double> row = data.get(0);
|
||||
assertEquals(5.0, (double)row.getResponse());
|
||||
assertEquals(3.0, row.getCovariateValue("x1").getValue());
|
||||
assertEquals("mouse", row.getCovariateValue("x2").getValue());
|
||||
assertEquals(true, row.getCovariateValue("x3").getValue());
|
||||
assertEquals(3.0, row.getCovariateValue(x1).getValue());
|
||||
assertEquals("mouse", row.getCovariateValue(x2).getValue());
|
||||
assertEquals(true, row.getCovariateValue(x3).getValue());
|
||||
|
||||
row = data.get(1);
|
||||
assertEquals(2.0, (double)row.getResponse());
|
||||
assertEquals(1.0, row.getCovariateValue("x1").getValue());
|
||||
assertEquals("dog", row.getCovariateValue("x2").getValue());
|
||||
assertEquals(false, row.getCovariateValue("x3").getValue());
|
||||
assertEquals(1.0, row.getCovariateValue(x1).getValue());
|
||||
assertEquals("dog", row.getCovariateValue(x2).getValue());
|
||||
assertEquals(false, row.getCovariateValue(x3).getValue());
|
||||
|
||||
row = data.get(2);
|
||||
assertEquals(9.0, (double)row.getResponse());
|
||||
assertEquals(1.5, row.getCovariateValue("x1").getValue());
|
||||
assertEquals("cat", row.getCovariateValue("x2").getValue());
|
||||
assertEquals(true, row.getCovariateValue("x3").getValue());
|
||||
assertEquals(1.5, row.getCovariateValue(x1).getValue());
|
||||
assertEquals("cat", row.getCovariateValue(x2).getValue());
|
||||
assertEquals(true, row.getCovariateValue(x3).getValue());
|
||||
|
||||
row = data.get(3);
|
||||
assertEquals(-3.0, (double)row.getResponse());
|
||||
assertTrue(row.getCovariateValue("x1").isNA());
|
||||
assertTrue(row.getCovariateValue("x2").isNA());
|
||||
assertTrue(row.getCovariateValue("x3").isNA());
|
||||
assertTrue(row.getCovariateValue(x1).isNA());
|
||||
assertTrue(row.getCovariateValue(x2).isNA());
|
||||
assertTrue(row.getCovariateValue(x3).isNA());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ public class TestPersistence {
|
|||
yVarSettings.set("name", new TextNode("y"));
|
||||
|
||||
final Settings settingsOriginal = Settings.builder()
|
||||
.covariates(Utils.easyList(
|
||||
.covariateSettings(Utils.easyList(
|
||||
new NumericCovariateSettings("x1"),
|
||||
new BooleanCovariateSettings("x2"),
|
||||
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||
|
|
|
@ -27,22 +27,22 @@ public class TrainForest {
|
|||
|
||||
final List<Covariate> covariateList = new ArrayList<>(p);
|
||||
for(int j =0; j < p; j++){
|
||||
final NumericCovariate covariate = new NumericCovariate("x"+j);
|
||||
final NumericCovariate covariate = new NumericCovariate("x"+j, j);
|
||||
covariateList.add(covariate);
|
||||
}
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
double y = 0.0;
|
||||
final Map<String, Covariate.Value> map = new HashMap<>();
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||
|
||||
for(final Covariate covariate : covariateList) {
|
||||
final double x = random.nextDouble();
|
||||
y += x;
|
||||
|
||||
map.put(covariate.getName(), covariate.createValue(x));
|
||||
valueArray[covariate.getIndex()] = covariate.createValue(y);
|
||||
}
|
||||
|
||||
data.add(i, new Row<>(map, i, y));
|
||||
data.add(i, new Row<>(valueArray, i, y));
|
||||
|
||||
if(y < minY){
|
||||
minY = y;
|
||||
|
|
|
@ -18,15 +18,13 @@ import java.util.stream.DoubleStream;
|
|||
public class TrainSingleTree {
|
||||
|
||||
public static void main(String[] args) {
|
||||
System.out.println("Hello world!");
|
||||
|
||||
final Random random = new Random(123);
|
||||
|
||||
final int n = 1000;
|
||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
|
||||
|
||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||
.generate(() -> random.nextDouble()*10.0)
|
||||
|
@ -100,17 +98,21 @@ public class TrainSingleTree {
|
|||
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
|
||||
double y = generateResponse(x1.getValue(), x2.getValue());
|
||||
|
||||
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[2];
|
||||
valueArray[0] = x1;
|
||||
valueArray[1] = x2;
|
||||
|
||||
return new Row<>(map, id, y);
|
||||
return new Row<>(valueArray, id, y);
|
||||
|
||||
}
|
||||
|
||||
|
||||
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
|
||||
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[2];
|
||||
valueArray[0] = x1;
|
||||
valueArray[1] = x2;
|
||||
|
||||
return new CovariateRow(map, id);
|
||||
return new CovariateRow(valueArray, id);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -28,9 +28,9 @@ public class TrainSingleTreeFactor {
|
|||
final int n = 10000;
|
||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
||||
final FactorCovariate x3Covariate = new FactorCovariate("x3", Utils.easyList("cat", "dog", "mouse"));
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
|
||||
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"));
|
||||
|
||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||
.generate(() -> random.nextDouble()*10.0)
|
||||
|
@ -128,17 +128,25 @@ public class TrainSingleTreeFactor {
|
|||
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
|
||||
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
|
||||
|
||||
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[3];
|
||||
valueArray[0] = x1;
|
||||
valueArray[1] = x2;
|
||||
valueArray[2] = x3;
|
||||
|
||||
return new Row<>(map, id, y);
|
||||
//final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2); // Missing x3?
|
||||
|
||||
return new Row<>(valueArray, id, y);
|
||||
|
||||
}
|
||||
|
||||
|
||||
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
|
||||
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2, "x3", x3);
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[3];
|
||||
valueArray[0] = x1;
|
||||
valueArray[1] = x2;
|
||||
valueArray[2] = x3;
|
||||
|
||||
return new CovariateRow(map, id);
|
||||
return new CovariateRow(valueArray, id);
|
||||
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue