Upgraded Settings class to allow for covariates to be built from provided values.
This commit is contained in:
Joel Therrien 2018-07-05 19:04:26 -07:00
parent b010e79269
commit fe9ff37dcf
8 changed files with 117 additions and 8 deletions

View file

@ -1,10 +1,10 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.*;
import java.io.File;
import java.io.IOException;
@ -17,6 +17,7 @@ import java.util.List;
@Data
@Builder
@AllArgsConstructor
@EqualsAndHashCode
public class Settings {
private int numberOfSplits = 5;
@ -27,7 +28,7 @@ public class Settings {
private String groupDifferentiator;
private String treeResponseCombiner;
private List<String> covariates = new ArrayList<>();
private List<CovariateSettings> covariates = new ArrayList<>();
// number of covariates to randomly try
private int mtry = 0;
@ -45,6 +46,7 @@ public class Settings {
public static Settings load(File file) throws IOException {
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
//mapper.enableDefaultTyping();
final Settings settings = mapper.readValue(file, Settings.class);
@ -54,6 +56,10 @@ public class Settings {
/**
 * Writes these settings to {@code file} as YAML using Jackson.
 *
 * <p>Side effect: when {@code covariates} is non-null it is replaced by an {@code ArrayList}
 * copy of itself before writing.
 *
 * @param file destination handed to {@code ObjectMapper.writeValue}
 * @throws IOException if Jackson fails to write the file
 */
public void save(File file) throws IOException {
    // Default typing stays off; the covariate subtypes are identified through the
    // @JsonTypeInfo/@JsonSubTypes annotations on CovariateSettings instead.
    final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());

    // Jackson can struggle with some types of Lists, such as the one returned by List.of(...),
    // so the list is copied into a plain ArrayList first. covariates may be null (for example
    // when the builder is used without supplying covariates); then there is nothing to copy.
    if (this.covariates != null) {
        this.covariates = new ArrayList<>(this.covariates);
    }

    mapper.writeValue(file, this);
}

View file

@ -3,10 +3,12 @@ package ca.joeltherrien.randomforest.covariates;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.*;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@RequiredArgsConstructor
public class BooleanCovariate implements Covariate<Boolean>{
public final class BooleanCovariate implements Covariate<Boolean>{
@Getter
private final String name;

View file

@ -0,0 +1,18 @@
package ca.joeltherrien.randomforest.covariates;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Settings for a boolean covariate. Declares no fields of its own; the only state is the
 * {@code name} inherited from {@link CovariateSettings}.
 */
@NoArgsConstructor // required by Jackson
@Data
public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {

    /**
     * @param name name given to the covariate created by {@link #build()}
     */
    public BooleanCovariateSettings(String name){
        super(name);
    }

    /**
     * @return a new {@link BooleanCovariate} carrying this settings object's name
     */
    @Override
    BooleanCovariate build() {
        return new BooleanCovariate(name);
    }

    // equals/hashCode are written by hand. The Lombok-generated versions only consider fields
    // declared in this class (there are none), so every instance would compare equal no matter
    // what the inherited name is. java.util.Objects is fully qualified to leave imports untouched.

    /**
     * @return {@code true} if {@code other} is a {@code BooleanCovariateSettings} with an equal name
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof BooleanCovariateSettings)) {
            return false;
        }
        return java.util.Objects.equals(name, ((BooleanCovariateSettings) other).name);
    }

    /**
     * @return a hash derived from the name, consistent with {@link #equals(Object)}
     */
    @Override
    public int hashCode() {
        return java.util.Objects.hashCode(name);
    }
}

View file

@ -0,0 +1,33 @@
package ca.joeltherrien.randomforest.covariates;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import lombok.Getter;
import lombok.NoArgsConstructor;
/**
 * Nuisance class to work with Jackson for persisting settings: the abstract base of the
 * per-covariate settings objects.
 *
 * <p>Jackson records the concrete subclass in a {@code type} property, using the names
 * {@code "boolean"}, {@code "numeric"} and {@code "factor"} declared below.
 *
 * @param <V> type argument of the {@link Covariate} returned by {@link #build()}
 */
@NoArgsConstructor // required for Jackson
@Getter
@JsonTypeInfo(
use = JsonTypeInfo.Id.NAME,
include = JsonTypeInfo.As.PROPERTY,
property = "type")
@JsonSubTypes({
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),
@JsonSubTypes.Type(value = NumericCovariateSettings.class, name = "numeric"),
@JsonSubTypes.Type(value = FactorCovariateSettings.class, name = "factor")
})
public abstract class CovariateSettings<V> {
// Covariate name; package-private so the subclasses in this package can pass it on in build().
String name;
// Package-private constructor used by the subclasses to set the name.
CovariateSettings(String name){
this.name = name;
}
// Creates the covariate described by these settings.
abstract Covariate<V> build();
}

View file

@ -0,0 +1,24 @@
package ca.joeltherrien.randomforest.covariates;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
/**
 * Settings for a factor covariate: the {@code name} inherited from {@link CovariateSettings}
 * plus the list of levels handed to {@link FactorCovariate}.
 */
@NoArgsConstructor // required by Jackson
@Data
public final class FactorCovariateSettings extends CovariateSettings<String> {

    private List<String> levels;

    /**
     * @param name   name given to the covariate created by {@link #build()}
     * @param levels levels of the factor; copied, so later changes to the argument are not seen
     */
    public FactorCovariateSettings(String name, List<String> levels){
        super(name);
        this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...)
    }

    /**
     * @return a new {@link FactorCovariate} created from this object's name and levels
     */
    @Override
    FactorCovariate build() {
        return new FactorCovariate(name, levels);
    }

    // equals/hashCode are written by hand. The Lombok-generated versions only consider fields
    // declared in this class (levels) and would ignore the inherited name, so two factors with
    // the same levels but different names would compare equal. java.util.Objects is fully
    // qualified to leave imports untouched.

    /**
     * @return {@code true} if {@code other} is a {@code FactorCovariateSettings} with an equal
     *         name and equal levels
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof FactorCovariateSettings)) {
            return false;
        }
        final FactorCovariateSettings that = (FactorCovariateSettings) other;
        return java.util.Objects.equals(name, that.name)
                && java.util.Objects.equals(levels, that.levels);
    }

    /**
     * @return a hash derived from name and levels, consistent with {@link #equals(Object)}
     */
    @Override
    public int hashCode() {
        return java.util.Objects.hash(name, levels);
    }
}

View file

@ -8,7 +8,7 @@ import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
@RequiredArgsConstructor
public class NumericCovariate implements Covariate<Double>{
public final class NumericCovariate implements Covariate<Double>{
@Getter
private final String name;

View file

@ -0,0 +1,18 @@
package ca.joeltherrien.randomforest.covariates;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Settings for a numeric covariate. Declares no fields of its own; the only state is the
 * {@code name} inherited from {@link CovariateSettings}.
 */
@NoArgsConstructor // required by Jackson
@Data
public final class NumericCovariateSettings extends CovariateSettings<Double> {

    /**
     * @param name name given to the covariate created by {@link #build()}
     */
    public NumericCovariateSettings(String name){
        super(name);
    }

    /**
     * @return a new {@link NumericCovariate} carrying this settings object's name
     */
    @Override
    NumericCovariate build() {
        return new NumericCovariate(name);
    }

    // equals/hashCode are written by hand. The Lombok-generated versions only consider fields
    // declared in this class (there are none), so every instance would compare equal no matter
    // what the inherited name is. java.util.Objects is fully qualified to leave imports untouched.

    /**
     * @return {@code true} if {@code other} is a {@code NumericCovariateSettings} with an equal name
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof NumericCovariateSettings)) {
            return false;
        }
        return java.util.Objects.equals(name, ((NumericCovariateSettings) other).name);
    }

    /**
     * @return a hash derived from the name, consistent with {@link #equals(Object)}
     */
    @Override
    public int hashCode() {
        return java.util.Objects.hashCode(name);
    }
}

View file

@ -2,10 +2,13 @@ package ca.joeltherrien.randomforest.settings;
import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals;
import ca.joeltherrien.randomforest.covariates.*;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class TestPersistence {
@ -13,7 +16,12 @@ public class TestPersistence {
@Test
public void testSaving() throws IOException {
final Settings settingsOriginal = Settings.builder()
.covariates(List.of("x1", "x2", "x3"))
.covariates(List.of(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
)
)
.dataFileLocation("data.csv")
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
.responseCombiner("MeanResponseCombiner")