Add support for loading datasets by CSV files.

This commit is contained in:
Joel Therrien 2018-07-06 13:21:34 -07:00
parent fe9ff37dcf
commit 6b62ad95c3
20 changed files with 313 additions and 5 deletions

28
pom.xml
View file

@ -58,5 +58,33 @@
</dependencies> </dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<archive>
<manifest>
<mainClass>ca.joeltherrien.randomforest.Main</mainClass>
</manifest>
</archive>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id> <!-- this is used for inheritance merges -->
<phase>package</phase> <!-- bind to the packaging phase -->
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project> </project>

View file

@ -0,0 +1,105 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class Main {
public static void main(String[] args) throws IOException {
if(args.length != 1){
System.out.println("Must provide one argument - the path to the settings.yaml file.");
if(args.length == 0){
System.out.println("Generating template file.");
defaultTemplate().save(new File("template.yaml"));
}
return;
}
final File settingsFile = new File(args[0]);
final Settings settings = Settings.load(settingsFile);
final List<Covariate> covariates = settings.getCovariates().stream()
.map(cs -> cs.build()).collect(Collectors.toList());
final List<Row<Double>> dataset = loadData(covariates, settings);
final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
if(settings.isSaveProgress()){
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
}
else{
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
}
}
public static List<Row<Double>> loadData(final List<Covariate> covariates, final Settings settings) throws IOException {
final List<Row<Double>> dataset = new ArrayList<>();
final Reader input = new FileReader(settings.getDataFileLocation());
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
int id = 1;
for(final CSVRecord record : parser){
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
for(final Covariate<?> covariate : covariates){
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
}
final String yStr = record.get(settings.getYVar());
final Double yNum = Double.parseDouble(yStr);
dataset.add(new Row<>(covariateValueMap, id++, yNum));
}
return dataset;
}
private static Settings defaultTemplate(){
return Settings.builder()
.covariates(List.of(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
)
)
.yVar("y")
.dataFileLocation("data.csv")
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
.responseCombiner("MeanResponseCombiner")
.treeResponseCombiner("MeanResponseCombiner")
.maxNodeDepth(100000)
.mtry(2)
.nodeSize(5)
.ntree(500)
.numberOfSplits(5)
.numberOfThreads(1)
.saveProgress(true)
.saveTreeLocation("trees/")
.build();
}
}

View file

@ -2,6 +2,9 @@ package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings; import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import lombok.*; import lombok.*;
@ -9,7 +12,9 @@ import lombok.*;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest. * This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
@ -29,6 +34,7 @@ public class Settings {
private String treeResponseCombiner; private String treeResponseCombiner;
private List<CovariateSettings> covariates = new ArrayList<>(); private List<CovariateSettings> covariates = new ArrayList<>();
private String yVar = "y";
// number of covariates to randomly try // number of covariates to randomly try
private int mtry = 0; private int mtry = 0;
@ -64,5 +70,4 @@ public class Settings {
mapper.writeValue(file, this); mapper.writeValue(file, this);
} }
} }

View file

@ -25,6 +25,23 @@ public final class BooleanCovariate implements Covariate<Boolean>{
return new BooleanValue(value); return new BooleanValue(value);
} }
@Override
public Value<Boolean> createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){
return createValue( (Boolean) null);
}
if(value.equalsIgnoreCase("true")){
return createValue(true);
}
else if(value.equalsIgnoreCase("false")){
return createValue(false);
}
else{
throw new IllegalArgumentException("Require either true/false/na to create BooleanCovariate");
}
}
public class BooleanValue implements Value<Boolean>{ public class BooleanValue implements Value<Boolean>{
private final Boolean value; private final Boolean value;

View file

@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
} }
@Override @Override
BooleanCovariate build() { public BooleanCovariate build() {
return new BooleanCovariate(name); return new BooleanCovariate(name);
} }
} }

View file

