diff --git a/pom.xml b/pom.xml index 01d9203..36c9a38 100644 --- a/pom.xml +++ b/pom.xml @@ -12,6 +12,7 @@ 1.10 1.10 1.10 + 2.9.6 @@ -23,6 +24,23 @@ provided + + org.apache.commons + commons-csv + 1.5 + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${jackson.version} + org.junit.jupiter diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java new file mode 100644 index 0000000..a3fccf1 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -0,0 +1,62 @@ +package ca.joeltherrien.randomforest; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Holds all settings used when training a forest. Instances are read from and written to a YAML configuration file through {@link #load(File)} and {@link #save(File)}. NOTE(review): Lombok's {@code @Builder} ignores field initializers unless a field carries {@code @Builder.Default}, so the defaults declared below presumably apply only through the no-argument constructor and not to objects created with {@code Settings.builder()} - confirm this is intended. 
+ */ +@Data +@Builder +@AllArgsConstructor +public class Settings { + + private int numberOfSplits = 5; + private int nodeSize = 5; + private int maxNodeDepth = 1000000; // large default, so node depth is effectively unlimited + + private String responseCombiner; + private String groupDifferentiator; + private String treeResponseCombiner; + + private List covariates = new ArrayList<>(); /* NOTE(review): raw List as shown; the persistence test builds this from List.of("x1", "x2", "x3"), so List<String> is presumably intended - confirm */ + + // number of covariates to randomly try + private int mtry = 0; + + // number of trees to train + private int ntree = 500; + + private String dataFileLocation = "data.csv"; + private String saveTreeLocation = "trees/"; + + private int numberOfThreads = 1; + private boolean saveProgress = false; + + public Settings(){} // no-argument constructor required by Jackson when loading + + public static Settings load(File file) throws IOException { /* reads the given file with a YAML-backed ObjectMapper and maps it onto a Settings instance */ + final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); + + final Settings settings = mapper.readValue(file, Settings.class); + + return settings; + + } + + public void save(File file) throws IOException { /* writes this instance to the given file as YAML */ + final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); + + mapper.writeValue(file, this); + } + + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java new file mode 100644 index 0000000..cb6bf19 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -0,0 +1,42 @@ +package ca.joeltherrien.randomforest.settings; + +import ca.joeltherrien.randomforest.Settings; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +public class TestPersistence { + + @Test + public void testSaving() throws IOException { + final Settings settingsOriginal = Settings.builder() + .covariates(List.of("x1", "x2", "x3")) + .dataFileLocation("data.csv") + .groupDifferentiator("WeightedVarianceGroupDifferentiator") 
+ .responseCombiner("MeanResponseCombiner") + .treeResponseCombiner("MeanResponseCombiner") + .maxNodeDepth(100000) + .mtry(2) + .nodeSize(5) + .ntree(500) + .numberOfSplits(5) + .numberOfThreads(1) + .saveProgress(true) + .saveTreeLocation("trees/") + .build(); + + final File templateFile = new File("template.yaml"); + settingsOriginal.save(templateFile); + + final Settings reloadedSettings = Settings.load(templateFile); + + assertEquals(settingsOriginal, reloadedSettings); /* round-trip check: relies on the equals() that Lombok's @Data generates for Settings */ + + templateFile.delete(); /* NOTE(review): template.yaml is a relative path, and this cleanup is skipped when save/load throws or the assertion above fails, leaving the file behind - consider a temp file with try/finally */ + + + } +} \ No newline at end of file