diff --git a/pom.xml b/pom.xml
index 36c9a38..cdb9617 100644
--- a/pom.xml
+++ b/pom.xml
@@ -58,5 +58,33 @@
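+    <build>
+        <plugins>
+            <plugin>
+                <artifactId>maven-assembly-plugin</artifactId>
+                <configuration>
+                    <archive>
+                        <manifest>
+                            <mainClass>ca.joeltherrien.randomforest.Main</mainClass>
+                        </manifest>
+                    </archive>
+                    <descriptorRefs>
+                        <descriptorRef>jar-with-dependencies</descriptorRef>
+                    </descriptorRefs>
+                </configuration>
+                <executions>
+                    <execution>
+                        <id>make-assembly</id>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>single</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>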
+
+
\ No newline at end of file
diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java
new file mode 100644
index 0000000..f557663
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/Main.java
@@ -0,0 +1,131 @@
+package ca.joeltherrien.randomforest;
+
+import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.Covariate;
+import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
+import ca.joeltherrien.randomforest.tree.ForestTrainer;
+import org.apache.commons.csv.CSVFormat;
+import org.apache.commons.csv.CSVParser;
+import org.apache.commons.csv.CSVRecord;
+
+import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.Reader;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Command-line entry point: loads a YAML settings file, reads the CSV data set it names and trains a forest.
+ */
+public class Main {
+
+    /**
+     * Trains a forest using the settings file whose path is the single command-line argument.
+     *
+     * <p>If the number of arguments is not exactly one, a usage message is printed and nothing is trained;
+     * when there are no arguments at all, a sample settings file is additionally saved as {@code template.yaml}.
+     * Training is delegated to {@code trainParallelOnDisk} when {@code saveProgress} is set and to
+     * {@code trainParallelInMemory} otherwise, in both cases with the configured number of threads.
+     *
+     * @param args expected to hold exactly one element, the path to the settings file
+     * @throws IOException propagated from the file operations performed here
+     */
+    public static void main(String[] args) throws IOException {
+        if(args.length != 1){
+            System.out.println("Must provide one argument - the path to the settings.yaml file.");
+            if(args.length == 0){
+                System.out.println("Generating template file.");
+                defaultTemplate().save(new File("template.yaml"));
+            }
+            return;
+        }
+
+        final Settings settings = Settings.load(new File(args[0]));
+
+        final List<Covariate> covariates = settings.getCovariates().stream()
+                .map(cs -> cs.build()).collect(Collectors.toList());
+
+        final List<Row<Double>> dataset = loadData(covariates, settings);
+
+        final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
+
+        if(settings.isSaveProgress()){
+            forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
+        }
+        else{
+            forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
+        }
+    }
+
+    /**
+     * Reads the CSV file at {@code settings.getDataFileLocation()} into a list of rows.
+     *
+     * <p>The first CSV record is used as the header. For every following record each covariate parses the
+     * field found under its own name, and the field under {@code settings.getYVar()} is parsed as the
+     * numeric response. Rows receive ids 1, 2, 3, ... in file order.
+     *
+     * @param covariates the covariates to extract from every record
+     * @param settings supplies the data file location and the name of the response column
+     * @return the rows in the order they appear in the file
+     * @throws IOException if the data file cannot be opened or read
+     */
+    public static List<Row<Double>> loadData(final List<Covariate> covariates, final Settings settings) throws IOException {
+
+        final List<Row<Double>> dataset = new ArrayList<>();
+
+        // try-with-resources: the reader and parser used to be left open, leaking the file handle
+        // (also when a malformed field aborts the loop with an exception).
+        try(final Reader input = new FileReader(settings.getDataFileLocation());
+            final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input)){
+
+            int id = 1;
+            for(final CSVRecord record : parser){
+                final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();
+
+                for(final Covariate<?> covariate : covariates){
+                    covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
+                }
+
+                final String yStr = record.get(settings.getYVar());
+                final Double yNum = Double.parseDouble(yStr);
+
+                dataset.add(new Row<>(covariateValueMap, id++, yNum));
+            }
+        }
+
+        return dataset;
+    }
+
+    /**
+     * Builds the sample settings that {@link #main(String[])} saves as {@code template.yaml} when run without arguments.
+     */
+    private static Settings defaultTemplate(){
+        return Settings.builder()
+                .covariates(List.of(
+                        new NumericCovariateSettings("x1"),
+                        new BooleanCovariateSettings("x2"),
+                        new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
+                        )
+                )
+                .yVar("y")
+                .dataFileLocation("data.csv")
+                .groupDifferentiator("WeightedVarianceGroupDifferentiator")
+                .responseCombiner("MeanResponseCombiner")
+                .treeResponseCombiner("MeanResponseCombiner")
+                .maxNodeDepth(100000)
+                .mtry(2)
+                .nodeSize(5)
+                .ntree(500)
+                .numberOfSplits(5)
+                .numberOfThreads(1)
+                .saveProgress(true)
+                .saveTreeLocation("trees/")
+                .build();
+    }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java
index bf3985a..da24933 100644
--- a/src/main/java/ca/joeltherrien/randomforest/Settings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java
@@ -2,6 +2,9 @@ package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
+import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import lombok.*;
@@ -9,7 +12,9 @@ import lombok.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
/**
* This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest.
@@ -29,6 +34,7 @@ public class Settings {
private String treeResponseCombiner;
private List covariates = new ArrayList<>();
+ private String yVar = "y";
// number of covariates to randomly try
private int mtry = 0;
@@ -64,5 +70,4 @@ public class Settings {
mapper.writeValue(file, this);
}
-
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java
index 6c6548d..9b5250e 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java
@@ -25,6 +25,31 @@ public final class BooleanCovariate implements Covariate{
return new BooleanValue(value);
}
+    /**
+     * Parses the textual form of a boolean covariate value.
+     *
+     * <p>{@code null} and "na" stand for a missing value; "true" and "false" give the matching
+     * boolean. All comparisons ignore case. The parsed result is handed to {@code createValue(Boolean)}.
+     *
+     * @param value the raw text, may be {@code null}
+     * @return the Value created for the parsed boolean, or for {@code null} when the value is missing
+     * @throws IllegalArgumentException if the text is none of true/false/na
+     */
+    @Override
+    public Value createValue(String value) {
+        final boolean missing = (value == null) || "na".equalsIgnoreCase(value);
+        if(missing){
+            return createValue((Boolean) null);
+        }
+        if("true".equalsIgnoreCase(value)){
+            return createValue(true);
+        }
+        if("false".equalsIgnoreCase(value)){
+            return createValue(false);
+        }
+        throw new IllegalArgumentException("Require either true/false/na to create BooleanCovariate");
+    }
+
public class BooleanValue implements Value{
private final Boolean value;
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java
index 2bf6263..ed384f3 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java
@@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings {
}
@Override
- BooleanCovariate build() {
+ public BooleanCovariate build() { // public so code outside this package (Main, the CSV test) can build the covariate
return new BooleanCovariate(name);
}
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java
index 75bc672..e7a5aef 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java
@@ -16,6 +16,14 @@ public interface Covariate extends Serializable {
Value createValue(V value);
+ /**
+ * Creates a Value of the appropriate type from a String; primarily used when parsing CSVs.
+ *
+ * @param value the raw text to convert, e.g. a single CSV field; which texts count as missing is up to the implementation
+ * @return the Value built from the parsed text
+ */
+ Value createValue(String value);
+
interface Value extends Serializable{
Covariate getParent();
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java
index baafe4c..1418428 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java
@@ -29,5 +29,5 @@ public abstract class CovariateSettings {
this.name = name;
}
- abstract Covariate build();
+ public abstract Covariate build(); // public so callers outside this package (e.g. Main) can turn settings into covariates
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java
index 04d4bb8..dbfaaae 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java
@@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings {
}
@Override
- FactorCovariate build() {
+ public FactorCovariate build() { // public so code outside this package (Main, the CSV test) can build the covariate
return new FactorCovariate(name, levels);
}
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java
index 6bed687..2746697 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java
@@ -55,6 +55,28 @@ public final class NumericCovariate implements Covariate{
return new NumericValue(value);
}
+    /**
+     * Parses the textual form of a numeric covariate value.
+     *
+     * <p>{@code null} and "na" (compared ignoring case) stand for a missing value; any other text is
+     * converted with {@link Double#parseDouble(String)}. The result is handed to {@code createValue(Double)}.
+     *
+     * @param value the raw text, may be {@code null}
+     * @return the Value created for the parsed number, or for {@code null} when the value is missing
+     * @throws NumberFormatException if the text is not a missing-value marker and cannot be parsed as a double
+     */
+    @Override
+    public Value createValue(String value) {
+        final Double parsed;
+        if(value != null && !value.equalsIgnoreCase("na")){
+            parsed = Double.parseDouble(value);
+        }
+        else{
+            parsed = null;
+        }
+        return createValue(parsed);
+    }
+
public class NumericValue implements Covariate.Value{
private final Double value; // may be null
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java
index 0be6cf0..b35a81a 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java
@@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings {
}
@Override
- NumericCovariate build() {
+ public NumericCovariate build() { // public so code outside this package (Main, the CSV test) can build the covariate
return new NumericCovariate(name);
}
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java
index 7c196e6..da3b0b1 100644
--- a/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java
@@ -1,11 +1,16 @@
package ca.joeltherrien.randomforest.regression;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import java.util.List;
public class MeanGroupDifferentiator implements GroupDifferentiator {
+ static{ // Self-registration: makes this differentiator available to by-name lookups under "MeanGroupDifferentiator".
+ GroupDifferentiator.registerGroupDifferentiator("MeanGroupDifferentiator", new MeanGroupDifferentiator());
+ } // NOTE(review): a static block only runs once this class is initialized; confirm something forces that before the first lookup.
+
@Override
public Double differentiate(List leftHand, List rightHand) {
diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java
index 57a3b35..3ff43f3 100644
--- a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java
@@ -17,6 +17,10 @@ import java.util.function.Supplier;
*/
public class MeanResponseCombiner implements ResponseCombiner {
+ static{ // Self-registration: makes this combiner available to by-name lookups under "MeanResponseCombiner".
+ ResponseCombiner.registerResponseCombiner("MeanResponseCombiner", new MeanResponseCombiner());
+ } // NOTE(review): a static block only runs once this class is initialized; confirm something forces that before the first lookup.
+
@Override
public Double combine(List responses) {
double size = responses.size();
diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java
index 2f40999..c27363f 100644
--- a/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java
@@ -6,6 +6,10 @@ import java.util.List;
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator {
+ static{ // Self-registration: makes this differentiator available to by-name lookups under "WeightedVarianceGroupDifferentiator".
+ GroupDifferentiator.registerGroupDifferentiator("WeightedVarianceGroupDifferentiator", new WeightedVarianceGroupDifferentiator());
+ } // NOTE(review): a static block only runs once this class is initialized; confirm something forces that before the first lookup.
+
@Override
public Double differentiate(List leftHand, List rightHand) {
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
index 76d4a70..f04a2f2 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
@@ -1,8 +1,11 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Bootstrapper;
+import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.Row;
+import lombok.AccessLevel;
+import lombok.AllArgsConstructor;
import lombok.Builder;
import java.io.FileOutputStream;
@@ -18,6 +21,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
@Builder
+@AllArgsConstructor(access=AccessLevel.PRIVATE)
public class ForestTrainer {
private final TreeTrainer treeTrainer;
@@ -34,6 +38,29 @@ public class ForestTrainer {
private final boolean displayProgress;
private final String saveTreeLocation;
+ /**
+ * Creates a trainer whose parameters are read from {@code settings}.
+ *
+ * <p>Progress display is always switched on for trainers built this way.
+ *
+ * @param settings source of mtry, ntree, the tree save location and the tree response combiner name
+ * @param data stored as this trainer's data
+ * @param covariates stored as the covariates to try
+ */
+ public ForestTrainer(final Settings settings, final List> data, final List covariates){
+ this.mtry = settings.getMtry();
+ this.ntree = settings.getNtree();
+ this.data = data;
+ this.displayProgress = true;
+ this.saveTreeLocation = settings.getSaveTreeLocation();
+
+ this.covariatesToTry = covariates;
+ // NOTE(review): the by-name lookup presumably yields null for an unregistered name, and no check is made here - confirm.
+ this.treeResponseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getTreeResponseCombiner());
+ this.treeTrainer = new TreeTrainer<>(settings);
+
+ }
+
public Forest trainSerial(){
final List> trees = new ArrayList<>(ntree);
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
index cbd1247..8950aa8 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
@@ -1,6 +1,8 @@
package ca.joeltherrien.randomforest.tree;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
/**
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
@@ -12,4 +14,44 @@ public interface GroupDifferentiator {
Double differentiate(List leftHand, List rightHand);
+    /**
+     * Registry of differentiators, keyed by the name that the settings file uses to refer to them.
+     */
+    Map<String, GroupDifferentiator> GROUP_DIFFERENTIATOR_MAP = new HashMap<>();
+
+    /**
+     * Looks up a differentiator by its registered name.
+     *
+     * <p>Implementations register themselves from a static initializer, and the JVM runs that initializer
+     * only once the implementing class is initialized. A name that is not registered yet is therefore
+     * retried after initializing the class of that simple name in the {@code regression} package.
+     *
+     * @param name the registered name; the implementations in this project use their simple class name
+     * @return the registered differentiator, or {@code null} when nothing is known under that name
+     */
+    static GroupDifferentiator loadGroupDifferentiatorByName(final String name){
+        GroupDifferentiator found = GROUP_DIFFERENTIATOR_MAP.get(name);
+        if(found == null && name != null && name.indexOf('.') < 0){
+            try{
+                // Initializing the class runs its static block, which performs the registration.
+                Class.forName("ca.joeltherrien.randomforest.regression." + name);
+            }
+            catch(ClassNotFoundException | NoClassDefFoundError ignored){
+                // No implementation with that name: keep the old contract and return null below.
+            }
+            found = GROUP_DIFFERENTIATOR_MAP.get(name);
+        }
+        return found;
+    }
+
+    /**
+     * Registers a differentiator under a name, replacing whatever was registered under it before.
+     *
+     * @param name the name that settings files will use
+     * @param groupDifferentiator the instance to hand out for that name
+     */
+    static void registerGroupDifferentiator(final String name, final GroupDifferentiator groupDifferentiator){
+        GROUP_DIFFERENTIATOR_MAP.put(name, groupDifferentiator);
+    }
+
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java
index bbcdd30..6430aa0 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java
@@ -1,10 +1,54 @@
package ca.joeltherrien.randomforest.tree;
+import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
+
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.stream.Collector;
public interface ResponseCombiner extends Collector {
Y combine(List responses);
+    /**
+     * Registry of response combiners, keyed by the name that the settings file uses to refer to them.
+     */
+    Map<String, ResponseCombiner> RESPONSE_COMBINER_MAP = new HashMap<>();
+
+    /**
+     * Looks up a response combiner by its registered name.
+     *
+     * <p>Implementations register themselves from a static initializer, and the JVM runs that initializer
+     * only once the implementing class is initialized. A name that is not registered yet is therefore
+     * retried after initializing the class of that simple name in the {@code regression} package.
+     *
+     * @param name the registered name; the implementations in this project use their simple class name
+     * @return the registered combiner, or {@code null} when nothing is known under that name
+     */
+    static ResponseCombiner loadResponseCombinerByName(final String name){
+        ResponseCombiner found = RESPONSE_COMBINER_MAP.get(name);
+        if(found == null && name != null && name.indexOf('.') < 0){
+            try{
+                // Initializing the class runs its static block, which performs the registration.
+                Class.forName("ca.joeltherrien.randomforest.regression." + name);
+            }
+            catch(ClassNotFoundException | NoClassDefFoundError ignored){
+                // No implementation with that name: keep the old contract and return null below.
+            }
+            found = RESPONSE_COMBINER_MAP.get(name);
+        }
+        return found;
+    }
+
+    /**
+     * Registers a response combiner under a name, replacing whatever was registered under it before.
+     *
+     * @param name the name that settings files will use
+     * @param responseCombiner the instance to hand out for that name
+     */
+    static void registerResponseCombiner(final String name, final ResponseCombiner responseCombiner){
+        RESPONSE_COMBINER_MAP.put(name, responseCombiner);
+    }
+
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
index fd83382..d447952 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
@@ -2,12 +2,15 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate;
+import lombok.AccessLevel;
+import lombok.AllArgsConstructor;
import lombok.Builder;
import java.util.*;
import java.util.stream.Collectors;
@Builder
+@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class TreeTrainer {
private final ResponseCombiner responseCombiner;
@@ -21,6 +24,25 @@ public class TreeTrainer {
private final int nodeSize;
private final int maxNodeDepth;
+    /**
+     * Creates a tree trainer configured from {@code settings}.
+     *
+     * <p>The number of splits, node size and maximum node depth are copied over; the response combiner
+     * and group differentiator are looked up by the names stored in the settings.
+     *
+     * @param settings the configuration to read from
+     */
+    public TreeTrainer(final Settings settings){
+        final String combinerName = settings.getResponseCombiner();
+        final String differentiatorName = settings.getGroupDifferentiator();
+
+        this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(combinerName);
+        this.groupDifferentiator = GroupDifferentiator.loadGroupDifferentiatorByName(differentiatorName);
+
+        this.maxNodeDepth = settings.getMaxNodeDepth();
+        this.nodeSize = settings.getNodeSize();
+        this.numberOfSplits = settings.getNumberOfSplits();
+    }
public Node growTree(List> data, List covariatesToTry){
return growNode(data, covariatesToTry, 0);
diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java
new file mode 100644
index 0000000..5a1faf7
--- /dev/null
+++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java
@@ -0,0 +1,66 @@
+package ca.joeltherrien.randomforest.csv;
+
+import ca.joeltherrien.randomforest.Main;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.Settings;
+import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.Covariate;
+import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Checks that {@link Main#loadData} turns the bundled test CSV into the expected rows.
+ */
+public class TestLoadingCSV {
+
+    /*
+        y,x1,x2,x3
+        5,3.0,"mouse",true
+        2,1.0,"dog",false
+        9,1.5,"cat",true
+     */
+
+    /**
+     * Loads {@code src/test/resources/testCSV.csv} and verifies the response and covariate values of its three rows.
+     */
+    @Test
+    public void verifyLoading() throws IOException {
+        final Settings settings = Settings.builder()
+                .dataFileLocation("src/test/resources/testCSV.csv")
+                .covariates(List.of(
+                        new NumericCovariateSettings("x1"),
+                        new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")),
+                        new BooleanCovariateSettings("x3")))
+                .yVar("y")
+                .build();
+
+        final List<Covariate> covariates = settings.getCovariates().stream()
+                .map(cs -> cs.build())
+                .collect(Collectors.toList());
+
+        final List<Row<Double>> data = Main.loadData(covariates, settings);
+
+        assertEquals(3, data.size());
+        assertRow(data.get(0), 5.0, 3.0, "mouse", true);
+        assertRow(data.get(1), 2.0, 1.0, "dog", false);
+        assertRow(data.get(2), 9.0, 1.5, "cat", true);
+    }
+
+    /**
+     * Asserts the response and the x1, x2 and x3 covariate values of one loaded row.
+     */
+    private static void assertRow(final Row<Double> row, final double y, final Object x1, final Object x2, final Object x3){
+        assertEquals(y, (double)row.getResponse());
+        assertEquals(x1, row.getCovariateValue("x1").getValue());
+        assertEquals(x2, row.getCovariateValue("x2").getValue());
+        assertEquals(x3, row.getCovariateValue("x3").getValue());
+    }
+
+}
diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java
index 562df01..421ca65 100644
--- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java
+++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java
@@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals;
import ca.joeltherrien.randomforest.covariates.*;
+import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import org.junit.jupiter.api.Test;
import java.io.File;
@@ -22,6 +24,7 @@ public class TestPersistence {
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
)
)
+ .yVar("y")
.dataFileLocation("data.csv")
.groupDifferentiator("WeightedVarianceGroupDifferentiator")
.responseCombiner("MeanResponseCombiner")
diff --git a/src/test/resources/testCSV.csv b/src/test/resources/testCSV.csv
new file mode 100644
index 0000000..ca1d181
--- /dev/null
+++ b/src/test/resources/testCSV.csv
@@ -0,0 +1,4 @@
+y,x1,x2,x3
+5,3.0,"mouse",true
+2,1.0,"dog",false
+9,1.5,"cat",true
\ No newline at end of file