From 6b62ad95c348ca760774ddf31c8982bb8a3d1f61 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Fri, 6 Jul 2018 13:21:34 -0700 Subject: [PATCH] Add support for loading datasets by CSV files. --- pom.xml | 28 +++++ .../ca/joeltherrien/randomforest/Main.java | 105 ++++++++++++++++++ .../joeltherrien/randomforest/Settings.java | 7 +- .../covariates/BooleanCovariate.java | 17 +++ .../covariates/BooleanCovariateSettings.java | 2 +- .../randomforest/covariates/Covariate.java | 8 ++ .../covariates/CovariateSettings.java | 2 +- .../covariates/FactorCovariateSettings.java | 2 +- .../covariates/NumericCovariate.java | 9 ++ .../covariates/NumericCovariateSettings.java | 2 +- .../regression/MeanGroupDifferentiator.java | 5 + .../regression/MeanResponseCombiner.java | 4 + .../WeightedVarianceGroupDifferentiator.java | 4 + .../randomforest/tree/ForestTrainer.java | 17 +++ .../tree/GroupDifferentiator.java | 10 ++ .../randomforest/tree/ResponseCombiner.java | 12 ++ .../randomforest/tree/TreeTrainer.java | 11 ++ .../randomforest/csv/TestLoadingCSV.java | 66 +++++++++++ .../settings/TestPersistence.java | 3 + src/test/resources/testCSV.csv | 4 + 20 files changed, 313 insertions(+), 5 deletions(-) create mode 100644 src/main/java/ca/joeltherrien/randomforest/Main.java create mode 100644 src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java create mode 100644 src/test/resources/testCSV.csv diff --git a/pom.xml b/pom.xml index 36c9a38..cdb9617 100644 --- a/pom.xml +++ b/pom.xml @@ -58,5 +58,33 @@ + + + + maven-assembly-plugin + + + + ca.joeltherrien.randomforest.Main + + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + + + \ No newline at end of file diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java new file mode 100644 index 0000000..f557663 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -0,0 +1,105 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; +import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import ca.joeltherrien.randomforest.tree.ForestTrainer; +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVRecord; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.Reader; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class Main { + + public static void main(String[] args) throws IOException { + if(args.length != 1){ + System.out.println("Must provide one argument - the path to the settings.yaml file."); + if(args.length == 0){ + System.out.println("Generating template file."); + defaultTemplate().save(new File("template.yaml")); + } + return; + } + final File settingsFile = new File(args[0]); + final Settings settings = Settings.load(settingsFile); + + final List covariates = settings.getCovariates().stream() + .map(cs -> cs.build()).collect(Collectors.toList()); + + final List> dataset = loadData(covariates, settings); + + final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); + + if(settings.isSaveProgress()){ + forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads()); + } + else{ + forestTrainer.trainParallelInMemory(settings.getNumberOfThreads()); + } + + + } + + + public static List> loadData(final List covariates, final Settings settings) throws IOException { + + final List> dataset = new ArrayList<>(); + + final Reader input = new FileReader(settings.getDataFileLocation()); + final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input); + + + int id = 1; + for(final CSVRecord record : parser){ + final Map covariateValueMap = new HashMap<>(); + + for(final Covariate covariate : covariates){ + covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName()))); + } + + final String yStr = record.get(settings.getYVar()); + final Double yNum = Double.parseDouble(yStr); + + dataset.add(new Row<>(covariateValueMap, id++, yNum)); + + } + + return dataset; + + } + + private static Settings defaultTemplate(){ + return Settings.builder() + .covariates(List.of( + new NumericCovariateSettings("x1"), + new BooleanCovariateSettings("x2"), + new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog")) + ) + ) + .yVar("y") + .dataFileLocation("data.csv") + .groupDifferentiator("WeightedVarianceGroupDifferentiator") + .responseCombiner("MeanResponseCombiner") + .treeResponseCombiner("MeanResponseCombiner") + .maxNodeDepth(100000) + .mtry(2) + .nodeSize(5) + .ntree(500) + .numberOfSplits(5) + .numberOfThreads(1) + .saveProgress(true) + .saveTreeLocation("trees/") + .build(); + } + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index bf3985a..da24933 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -2,6 +2,9 @@ package ca.joeltherrien.randomforest; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.CovariateSettings; +import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.ResponseCombiner; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import lombok.*; @@ -9,7 +12,9 @@ import lombok.*; import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * This class is saved & loaded using a saved configuration file. It contains all relevant settings when training a forest. @@ -29,6 +34,7 @@ public class Settings { private String treeResponseCombiner; private List covariates = new ArrayList<>(); + private String yVar = "y"; // number of covariates to randomly try private int mtry = 0; @@ -64,5 +70,4 @@ public class Settings { mapper.writeValue(file, this); } - } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java index 6c6548d..9b5250e 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java @@ -25,6 +25,23 @@ public final class BooleanCovariate implements Covariate{ return new BooleanValue(value); } + @Override + public Value createValue(String value) { + if(value == null || value.equalsIgnoreCase("na")){ + return createValue( (Boolean) null); + } + + if(value.equalsIgnoreCase("true")){ + return createValue(true); + } + else if(value.equalsIgnoreCase("false")){ + return createValue(false); + } + else{ + throw new IllegalArgumentException("Require either true/false/na to create BooleanCovariate"); + } + } + public class BooleanValue implements Value{ private final Boolean value; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java index 2bf6263..ed384f3 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java @@ -12,7 +12,7 @@ public final class BooleanCovariateSettings extends CovariateSettings { } @Override - BooleanCovariate build() { + public BooleanCovariate build() { return new BooleanCovariate(name); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index 75bc672..e7a5aef 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -16,6 +16,14 @@ public interface Covariate extends Serializable { Value createValue(V value); + /** + * Creates a Value of the appropriate type from a String; primarily used when parsing CSVs. + * + * @param value + * @return + */ + Value createValue(String value); + interface Value extends Serializable{ Covariate getParent(); diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java index baafe4c..1418428 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java @@ -29,5 +29,5 @@ public abstract class CovariateSettings { this.name = name; } - abstract Covariate build(); + public abstract Covariate build(); } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java index 04d4bb8..dbfaaae 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java @@ -18,7 +18,7 @@ public final class FactorCovariateSettings extends CovariateSettings { } @Override - FactorCovariate build() { + public FactorCovariate build() { return new FactorCovariate(name, levels); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java index 6bed687..2746697 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java @@ -55,6 +55,15 @@ public final class NumericCovariate implements Covariate{ return new NumericValue(value); } + @Override + public Value createValue(String value) { + if(value == null || value.equalsIgnoreCase("na")){ + return createValue((Double) null); + } + + return createValue(Double.parseDouble(value)); + } + public class NumericValue implements Covariate.Value{ private final Double value; // may be null diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java index 0be6cf0..b35a81a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java @@ -12,7 +12,7 @@ public final class NumericCovariateSettings extends CovariateSettings { } @Override - NumericCovariate build() { + public NumericCovariate build() { return new NumericCovariate(name); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java index 7c196e6..da3b0b1 100644 --- a/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanGroupDifferentiator.java @@ -1,11 +1,16 @@ package ca.joeltherrien.randomforest.regression; import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.ResponseCombiner; import java.util.List; public class MeanGroupDifferentiator implements GroupDifferentiator { + static{ + GroupDifferentiator.registerGroupDifferentiator("MeanGroupDifferentiator", new MeanGroupDifferentiator()); + } + @Override public Double differentiate(List leftHand, List rightHand) { diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java index 57a3b35..3ff43f3 100644 --- a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java @@ -17,6 +17,10 @@ import java.util.function.Supplier; */ public class MeanResponseCombiner implements ResponseCombiner { + static{ + ResponseCombiner.registerResponseCombiner("MeanResponseCombiner", new MeanResponseCombiner()); + } + @Override public Double combine(List responses) { double size = responses.size(); diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java index 2f40999..c27363f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/regression/WeightedVarianceGroupDifferentiator.java @@ -6,6 +6,10 @@ import java.util.List; public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator { + static{ + GroupDifferentiator.registerGroupDifferentiator("WeightedVarianceGroupDifferentiator", new WeightedVarianceGroupDifferentiator()); + } + @Override public Double differentiate(List leftHand, List rightHand) { diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 76d4a70..f04a2f2 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -1,8 +1,11 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Bootstrapper; +import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.Row; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; import lombok.Builder; import java.io.FileOutputStream; @@ -18,6 +21,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; @Builder +@AllArgsConstructor(access=AccessLevel.PRIVATE) public class ForestTrainer { private final TreeTrainer treeTrainer; @@ -34,6 +38,19 @@ public class ForestTrainer { private final boolean displayProgress; private final String saveTreeLocation; + public ForestTrainer(final Settings settings, final List> data, final List covariates){ + this.mtry = settings.getMtry(); + this.ntree = settings.getNtree(); + this.data = data; + this.displayProgress = true; + this.saveTreeLocation = settings.getSaveTreeLocation(); + + this.covariatesToTry = covariates; + this.treeResponseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getTreeResponseCombiner()); + this.treeTrainer = new TreeTrainer<>(settings); + + } + public Forest trainSerial(){ final List> trees = new ArrayList<>(ntree); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java index cbd1247..8950aa8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java @@ -1,6 +1,8 @@ package ca.joeltherrien.randomforest.tree; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups. @@ -12,4 +14,12 @@ public interface GroupDifferentiator { Double differentiate(List leftHand, List rightHand); + Map GROUP_DIFFERENTIATOR_MAP = new HashMap<>(); + static GroupDifferentiator loadGroupDifferentiatorByName(final String name){ + return GROUP_DIFFERENTIATOR_MAP.get(name); + } + static void registerGroupDifferentiator(final String name, final GroupDifferentiator groupDifferentiator){ + GROUP_DIFFERENTIATOR_MAP.put(name, groupDifferentiator); + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java index bbcdd30..6430aa0 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ResponseCombiner.java @@ -1,10 +1,22 @@ package ca.joeltherrien.randomforest.tree; +import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; + +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collector; public interface ResponseCombiner extends Collector { Y combine(List responses); + final static Map RESPONSE_COMBINER_MAP = new HashMap<>(); + static ResponseCombiner loadResponseCombinerByName(final String name){ + return RESPONSE_COMBINER_MAP.get(name); + } + static void registerResponseCombiner(final String name, final ResponseCombiner responseCombiner){ + RESPONSE_COMBINER_MAP.put(name, responseCombiner); + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index fd83382..d447952 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -2,12 +2,15 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; import lombok.Builder; import java.util.*; import java.util.stream.Collectors; @Builder +@AllArgsConstructor(access = AccessLevel.PRIVATE) public class TreeTrainer { private final ResponseCombiner responseCombiner; @@ -21,6 +24,14 @@ public class TreeTrainer { private final int nodeSize; private final int maxNodeDepth; + public TreeTrainer(final Settings settings){ + this.numberOfSplits = settings.getNumberOfSplits(); + this.nodeSize = settings.getNodeSize(); + this.maxNodeDepth = settings.getMaxNodeDepth(); + + this.responseCombiner = ResponseCombiner.loadResponseCombinerByName(settings.getResponseCombiner()); + this.groupDifferentiator = GroupDifferentiator.loadGroupDifferentiatorByName(settings.getGroupDifferentiator()); + } public Node growTree(List> data, List covariatesToTry){ return growNode(data, covariatesToTry, 0); diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java new file mode 100644 index 0000000..5a1faf7 --- /dev/null +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -0,0 +1,66 @@ +package ca.joeltherrien.randomforest.csv; + +import ca.joeltherrien.randomforest.Main; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.Settings; +import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; +import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TestLoadingCSV { + + /* + y,x1,x2,x3 + 5,3.0,"mouse",true + 2,1.0,"dog",false + 9,1.5,"cat",true + */ + + @Test + public void verifyLoading() throws IOException { + final Settings settings = Settings.builder() + .dataFileLocation("src/test/resources/testCSV.csv") + .covariates( + List.of(new NumericCovariateSettings("x1"), + new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")), + new BooleanCovariateSettings("x3")) + ) + .yVar("y") + .build(); + + final List covariates = settings.getCovariates().stream() + .map(cs -> cs.build()).collect(Collectors.toList()); + + final List> data = Main.loadData(covariates, settings); + + assertEquals(3, data.size()); + + Row row = data.get(0); + assertEquals(5.0, (double)row.getResponse()); + assertEquals(3.0, row.getCovariateValue("x1").getValue()); + assertEquals("mouse", row.getCovariateValue("x2").getValue()); + assertEquals(true, row.getCovariateValue("x3").getValue()); + + row = data.get(1); + assertEquals(2.0, (double)row.getResponse()); + assertEquals(1.0, row.getCovariateValue("x1").getValue()); + assertEquals("dog", row.getCovariateValue("x2").getValue()); + assertEquals(false, row.getCovariateValue("x3").getValue()); + + row = data.get(2); + assertEquals(9.0, (double)row.getResponse()); + assertEquals(1.5, row.getCovariateValue("x1").getValue()); + assertEquals("cat", row.getCovariateValue("x2").getValue()); + assertEquals(true, row.getCovariateValue("x3").getValue()); + + } + +} diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index 562df01..421ca65 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -4,6 +4,8 @@ import ca.joeltherrien.randomforest.Settings; import static org.junit.jupiter.api.Assertions.assertEquals; import ca.joeltherrien.randomforest.covariates.*; +import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; +import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; import org.junit.jupiter.api.Test; import java.io.File; @@ -22,6 +24,7 @@ public class TestPersistence { new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog")) ) ) + .yVar("y") .dataFileLocation("data.csv") .groupDifferentiator("WeightedVarianceGroupDifferentiator") .responseCombiner("MeanResponseCombiner") diff --git a/src/test/resources/testCSV.csv b/src/test/resources/testCSV.csv new file mode 100644 index 0000000..ca1d181 --- /dev/null +++ b/src/test/resources/testCSV.csv @@ -0,0 +1,4 @@ +y,x1,x2,x3 +5,3.0,"mouse",true +2,1.0,"dog",false +9,1.5,"cat",true \ No newline at end of file