map = new HashMap<>();
-            for(int j=0; j<p; j++){
            data.add(new Row<>(map, i, y));
@@ -44,10 +51,8 @@ public class TrainForest {
}
- final List covariateNames = IntStream.range(0, p).mapToObj(j -> "x"+j).collect(Collectors.toList());
-
- TreeTrainer treeTrainer = TreeTrainer.builder()
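+        // Per-tree settings: 5 candidate splits per covariate, node size 5, effectively unbounded depth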
+ final TreeTrainer treeTrainer = TreeTrainer.builder()
.numberOfSplits(5)
.nodeSize(5)
.maxNodeDepth(100000000)
@@ -58,7 +63,7 @@ public class TrainForest {
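+        // Forest settings: 100 trees, 4 covariates tried per split, tree predictions combined by their mean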
final ForestTrainer forestTrainer = ForestTrainer.builder()
.treeTrainer(treeTrainer)
.data(data)
- .covariatesToTry(covariateNames)
+ .covariatesToTry(covariateList)
.mtry(4)
.ntree(100)
.treeResponseCombiner(new MeanResponseCombiner())
@@ -69,7 +74,7 @@ public class TrainForest {
final long startTime = System.currentTimeMillis();
//final Forest forest = forestTrainer.trainSerial();
- //final Forest forest = forestTrainer.trainParallel(8);
+ //final Forest forest = forestTrainer.trainParallelInMemory(3);
forestTrainer.trainParallelOnDisk(3);
final long endTime = System.currentTimeMillis();
@@ -88,9 +93,9 @@ public class TrainForest {
System.out.println(forest.evaluate(testRow1));
System.out.println(forest.evaluate(testRow2));
-
- System.out.println("MinY = " + minY);
*/
+ System.out.println("MinY = " + minY);
+
}
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
similarity index 57%
rename from src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
rename to src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
index d218687..72654a8 100644
--- a/src/main/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
+++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTree.java
@@ -1,11 +1,10 @@
package ca.joeltherrien.randomforest.workshop;
+import ca.joeltherrien.randomforest.Covariate;
import ca.joeltherrien.randomforest.CovariateRow;
-import ca.joeltherrien.randomforest.NumericValue;
+import ca.joeltherrien.randomforest.NumericCovariate;
import ca.joeltherrien.randomforest.Row;
-import ca.joeltherrien.randomforest.Value;
-import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Node;
@@ -25,30 +24,30 @@ public class TrainSingleTree {
final int n = 1000;
        final List<Row<Double>> trainingSet = new ArrayList<>(n);
-        final List<Value<Double>> x1List = DoubleStream
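+        // Numeric covariates from which the x1 and x2 values below are created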
+ final Covariate x1Covariate = new NumericCovariate("x1");
+ final Covariate x2Covariate = new NumericCovariate("x2");
+
+        final List<Covariate.Value> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0)
.limit(n)
- .mapToObj(x1 -> new NumericValue(x1))
+ .mapToObj(x1 -> x1Covariate.createValue(x1))
.collect(Collectors.toList());
-        final List<Value<Double>> x2List = DoubleStream
+        final List<Covariate.Value> x2List = DoubleStream
.generate(() -> random.nextDouble()*10.0)
.limit(n)
- .mapToObj(x1 -> new NumericValue(x1))
+                .mapToObj(x2 -> x2Covariate.createValue(x2))
.collect(Collectors.toList());
        for(int i=0; i<n; i++){
+            final Covariate.Value x1 = x1List.get(i);
+ final Covariate.Value x2 = x2List.get(i);
trainingSet.add(generateRow(x1, x2, i));
}
-
- final long startTime = System.currentTimeMillis();
-
final TreeTrainer treeTrainer = TreeTrainer.builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.responseCombiner(new MeanResponseCombiner())
@@ -57,25 +56,29 @@ public class TrainSingleTree {
.numberOfSplits(0)
.build();
+ final List covariateNames = List.of(x1Covariate, x2Covariate);
+
+ final long startTime = System.currentTimeMillis();
+ final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames);
final long endTime = System.currentTimeMillis();
System.out.println(((double)(endTime - startTime))/1000.0);
- final List covariateNames = List.of("x1", "x2");
- final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames);
+
+
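+        // Evaluate the trained tree on a few hand-picked covariate combinations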
final List testSet = new ArrayList<>();
- testSet.add(generateCovariateRow(9, 2, 1)); // expect 1
- testSet.add(generateCovariateRow(5, 2, 5));
- testSet.add(generateCovariateRow(2, 2, 3));
- testSet.add(generateCovariateRow(9, 5, 0));
- testSet.add(generateCovariateRow(6, 5, 8));
- testSet.add(generateCovariateRow(3, 5, 10));
- testSet.add(generateCovariateRow(1, 5, 3));
- testSet.add(generateCovariateRow(7, 9, 2));
- testSet.add(generateCovariateRow(1, 9, 4));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), 1)); // expect 1
+ testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), 5));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), 3));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), 0));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), 8));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), 10));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), 3));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), 2));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), 4));
for(final CovariateRow testCase : testSet){
System.out.println(testCase);
@@ -91,18 +94,18 @@ public class TrainSingleTree {
}
- public static Row generateRow(double x1, double x2, int id){
- double y = generateResponse(x1, x2);
+ public static Row generateRow(Covariate.Value x1, Covariate.Value x2, int id){
+ double y = generateResponse(x1.getValue(), x2.getValue());
- final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2));
+ final Map map = Map.of("x1", x1, "x2", x2);
return new Row<>(map, id, y);
}
- public static CovariateRow generateCovariateRow(double x1, double x2, int id){
- final Map map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2));
+ public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
+ final Map map = Map.of("x1", x1, "x2", x2);
return new CovariateRow(map, id);
diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java
new file mode 100644
index 0000000..e63f8cd
--- /dev/null
+++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java
@@ -0,0 +1,190 @@
+package ca.joeltherrien.randomforest.workshop;
+
+
+import ca.joeltherrien.randomforest.*;
+import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.Node;
+import ca.joeltherrien.randomforest.tree.TreeTrainer;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
+
+public class TrainSingleTreeFactor {
+
+ public static void main(String[] args) {
+ System.out.println("Hello world!");
+
+ final Random random = new Random(123);
+
+ final int n = 10000;
+        final List<Row<Double>> trainingSet = new ArrayList<>(n);
+
+ final Covariate x1Covariate = new NumericCovariate("x1");
+ final Covariate x2Covariate = new NumericCovariate("x2");
+ final FactorCovariate x3Covariate = new FactorCovariate("x3", List.of("cat", "dog", "mouse"));
+
+        final List<Covariate.Value> x1List = DoubleStream
+ .generate(() -> random.nextDouble()*10.0)
+ .limit(n)
+ .mapToObj(x1 -> x1Covariate.createValue(x1))
+ .collect(Collectors.toList());
+
+        final List<Covariate.Value> x2List = DoubleStream
+ .generate(() -> random.nextDouble()*10.0)
+ .limit(n)
+            .mapToObj(x2 -> x2Covariate.createValue(x2))
+ .collect(Collectors.toList());
+
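+        // Assign x3 levels at random: roughly 1/3 "cat", 1/6 "mouse", 1/2 "dog"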
+        final List<Covariate.Value> x3List = DoubleStream
+ .generate(() -> random.nextDouble())
+ .limit(n)
+ .mapToObj(db -> {
+ if(db < 0.333){
+ return "cat";
+ }
+ else if(db < 0.5){
+ return "mouse";
+ }
+ else{
+ return "dog";
+ }
+ })
+ .map(str -> x3Covariate.createValue(str))
+ .collect(Collectors.toList());
+
+
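+        // Pair the generated x1, x2, x3 values into training rows with simulated responses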
+        for(int i=0; i<n; i++){
+            final Covariate.Value x1 = x1List.get(i);
+ final Covariate.Value x2 = x2List.get(i);
+ final Covariate.Value x3 = x3List.get(i);
+
+ trainingSet.add(generateRow(x1, x2, x3, i));
+ }
+
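+        // Regression tree settings: weighted-variance splits, mean leaf predictions,
+        // max depth 30, node size 5, 5 candidate splits per covariate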
+ final TreeTrainer treeTrainer = TreeTrainer.builder()
+ .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
+ .responseCombiner(new MeanResponseCombiner())
+ .maxNodeDepth(30)
+ .nodeSize(5)
+ .numberOfSplits(5)
+ .build();
+
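+        // Covariates offered to the tree trainer as splitting candidates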
+        final List covariateNames = List.of(x1Covariate, x2Covariate, x3Covariate);
+
+ final long startTime = System.currentTimeMillis();
+ final Node baseNode = treeTrainer.growTree(trainingSet, covariateNames);
+ final long endTime = System.currentTimeMillis();
+
+ System.out.println(((double)(endTime - startTime))/1000.0);
+
+
+
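+        // Factor values reused when building the test rows below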
+ final Covariate.Value cat = x3Covariate.createValue("cat");
+ final Covariate.Value dog = x3Covariate.createValue("dog");
+ final Covariate.Value mouse = x3Covariate.createValue("mouse");
+
+
+ final List testSet = new ArrayList<>();
+ testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), cat, 1)); // expect 1
+ testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), dog, 5));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), cat, 3));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), dog, 0));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), cat, 8));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), dog, 10));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), cat, 3));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), dog, 2));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), cat, 4));
+
+
+ testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(9.0), mouse, 0));
+ testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), mouse, 5));
+
+ for(final CovariateRow testCase : testSet){
+ System.out.println(testCase);
+ System.out.println(baseNode.evaluate(testCase));
+ System.out.println();
+
+
+ }
+
+
+
+
+
+ }
+
+ public static Row generateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
+ double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
+
+        final Map map = Map.of("x1", x1, "x2", x2, "x3", x3);
+
+ return new Row<>(map, id, y);
+
+ }
+
+
+ public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
+ final Map map = Map.of("x1", x1, "x2", x2, "x3", x3);
+
+ return new CovariateRow(map, id);
+
+ }
+
+
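+    // Simulated response: a special rule when x3 is "mouse", otherwise piecewise-constant in x1 and x2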
+ public static double generateResponse(double x1, double x2, String x3){
+
+ if(x3.equalsIgnoreCase("mouse")){
+ if(x1 <= 5){
+ return 0;
+ }
+ else{
+ return 5;
+ }
+ }
+
+ // cat & dog below
+
+ if(x2 <= 3){
+ if(x1 <= 3){
+ return 3;
+ }
+ else if(x1 <= 7){
+ return 5;
+ }
+ else{
+ return 1;
+ }
+ }
+ else if(x1 >= 5){
+ if(x2 > 6){
+ return 2;
+ }
+ else if(x1 >= 8){
+ return 0;
+ }
+ else{
+ return 8;
+ }
+ }
+ else if(x1 <= 2){
+ if(x2 >= 7){
+ return 4;
+ }
+ else{
+ return 3;
+ }
+ }
+ else{
+ return 10;
+ }
+
+
+ }
+
+}