diff --git a/pom.xml b/pom.xml index 637f869..01d9203 100644 --- a/pom.xml +++ b/pom.xml @@ -24,6 +24,19 @@ + + org.junit.jupiter + junit-jupiter-api + 5.2.0 + test + + + + org.junit.jupiter + junit-jupiter-engine + 5.2.0 + test +
diff --git a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java index ba3ac81..5200e02 100644 --- a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java @@ -26,7 +26,7 @@ public final class FactorCovariate implements Covariate{ for(int i=0; i{ } @Override - public Collection generateSplitRules(List> data, int number) { + public Set generateSplitRules(List> data, int number) { final Set splitRules = new HashSet<>(); // This is to ensure we don't get stuck in an infinite loop for small factors
diff --git a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java
new file mode 100644
index 0000000..1a22ce9
--- /dev/null
+++ b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java
@@ -0,0 +1,84 @@
+package ca.joeltherrien.randomforest.covariates;
+
+
+import ca.joeltherrien.randomforest.FactorCovariate;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
+
+import java.util.Collection;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * Tests for {@link FactorCovariate}, using a "pet" factor with the levels
+ * DOG, CAT and MOUSE.
+ */
+public class FactorCovariateTest {
+
+    /**
+     * Two equal level names must yield the very same value instance.
+     */
+    @Test
+    void verifyEqualLevels() {
+        final FactorCovariate petCovariate = createTestCovariate();
+
+        // new String(...) forces a distinct String instance. A constant
+        // expression such as "DO" + "G" is folded and interned by the compiler,
+        // so it would be the same reference as the literal "DOG" and the test
+        // could never tell an equals-based lookup from an identity-based one.
+        final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
+        final FactorCovariate.FactorValue dog2 = petCovariate.createValue(new String("DOG"));
+
+        assertSame(dog1, dog2);
+
+        final FactorCovariate.FactorValue cat1 = petCovariate.createValue("CAT");
+        final FactorCovariate.FactorValue cat2 = petCovariate.createValue(new String("CAT"));
+
+        assertSame(cat1, cat2);
+
+        final FactorCovariate.FactorValue mouse1 = petCovariate.createValue("MOUSE");
+        final FactorCovariate.FactorValue mouse2 = petCovariate.createValue(new String("MOUSE"));
+
+        assertSame(mouse1, mouse2);
+    }
+
+    /**
+     * Creating a value for a level that was never declared must fail.
+     */
+    @Test
+    void verifyBadLevelException() {
+        final FactorCovariate petCovariate = createTestCovariate();
+        final Executable badCode = () -> petCovariate.createValue("vulcan");
+
+        // The third argument is only the text reported when this assertion
+        // fails; the message of the thrown exception is not compared with it.
+        assertThrows(IllegalArgumentException.class, badCode, "vulcan is not a level in FactorCovariate pet");
+    }
+
+    /**
+     * Asks for far more split rules than three levels allow and expects
+     * exactly three to come back.
+     */
+    @Test
+    void testAllSubsets() {
+        final FactorCovariate petCovariate = createTestCovariate();
+
+        // NOTE(review): data is passed as null - presumably FactorCovariate
+        // ignores it; confirm against generateSplitRules.
+        final Collection<?> splitRules = petCovariate.generateSplitRules(null, 100);
+
+        // JUnit 5 signature is assertEquals(expected, actual).
+        assertEquals(3, splitRules.size());
+
+        // TODO verify the contents of the split rules
+    }
+
+    /** Builds the "pet" covariate shared by all tests. */
+    private FactorCovariate createTestCovariate() {
+        final List<String> levels = List.of("DOG", "CAT", "MOUSE");
+
+        return new FactorCovariate("pet", levels);
+    }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
similarity index 100%
rename from src/main/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
rename to src/test/java/ca/joeltherrien/randomforest/workshop/TrainForest.java
diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
similarity index 100%
rename from src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
rename to src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java
new file mode 100644
index 0000000..e63f8cd
--- /dev/null
+++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java
@@ -0,0 +1,221 @@
+package ca.joeltherrien.randomforest.workshop;
+
+
+import ca.joeltherrien.randomforest.*;
+import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.Node;
+import ca.joeltherrien.randomforest.tree.TreeTrainer;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
+
+/**
+ * Workshop program that grows a single regression tree on simulated data.
+ *
+ * <p>The data has two numeric covariates (x1 and x2, each uniform on [0, 10))
+ * and one factor covariate (x3 with the levels cat, dog and mouse). The
+ * response is the step function {@link #generateResponse}.
+ */
+public class TrainSingleTreeFactor {
+
+    /**
+     * Simulates the training set, grows one tree, prints the training time in
+     * seconds, then prints every test row followed by the tree's prediction.
+     *
+     * @param args ignored
+     */
+    public static void main(String[] args) {
+        System.out.println("Hello world!");
+
+        final Random random = new Random(123);
+
+        final int n = 10000;
+        final List<Row<Double>> trainingSet = new ArrayList<>(n);
+
+        final Covariate<Double> x1Covariate = new NumericCovariate("x1");
+        final Covariate<Double> x2Covariate = new NumericCovariate("x2");
+        final FactorCovariate x3Covariate = new FactorCovariate("x3", List.of("cat", "dog", "mouse"));
+
+        final List<Covariate.Value<Double>> x1List = DoubleStream
+                .generate(() -> random.nextDouble()*10.0)
+                .limit(n)
+                .mapToObj(x1 -> x1Covariate.createValue(x1))
+                .collect(Collectors.toList());
+
+        // x2 values come from x2Covariate (x1Covariate was used here before).
+        final List<Covariate.Value<Double>> x2List = DoubleStream
+                .generate(() -> random.nextDouble()*10.0)
+                .limit(n)
+                .mapToObj(x2 -> x2Covariate.createValue(x2))
+                .collect(Collectors.toList());
+
+        // Uniform draw mapped to a level: [0, 0.333) cat, [0.333, 0.5) mouse, else dog.
+        final List<Covariate.Value<String>> x3List = DoubleStream
+                .generate(() -> random.nextDouble())
+                .limit(n)
+                .mapToObj(db -> {
+                    if(db < 0.333){
+                        return "cat";
+                    }
+                    else if(db < 0.5){
+                        return "mouse";
+                    }
+                    else{
+                        return "dog";
+                    }
+                })
+                .map(str -> x3Covariate.createValue(str))
+                .collect(Collectors.toList());
+
+        for(int i=0; i<n; i++){
+            final Covariate.Value<Double> x1 = x1List.get(i);
+            final Covariate.Value<Double> x2 = x2List.get(i);
+            final Covariate.Value<String> x3 = x3List.get(i);
+
+            trainingSet.add(generateRow(x1, x2, x3, i));
+        }
+
+        final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
+                .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
+                .responseCombiner(new MeanResponseCombiner())
+                .maxNodeDepth(30)
+                .nodeSize(5)
+                .numberOfSplits(5)
+                .build();
+
+        // x3 is offered to the trainer as well: the response depends on it.
+        final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate, x3Covariate);
+
+        final long startTime = System.currentTimeMillis();
+        final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
+        final long endTime = System.currentTimeMillis();
+
+        System.out.println(((double)(endTime - startTime))/1000.0);
+
+        final Covariate.Value<String> cat = x3Covariate.createValue("cat");
+        final Covariate.Value<String> dog = x3Covariate.createValue("dog");
+        final Covariate.Value<String> mouse = x3Covariate.createValue("mouse");
+
+        // The id given to each test row is the value generateResponse returns
+        // for it, i.e. the prediction a perfect tree would print.
+        final List<CovariateRow> testSet = new ArrayList<>();
+        testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), cat, 1)); // expect 1
+        testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), dog, 5));
+        testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), cat, 3));
+        testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), dog, 0));
+        testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), cat, 8));
+        testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), dog, 10));
+        testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), cat, 3));
+        testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), dog, 2));
+        testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), cat, 4));
+
+        testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(9.0), mouse, 0));
+        testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), mouse, 5));
+
+        for(final CovariateRow testCase : testSet){
+            System.out.println(testCase);
+            System.out.println(baseNode.evaluate(testCase));
+            System.out.println();
+        }
+    }
+
+    /**
+     * Builds one training row: the three covariate values plus the response
+     * computed from them by {@link #generateResponse}.
+     *
+     * @param x1 value of covariate x1
+     * @param x2 value of covariate x2
+     * @param x3 value of covariate x3
+     * @param id identifier stored on the row
+     * @return the training row
+     */
+    public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
+        double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
+
+        // x3 is stored on the row too; it was left out before, so the tree
+        // could never split on the factor that drives the response.
+        final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2, "x3", x3);
+
+        return new Row<>(map, id, y);
+    }
+
+    /**
+     * Builds a row that carries the three covariate values but no response.
+     *
+     * @param x1 value of covariate x1
+     * @param x2 value of covariate x2
+     * @param x3 value of covariate x3
+     * @param id identifier stored on the row
+     * @return the covariate row
+     */
+    public static CovariateRow generateCovariateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
+        final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2, "x3", x3);
+
+        return new CovariateRow(map, id);
+    }
+
+    /**
+     * Deterministic step-function response used to simulate the data.
+     *
+     * <p>For the level "mouse" (compared ignoring case) only x1 matters; every
+     * other level falls through to a piecewise-constant function of x1 and x2.
+     *
+     * @param x1 first numeric covariate
+     * @param x2 second numeric covariate
+     * @param x3 factor level; must not be null
+     * @return the response for the given covariates
+     */
+    public static double generateResponse(double x1, double x2, String x3){
+        if(x3.equalsIgnoreCase("mouse")){
+            if(x1 <= 5){
+                return 0;
+            }
+            else{
+                return 5;
+            }
+        }
+
+        // cat & dog below
+
+        if(x2 <= 3){
+            if(x1 <= 3){
+                return 3;
+            }
+            else if(x1 <= 7){
+                return 5;
+            }
+            else{
+                return 1;
+            }
+        }
+        else if(x1 >= 5){
+            if(x2 > 6){
+                return 2;
+            }
+            else if(x1 >= 8){
+                return 0;
+            }
+            else{
+                return 8;
+            }
+        }
+        else if(x1 <= 2){
+            if(x2 >= 7){
+                return 4;
+            }
+            else{
+                return 3;
+            }
+        }
+        else{
+            return 10;
+        }
+    }
+
+}