@ -16,6 +16,14 @@ public interface Covariate<V> extends Serializable {
Value<V> createValue(V value); Value<V> createValue(V value);
/**
* Creates a Value of the appropriate type from a String; primarily used when parsing CSVs.
*
* @param value
* @return
*/
Value<V> createValue(String value);
interface Value<V> extends Serializable{ interface Value<V> extends Serializable{
Covariate<V> getParent(); Covariate<V> getParent();

View file

@ -29,5 +29,5 @@ public abstract class CovariateSettings<V> {
this.name = name; this.name = name;
} }
abstract Covariate<V> build(); public abstract Covariate<V> build();
} }

View file

@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings<String> {
} }
@Override @Override
FactorCovariate build() { public FactorCovariate build() {
return new FactorCovariate(name, levels); return new FactorCovariate(name, levels);
} }
} }

View file

@ -55,6 +55,15 @@ public final class NumericCovariate implements Covariate<Double>{
return new NumericValue(value); return new NumericValue(value);
} }
@Override
public Value<Double> createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){
return createValue((Double) null);
}
return createValue(Double.parseDouble(value));
}
public class NumericValue implements Covariate.Value<Double>{ public class NumericValue implements Covariate.Value<Double>{
private final Double value; // may be null private final Double value; // may be null

View file

@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings<Double> {
} }
@Override @Override
NumericCovariate build() { public NumericCovariate build() {
return new NumericCovariate(name); return new NumericCovariate(name);
} }
} }

View file

