Add basic Settings class with persistence.
This commit is contained in:
parent
2cdcbe6cbf
commit
b010e79269
3 changed files with 122 additions and 0 deletions
18
pom.xml
18
pom.xml
|
@ -12,6 +12,7 @@
|
||||||
<java.version>1.10</java.version>
|
<java.version>1.10</java.version>
|
||||||
<maven.compiler.target>1.10</maven.compiler.target>
|
<maven.compiler.target>1.10</maven.compiler.target>
|
||||||
<maven.compiler.source>1.10</maven.compiler.source>
|
<maven.compiler.source>1.10</maven.compiler.source>
|
||||||
|
<jackson.version>2.9.6</jackson.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,6 +24,23 @@
|
||||||
<scope>provided</scope>
|
<scope>provided</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-csv</artifactId>
|
||||||
|
<version>1.5</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.core</groupId>
|
||||||
|
<artifactId>jackson-databind</artifactId>
|
||||||
|
<version>${jackson.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.dataformat</groupId>
|
||||||
|
<artifactId>jackson-dataformat-yaml</artifactId>
|
||||||
|
<version>${jackson.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.junit.jupiter</groupId>
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
|
62
src/main/java/ca/joeltherrien/randomforest/Settings.java
Normal file
62
src/main/java/ca/joeltherrien/randomforest/Settings.java
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class Settings {
|
||||||
|
|
||||||
|
private int numberOfSplits = 5;
|
||||||
|
private int nodeSize = 5;
|
||||||
|
private int maxNodeDepth = 1000000; // basically no maxNodeDepth
|
||||||
|
|
||||||
|
private String responseCombiner;
|
||||||
|
private String groupDifferentiator;
|
||||||
|
private String treeResponseCombiner;
|
||||||
|
|
||||||
|
private List<String> covariates = new ArrayList<>();
|
||||||
|
|
||||||
|
// number of covariates to randomly try
|
||||||
|
private int mtry = 0;
|
||||||
|
|
||||||
|
// number of trees to try
|
||||||
|
private int ntree = 500;
|
||||||
|
|
||||||
|
private String dataFileLocation = "data.csv";
|
||||||
|
private String saveTreeLocation = "trees/";
|
||||||
|
|
||||||
|
private int numberOfThreads = 1;
|
||||||
|
private boolean saveProgress = false;
|
||||||
|
|
||||||
|
public Settings(){} // required for Jackson
|
||||||
|
|
||||||
|
public static Settings load(File file) throws IOException {
|
||||||
|
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||||
|
|
||||||
|
final Settings settings = mapper.readValue(file, Settings.class);
|
||||||
|
|
||||||
|
return settings;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public void save(File file) throws IOException {
|
||||||
|
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||||
|
|
||||||
|
mapper.writeValue(file, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
package ca.joeltherrien.randomforest.settings;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class TestPersistence {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSaving() throws IOException {
|
||||||
|
final Settings settingsOriginal = Settings.builder()
|
||||||
|
.covariates(List.of("x1", "x2", "x3"))
|
||||||
|
.dataFileLocation("data.csv")
|
||||||
|
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
||||||
|
.responseCombiner("MeanResponseCombiner")
|
||||||
|
.treeResponseCombiner("MeanResponseCombiner")
|
||||||
|
.maxNodeDepth(100000)
|
||||||
|
.mtry(2)
|
||||||
|
.nodeSize(5)
|
||||||
|
.ntree(500)
|
||||||
|
.numberOfSplits(5)
|
||||||
|
.numberOfThreads(1)
|
||||||
|
.saveProgress(true)
|
||||||
|
.saveTreeLocation("trees/")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final File templateFile = new File("template.yaml");
|
||||||
|
settingsOriginal.save(templateFile);
|
||||||
|
|
||||||
|
final Settings reloadedSettings = Settings.load(templateFile);
|
||||||
|
|
||||||
|
assertEquals(settingsOriginal, reloadedSettings);
|
||||||
|
|
||||||
|
templateFile.delete();
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue