Upgraded Settings class to allow for covariates to be built from
provided values.
This commit is contained in:
parent
b010e79269
commit
fe9ff37dcf
8 changed files with 117 additions and 8 deletions
|
@ -1,10 +1,10 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.*;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
@ -17,6 +17,7 @@ import java.util.List;
|
|||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@EqualsAndHashCode
|
||||
public class Settings {
|
||||
|
||||
private int numberOfSplits = 5;
|
||||
|
@ -27,7 +28,7 @@ public class Settings {
|
|||
private String groupDifferentiator;
|
||||
private String treeResponseCombiner;
|
||||
|
||||
private List<String> covariates = new ArrayList<>();
|
||||
private List<CovariateSettings> covariates = new ArrayList<>();
|
||||
|
||||
// number of covariates to randomly try
|
||||
private int mtry = 0;
|
||||
|
@ -45,6 +46,7 @@ public class Settings {
|
|||
|
||||
public static Settings load(File file) throws IOException {
|
||||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||
//mapper.enableDefaultTyping();
|
||||
|
||||
final Settings settings = mapper.readValue(file, Settings.class);
|
||||
|
||||
|
@ -54,6 +56,10 @@ public class Settings {
|
|||
|
||||
public void save(File file) throws IOException {
|
||||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||
//mapper.enableDefaultTyping();
|
||||
|
||||
// Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
|
||||
this.covariates = new ArrayList<>(this.covariates);
|
||||
|
||||
mapper.writeValue(file, this);
|
||||
}
|
||||
|
|
|
@ -3,10 +3,12 @@ package ca.joeltherrien.randomforest.covariates;
|
|||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class BooleanCovariate implements Covariate<Boolean>{
|
||||
public final class BooleanCovariate implements Covariate<Boolean>{
|
||||
|
||||
@Getter
|
||||
private final String name;
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@NoArgsConstructor // required by Jackson
|
||||
@Data
|
||||
public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
|
||||
|
||||
public BooleanCovariateSettings(String name){
|
||||
super(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
BooleanCovariate build() {
|
||||
return new BooleanCovariate(name);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
/**
|
||||
* Nuisance class to work with Jackson for persisting settings.
|
||||
*
|
||||
* @param <V>
|
||||
*/
|
||||
@NoArgsConstructor // required for Jackson
|
||||
@Getter
|
||||
@JsonTypeInfo(
|
||||
use = JsonTypeInfo.Id.NAME,
|
||||
include = JsonTypeInfo.As.PROPERTY,
|
||||
property = "type")
|
||||
@JsonSubTypes({
|
||||
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),
|
||||
@JsonSubTypes.Type(value = NumericCovariateSettings.class, name = "numeric"),
|
||||
@JsonSubTypes.Type(value = FactorCovariateSettings.class, name = "factor")
|
||||
})
|
||||
public abstract class CovariateSettings<V> {
|
||||
|
||||
String name;
|
||||
|
||||
CovariateSettings(String name){
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
abstract Covariate<V> build();
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor // required by Jackson
|
||||
@Data
|
||||
public final class FactorCovariateSettings extends CovariateSettings<String> {
|
||||
|
||||
private List<String> levels;
|
||||
|
||||
public FactorCovariateSettings(String name, List<String> levels){
|
||||
super(name);
|
||||
this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...)
|
||||
}
|
||||
|
||||
@Override
|
||||
FactorCovariate build() {
|
||||
return new FactorCovariate(name, levels);
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ import java.util.concurrent.ThreadLocalRandom;
|
|||
import java.util.stream.Collectors;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class NumericCovariate implements Covariate<Double>{
|
||||
public final class NumericCovariate implements Covariate<Double>{
|
||||
|
||||
@Getter
|
||||
private final String name;
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@NoArgsConstructor // required by Jackson
|
||||
@Data
|
||||
public final class NumericCovariateSettings extends CovariateSettings<Double> {
|
||||
|
||||
public NumericCovariateSettings(String name){
|
||||
super(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
NumericCovariate build() {
|
||||
return new NumericCovariate(name);
|
||||
}
|
||||
}
|
|
@ -2,10 +2,13 @@ package ca.joeltherrien.randomforest.settings;
|
|||
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class TestPersistence {
|
||||
|
@ -13,7 +16,12 @@ public class TestPersistence {
|
|||
@Test
|
||||
public void testSaving() throws IOException {
|
||||
final Settings settingsOriginal = Settings.builder()
|
||||
.covariates(List.of("x1", "x2", "x3"))
|
||||
.covariates(List.of(
|
||||
new NumericCovariateSettings("x1"),
|
||||
new BooleanCovariateSettings("x2"),
|
||||
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
||||
)
|
||||
)
|
||||
.dataFileLocation("data.csv")
|
||||
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
||||
.responseCombiner("MeanResponseCombiner")
|
||||
|
|
Loading…
Reference in a new issue