Add basic Settings class with persistence.
This commit is contained in:
parent
2cdcbe6cbf
commit
b010e79269
3 changed files with 122 additions and 0 deletions
18
pom.xml
18
pom.xml
|
@ -12,6 +12,7 @@
|
|||
<java.version>1.10</java.version>
|
||||
<maven.compiler.target>1.10</maven.compiler.target>
|
||||
<maven.compiler.source>1.10</maven.compiler.source>
|
||||
<jackson.version>2.9.6</jackson.version>
|
||||
</properties>
|
||||
|
||||
|
||||
|
@ -23,6 +24,23 @@
|
|||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-csv</artifactId>
|
||||
<version>1.5</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.dataformat</groupId>
|
||||
<artifactId>jackson-dataformat-yaml</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
|
|
62
src/main/java/ca/joeltherrien/randomforest/Settings.java
Normal file
62
src/main/java/ca/joeltherrien/randomforest/Settings.java
Normal file
|
@ -0,0 +1,62 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
||||
*/
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
public class Settings {
|
||||
|
||||
private int numberOfSplits = 5;
|
||||
private int nodeSize = 5;
|
||||
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
|
||||
|
||||
private String responseCombiner;
|
||||
private String groupDifferentiator;
|
||||
private String treeResponseCombiner;
|
||||
|
||||
private List<String> covariates = new ArrayList<>();
|
||||
|
||||
// number of covariates to randomly try
|
||||
private int mtry = 0;
|
||||
|
||||
// number of trees to try
|
||||
private int ntree = 500;
|
||||
|
||||
private String dataFileLocation = "data.csv";
|
||||
private String saveTreeLocation = "trees/";
|
||||
|
||||
private int numberOfThreads = 1;
|
||||
private boolean saveProgress = false;
|
||||
|
||||
public Settings(){} // required for Jackson
|
||||
|
||||
public static Settings load(File file) throws IOException {
|
||||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||
|
||||
final Settings settings = mapper.readValue(file, Settings.class);
|
||||
|
||||
return settings;
|
||||
|
||||
}
|
||||
|
||||
public void save(File file) throws IOException {
|
||||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||
|
||||
mapper.writeValue(file, this);
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package ca.joeltherrien.randomforest.settings;
|
||||
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
public class TestPersistence {
|
||||
|
||||
@Test
|
||||
public void testSaving() throws IOException {
|
||||
final Settings settingsOriginal = Settings.builder()
|
||||
.covariates(List.of("x1", "x2", "x3"))
|
||||
.dataFileLocation("data.csv")
|
||||
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
||||
.responseCombiner("MeanResponseCombiner")
|
||||
.treeResponseCombiner("MeanResponseCombiner")
|
||||
.maxNodeDepth(100000)
|
||||
.mtry(2)
|
||||
.nodeSize(5)
|
||||
.ntree(500)
|
||||
.numberOfSplits(5)
|
||||
.numberOfThreads(1)
|
||||
.saveProgress(true)
|
||||
.saveTreeLocation("trees/")
|
||||
.build();
|
||||
|
||||
final File templateFile = new File("template.yaml");
|
||||
settingsOriginal.save(templateFile);
|
||||
|
||||
final Settings reloadedSettings = Settings.load(templateFile);
|
||||
|
||||
assertEquals(settingsOriginal, reloadedSettings);
|
||||
|
||||
templateFile.delete();
|
||||
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue