diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index a3fccf1..bf3985a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -1,10 +1,10 @@ package ca.joeltherrien.randomforest; +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.CovariateSettings; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; +import lombok.*; import java.io.File; import java.io.IOException; @@ -17,6 +17,7 @@ import java.util.List; @Data @Builder @AllArgsConstructor +@EqualsAndHashCode public class Settings { private int numberOfSplits = 5; @@ -27,7 +28,7 @@ public class Settings { private String groupDifferentiator; private String treeResponseCombiner; - private List covariates = new ArrayList<>(); + private List covariates = new ArrayList<>(); // number of covariates to randomly try private int mtry = 0; @@ -45,6 +46,7 @@ public class Settings { public static Settings load(File file) throws IOException { final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); + //mapper.enableDefaultTyping(); final Settings settings = mapper.readValue(file, Settings.class); @@ -54,6 +56,10 @@ public class Settings { public void save(File file) throws IOException { final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); + //mapper.enableDefaultTyping(); + + // Jackson can struggle with some types of Lists, such as that returned by the useful List.of(...) + this.covariates = new ArrayList<>(this.covariates); mapper.writeValue(file, this); } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java index b57df47..6c6548d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java @@ -3,10 +3,12 @@ package ca.joeltherrien.randomforest.covariates; import lombok.Getter; import lombok.RequiredArgsConstructor; -import java.util.*; +import java.util.Collection; +import java.util.Collections; +import java.util.List; @RequiredArgsConstructor -public class BooleanCovariate implements Covariate{ +public final class BooleanCovariate implements Covariate{ @Getter private final String name; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java new file mode 100644 index 0000000..2bf6263 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java @@ -0,0 +1,18 @@ +package ca.joeltherrien.randomforest.covariates; + +import lombok.Data; +import lombok.NoArgsConstructor; + +@NoArgsConstructor // required by Jackson +@Data +public final class BooleanCovariateSettings extends CovariateSettings { + + public BooleanCovariateSettings(String name){ + super(name); + } + + @Override + BooleanCovariate build() { + return new BooleanCovariate(name); + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java new file mode 100644 index 0000000..baafe4c --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java @@ -0,0 +1,33 @@ +package ca.joeltherrien.randomforest.covariates; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import lombok.Getter; +import lombok.NoArgsConstructor; + +/** + * Nuisance class to work with Jackson for persisting settings. + * + * @param + */ +@NoArgsConstructor // required for Jackson +@Getter +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"), + @JsonSubTypes.Type(value = NumericCovariateSettings.class, name = "numeric"), + @JsonSubTypes.Type(value = FactorCovariateSettings.class, name = "factor") +}) +public abstract class CovariateSettings { + + String name; + + CovariateSettings(String name){ + this.name = name; + } + + abstract Covariate build(); +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java new file mode 100644 index 0000000..04d4bb8 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java @@ -0,0 +1,24 @@ +package ca.joeltherrien.randomforest.covariates; + +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.ArrayList; +import java.util.List; + +@NoArgsConstructor // required by Jackson +@Data +public final class FactorCovariateSettings extends CovariateSettings { + + private List levels; + + public FactorCovariateSettings(String name, List levels){ + super(name); + this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...) + } + + @Override + FactorCovariate build() { + return new FactorCovariate(name, levels); + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java index 155beee..6bed687 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java @@ -8,7 +8,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; @RequiredArgsConstructor -public class NumericCovariate implements Covariate{ +public final class NumericCovariate implements Covariate{ @Getter private final String name; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java new file mode 100644 index 0000000..0be6cf0 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java @@ -0,0 +1,18 @@ +package ca.joeltherrien.randomforest.covariates; + +import lombok.Data; +import lombok.NoArgsConstructor; + +@NoArgsConstructor // required by Jackson +@Data +public final class NumericCovariateSettings extends CovariateSettings { + + public NumericCovariateSettings(String name){ + super(name); + } + + @Override + NumericCovariate build() { + return new NumericCovariate(name); + } +} diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index cb6bf19..562df01 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -2,10 +2,13 @@ package ca.joeltherrien.randomforest.settings; import ca.joeltherrien.randomforest.Settings; import static org.junit.jupiter.api.Assertions.assertEquals; + +import ca.joeltherrien.randomforest.covariates.*; import org.junit.jupiter.api.Test; import java.io.File; import java.io.IOException; +import java.util.ArrayList; import java.util.List; public class TestPersistence { @@ -13,7 +16,12 @@ public class TestPersistence { @Test public void testSaving() throws IOException { final Settings settingsOriginal = Settings.builder() - .covariates(List.of("x1", "x2", "x3")) + .covariates(List.of( + new NumericCovariateSettings("x1"), + new BooleanCovariateSettings("x2"), + new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog")) + ) + ) .dataFileLocation("data.csv") .groupDifferentiator("WeightedVarianceGroupDifferentiator") .responseCombiner("MeanResponseCombiner")