diff --git a/pom.xml b/pom.xml
index 637f869..01d9203 100644
--- a/pom.xml
+++ b/pom.xml
@@ -24,6 +24,19 @@
+
+ org.junit.jupiter
+ junit-jupiter-api
+ 5.2.0
+ test
+
+
+
+ org.junit.jupiter
+ junit-jupiter-engine
+ 5.2.0
+ test
+
diff --git a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java
index ba3ac81..5200e02 100644
--- a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java
+++ b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java
@@ -26,7 +26,7 @@ public final class FactorCovariate implements Covariate{
for(int i=0; i{
}
@Override
- public Collection generateSplitRules(List> data, int number) {
+ public Set generateSplitRules(List> data, int number) {
final Set splitRules = new HashSet<>();
// This is to ensure we don't get stuck in an infinite loop for small factors
diff --git a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java
new file mode 100644
index 0000000..1a22ce9
--- /dev/null
+++ b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java
@@ -0,0 +1,65 @@
+package ca.joeltherrien.randomforest.covariates;
+
+
+import ca.joeltherrien.randomforest.FactorCovariate;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
+
+import java.util.Collection;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+public class FactorCovariateTest {
+
+ @Test
+ void verifyEqualLevels() {
+ final FactorCovariate petCovariate = createTestCovariate();
+
+ final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
+ final FactorCovariate.FactorValue dog2 = petCovariate.createValue("DO" + "G");
+
+ assertSame(dog1, dog2);
+
+ final FactorCovariate.FactorValue cat1 = petCovariate.createValue("CAT");
+ final FactorCovariate.FactorValue cat2 = petCovariate.createValue("CA" + "T");
+
+ assertSame(cat1, cat2);
+
+ final FactorCovariate.FactorValue mouse1 = petCovariate.createValue("MOUSE");
+ final FactorCovariate.FactorValue mouse2 = petCovariate.createValue("MOUS" + "E");
+
+ assertSame(mouse1, mouse2);
+
+
+ }
+
+ @Test
+ void verifyBadLevelException(){
+ final FactorCovariate petCovariate = createTestCovariate();
+ final Executable badCode = () -> petCovariate.createValue("vulcan");
+
+ assertThrows(IllegalArgumentException.class, badCode, "vulcan is not a level in FactorCovariate pet");
+ }
+
+ @Test
+ void testAllSubsets(){
+ final FactorCovariate petCovariate = createTestCovariate();
+
+ final Collection splitRules = petCovariate.generateSplitRules(null, 100);
+
+ assertEquals(splitRules.size(), 3);
+
+ // TODO verify the contents of the split rules
+
+ }
+
+
+ private FactorCovariate createTestCovariate(){
+ final List levels = List.of("DOG", "CAT", "MOUSE");
+
+ return new FactorCovariate("pet", levels);
+ }
+
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
similarity index 100%
rename from src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
rename to src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
similarity index 100%
rename from src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
rename to src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java
new file mode 100644
index 0000000..e63f8cd
--- /dev/null
+++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java
@@ -0,0 +1,190 @@
+package ca.joeltherrien.randomforest.workshop;
+
+
+import ca.joeltherrien.randomforest.*;
+import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.Node;
+import ca.joeltherrien.randomforest.tree.TreeTrainer;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
+
+public class TrainSingleTreeFactor {
+
+ public static void main(String[] args) {
+ System.out.println("Hello world!");
+
+ final Random random = new Random(123);
+
+ final int n = 10000;
+ final List> trainingSet = new ArrayList<>(n);
+
+ final Covariate x1Covariate = new NumericCovariate("x1");
+ final Covariate x2Covariate = new NumericCovariate("x2");
+ final FactorCovariate x3Covariate = new FactorCovariate("x3", List.of("cat", "dog", "mouse"));
+
+ final List> x1List = DoubleStream
+ .generate(() -> random.nextDouble()*10.0)
+ .limit(n)
+ .mapToObj(x1 -> x1Covariate.createValue(x1))
+ .collect(Collectors.toList());
+
+ final List> x2List = DoubleStream
+ .generate(() -> random.nextDouble()*10.0)
+ .limit(n)
+ .mapToObj(x2 -> x1Covariate.createValue(x2))
+ .collect(Collectors.toList());
+
+ final List> x3List = DoubleStream
+ .generate(() -> random.nextDouble())
+ .limit(n)
+ .mapToObj(db -> {
+ if(db < 0.333){
+ return "cat";
+ }
+ else if(db < 0.5){
+ return "mouse";
+ }
+ else{
+ return "dog";
+ }
+ })
+ .map(str -> x3Covariate.createValue(str))
+ .collect(Collectors.toList());
+
+
+ for(int i=0; i x1 = x1List.get(i);
+ final Covariate.Value x2 = x2List.get(i);
+ final Covariate.Value x3 = x3List.get(i);
+
+ trainingSet.add(generateRow(x1, x2, x3, i));
+ }
+
+ final TreeTrainer treeTrainer = TreeTrainer.builder()
+ .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
+ .responseCombiner(new MeanResponseCombiner())
+ .maxNodeDepth(30)
+ .nodeSize(5)
+ .numberOfSplits(5)
+ .build();
+
+ final List covariateNames = List.of(x1Covariate, x2Covariate);
+
+ final long startTime = System.currentTimeMillis();
+ final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames);
+ final long endTime = System.currentTimeMillis();
+
+ System.out.println(((double)(endTime - startTime))/1000.0);
+
+
+
+ final Covariate.Value cat = x3Covariate.createValue("cat");
+ final Covariate.Value dog = x3Covariate.createValue("dog");
+ final Covariate.Value mouse = x3Covariate.createValue("mouse");
+
+
+ final List testSet = new ArrayList<>();
+ testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), cat, 1)); // expect 1
+ testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), dog, 5));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), cat, 3));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), dog, 0));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), cat, 8));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), dog, 10));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), cat, 3));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), dog, 2));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), cat, 4));
+
+
+ testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(9.0), mouse, 0));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), mouse, 5));
+
+ for(final CovariateRow testCase : testSet){
+ System.out.println(testCase);
+ System.out.println(baseNode.evaluate(testCase));
+ System.out.println();
+
+
+ }
+
+
+
+
+
+ }
+
+ public static Row generateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
+ double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
+
+ final Map map = Map.of("x1", x1, "x2", x2);
+
+ return new Row<>(map, id, y);
+
+ }
+
+
+ public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
+ final Map map = Map.of("x1", x1, "x2", x2, "x3", x3);
+
+ return new CovariateRow(map, id);
+
+ }
+
+
+ public static double generateResponse(double x1, double x2, String x3){
+
+ if(x3.equalsIgnoreCase("mouse")){
+ if(x1 <= 5){
+ return 0;
+ }
+ else{
+ return 5;
+ }
+ }
+
+ // cat & dog below
+
+ if(x2 <= 3){
+ if(x1 <= 3){
+ return 3;
+ }
+ else if(x1 <= 7){
+ return 5;
+ }
+ else{
+ return 1;
+ }
+ }
+ else if(x1 >= 5){
+ if(x2 > 6){
+ return 2;
+ }
+ else if(x1 >= 8){
+ return 0;
+ }
+ else{
+ return 8;
+ }
+ }
+ else if(x1 <= 2){
+ if(x2 >= 7){
+ return 4;
+ }
+ else{
+ return 3;
+ }
+ }
+ else{
+ return 10;
+ }
+
+
+ }
+
+}