Switch code to storing Covariate.Value using arrays instead of Maps
This commit is contained in:
parent
de39f60314
commit
aa733d5eba
23 changed files with 145 additions and 101 deletions
|
@ -12,14 +12,13 @@ import java.util.Map;
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class CovariateRow implements Serializable {
|
public class CovariateRow implements Serializable {
|
||||||
|
|
||||||
private final Map<String, Covariate.Value> valueMap;
|
private final Covariate.Value[] valueArray;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final int id;
|
private final int id;
|
||||||
|
|
||||||
public Covariate.Value<?> getCovariateValue(String name){
|
public Covariate.Value<?> getCovariateValue(Covariate covariate){
|
||||||
return valueMap.get(name);
|
return valueArray[covariate.getIndex()];
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -28,18 +27,21 @@ public class CovariateRow implements Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
public static CovariateRow createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id){
|
||||||
final Map<String, Covariate.Value> valueMap = new HashMap<>();
|
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||||
final Map<String, Covariate> covariateMap = new HashMap<>();
|
final Map<String, Covariate> covariateMap = new HashMap<>();
|
||||||
|
|
||||||
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
|
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
|
||||||
|
|
||||||
simpleMap.forEach((name, valueStr) -> {
|
simpleMap.forEach((name, valueStr) -> {
|
||||||
if(covariateMap.containsKey(name)){
|
final Covariate covariate = covariateMap.get(name);
|
||||||
valueMap.put(name, covariateMap.get(name).createValue(valueStr));
|
|
||||||
|
if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates
|
||||||
|
valueArray[covariate.getIndex()] = covariate.createValue(valueStr);
|
||||||
}
|
}
|
||||||
|
|
||||||
});
|
});
|
||||||
|
|
||||||
return new CovariateRow(valueMap, id);
|
return new CovariateRow(valueArray, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,15 +37,15 @@ public class DataLoader {
|
||||||
|
|
||||||
int id = 1;
|
int id = 1;
|
||||||
for(final CSVRecord record : parser){
|
for(final CSVRecord record : parser){
|
||||||
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
|
final Covariate.Value[] valueArray = new Covariate.Value[covariates.size()];
|
||||||
|
|
||||||
for(final Covariate<?> covariate : covariates){
|
for(final Covariate<?> covariate : covariates){
|
||||||
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
|
valueArray[covariate.getIndex()] = covariate.createValue(record.get(covariate.getName()));
|
||||||
}
|
}
|
||||||
|
|
||||||
final Y y = responseLoader.parse(record);
|
final Y y = responseLoader.parse(record);
|
||||||
|
|
||||||
dataset.add(new Row<>(covariateValueMap, id++, y));
|
dataset.add(new Row<>(valueArray, id++, y));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.alternative.CompetingRiskListCombiner;
|
||||||
|
@ -17,6 +14,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import com.fasterxml.jackson.databind.node.TextNode;
|
import com.fasterxml.jackson.databind.node.TextNode;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@ -35,8 +33,7 @@ public class Main {
|
||||||
final File settingsFile = new File(args[0]);
|
final File settingsFile = new File(args[0]);
|
||||||
final Settings settings = Settings.load(settingsFile);
|
final Settings settings = Settings.load(settingsFile);
|
||||||
|
|
||||||
final List<Covariate> covariates = settings.getCovariates().stream()
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
|
||||||
|
|
||||||
if(args[1].equalsIgnoreCase("train")){
|
if(args[1].equalsIgnoreCase("train")){
|
||||||
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
@ -168,7 +165,7 @@ public class Main {
|
||||||
yVarSettings.set("name", new TextNode("y"));
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
return Settings.builder()
|
return Settings.builder()
|
||||||
.covariates(Utils.easyList(
|
.covariateSettings(Utils.easyList(
|
||||||
new NumericCovariateSettings("x1"),
|
new NumericCovariateSettings("x1"),
|
||||||
new BooleanCovariateSettings("x2"),
|
new BooleanCovariateSettings("x2"),
|
||||||
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||||
|
|
|
@ -3,14 +3,16 @@ package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public class Row<Y> extends CovariateRow {
|
public class Row<Y> extends CovariateRow {
|
||||||
|
|
||||||
private final Y response;
|
private final Y response;
|
||||||
|
|
||||||
public Row(Map<String, Covariate.Value> valueMap, int id, Y response){
|
public Row(final Covariate.Value[] valueArray, final int id, final Y response){
|
||||||
super(valueMap, id);
|
super(valueArray, id);
|
||||||
this.response = response;
|
this.response = response;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,7 +25,21 @@ public class Row<Y> extends CovariateRow {
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "Row " + this.getId();
|
return "Row " + this.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static <Y> Row<Y> createSimple(Map<String, String> simpleMap, List<Covariate> covariateList, int id, final Y response){
|
||||||
|
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||||
|
final Map<String, Covariate> covariateMap = new HashMap<>();
|
||||||
|
|
||||||
|
covariateList.forEach(covariate -> covariateMap.put(covariate.getName(), covariate));
|
||||||
|
|
||||||
|
simpleMap.forEach((name, valueStr) -> {
|
||||||
|
final Covariate covariate = covariateMap.get(name);
|
||||||
|
if(covariate != null){ // happens often in tests where we experiment with adding / removing covariates
|
||||||
|
valueArray[covariate.getIndex()] = covariate.createValue(valueStr);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return new Row<Y>(valueArray, id, response);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
|
@ -195,7 +196,7 @@ public class Settings {
|
||||||
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
private ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
private ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
|
||||||
private List<CovariateSettings> covariates = new ArrayList<>();
|
private List<CovariateSettings> covariateSettings = new ArrayList<>();
|
||||||
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
private ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
|
|
||||||
// number of covariates to randomly try
|
// number of covariates to randomly try
|
||||||
|
@ -227,7 +228,7 @@ public class Settings {
|
||||||
//mapper.enableDefaultTyping();
|
//mapper.enableDefaultTyping();
|
||||||
|
|
||||||
// Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
|
// Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
|
||||||
this.covariates = new ArrayList<>(this.covariates);
|
this.covariateSettings = new ArrayList<>(this.covariateSettings);
|
||||||
|
|
||||||
mapper.writeValue(file, this);
|
mapper.writeValue(file, this);
|
||||||
}
|
}
|
||||||
|
@ -260,4 +261,14 @@ public class Settings {
|
||||||
return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
|
return getResponseCombinerConstructor(type).apply(treeCombinerSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@JsonIgnore
|
||||||
|
public List<Covariate> getCovariates(){
|
||||||
|
final List<CovariateSettings> covariateSettingsList = this.getCovariateSettings();
|
||||||
|
final List<Covariate> covariates = new ArrayList<>(covariateSettingsList.size());
|
||||||
|
for(int i = 0; i < covariateSettingsList.size(); i++){
|
||||||
|
covariates.add(covariateSettingsList.get(i).build(i));
|
||||||
|
}
|
||||||
|
return covariates;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,9 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final int index;
|
||||||
|
|
||||||
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BooleanCovariate build() {
|
public BooleanCovariate build(int index) {
|
||||||
return new BooleanCovariate(name);
|
return new BooleanCovariate(name, index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,8 @@ public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
String getName();
|
String getName();
|
||||||
|
|
||||||
|
int getIndex();
|
||||||
|
|
||||||
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
|
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
|
||||||
|
|
||||||
Value<V> createValue(V value);
|
Value<V> createValue(V value);
|
||||||
|
@ -54,7 +56,7 @@ public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
|
|
||||||
for(final Row<Y> row : rows) {
|
for(final Row<Y> row : rows) {
|
||||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
final Value<V> value = (Value<V>) row.getCovariateValue(getParent());
|
||||||
|
|
||||||
if(value.isNA()){
|
if(value.isNA()){
|
||||||
missingValueRows.add(row);
|
missingValueRows.add(row);
|
||||||
|
@ -76,7 +78,7 @@ public interface Covariate<V> extends Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
||||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
final Value<V> value = (Value<V>) row.getCovariateValue(getParent());
|
||||||
|
|
||||||
if(value.isNA()){
|
if(value.isNA()){
|
||||||
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
|
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
|
||||||
|
|
|
@ -28,5 +28,5 @@ public abstract class CovariateSettings<V> {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract Covariate<V> build();
|
public abstract Covariate<V> build(int index);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,20 +1,27 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
|
||||||
public final class FactorCovariate implements Covariate<String>{
|
public final class FactorCovariate implements Covariate<String>{
|
||||||
|
|
||||||
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final int index;
|
||||||
|
|
||||||
private final Map<String, FactorValue> factorLevels;
|
private final Map<String, FactorValue> factorLevels;
|
||||||
private final FactorValue naValue;
|
private final FactorValue naValue;
|
||||||
private final int numberOfPossiblePairings;
|
private final int numberOfPossiblePairings;
|
||||||
|
|
||||||
|
|
||||||
public FactorCovariate(final String name, List<String> levels){
|
public FactorCovariate(final String name, final int index, List<String> levels){
|
||||||
this.name = name;
|
this.name = name;
|
||||||
|
this.index = index;
|
||||||
this.factorLevels = new HashMap<>();
|
this.factorLevels = new HashMap<>();
|
||||||
|
|
||||||
for(final String level : levels){
|
for(final String level : levels){
|
||||||
|
@ -33,10 +40,6 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getName() {
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
||||||
|
|
|
@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings<String> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public FactorCovariate build() {
|
public FactorCovariate build(int index) {
|
||||||
return new FactorCovariate(name, levels);
|
return new FactorCovariate(name, index, levels);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,9 @@ public final class NumericCovariate implements Covariate<Double>{
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final int index;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {
|
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings<Double> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public NumericCovariate build() {
|
public NumericCovariate build(int index) {
|
||||||
return new NumericCovariate(name);
|
return new NumericCovariate(name, index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -113,7 +113,7 @@ public class TreeTrainer<Y, O> {
|
||||||
.generateSplitRules(
|
.generateSplitRules(
|
||||||
data
|
data
|
||||||
.stream()
|
.stream()
|
||||||
.map(row -> row.getCovariateValue(covariate.getName()))
|
.map(row -> row.getCovariateValue(covariate))
|
||||||
.collect(Collectors.toList())
|
.collect(Collectors.toList())
|
||||||
, numberToTry);
|
, numberToTry);
|
||||||
|
|
||||||
|
|
|
@ -53,7 +53,7 @@ public class TestSavingLoading {
|
||||||
yVarSettings.set("delta", new TextNode("status"));
|
yVarSettings.set("delta", new TextNode("status"));
|
||||||
|
|
||||||
return Settings.builder()
|
return Settings.builder()
|
||||||
.covariates(Utils.easyList(
|
.covariateSettings(Utils.easyList(
|
||||||
new NumericCovariateSettings("ageatfda"),
|
new NumericCovariateSettings("ageatfda"),
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black"),
|
new BooleanCovariateSettings("black"),
|
||||||
|
@ -87,14 +87,10 @@ public class TestSavingLoading {
|
||||||
, covariates, 1);
|
, covariates, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Covariate> getCovariates(Settings settings){
|
|
||||||
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final File directory = new File(settings.getSaveTreeLocation());
|
final File directory = new File(settings.getSaveTreeLocation());
|
||||||
|
|
|
@ -55,7 +55,7 @@ public class TestCompetingRisk {
|
||||||
yVarSettings.set("delta", new TextNode("status"));
|
yVarSettings.set("delta", new TextNode("status"));
|
||||||
|
|
||||||
return Settings.builder()
|
return Settings.builder()
|
||||||
.covariates(Utils.easyList(
|
.covariateSettings(Utils.easyList(
|
||||||
new NumericCovariateSettings("ageatfda"),
|
new NumericCovariateSettings("ageatfda"),
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black"),
|
new BooleanCovariateSettings("black"),
|
||||||
|
@ -79,9 +79,6 @@ public class TestCompetingRisk {
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Covariate> getCovariates(Settings settings){
|
|
||||||
return settings.getCovariates().stream().map(covariateSettings -> covariateSettings.build()).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
||||||
return CovariateRow.createSimple(Utils.easyMap(
|
return CovariateRow.createSimple(Utils.easyMap(
|
||||||
|
@ -96,12 +93,12 @@ public class TestCompetingRisk {
|
||||||
public void testSingleTree() throws IOException {
|
public void testSingleTree() throws IOException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
|
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped.csv");
|
||||||
settings.setCovariates(Utils.easyList(
|
settings.setCovariateSettings(Utils.easyList(
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black")
|
new BooleanCovariateSettings("black")
|
||||||
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
|
@ -154,7 +151,7 @@ public class TestCompetingRisk {
|
||||||
settings.setNumberOfSplits(0);
|
settings.setNumberOfSplits(0);
|
||||||
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
|
settings.setTrainingDataLocation("src/test/resources/wihs.bootstrapped2.csv");
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
|
@ -199,12 +196,12 @@ public class TestCompetingRisk {
|
||||||
@Test
|
@Test
|
||||||
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setCovariates(Utils.easyList(
|
settings.setCovariateSettings(Utils.easyList(
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black")
|
new BooleanCovariateSettings("black")
|
||||||
));
|
));
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
|
@ -254,7 +251,7 @@ public class TestCompetingRisk {
|
||||||
public void verifyDataset() throws IOException {
|
public void verifyDataset() throws IOException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
|
@ -291,7 +288,7 @@ public class TestCompetingRisk {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setNtree(300); // results are too variable at 100
|
settings.setNtree(300); // results are too variable at 100
|
||||||
|
|
||||||
final List<Covariate> covariates = getCovariates(settings);
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
|
@ -54,10 +55,10 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = Utils.easyList(
|
final List<Row<CompetingRiskResponse>> dataset = Utils.easyList(
|
||||||
new Row<>(Collections.emptyMap(), 1, response1),
|
new Row<>(new Covariate.Value[]{}, 1, response1),
|
||||||
new Row<>(Collections.emptyMap(), 2, response2),
|
new Row<>(new Covariate.Value[]{}, 2, response2),
|
||||||
new Row<>(Collections.emptyMap(), 3, response3),
|
new Row<>(new Covariate.Value[]{}, 3, response3),
|
||||||
new Row<>(Collections.emptyMap(), 4, response4)
|
new Row<>(new Covariate.Value[]{}, 4, response4)
|
||||||
);
|
);
|
||||||
|
|
||||||
final double[] mortalityOneArray = new double[]{1, 4, 3, 9};
|
final double[] mortalityOneArray = new double[]{1, 4, 3, 9};
|
||||||
|
|
|
@ -58,7 +58,7 @@ public class FactorCovariateTest {
|
||||||
private FactorCovariate createTestCovariate(){
|
private FactorCovariate createTestCovariate(){
|
||||||
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
|
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
|
||||||
|
|
||||||
return new FactorCovariate("pet", levels);
|
return new FactorCovariate("pet", 0, levels);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ public class TestLoadingCSV {
|
||||||
|
|
||||||
final Settings settings = Settings.builder()
|
final Settings settings = Settings.builder()
|
||||||
.trainingDataLocation(filename)
|
.trainingDataLocation(filename)
|
||||||
.covariates(
|
.covariateSettings(
|
||||||
Utils.easyList(new NumericCovariateSettings("x1"),
|
Utils.easyList(new NumericCovariateSettings("x1"),
|
||||||
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
||||||
new BooleanCovariateSettings("x3"))
|
new BooleanCovariateSettings("x3"))
|
||||||
|
@ -46,8 +46,7 @@ public class TestLoadingCSV {
|
||||||
.yVarSettings(yVarSettings)
|
.yVarSettings(yVarSettings)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final List<Covariate> covariates = settings.getCovariates().stream()
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
|
||||||
|
|
||||||
|
|
||||||
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||||
|
@ -56,46 +55,50 @@ public class TestLoadingCSV {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void verifyLoadingNormal() throws IOException {
|
public void verifyLoadingNormal(final List<Covariate> covariates) throws IOException {
|
||||||
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
|
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
|
||||||
|
|
||||||
assertData(data);
|
assertData(data, covariates);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void verifyLoadingGz() throws IOException {
|
public void verifyLoadingGz(final List<Covariate> covariates) throws IOException {
|
||||||
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
|
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
|
||||||
|
|
||||||
assertData(data);
|
assertData(data, covariates);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private void assertData(final List<Row<Double>> data){
|
private void assertData(final List<Row<Double>> data, final List<Covariate> covariates){
|
||||||
|
final Covariate x1 = covariates.get(0);
|
||||||
|
final Covariate x2 = covariates.get(0);
|
||||||
|
final Covariate x3 = covariates.get(0);
|
||||||
|
|
||||||
assertEquals(4, data.size());
|
assertEquals(4, data.size());
|
||||||
|
|
||||||
Row<Double> row = data.get(0);
|
Row<Double> row = data.get(0);
|
||||||
assertEquals(5.0, (double)row.getResponse());
|
assertEquals(5.0, (double)row.getResponse());
|
||||||
assertEquals(3.0, row.getCovariateValue("x1").getValue());
|
assertEquals(3.0, row.getCovariateValue(x1).getValue());
|
||||||
assertEquals("mouse", row.getCovariateValue("x2").getValue());
|
assertEquals("mouse", row.getCovariateValue(x2).getValue());
|
||||||
assertEquals(true, row.getCovariateValue("x3").getValue());
|
assertEquals(true, row.getCovariateValue(x3).getValue());
|
||||||
|
|
||||||
row = data.get(1);
|
row = data.get(1);
|
||||||
assertEquals(2.0, (double)row.getResponse());
|
assertEquals(2.0, (double)row.getResponse());
|
||||||
assertEquals(1.0, row.getCovariateValue("x1").getValue());
|
assertEquals(1.0, row.getCovariateValue(x1).getValue());
|
||||||
assertEquals("dog", row.getCovariateValue("x2").getValue());
|
assertEquals("dog", row.getCovariateValue(x2).getValue());
|
||||||
assertEquals(false, row.getCovariateValue("x3").getValue());
|
assertEquals(false, row.getCovariateValue(x3).getValue());
|
||||||
|
|
||||||
row = data.get(2);
|
row = data.get(2);
|
||||||
assertEquals(9.0, (double)row.getResponse());
|
assertEquals(9.0, (double)row.getResponse());
|
||||||
assertEquals(1.5, row.getCovariateValue("x1").getValue());
|
assertEquals(1.5, row.getCovariateValue(x1).getValue());
|
||||||
assertEquals("cat", row.getCovariateValue("x2").getValue());
|
assertEquals("cat", row.getCovariateValue(x2).getValue());
|
||||||
assertEquals(true, row.getCovariateValue("x3").getValue());
|
assertEquals(true, row.getCovariateValue(x3).getValue());
|
||||||
|
|
||||||
row = data.get(3);
|
row = data.get(3);
|
||||||
assertEquals(-3.0, (double)row.getResponse());
|
assertEquals(-3.0, (double)row.getResponse());
|
||||||
assertTrue(row.getCovariateValue("x1").isNA());
|
assertTrue(row.getCovariateValue(x1).isNA());
|
||||||
assertTrue(row.getCovariateValue("x2").isNA());
|
assertTrue(row.getCovariateValue(x2).isNA());
|
||||||
assertTrue(row.getCovariateValue("x3").isNA());
|
assertTrue(row.getCovariateValue(x3).isNA());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ public class TestPersistence {
|
||||||
yVarSettings.set("name", new TextNode("y"));
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
final Settings settingsOriginal = Settings.builder()
|
final Settings settingsOriginal = Settings.builder()
|
||||||
.covariates(Utils.easyList(
|
.covariateSettings(Utils.easyList(
|
||||||
new NumericCovariateSettings("x1"),
|
new NumericCovariateSettings("x1"),
|
||||||
new BooleanCovariateSettings("x2"),
|
new BooleanCovariateSettings("x2"),
|
||||||
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||||
|
|
|
@ -27,22 +27,22 @@ public class TrainForest {
|
||||||
|
|
||||||
final List<Covariate> covariateList = new ArrayList<>(p);
|
final List<Covariate> covariateList = new ArrayList<>(p);
|
||||||
for(int j =0; j < p; j++){
|
for(int j =0; j < p; j++){
|
||||||
final NumericCovariate covariate = new NumericCovariate("x"+j);
|
final NumericCovariate covariate = new NumericCovariate("x"+j, j);
|
||||||
covariateList.add(covariate);
|
covariateList.add(covariate);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int i=0; i<n; i++){
|
for(int i=0; i<n; i++){
|
||||||
double y = 0.0;
|
double y = 0.0;
|
||||||
final Map<String, Covariate.Value> map = new HashMap<>();
|
final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];
|
||||||
|
|
||||||
for(final Covariate covariate : covariateList) {
|
for(final Covariate covariate : covariateList) {
|
||||||
final double x = random.nextDouble();
|
final double x = random.nextDouble();
|
||||||
y += x;
|
y += x;
|
||||||
|
|
||||||
map.put(covariate.getName(), covariate.createValue(x));
|
valueArray[covariate.getIndex()] = covariate.createValue(y);
|
||||||
}
|
}
|
||||||
|
|
||||||
data.add(i, new Row<>(map, i, y));
|
data.add(i, new Row<>(valueArray, i, y));
|
||||||
|
|
||||||
if(y < minY){
|
if(y < minY){
|
||||||
minY = y;
|
minY = y;
|
||||||
|
|
|
@ -18,15 +18,13 @@ import java.util.stream.DoubleStream;
|
||||||
public class TrainSingleTree {
|
public class TrainSingleTree {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
System.out.println("Hello world!");
|
|
||||||
|
|
||||||
final Random random = new Random(123);
|
final Random random = new Random(123);
|
||||||
|
|
||||||
final int n = 1000;
|
final int n = 1000;
|
||||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||||
|
|
||||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
|
||||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
|
||||||
|
|
||||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||||
.generate(() -> random.nextDouble()*10.0)
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
@ -100,17 +98,21 @@ public class TrainSingleTree {
|
||||||
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
|
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
|
||||||
double y = generateResponse(x1.getValue(), x2.getValue());
|
double y = generateResponse(x1.getValue(), x2.getValue());
|
||||||
|
|
||||||
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
final Covariate.Value[] valueArray = new Covariate.Value[2];
|
||||||
|
valueArray[0] = x1;
|
||||||
|
valueArray[1] = x2;
|
||||||
|
|
||||||
return new Row<>(map, id, y);
|
return new Row<>(valueArray, id, y);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
|
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
|
||||||
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
final Covariate.Value[] valueArray = new Covariate.Value[2];
|
||||||
|
valueArray[0] = x1;
|
||||||
|
valueArray[1] = x2;
|
||||||
|
|
||||||
return new CovariateRow(map, id);
|
return new CovariateRow(valueArray, id);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,9 +28,9 @@ public class TrainSingleTreeFactor {
|
||||||
final int n = 10000;
|
final int n = 10000;
|
||||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||||
|
|
||||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
|
||||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
|
||||||
final FactorCovariate x3Covariate = new FactorCovariate("x3", Utils.easyList("cat", "dog", "mouse"));
|
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"));
|
||||||
|
|
||||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||||
.generate(() -> random.nextDouble()*10.0)
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
@ -128,17 +128,25 @@ public class TrainSingleTreeFactor {
|
||||||
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
|
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
|
||||||
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
|
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
|
||||||
|
|
||||||
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
final Covariate.Value[] valueArray = new Covariate.Value[3];
|
||||||
|
valueArray[0] = x1;
|
||||||
|
valueArray[1] = x2;
|
||||||
|
valueArray[2] = x3;
|
||||||
|
|
||||||
return new Row<>(map, id, y);
|
//final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2); // Missing x3?
|
||||||
|
|
||||||
|
return new Row<>(valueArray, id, y);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
|
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
|
||||||
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2, "x3", x3);
|
final Covariate.Value[] valueArray = new Covariate.Value[3];
|
||||||
|
valueArray[0] = x1;
|
||||||
|
valueArray[1] = x2;
|
||||||
|
valueArray[2] = x3;
|
||||||
|
|
||||||
return new CovariateRow(map, id);
|
return new CovariateRow(valueArray, id);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue