Add basic Settings class with persistence.

This commit is contained in:
Joel Therrien 2018-07-05 13:59:32 -07:00
parent 2cdcbe6cbf
commit b010e79269
3 changed files with 122 additions and 0 deletions

18
pom.xml
View file

@ -12,6 +12,7 @@
<java.version>1.10</java.version>
<maven.compiler.target>1.10</maven.compiler.target>
<maven.compiler.source>1.10</maven.compiler.source>
<jackson.version>2.9.6</jackson.version>
</properties>
@ -23,6 +24,23 @@
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>1.5</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-yaml</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>

View file

@ -0,0 +1,62 @@
package ca.joeltherrien.randomforest;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
*/
@Data
@Builder
@AllArgsConstructor
public class Settings {
private int numberOfSplits = 5;
private int nodeSize = 5;
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
private String responseCombiner;
private String groupDifferentiator;
private String treeResponseCombiner;
private List<String> covariates = new ArrayList<>();
// number of covariates to randomly try
private int mtry = 0;
// number of trees to try
private int ntree = 500;
private String dataFileLocation = "data.csv";
private String saveTreeLocation = "trees/";
private int numberOfThreads = 1;
private boolean saveProgress = false;
public Settings(){} // required for Jackson
public static Settings load(File file) throws IOException {
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
final Settings settings = mapper.readValue(file, Settings.class);
return settings;
}
public void save(File file) throws IOException {
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
mapper.writeValue(file, this);
}
}

View file

@ -0,0 +1,42 @@
package ca.joeltherrien.randomforest.settings;
import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.util.List;
public class TestPersistence {
@Test
public void testSaving() throws IOException {
final Settings settingsOriginal = Settings.builder()
.covariates(List.of("x1", "x2", "x3"))
.dataFileLocation("data.csv")
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
.responseCombiner("MeanResponseCombiner")
.treeResponseCombiner("MeanResponseCombiner")
.maxNodeDepth(100000)
.mtry(2)
.nodeSize(5)
.ntree(500)
.numberOfSplits(5)
.numberOfThreads(1)
.saveProgress(true)
.saveTreeLocation("trees/")
.build();
final File templateFile = new File("template.yaml");
settingsOriginal.save(templateFile);
final Settings reloadedSettings = Settings.load(templateFile);
assertEquals(settingsOriginal, reloadedSettings);
templateFile.delete();
}
}