@ -1,11 +1,16 @@
package ca.joeltherrien.randomforest.regression; package ca.joeltherrien.randomforest.regression;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import java.util.List; import java.util.List;
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> { public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
static{
GroupDifferentiator.registerGroupDifferentiator("MeanGroupDifferentiator", new MeanGroupDifferentiator());
}
@Override @Override
public Double differentiate(List<Double> leftHand, List<Double> rightHand) { public Double differentiate(List<Double> leftHand, List<Double> rightHand) {

View file

@ -17,6 +17,10 @@ import java.util.function.Supplier;
*/ */
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> { public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
static{
ResponseCombiner.registerResponseCombiner("MeanResponseCombiner", new MeanResponseCombiner());
}
@Override @Override
public Double combine(List<Double> responses) { public Double combine(List<Double> responses) {
double size = responses.size(); double size = responses.size();

View file

@ -6,6 +6,10 @@ import java.util.List;
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> { public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
static{
GroupDifferentiator.registerGroupDifferentiator("WeightedVarianceGroupDifferentiator", new WeightedVarianceGroupDifferentiator());
}
@Override @Override
public Double differentiate(List<Double> leftHand, List<Double> rightHand) { public Double differentiate(List<Double> leftHand, List<Double> rightHand) {

View file

@ -1,8 +1,11 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Bootstrapper; import ca.joeltherrien.randomforest.Bootstrapper;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import java.io.FileOutputStream; import java.io.FileOutputStream;
@ -18,6 +21,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@Builder @Builder
@AllArgsConstructor(access=AccessLevel.PRIVATE)
public class ForestTrainer<Y> { public class ForestTrainer<Y> {
private final TreeTrainer<Y> treeTrainer; private final TreeTrainer<Y> treeTrainer;
@ -34,6 +38,19 @@ public class ForestTrainer<Y> {
private final boolean displayProgress; private final boolean displayProgress;
private final String saveTreeLocation; private final String saveTreeLocation;
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
this.mtry = settings.getMtry();
this.ntree = settings.getNtree();
this.data = data;
this.displayProgress = true;
this.saveTreeLocation = settings.getSaveTreeLocation();
this.covariatesToTry = covariates;
this.treeResponseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getTreeResponseCombiner());
this.treeTrainer = new TreeTrainer<>(settings);
}
public Forest<Y> trainSerial(){ public Forest<Y> trainSerial(){
final List<Node<Y>> trees = new ArrayList<>(ntree); final List<Node<Y>> trees = new ArrayList<>(ntree);

View file

@ -1,6 +1,8 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups. * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
@ -12,4 +14,12 @@ public interface GroupDifferentiator<Y> {
Double differentiate(List<Y> leftHand, List<Y> rightHand); Double differentiate(List<Y> leftHand, List<Y> rightHand);
Map<String, GroupDifferentiator> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
static GroupDifferentiator loadGroupDifferentiatorByName(final String name){
return GROUP_DIFFERENTIATOR_MAP.get(name);
}
static void registerGroupDifferentiator(final String name, final GroupDifferentiator groupDifferentiator){
GROUP_DIFFERENTIATOR_MAP.put(name, groupDifferentiator);
}
} }

View file

@ -1,10 +1,22 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.stream.Collector; import java.util.stream.Collector;
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> { public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
Y combine(List<Y> responses); Y combine(List<Y> responses);
final static Map<String, ResponseCombiner> RESPONSE_COMBINER_MAP = new HashMap<>();
static ResponseCombiner loadResponseCombinerByName(final String name){
return RESPONSE_COMBINER_MAP.get(name);
}
static void registerResponseCombiner(final String name, final ResponseCombiner responseCombiner){
RESPONSE_COMBINER_MAP.put(name, responseCombiner);
}
} }

View file

@ -2,12 +2,15 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Builder @Builder
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class TreeTrainer<Y> { public class TreeTrainer<Y> {
private final ResponseCombiner<Y, ?> responseCombiner; private final ResponseCombiner<Y, ?> responseCombiner;
@ -21,6 +24,14 @@ public class TreeTrainer<Y> {
private final int nodeSize; private final int nodeSize;
private final int maxNodeDepth; private final int maxNodeDepth;
public TreeTrainer(final Settings settings){
this.numberOfSplits = settings.getNumberOfSplits();
this.nodeSize = settings.getNodeSize();
this.maxNodeDepth = settings.getMaxNodeDepth();
this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner());
this.groupDifferentiator = GroupDifferentiator.loadGroupDifferentiatorByName(settings.getGroupDifferentiator());
}
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){ public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
return growNode(data, covariatesToTry, 0); return growNode(data, covariatesToTry, 0);

View file

@ -0,0 +1,66 @@
package ca.joeltherrien.randomforest.csv;
import ca.joeltherrien.randomforest.Main;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestLoadingCSV {
/*
y,x1,x2,x3
5,3.0,"mouse",true
2,1.0,"dog",false
9,1.5,"cat",true
*/
@Test
public void verifyLoading() throws IOException {
final Settings settings = Settings.builder()
.dataFileLocation("src/test/resources/testCSV.csv")
.covariates(
List.of(new NumericCovariateSettings("x1"),
new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")),
new BooleanCovariateSettings("x3"))
)
.yVar("y")
.build();
final List<Covariate> covariates = settings.getCovariates().stream()
.map(cs -> cs.build()).collect(Collectors.toList());
final List<Row<Double>> data = Main.loadData(covariates, settings);
assertEquals(3, data.size());
Row<Double> row = data.get(0);
assertEquals(5.0, (double)row.getResponse());
assertEquals(3.0, row.getCovariateValue("x1").getValue());
assertEquals("mouse", row.getCovariateValue("x2").getValue());
assertEquals(true, row.getCovariateValue("x3").getValue());
row = data.get(1);
assertEquals(2.0, (double)row.getResponse());
assertEquals(1.0, row.getCovariateValue("x1").getValue());
assertEquals("dog", row.getCovariateValue("x2").getValue());
assertEquals(false, row.getCovariateValue("x3").getValue());
row = data.get(2);
assertEquals(9.0, (double)row.getResponse());
assertEquals(1.5, row.getCovariateValue("x1").getValue());
assertEquals("cat", row.getCovariateValue("x2").getValue());
assertEquals(true, row.getCovariateValue("x3").getValue());
}
}

View file

@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
@ -22,6 +24,7 @@ public class TestPersistence {
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog")) new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
) )
) )
.yVar("y")
.dataFileLocation("data.csv") .dataFileLocation("data.csv")
.groupDifferentiator("WeightedVarianceGroupDifferentiator") .groupDifferentiator("WeightedVarianceGroupDifferentiator")
.responseCombiner("MeanResponseCombiner") .responseCombiner("MeanResponseCombiner")

View file

@ -0,0 +1,4 @@
y,x1,x2,x3
5,3.0,"mouse",true
2,1.0,"dog",false
9,1.5,"cat",true
1 y x1 x2 x3
2 5 3.0 mouse true
3 2 1.0 dog false
4 9 1.5 cat true