Add support for loading datasets by CSV files.
This commit is contained in:
parent
fe9ff37dcf
commit
6b62ad95c3
20 changed files with 313 additions and 5 deletions
28
pom.xml
28
pom.xml
|
@ -58,5 +58,33 @@
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<artifactId>maven-assembly-plugin</artifactId>
|
||||||
|
<configuration>
|
||||||
|
<archive>
|
||||||
|
<manifest>
|
||||||
|
<mainClass>ca.joeltherrien.randomforest.Main</mainClass>
|
||||||
|
</manifest>
|
||||||
|
</archive>
|
||||||
|
<descriptorRefs>
|
||||||
|
<descriptorRef>jar-with-dependencies</descriptorRef>
|
||||||
|
</descriptorRefs>
|
||||||
|
</configuration>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>make-assembly</id> <!-- this is used for inheritance merges -->
|
||||||
|
<phase>package</phase> <!-- bind to the packaging phase -->
|
||||||
|
<goals>
|
||||||
|
<goal>single</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
</project>
|
</project>
|
105
src/main/java/ca/joeltherrien/randomforest/Main.java
Normal file
105
src/main/java/ca/joeltherrien/randomforest/Main.java
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import org.apache.commons.csv.CSVFormat;
|
||||||
|
import org.apache.commons.csv.CSVParser;
|
||||||
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileReader;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.Reader;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
public class Main {
|
||||||
|
|
||||||
|
public static void main(String[] args) throws IOException {
|
||||||
|
if(args.length != 1){
|
||||||
|
System.out.println("Must provide one argument - the path to the settings.yaml file.");
|
||||||
|
if(args.length == 0){
|
||||||
|
System.out.println("Generating template file.");
|
||||||
|
defaultTemplate().save(new File("template.yaml"));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
final File settingsFile = new File(args[0]);
|
||||||
|
final Settings settings = Settings.load(settingsFile);
|
||||||
|
|
||||||
|
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||||
|
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<Row<Double>> dataset = loadData(covariates, settings);
|
||||||
|
|
||||||
|
final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
|
if(settings.isSaveProgress()){
|
||||||
|
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static List<Row<Double>> loadData(final List<Covariate> covariates, final Settings settings) throws IOException {
|
||||||
|
|
||||||
|
final List<Row<Double>> dataset = new ArrayList<>();
|
||||||
|
|
||||||
|
final Reader input = new FileReader(settings.getDataFileLocation());
|
||||||
|
final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input);
|
||||||
|
|
||||||
|
|
||||||
|
int id = 1;
|
||||||
|
for(final CSVRecord record : parser){
|
||||||
|
final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
|
||||||
|
|
||||||
|
for(final Covariate<?> covariate : covariates){
|
||||||
|
covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
|
||||||
|
}
|
||||||
|
|
||||||
|
final String yStr = record.get(settings.getYVar());
|
||||||
|
final Double yNum = Double.parseDouble(yStr);
|
||||||
|
|
||||||
|
dataset.add(new Row<>(covariateValueMap, id++, yNum));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return dataset;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Settings defaultTemplate(){
|
||||||
|
return Settings.builder()
|
||||||
|
.covariates(List.of(
|
||||||
|
new NumericCovariateSettings("x1"),
|
||||||
|
new BooleanCovariateSettings("x2"),
|
||||||
|
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.yVar("y")
|
||||||
|
.dataFileLocation("data.csv")
|
||||||
|
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
||||||
|
.responseCombiner("MeanResponseCombiner")
|
||||||
|
.treeResponseCombiner("MeanResponseCombiner")
|
||||||
|
.maxNodeDepth(100000)
|
||||||
|
.mtry(2)
|
||||||
|
.nodeSize(5)
|
||||||
|
.ntree(500)
|
||||||
|
.numberOfSplits(5)
|
||||||
|
.numberOfThreads(1)
|
||||||
|
.saveProgress(true)
|
||||||
|
.saveTreeLocation("trees/")
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -2,6 +2,9 @@ package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
@ -9,7 +12,9 @@ import lombok.*;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
|
||||||
|
@ -29,6 +34,7 @@ public class Settings {
|
||||||
private String treeResponseCombiner;
|
private String treeResponseCombiner;
|
||||||
|
|
||||||
private List<CovariateSettings> covariates = new ArrayList<>();
|
private List<CovariateSettings> covariates = new ArrayList<>();
|
||||||
|
private String yVar = "y";
|
||||||
|
|
||||||
// number of covariates to randomly try
|
// number of covariates to randomly try
|
||||||
private int mtry = 0;
|
private int mtry = 0;
|
||||||
|
@ -64,5 +70,4 @@ public class Settings {
|
||||||
mapper.writeValue(file, this);
|
mapper.writeValue(file, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,23 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
||||||
return new BooleanValue(value);
|
return new BooleanValue(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Value<Boolean> createValue(String value) {
|
||||||
|
if(value == null || value.equalsIgnoreCase("na")){
|
||||||
|
return createValue( (Boolean) null);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(value.equalsIgnoreCase("true")){
|
||||||
|
return createValue(true);
|
||||||
|
}
|
||||||
|
else if(value.equalsIgnoreCase("false")){
|
||||||
|
return createValue(false);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
throw new IllegalArgumentException("Require either true/false/na to create BooleanCovariate");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public class BooleanValue implements Value<Boolean>{
|
public class BooleanValue implements Value<Boolean>{
|
||||||
|
|
||||||
private final Boolean value;
|
private final Boolean value;
|
||||||
|
|
|
@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
BooleanCovariate build() {
|
public BooleanCovariate build() {
|
||||||
return new BooleanCovariate(name);
|
return new BooleanCovariate(name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,14 @@ public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
Value<V> createValue(V value);
|
Value<V> createValue(V value);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a Value of the appropriate type from a String; primarily used when parsing CSVs.
|
||||||
|
*
|
||||||
|
* @param value
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Value<V> createValue(String value);
|
||||||
|
|
||||||
interface Value<V> extends Serializable{
|
interface Value<V> extends Serializable{
|
||||||
|
|
||||||
Covariate<V> getParent();
|
Covariate<V> getParent();
|
||||||
|
|
|
@ -29,5 +29,5 @@ public abstract class CovariateSettings<V> {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract Covariate<V> build();
|
public abstract Covariate<V> build();
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings<String> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
FactorCovariate build() {
|
public FactorCovariate build() {
|
||||||
return new FactorCovariate(name, levels);
|
return new FactorCovariate(name, levels);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,6 +55,15 @@ public final class NumericCovariate implements Covariate<Double>{
|
||||||
return new NumericValue(value);
|
return new NumericValue(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Value<Double> createValue(String value) {
|
||||||
|
if(value == null || value.equalsIgnoreCase("na")){
|
||||||
|
return createValue((Double) null);
|
||||||
|
}
|
||||||
|
|
||||||
|
return createValue(Double.parseDouble(value));
|
||||||
|
}
|
||||||
|
|
||||||
public class NumericValue implements Covariate.Value<Double>{
|
public class NumericValue implements Covariate.Value<Double>{
|
||||||
|
|
||||||
private final Double value; // may be null
|
private final Double value; // may be null
|
||||||
|
|
|
@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings<Double> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
NumericCovariate build() {
|
public NumericCovariate build() {
|
||||||
return new NumericCovariate(name);
|
return new NumericCovariate(name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,16 @@
|
||||||
package ca.joeltherrien.randomforest.regression;
|
package ca.joeltherrien.randomforest.regression;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
|
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||||
|
|
||||||
|
static{
|
||||||
|
GroupDifferentiator.registerGroupDifferentiator("MeanGroupDifferentiator", new MeanGroupDifferentiator());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,10 @@ import java.util.function.Supplier;
|
||||||
*/
|
*/
|
||||||
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
|
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
|
||||||
|
|
||||||
|
static{
|
||||||
|
ResponseCombiner.registerResponseCombiner("MeanResponseCombiner", new MeanResponseCombiner());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double combine(List<Double> responses) {
|
public Double combine(List<Double> responses) {
|
||||||
double size = responses.size();
|
double size = responses.size();
|
||||||
|
|
|
@ -6,6 +6,10 @@ import java.util.List;
|
||||||
|
|
||||||
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||||
|
|
||||||
|
static{
|
||||||
|
GroupDifferentiator.registerGroupDifferentiator("WeightedVarianceGroupDifferentiator", new WeightedVarianceGroupDifferentiator());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Bootstrapper;
|
import ca.joeltherrien.randomforest.Bootstrapper;
|
||||||
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import lombok.AccessLevel;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.io.FileOutputStream;
|
import java.io.FileOutputStream;
|
||||||
|
@ -18,6 +21,7 @@ import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
|
@AllArgsConstructor(access=AccessLevel.PRIVATE)
|
||||||
public class ForestTrainer<Y> {
|
public class ForestTrainer<Y> {
|
||||||
|
|
||||||
private final TreeTrainer<Y> treeTrainer;
|
private final TreeTrainer<Y> treeTrainer;
|
||||||
|
@ -34,6 +38,19 @@ public class ForestTrainer<Y> {
|
||||||
private final boolean displayProgress;
|
private final boolean displayProgress;
|
||||||
private final String saveTreeLocation;
|
private final String saveTreeLocation;
|
||||||
|
|
||||||
|
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
|
||||||
|
this.mtry = settings.getMtry();
|
||||||
|
this.ntree = settings.getNtree();
|
||||||
|
this.data = data;
|
||||||
|
this.displayProgress = true;
|
||||||
|
this.saveTreeLocation = settings.getSaveTreeLocation();
|
||||||
|
|
||||||
|
this.covariatesToTry = covariates;
|
||||||
|
this.treeResponseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getTreeResponseCombiner());
|
||||||
|
this.treeTrainer = new TreeTrainer<>(settings);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
public Forest<Y> trainSerial(){
|
public Forest<Y> trainSerial(){
|
||||||
|
|
||||||
final List<Node<Y>> trees = new ArrayList<>(ntree);
|
final List<Node<Y>> trees = new ArrayList<>(ntree);
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
||||||
|
@ -12,4 +14,12 @@ public interface GroupDifferentiator<Y> {
|
||||||
|
|
||||||
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
|
Map<String, GroupDifferentiator> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
|
||||||
|
static GroupDifferentiator loadGroupDifferentiatorByName(final String name){
|
||||||
|
return GROUP_DIFFERENTIATOR_MAP.get(name);
|
||||||
|
}
|
||||||
|
static void registerGroupDifferentiator(final String name, final GroupDifferentiator groupDifferentiator){
|
||||||
|
GROUP_DIFFERENTIATOR_MAP.put(name, groupDifferentiator);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,22 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.stream.Collector;
|
import java.util.stream.Collector;
|
||||||
|
|
||||||
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
|
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
|
||||||
|
|
||||||
Y combine(List<Y> responses);
|
Y combine(List<Y> responses);
|
||||||
|
|
||||||
|
final static Map<String, ResponseCombiner> RESPONSE_COMBINER_MAP = new HashMap<>();
|
||||||
|
static ResponseCombiner loadResponseCombinerByName(final String name){
|
||||||
|
return RESPONSE_COMBINER_MAP.get(name);
|
||||||
|
}
|
||||||
|
static void registerResponseCombiner(final String name, final ResponseCombiner responseCombiner){
|
||||||
|
RESPONSE_COMBINER_MAP.put(name, responseCombiner);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,15 @@ package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import lombok.AccessLevel;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
public class TreeTrainer<Y> {
|
public class TreeTrainer<Y> {
|
||||||
|
|
||||||
private final ResponseCombiner<Y, ?> responseCombiner;
|
private final ResponseCombiner<Y, ?> responseCombiner;
|
||||||
|
@ -21,6 +24,14 @@ public class TreeTrainer<Y> {
|
||||||
private final int nodeSize;
|
private final int nodeSize;
|
||||||
private final int maxNodeDepth;
|
private final int maxNodeDepth;
|
||||||
|
|
||||||
|
public TreeTrainer(final Settings settings){
|
||||||
|
this.numberOfSplits = settings.getNumberOfSplits();
|
||||||
|
this.nodeSize = settings.getNodeSize();
|
||||||
|
this.maxNodeDepth = settings.getMaxNodeDepth();
|
||||||
|
|
||||||
|
this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner());
|
||||||
|
this.groupDifferentiator = GroupDifferentiator.loadGroupDifferentiatorByName(settings.getGroupDifferentiator());
|
||||||
|
}
|
||||||
|
|
||||||
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||||
return growNode(data, covariatesToTry, 0);
|
return growNode(data, covariatesToTry, 0);
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
package ca.joeltherrien.randomforest.csv;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Main;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class TestLoadingCSV {
|
||||||
|
|
||||||
|
/*
|
||||||
|
y,x1,x2,x3
|
||||||
|
5,3.0,"mouse",true
|
||||||
|
2,1.0,"dog",false
|
||||||
|
9,1.5,"cat",true
|
||||||
|
*/
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void verifyLoading() throws IOException {
|
||||||
|
final Settings settings = Settings.builder()
|
||||||
|
.dataFileLocation("src/test/resources/testCSV.csv")
|
||||||
|
.covariates(
|
||||||
|
List.of(new NumericCovariateSettings("x1"),
|
||||||
|
new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")),
|
||||||
|
new BooleanCovariateSettings("x3"))
|
||||||
|
)
|
||||||
|
.yVar("y")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final List<Covariate> covariates = settings.getCovariates().stream()
|
||||||
|
.map(cs -> cs.build()).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<Row<Double>> data = Main.loadData(covariates, settings);
|
||||||
|
|
||||||
|
assertEquals(3, data.size());
|
||||||
|
|
||||||
|
Row<Double> row = data.get(0);
|
||||||
|
assertEquals(5.0, (double)row.getResponse());
|
||||||
|
assertEquals(3.0, row.getCovariateValue("x1").getValue());
|
||||||
|
assertEquals("mouse", row.getCovariateValue("x2").getValue());
|
||||||
|
assertEquals(true, row.getCovariateValue("x3").getValue());
|
||||||
|
|
||||||
|
row = data.get(1);
|
||||||
|
assertEquals(2.0, (double)row.getResponse());
|
||||||
|
assertEquals(1.0, row.getCovariateValue("x1").getValue());
|
||||||
|
assertEquals("dog", row.getCovariateValue("x2").getValue());
|
||||||
|
assertEquals(false, row.getCovariateValue("x3").getValue());
|
||||||
|
|
||||||
|
row = data.get(2);
|
||||||
|
assertEquals(9.0, (double)row.getResponse());
|
||||||
|
assertEquals(1.5, row.getCovariateValue("x1").getValue());
|
||||||
|
assertEquals("cat", row.getCovariateValue("x2").getValue());
|
||||||
|
assertEquals(true, row.getCovariateValue("x3").getValue());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.Settings;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.*;
|
import ca.joeltherrien.randomforest.covariates.*;
|
||||||
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -22,6 +24,7 @@ public class TestPersistence {
|
||||||
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
.yVar("y")
|
||||||
.dataFileLocation("data.csv")
|
.dataFileLocation("data.csv")
|
||||||
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
|
||||||
.responseCombiner("MeanResponseCombiner")
|
.responseCombiner("MeanResponseCombiner")
|
||||||
|
|
4
src/test/resources/testCSV.csv
Normal file
4
src/test/resources/testCSV.csv
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
y,x1,x2,x3
|
||||||
|
5,3.0,"mouse",true
|
||||||
|
2,1.0,"dog",false
|
||||||
|
9,1.5,"cat",true
|
|
Loading…
Reference in a new issue