diff --git a/pom.xml b/pom.xml
index 01d9203..36c9a38 100644
--- a/pom.xml
+++ b/pom.xml
@@ -12,6 +12,7 @@
1.10
1.10
1.10
+ 2.9.6
@@ -23,6 +24,23 @@
provided
+
+ org.apache.commons
+ commons-csv
+ 1.5
+
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+ ${jackson.version}
+
+
+
+ com.fasterxml.jackson.dataformat
+ jackson-dataformat-yaml
+ ${jackson.version}
+
org.junit.jupiter
diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java
new file mode 100644
index 0000000..a3fccf1
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java
@@ -0,0 +1,62 @@
+package ca.joeltherrien.randomforest;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
+ */
+@Data
+@Builder
+@AllArgsConstructor
+public class Settings {
+
+ private int numberOfSplits = 5;
+ private int nodeSize = 5;
+ private int maxNodeDepth = 1000000; // basically no maxNodeDepth
+
+ private String responseCombiner;
+ private String groupDifferentiator;
+ private String treeResponseCombiner;
+
+ private List covariates = new ArrayList<>();
+
+ // number of covariates to randomly try
+ private int mtry = 0;
+
+ // number of trees to try
+ private int ntree = 500;
+
+ private String dataFileLocation = "data.csv";
+ private String saveTreeLocation = "trees/";
+
+ private int numberOfThreads = 1;
+ private boolean saveProgress = false;
+
+ public Settings(){} // required for Jackson
+
+ public static Settings load(File file) throws IOException {
+ final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
+
+ final Settings settings = mapper.readValue(file, Settings.class);
+
+ return settings;
+
+ }
+
+ public void save(File file) throws IOException {
+ final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
+
+ mapper.writeValue(file, this);
+ }
+
+
+}
diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java
new file mode 100644
index 0000000..cb6bf19
--- /dev/null
+++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java
@@ -0,0 +1,42 @@
+package ca.joeltherrien.randomforest.settings;
+
+import ca.joeltherrien.randomforest.Settings;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+
+public class TestPersistence {
+
+ @Test
+ public void testSaving() throws IOException {
+ final Settings settingsOriginal = Settings.builder()
+ .covariates(List.of("x1", "x2", "x3"))
+ .dataFileLocation("data.csv")
+ .groupDifferentiator("WeightedVarianceGroupDifferentiator")
+ .responseCombiner("MeanResponseCombiner")
+ .treeResponseCombiner("MeanResponseCombiner")
+ .maxNodeDepth(100000)
+ .mtry(2)
+ .nodeSize(5)
+ .ntree(500)
+ .numberOfSplits(5)
+ .numberOfThreads(1)
+ .saveProgress(true)
+ .saveTreeLocation("trees/")
+ .build();
+
+ final File templateFile = new File("template.yaml");
+ settingsOriginal.save(templateFile);
+
+ final Settings reloadedSettings = Settings.load(templateFile);
+
+ assertEquals(settingsOriginal, reloadedSettings);
+
+ templateFile.delete();
+
+
+ }
+}
\ No newline at end of file