Add support for loading datasets by CSV files.
This commit is contained in:
parent
fe9ff37dcf
commit
6b62ad95c3
20 changed files with 313 additions and 5 deletions
28
pom.xml
28
pom.xml
|
@ -58,5 +58,33 @@
|
|||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<archive>
|
||||
<manifest>
|
||||
<mainClass>ca.joeltherrien.randomforest.Main</mainClass>
|
||||
</manifest>
|
||||
</archive>
|
||||
<descriptorRefs>
|
||||
<descriptorRef>jar-with-dependencies</descriptorRef>
|
||||
</descriptorRefs>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>make-assembly</id> <!-- this is used for inheritance merges -->
|
||||
<phase>package</phase> <!-- bind to the packaging phase -->
|
||||
<goals>
|
||||
<goal>single</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
|
||||
|
||||
</project>
|
105
src/main/java/ca/joeltherrien/randomforest/Main.java
Normal file
105
src/main/java/ca/joeltherrien/randomforest/Main.java
Normal file
|
@ -0,0 +1,105 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import org.apache.commons.csv.CSVFormat;
|
||||
import org.apache.commons.csv.CSVParser;
|
||||
import org.apache.commons.csv.CSVRecord;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.io.Reader;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
if(args.length != 1){
|
||||
System.out.println("Must provide one argument - the path to the settings.yaml file.");
|
||||
if(args.length == 0){
|
||||
System.out.println("Generating template file.");
|
||||
defaultTemplate().save(new File("template.yaml"));
|
||||
}
|
||||
return;
|
||||
}
|
||||
final File settingsFile = new File(args[0]);
|
||||
final Settings settings = Settings.load(settingsFile);
|
||||
|
||||
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||
|
||||
final List<Row<Double>> dataset = loadData(covariates, settings);
|
||||
|
||||
final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||
|
||||
if(settings.isSaveProgress()){
|
||||
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||
}
|
||||
else{
|
||||
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
public static List<Row<Double>> loadData(final List<Covariate> covariates, final Settings settings) throws IOException {
|
||||
|
||||
final List<Row<Double>> dataset = new ArrayList<>();
|
||||
|
||||
final Reader input = new FileReader(settings.getDataFileLocation());
|
||||
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
|
||||
|
||||
|
||||
int id = 1;
|
||||
for(final CSVRecord record : parser){
|
||||
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
|
||||
|
||||
for(final Covariate<?> covariate : covariates){
|
||||
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
|
||||
}
|
||||
|
||||
final String yStr = record.get(settings.getYVar());
|
||||
final Double yNum = Double.parseDouble(yStr);
|
||||
|
||||
dataset.add(new Row<>(covariateValueMap, id++, yNum));
|
||||
|
||||
}
|
||||
|
||||
return dataset;
|
||||
|
||||
}
|
||||
|
||||
private static Settings defaultTemplate(){
|
||||
return Settings.builder()
|
||||
.covariates(List.of(
|
||||
new NumericCovariateSettings("x1"),
|
||||
new BooleanCovariateSettings("x2"),
|
||||
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
||||
)
|
||||
)
|
||||
.yVar("y")
|
||||
.dataFileLocation("data.csv")
|
||||
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
||||
.responseCombiner("MeanResponseCombiner")
|
||||
.treeResponseCombiner("MeanResponseCombiner")
|
||||
.maxNodeDepth(100000)
|
||||
.mtry(2)
|
||||
.nodeSize(5)
|
||||
.ntree(500)
|
||||
.numberOfSplits(5)
|
||||
.numberOfThreads(1)
|
||||
.saveProgress(true)
|
||||
.saveTreeLocation("trees/")
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
|
@ -2,6 +2,9 @@ package ca.joeltherrien.randomforest;
|
|||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||
import lombok.*;
|
||||
|
@ -9,7 +12,9 @@ import lombok.*;
|
|||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
||||
|
@ -29,6 +34,7 @@ public class Settings {
|
|||
private String treeResponseCombiner;
|
||||
|
||||
private List<CovariateSettings> covariates = new ArrayList<>();
|
||||
private String yVar = "y";
|
||||
|
||||
// number of covariates to randomly try
|
||||
private int mtry = 0;
|
||||
|
@ -64,5 +70,4 @@ public class Settings {
|
|||
mapper.writeValue(file, this);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -25,6 +25,23 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
|||
return new BooleanValue(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Value<Boolean> createValue(String value) {
|
||||
if(value == null || value.equalsIgnoreCase("na")){
|
||||
return createValue( (Boolean) null);
|
||||
}
|
||||
|
||||
if(value.equalsIgnoreCase("true")){
|
||||
return createValue(true);
|
||||
}
|
||||
else if(value.equalsIgnoreCase("false")){
|
||||
return createValue(false);
|
||||
}
|
||||
else{
|
||||
throw new IllegalArgumentException("Require either true/false/na to create BooleanCovariate");
|
||||
}
|
||||
}
|
||||
|
||||
public class BooleanValue implements Value<Boolean>{
|
||||
|
||||
private final Boolean value;
|
||||
|
|
|
@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
|
|||
}
|
||||
|
||||
@Override
|
||||
BooleanCovariate build() {
|
||||
public BooleanCovariate build() {
|
||||
return new BooleanCovariate(name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,14 @@ public interface Covariate<V> extends Serializable {
|
|||
|
||||
Value<V> createValue(V value);
|
||||
|
||||
/**
|
||||
* Creates a Value of the appropriate type from a String; primarily used when parsing CSVs.
|
||||
*
|
||||
* @param value
|
||||
* @return
|
||||
*/
|
||||
Value<V> createValue(String value);
|
||||
|
||||
interface Value<V> extends Serializable{
|
||||
|
||||
Covariate<V> getParent();
|
||||
|
|
|
@ -29,5 +29,5 @@ public abstract class CovariateSettings<V> {
|
|||
this.name = name;
|
||||
}
|
||||
|
||||
abstract Covariate<V> build();
|
||||
public abstract Covariate<V> build();
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings<String> {
|
|||
}
|
||||
|
||||
@Override
|
||||
FactorCovariate build() {
|
||||
public FactorCovariate build() {
|
||||
return new FactorCovariate(name, levels);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,6 +55,15 @@ public final class NumericCovariate implements Covariate<Double>{
|
|||
return new NumericValue(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Value<Double> createValue(String value) {
|
||||
if(value == null || value.equalsIgnoreCase("na")){
|
||||
return createValue((Double) null);
|
||||
}
|
||||
|
||||
return createValue(Double.parseDouble(value));
|
||||
}
|
||||
|
||||
public class NumericValue implements Covariate.Value<Double>{
|
||||
|
||||
private final Double value; // may be null
|
||||
|
|
|
@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings<Double> {
|
|||
}
|
||||
|
||||
@Override
|
||||
NumericCovariate build() {
|
||||
public NumericCovariate build() {
|
||||
return new NumericCovariate(name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
package ca.joeltherrien.randomforest.regression;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||
|
||||
static{
|
||||
GroupDifferentiator.registerGroupDifferentiator("MeanGroupDifferentiator", new MeanGroupDifferentiator());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||
|
||||
|
|
|
@ -17,6 +17,10 @@ import java.util.function.Supplier;
|
|||
*/
|
||||
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
|
||||
|
||||
static{
|
||||
ResponseCombiner.registerResponseCombiner("MeanResponseCombiner", new MeanResponseCombiner());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double combine(List<Double> responses) {
|
||||
double size = responses.size();
|
||||
|
|
|
@ -6,6 +6,10 @@ import java.util.List;
|
|||
|
||||
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||
|
||||
static{
|
||||
GroupDifferentiator.registerGroupDifferentiator("WeightedVarianceGroupDifferentiator", new WeightedVarianceGroupDifferentiator());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.Bootstrapper;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.io.FileOutputStream;
|
||||
|
@ -18,6 +21,7 @@ import java.util.stream.Collectors;
|
|||
import java.util.stream.Stream;
|
||||
|
||||
@Builder
|
||||
@AllArgsConstructor(access=AccessLevel.PRIVATE)
|
||||
public class ForestTrainer<Y> {
|
||||
|
||||
private final TreeTrainer<Y> treeTrainer;
|
||||
|
@ -34,6 +38,19 @@ public class ForestTrainer<Y> {
|
|||
private final boolean displayProgress;
|
||||
private final String saveTreeLocation;
|
||||
|
||||
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
|
||||
this.mtry = settings.getMtry();
|
||||
this.ntree = settings.getNtree();
|
||||
this.data = data;
|
||||
this.displayProgress = true;
|
||||
this.saveTreeLocation = settings.getSaveTreeLocation();
|
||||
|
||||
this.covariatesToTry = covariates;
|
||||
this.treeResponseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getTreeResponseCombiner());
|
||||
this.treeTrainer = new TreeTrainer<>(settings);
|
||||
|
||||
}
|
||||
|
||||
public Forest<Y> trainSerial(){
|
||||
|
||||
final List<Node<Y>> trees = new ArrayList<>(ntree);
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
||||
|
@ -12,4 +14,12 @@ public interface GroupDifferentiator<Y> {
|
|||
|
||||
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||
|
||||
Map<String, GroupDifferentiator> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
|
||||
static GroupDifferentiator loadGroupDifferentiatorByName(final String name){
|
||||
return GROUP_DIFFERENTIATOR_MAP.get(name);
|
||||
}
|
||||
static void registerGroupDifferentiator(final String name, final GroupDifferentiator groupDifferentiator){
|
||||
GROUP_DIFFERENTIATOR_MAP.put(name, groupDifferentiator);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,10 +1,22 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collector;
|
||||
|
||||
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
|
||||
|
||||
Y combine(List<Y> responses);
|
||||
|
||||
final static Map<String, ResponseCombiner> RESPONSE_COMBINER_MAP = new HashMap<>();
|
||||
static ResponseCombiner loadResponseCombinerByName(final String name){
|
||||
return RESPONSE_COMBINER_MAP.get(name);
|
||||
}
|
||||
static void registerResponseCombiner(final String name, final ResponseCombiner responseCombiner){
|
||||
RESPONSE_COMBINER_MAP.put(name, responseCombiner);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -2,12 +2,15 @@ package ca.joeltherrien.randomforest.tree;
|
|||
|
||||
import ca.joeltherrien.randomforest.*;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Builder
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public class TreeTrainer<Y> {
|
||||
|
||||
private final ResponseCombiner<Y, ?> responseCombiner;
|
||||
|
@ -21,6 +24,14 @@ public class TreeTrainer<Y> {
|
|||
private final int nodeSize;
|
||||
private final int maxNodeDepth;
|
||||
|
||||
public TreeTrainer(final Settings settings){
|
||||
this.numberOfSplits = settings.getNumberOfSplits();
|
||||
this.nodeSize = settings.getNodeSize();
|
||||
this.maxNodeDepth = settings.getMaxNodeDepth();
|
||||
|
||||
this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner());
|
||||
this.groupDifferentiator = GroupDifferentiator.loadGroupDifferentiatorByName(settings.getGroupDifferentiator());
|
||||
}
|
||||
|
||||
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||
return growNode(data, covariatesToTry, 0);
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
package ca.joeltherrien.randomforest.csv;
|
||||
|
||||
import ca.joeltherrien.randomforest.Main;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestLoadingCSV {
|
||||
|
||||
/*
|
||||
y,x1,x2,x3
|
||||
5,3.0,"mouse",true
|
||||
2,1.0,"dog",false
|
||||
9,1.5,"cat",true
|
||||
*/
|
||||
|
||||
@Test
|
||||
public void verifyLoading() throws IOException {
|
||||
final Settings settings = Settings.builder()
|
||||
.dataFileLocation("src/test/resources/testCSV.csv")
|
||||
.covariates(
|
||||
List.of(new NumericCovariateSettings("x1"),
|
||||
new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")),
|
||||
new BooleanCovariateSettings("x3"))
|
||||
)
|
||||
.yVar("y")
|
||||
.build();
|
||||
|
||||
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||
|
||||
final List<Row<Double>> data = Main.loadData(covariates, settings);
|
||||
|
||||
assertEquals(3, data.size());
|
||||
|
||||
Row<Double> row = data.get(0);
|
||||
assertEquals(5.0, (double)row.getResponse());
|
||||
assertEquals(3.0, row.getCovariateValue("x1").getValue());
|
||||
assertEquals("mouse", row.getCovariateValue("x2").getValue());
|
||||
assertEquals(true, row.getCovariateValue("x3").getValue());
|
||||
|
||||
row = data.get(1);
|
||||
assertEquals(2.0, (double)row.getResponse());
|
||||
assertEquals(1.0, row.getCovariateValue("x1").getValue());
|
||||
assertEquals("dog", row.getCovariateValue("x2").getValue());
|
||||
assertEquals(false, row.getCovariateValue("x3").getValue());
|
||||
|
||||
row = data.get(2);
|
||||
assertEquals(9.0, (double)row.getResponse());
|
||||
assertEquals(1.5, row.getCovariateValue("x1").getValue());
|
||||
assertEquals("cat", row.getCovariateValue("x2").getValue());
|
||||
assertEquals(true, row.getCovariateValue("x3").getValue());
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.Settings;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.*;
|
||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -22,6 +24,7 @@ public class TestPersistence {
|
|||
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
||||
)
|
||||
)
|
||||
.yVar("y")
|
||||
.dataFileLocation("data.csv")
|
||||
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
||||
.responseCombiner("MeanResponseCombiner")
|
||||
|
|
4
src/test/resources/testCSV.csv
Normal file
4
src/test/resources/testCSV.csv
Normal file
|
@ -0,0 +1,4 @@
|
|||
y,x1,x2,x3
|
||||
5,3.0,"mouse",true
|
||||
2,1.0,"dog",false
|
||||
9,1.5,"cat",true
|
|
Loading…
Reference in a new issue