diff --git a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java index c8e2543..63d8e72 100644 --- a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java +++ b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java @@ -12,14 +12,13 @@ import java.util.Map; @RequiredArgsConstructor public class CovariateRow implements Serializable { - private final Map valueMap; + private final Covariate.Value[] valueArray; @Getter private final int id; - public Covariate.Value getCovariateValue(String name){ - return valueMap.get(name); - + public Covariate.Value getCovariateValue(Covariate covariate){ + return valueArray[covariate.getIndex()]; } @Override @@ -28,18 +27,21 @@ public class CovariateRow implements Serializable { } public static CovariateRow createSimple(Map simpleMap, List covariateList, int id){ - final Map valueMap = new HashMap<>(); + final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()]; final Map covariateMap = new HashMap<>(); covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate)); simpleMap.forEach((name, valueStr) -> { - if(covariateMap.containsKey(name)){ - valueMap.put(name, covariateMap.get(name).createValue(valueStr)); + final Covariate covariate = covariateMap.get(name); + + if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates + valueArray[covariate.getIndex()] = covariate.createValue(valueStr); } + }); - return new CovariateRow(valueMap, id); + return new CovariateRow(valueArray, id); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java index b7e2e8c..25e8465 100644 --- a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java +++ b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java @@ -37,15 +37,15 @@ public class DataLoader { int id = 1; for(final CSVRecord record : parser){ - final Map covariateValueMap = new HashMap<>(); + final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()]; for(final Covariate covariate : covariates){ - covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName()))); + valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName())); } final Y y = responseLoader.parse(record); - dataset.add(new Row<>(covariateValueMap, id++, y)); + dataset.add(new Row<>(valueArray, id++, y)); } diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index d90328d..aa8173a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -1,9 +1,6 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; -import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; -import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner; @@ -17,6 +14,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.TextNode; import java.io.*; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -35,8 +33,7 @@ public class Main { final File settingsFile = new File(args[0]); final Settings settings = Settings.load(settingsFile); - final List covariates = settings.getCovariates().stream() - .map(cs -> cs.build()).collect(Collectors.toList()); + final List covariates = settings.getCovariates(); if(args[1].equalsIgnoreCase("train")){ final List dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); @@ -168,7 +165,7 @@ public class Main { yVarSettings.set("name", new TextNode("y")); return Settings.builder() - .covariates(Utils.easyList( + .covariateSettings(Utils.easyList( new NumericCovariateSettings("x1"), new BooleanCovariateSettings("x2"), new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) diff --git a/src/main/java/ca/joeltherrien/randomforest/Row.java b/src/main/java/ca/joeltherrien/randomforest/Row.java index ef263bf..850f035 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Row.java +++ b/src/main/java/ca/joeltherrien/randomforest/Row.java @@ -3,14 +3,16 @@ package ca.joeltherrien.randomforest; import ca.joeltherrien.randomforest.covariates.Covariate; +import java.util.HashMap; +import java.util.List; import java.util.Map; public class Row extends CovariateRow { private final Y response; - public Row(Map valueMap, int id, Y response){ - super(valueMap, id); + public Row(final Covariate.Value[] valueArray, final int id, final Y response){ + super(valueArray, id); this.response = response; } @@ -23,7 +25,21 @@ public class Row extends CovariateRow { public String toString() { return "Row " + this.getId(); } - - + + public static Row createSimple(Map simpleMap, List covariateList, int id, final Y response){ + final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()]; + final Map covariateMap = new HashMap<>(); + + covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate)); + + simpleMap.forEach((name, valueStr) -> { + final Covariate covariate = covariateMap.get(name); + if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates + valueArray[covariate.getIndex()] = covariate.createValue(valueStr); + } + }); + + return new Row(valueArray, id, response); + } } diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index e53374b..0c15b57 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -1,5 +1,6 @@ package ca.joeltherrien.randomforest; +import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.CovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; @@ -195,7 +196,7 @@ public class Settings { private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance); private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); - private List covariates = new ArrayList<>(); + private List covariateSettings = new ArrayList<>(); private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); // number of covariates to randomly try @@ -227,7 +228,7 @@ public class Settings { //mapper.enableDefaultTyping(); // Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...) - this.covariates = new ArrayList<>(this.covariates); + this.covariateSettings = new ArrayList<>(this.covariateSettings); mapper.writeValue(file, this); } @@ -260,4 +261,14 @@ public class Settings { return getResponseCombinerConstructor(type).apply(treeCombinerSettings); } + @JsonIgnore + public List getCovariates(){ + final List covariateSettingsList = this.getCovariateSettings(); + final List covariates = new ArrayList<>(covariateSettingsList.size()); + for(int i = 0; i < covariateSettingsList.size(); i++){ + covariates.add(covariateSettingsList.get(i).build(i)); + } + return covariates; + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java index c4581ff..8959bc4 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java @@ -13,6 +13,9 @@ public final class BooleanCovariate implements Covariate{ @Getter private final String name; + @Getter + private final int index; + private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. @Override diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java index ed384f3..5f57c15 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java @@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings { } @Override - public BooleanCovariate build() { - return new BooleanCovariate(name); + public BooleanCovariate build(int index) { + return new BooleanCovariate(name, index); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index 83760a5..77f1f10 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -12,6 +12,8 @@ public interface Covariate extends Serializable { String getName(); + int getIndex(); + Collection> generateSplitRules(final List> data, final int number); Value createValue(V value); @@ -54,7 +56,7 @@ public interface Covariate extends Serializable { for(final Row row : rows) { - final Value value = (Value) row.getCovariateValue(getParent().getName()); + final Value value = (Value) row.getCovariateValue(getParent()); if(value.isNA()){ missingValueRows.add(row); @@ -76,7 +78,7 @@ public interface Covariate extends Serializable { } default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ - final Value value = (Value) row.getCovariateValue(getParent().getName()); + final Value value = (Value) row.getCovariateValue(getParent()); if(value.isNA()){ return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java index 18f53f2..4d850ac 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java @@ -28,5 +28,5 @@ public abstract class CovariateSettings { this.name = name; } - public abstract Covariate build(); + public abstract Covariate build(int index); } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java index 4638f92..402cb7d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java @@ -1,20 +1,27 @@ package ca.joeltherrien.randomforest.covariates; import lombok.EqualsAndHashCode; +import lombok.Getter; import java.util.*; import java.util.concurrent.ThreadLocalRandom; public final class FactorCovariate implements Covariate{ + @Getter private final String name; + + @Getter + private final int index; + private final Map factorLevels; private final FactorValue naValue; private final int numberOfPossiblePairings; - public FactorCovariate(final String name, List levels){ + public FactorCovariate(final String name, final int index, List levels){ this.name = name; + this.index = index; this.factorLevels = new HashMap<>(); for(final String level : levels){ @@ -33,10 +40,6 @@ public final class FactorCovariate implements Covariate{ } - @Override - public String getName() { - return name; - } @Override public Set generateSplitRules(List> data, int number) { diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java index dbfaaae..9d7ece5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java @@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings { } @Override - public FactorCovariate build() { - return new FactorCovariate(name, levels); + public FactorCovariate build(int index) { + return new FactorCovariate(name, index, levels); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java index 6ec8d05..a268ba8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java @@ -16,6 +16,9 @@ public final class NumericCovariate implements Covariate{ @Getter private final String name; + @Getter + private final int index; + @Override public Collection generateSplitRules(List> data, int number) { diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java index b35a81a..9cdf898 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java @@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings { } @Override - public NumericCovariate build() { - return new NumericCovariate(name); + public NumericCovariate build(int index) { + return new NumericCovariate(name, index); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 7ec00fa..98805b5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -113,7 +113,7 @@ public class TreeTrainer { .generateSplitRules( data .stream() - .map(row -> row.getCovariateValue(covariate.getName())) + .map(row -> row.getCovariateValue(covariate)) .collect(Collectors.toList()) , numberToTry); diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index f6f800b..1cc183a 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -53,7 +53,7 @@ public class TestSavingLoading { yVarSettings.set("delta", new TextNode("status")); return Settings.builder() - .covariates(Utils.easyList( + .covariateSettings(Utils.easyList( new NumericCovariateSettings("ageatfda"), new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("black"), @@ -87,14 +87,10 @@ public class TestSavingLoading { , covariates, 1); } - public List getCovariates(Settings settings){ - return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList()); - } - @Test public void testSavingLoading() throws IOException, ClassNotFoundException { final Settings settings = getSettings(); - final List covariates = getCovariates(settings); + final List covariates = settings.getCovariates(); final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final File directory = new File(settings.getSaveTreeLocation()); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index da18ce4..c761401 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -55,7 +55,7 @@ public class TestCompetingRisk { yVarSettings.set("delta", new TextNode("status")); return Settings.builder() - .covariates(Utils.easyList( + .covariateSettings(Utils.easyList( new NumericCovariateSettings("ageatfda"), new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("black"), @@ -79,9 +79,6 @@ public class TestCompetingRisk { .build(); } - public List getCovariates(Settings settings){ - return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList()); - } public CovariateRow getPredictionRow(List covariates){ return CovariateRow.createSimple(Utils.easyMap( @@ -96,12 +93,12 @@ public class TestCompetingRisk { public void testSingleTree() throws IOException { final Settings settings = getSettings(); settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv"); - settings.setCovariates(Utils.easyList( + settings.setCovariateSettings(Utils.easyList( new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("black") )); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree. - final List covariates = getCovariates(settings); + final List covariates = settings.getCovariates(); final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); @@ -154,7 +151,7 @@ public class TestCompetingRisk { settings.setNumberOfSplits(0); settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv"); - final List covariates = getCovariates(settings); + final List covariates = settings.getCovariates(); final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); @@ -199,12 +196,12 @@ public class TestCompetingRisk { @Test public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException { final Settings settings = getSettings(); - settings.setCovariates(Utils.easyList( + settings.setCovariateSettings(Utils.easyList( new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("black") )); - final List covariates = getCovariates(settings); + final List covariates = settings.getCovariates(); final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); @@ -254,7 +251,7 @@ public class TestCompetingRisk { public void verifyDataset() throws IOException { final Settings settings = getSettings(); - final List covariates = getCovariates(settings); + final List covariates = settings.getCovariates(); final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); @@ -291,7 +288,7 @@ public class TestCompetingRisk { final Settings settings = getSettings(); settings.setNtree(300); // results are too variable at 100 - final List covariates = getCovariates(settings); + final List covariates = settings.getCovariates(); final List> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final Forest forest = forestTrainer.trainSerial(); diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java index 6dee34f..6da0d38 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskErrorRateCalculator.java @@ -1,6 +1,7 @@ package ca.joeltherrien.randomforest.competingrisk; import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.utils.MathFunction; @@ -54,10 +55,10 @@ public class TestCompetingRiskErrorRateCalculator { final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0); final List> dataset = Utils.easyList( - new Row<>(Collections.emptyMap(), 1, response1), - new Row<>(Collections.emptyMap(), 2, response2), - new Row<>(Collections.emptyMap(), 3, response3), - new Row<>(Collections.emptyMap(), 4, response4) + new Row<>(new Covariate.Value[]{}, 1, response1), + new Row<>(new Covariate.Value[]{}, 2, response2), + new Row<>(new Covariate.Value[]{}, 3, response3), + new Row<>(new Covariate.Value[]{}, 4, response4) ); final double[] mortalityOneArray = new double[]{1, 4, 3, 9}; diff --git a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java index 5a092b1..7d80a8c 100644 --- a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java +++ b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java @@ -58,7 +58,7 @@ public class FactorCovariateTest { private FactorCovariate createTestCovariate(){ final List levels = Utils.easyList("DOG", "CAT", "MOUSE"); - return new FactorCovariate("pet", levels); + return new FactorCovariate("pet", 0, levels); } diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index 5decd46..82262df 100644 --- a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -38,7 +38,7 @@ public class TestLoadingCSV { final Settings settings = Settings.builder() .trainingDataLocation(filename) - .covariates( + .covariateSettings( Utils.easyList(new NumericCovariateSettings("x1"), new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")), new BooleanCovariateSettings("x3")) @@ -46,8 +46,7 @@ public class TestLoadingCSV { .yVarSettings(yVarSettings) .build(); - final List covariates = settings.getCovariates().stream() - .map(cs -> cs.build()).collect(Collectors.toList()); + final List covariates = settings.getCovariates(); final DataLoader.ResponseLoader loader = settings.getResponseLoader(); @@ -56,46 +55,50 @@ public class TestLoadingCSV { } @Test - public void verifyLoadingNormal() throws IOException { + public void verifyLoadingNormal(final List covariates) throws IOException { final List> data = loadData("src/test/resources/testCSV.csv"); - assertData(data); + assertData(data, covariates); } @Test - public void verifyLoadingGz() throws IOException { + public void verifyLoadingGz(final List covariates) throws IOException { final List> data = loadData("src/test/resources/testCSV.csv.gz"); - assertData(data); + assertData(data, covariates); } - private void assertData(final List> data){ + private void assertData(final List> data, final List covariates){ + final Covariate x1 = covariates.get(0); + final Covariate x2 = covariates.get(0); + final Covariate x3 = covariates.get(0); + assertEquals(4, data.size()); Row row = data.get(0); assertEquals(5.0, (double)row.getResponse()); - assertEquals(3.0, row.getCovariateValue("x1").getValue()); - assertEquals("mouse", row.getCovariateValue("x2").getValue()); - assertEquals(true, row.getCovariateValue("x3").getValue()); + assertEquals(3.0, row.getCovariateValue(x1).getValue()); + assertEquals("mouse", row.getCovariateValue(x2).getValue()); + assertEquals(true, row.getCovariateValue(x3).getValue()); row = data.get(1); assertEquals(2.0, (double)row.getResponse()); - assertEquals(1.0, row.getCovariateValue("x1").getValue()); - assertEquals("dog", row.getCovariateValue("x2").getValue()); - assertEquals(false, row.getCovariateValue("x3").getValue()); + assertEquals(1.0, row.getCovariateValue(x1).getValue()); + assertEquals("dog", row.getCovariateValue(x2).getValue()); + assertEquals(false, row.getCovariateValue(x3).getValue()); row = data.get(2); assertEquals(9.0, (double)row.getResponse()); - assertEquals(1.5, row.getCovariateValue("x1").getValue()); - assertEquals("cat", row.getCovariateValue("x2").getValue()); - assertEquals(true, row.getCovariateValue("x3").getValue()); + assertEquals(1.5, row.getCovariateValue(x1).getValue()); + assertEquals("cat", row.getCovariateValue(x2).getValue()); + assertEquals(true, row.getCovariateValue(x3).getValue()); row = data.get(3); assertEquals(-3.0, (double)row.getResponse()); - assertTrue(row.getCovariateValue("x1").isNA()); - assertTrue(row.getCovariateValue("x2").isNA()); - assertTrue(row.getCovariateValue("x3").isNA()); + assertTrue(row.getCovariateValue(x1).isNA()); + assertTrue(row.getCovariateValue(x2).isNA()); + assertTrue(row.getCovariateValue(x3).isNA()); } } diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index a7f11a8..901cdb1 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -31,7 +31,7 @@ public class TestPersistence { yVarSettings.set("name", new TextNode("y")); final Settings settingsOriginal = Settings.builder() - .covariates(Utils.easyList( + .covariateSettings(Utils.easyList( new NumericCovariateSettings("x1"), new BooleanCovariateSettings("x2"), new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog")) diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java index e2b07e8..f453cd9 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java @@ -27,22 +27,22 @@ public class TrainForest { final List covariateList = new ArrayList<>(p); for(int j =0; j < p; j++){ - final NumericCovariate covariate = new NumericCovariate("x"+j); + final NumericCovariate covariate = new NumericCovariate("x"+j, j); covariateList.add(covariate); } for(int i=0; i map = new HashMap<>(); + final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()]; for(final Covariate covariate : covariateList) { final double x = random.nextDouble(); y += x; - map.put(covariate.getName(), covariate.createValue(x)); + valueArray[covariate.getIndex()] = covariate.createValue(y); } - data.add(i, new Row<>(map, i, y)); + data.add(i, new Row<>(valueArray, i, y)); if(y < minY){ minY = y; diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java index b68c2af..88a855e 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java @@ -18,15 +18,13 @@ import java.util.stream.DoubleStream; public class TrainSingleTree { public static void main(String[] args) { - System.out.println("Hello world!"); - final Random random = new Random(123); final int n = 1000; final List> trainingSet = new ArrayList<>(n); - final Covariate x1Covariate = new NumericCovariate("x1"); - final Covariate x2Covariate = new NumericCovariate("x2"); + final Covariate x1Covariate = new NumericCovariate("x1", 0); + final Covariate x2Covariate = new NumericCovariate("x2", 1); final List> x1List = DoubleStream .generate(() -> random.nextDouble()*10.0) @@ -100,17 +98,21 @@ public class TrainSingleTree { public static Row generateRow(Covariate.Value x1, Covariate.Value x2, int id){ double y = generateResponse(x1.getValue(), x2.getValue()); - final Map map = Utils.easyMap("x1", x1, "x2", x2); + final Covariate.Value[] valueArray = new Covariate.Value[2]; + valueArray[0] = x1; + valueArray[1] = x2; - return new Row<>(map, id, y); + return new Row<>(valueArray, id, y); } public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){ - final Map map = Utils.easyMap("x1", x1, "x2", x2); + final Covariate.Value[] valueArray = new Covariate.Value[2]; + valueArray[0] = x1; + valueArray[1] = x2; - return new CovariateRow(map, id); + return new CovariateRow(valueArray, id); } diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index c1a6b46..6c2d430 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -28,9 +28,9 @@ public class TrainSingleTreeFactor { final int n = 10000; final List> trainingSet = new ArrayList<>(n); - final Covariate x1Covariate = new NumericCovariate("x1"); - final Covariate x2Covariate = new NumericCovariate("x2"); - final FactorCovariate x3Covariate = new FactorCovariate("x3", Utils.easyList("cat", "dog", "mouse")); + final Covariate x1Covariate = new NumericCovariate("x1", 0); + final Covariate x2Covariate = new NumericCovariate("x2", 1); + final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse")); final List> x1List = DoubleStream .generate(() -> random.nextDouble()*10.0) @@ -128,17 +128,25 @@ public class TrainSingleTreeFactor { public static Row generateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){ double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue()); - final Map map = Utils.easyMap("x1", x1, "x2", x2); + final Covariate.Value[] valueArray = new Covariate.Value[3]; + valueArray[0] = x1; + valueArray[1] = x2; + valueArray[2] = x3; - return new Row<>(map, id, y); + //final Map map = Utils.easyMap("x1", x1, "x2", x2); // Missing x3? + + return new Row<>(valueArray, id, y); } public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){ - final Map map = Utils.easyMap("x1", x1, "x2", x2, "x3", x3); + final Covariate.Value[] valueArray = new Covariate.Value[3]; + valueArray[0] = x1; + valueArray[1] = x2; + valueArray[2] = x3; - return new CovariateRow(map, id); + return new CovariateRow(valueArray, id); }