Upgraded Settings class to allow for covariates to be built from
provided values.
This commit is contained in:
parent
b010e79269
commit
fe9ff37dcf
8 changed files with 117 additions and 8 deletions
|
@ -1,10 +1,10 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.*;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -17,6 +17,7 @@ import java.util.List;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
|
@EqualsAndHashCode
|
||||||
public class Settings {
|
public class Settings {
|
||||||
|
|
||||||
private int numberOfSplits = 5;
|
private int numberOfSplits = 5;
|
||||||
|
@ -27,7 +28,7 @@ public class Settings {
|
||||||
private String groupDifferentiator;
|
private String groupDifferentiator;
|
||||||
private String treeResponseCombiner;
|
private String treeResponseCombiner;
|
||||||
|
|
||||||
private List<String> covariates = new ArrayList<>();
|
private List<CovariateSettings> covariates = new ArrayList<>();
|
||||||
|
|
||||||
// number of covariates to randomly try
|
// number of covariates to randomly try
|
||||||
private int mtry = 0;
|
private int mtry = 0;
|
||||||
|
@ -45,6 +46,7 @@ public class Settings {
|
||||||
|
|
||||||
public static Settings load(File file) throws IOException {
|
public static Settings load(File file) throws IOException {
|
||||||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||||
|
//mapper.enableDefaultTyping();
|
||||||
|
|
||||||
final Settings settings = mapper.readValue(file, Settings.class);
|
final Settings settings = mapper.readValue(file, Settings.class);
|
||||||
|
|
||||||
|
@ -54,6 +56,10 @@ public class Settings {
|
||||||
|
|
||||||
public void save(File file) throws IOException {
|
public void save(File file) throws IOException {
|
||||||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||||
|
//mapper.enableDefaultTyping();
|
||||||
|
|
||||||
|
// Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...)
|
||||||
|
this.covariates = new ArrayList<>(this.covariates);
|
||||||
|
|
||||||
mapper.writeValue(file, this);
|
mapper.writeValue(file, this);
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,10 +3,12 @@ package ca.joeltherrien.randomforest.covariates;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class BooleanCovariate implements Covariate<Boolean>{
|
public final class BooleanCovariate implements Covariate<Boolean>{
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@NoArgsConstructor // required by Jackson
|
||||||
|
@Data
|
||||||
|
public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
|
||||||
|
|
||||||
|
public BooleanCovariateSettings(String name){
|
||||||
|
super(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
BooleanCovariate build() {
|
||||||
|
return new BooleanCovariate(name);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Nuisance class to work with Jackson for persisting settings.
|
||||||
|
*
|
||||||
|
* @param <V>
|
||||||
|
*/
|
||||||
|
@NoArgsConstructor // required for Jackson
|
||||||
|
@Getter
|
||||||
|
@JsonTypeInfo(
|
||||||
|
use = JsonTypeInfo.Id.NAME,
|
||||||
|
include = JsonTypeInfo.As.PROPERTY,
|
||||||
|
property = "type")
|
||||||
|
@JsonSubTypes({
|
||||||
|
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),
|
||||||
|
@JsonSubTypes.Type(value = NumericCovariateSettings.class, name = "numeric"),
|
||||||
|
@JsonSubTypes.Type(value = FactorCovariateSettings.class, name = "factor")
|
||||||
|
})
|
||||||
|
public abstract class CovariateSettings<V> {
|
||||||
|
|
||||||
|
String name;
|
||||||
|
|
||||||
|
CovariateSettings(String name){
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract Covariate<V> build();
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@NoArgsConstructor // required by Jackson
|
||||||
|
@Data
|
||||||
|
public final class FactorCovariateSettings extends CovariateSettings<String> {
|
||||||
|
|
||||||
|
private List<String> levels;
|
||||||
|
|
||||||
|
public FactorCovariateSettings(String name, List<String> levels){
|
||||||
|
super(name);
|
||||||
|
this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
FactorCovariate build() {
|
||||||
|
return new FactorCovariate(name, levels);
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,7 +8,7 @@ import java.util.concurrent.ThreadLocalRandom;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class NumericCovariate implements Covariate<Double>{
|
public final class NumericCovariate implements Covariate<Double>{
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@NoArgsConstructor // required by Jackson
|
||||||
|
@Data
|
||||||
|
public final class NumericCovariateSettings extends CovariateSettings<Double> {
|
||||||
|
|
||||||
|
public NumericCovariateSettings(String name){
|
||||||
|
super(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
NumericCovariate build() {
|
||||||
|
return new NumericCovariate(name);
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,10 +2,13 @@ package ca.joeltherrien.randomforest.settings;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Settings;
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.*;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class TestPersistence {
|
public class TestPersistence {
|
||||||
|
@ -13,7 +16,12 @@ public class TestPersistence {
|
||||||
@Test
|
@Test
|
||||||
public void testSaving() throws IOException {
|
public void testSaving() throws IOException {
|
||||||
final Settings settingsOriginal = Settings.builder()
|
final Settings settingsOriginal = Settings.builder()
|
||||||
.covariates(List.of("x1", "x2", "x3"))
|
.covariates(List.of(
|
||||||
|
new NumericCovariateSettings("x1"),
|
||||||
|
new BooleanCovariateSettings("x2"),
|
||||||
|
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
||||||
|
)
|
||||||
|
)
|
||||||
.dataFileLocation("data.csv")
|
.dataFileLocation("data.csv")
|
||||||
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
||||||
.responseCombiner("MeanResponseCombiner")
|
.responseCombiner("MeanResponseCombiner")
|
||||||
|
|
Loading…
Reference in a new issue