Introduce Support for Factors #7
6 changed files with 270 additions and 2 deletions
13
pom.xml
13
pom.xml
|
@ -24,6 +24,19 @@
|
|||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-api</artifactId>
|
||||
<version>5.2.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<version>5.2.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
for(int i=0; i<levels.size()-1; i++){
|
||||
numberOfPossiblePairingsTemp *= 2;
|
||||
}
|
||||
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp;
|
||||
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
||||
|
||||
}
|
||||
|
||||
|
@ -37,7 +37,7 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
}
|
||||
|
||||
@Override
|
||||
public Collection<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
||||
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
||||
final Set<FactorSplitRule> splitRules = new HashSet<>();
|
||||
|
||||
// This is to ensure we don't get stuck in an infinite loop for small factors
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
|
||||
import ca.joeltherrien.randomforest.FactorCovariate;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.function.Executable;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class FactorCovariateTest {
|
||||
|
||||
@Test
|
||||
void verifyEqualLevels() {
|
||||
final FactorCovariate petCovariate = createTestCovariate();
|
||||
|
||||
final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
|
||||
final FactorCovariate.FactorValue dog2 = petCovariate.createValue("DO" + "G");
|
||||
|
||||
assertSame(dog1, dog2);
|
||||
|
||||
final FactorCovariate.FactorValue cat1 = petCovariate.createValue("CAT");
|
||||
final FactorCovariate.FactorValue cat2 = petCovariate.createValue("CA" + "T");
|
||||
|
||||
assertSame(cat1, cat2);
|
||||
|
||||
final FactorCovariate.FactorValue mouse1 = petCovariate.createValue("MOUSE");
|
||||
final FactorCovariate.FactorValue mouse2 = petCovariate.createValue("MOUS" + "E");
|
||||
|
||||
assertSame(mouse1, mouse2);
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void verifyBadLevelException(){
|
||||
final FactorCovariate petCovariate = createTestCovariate();
|
||||
final Executable badCode = () -> petCovariate.createValue("vulcan");
|
||||
|
||||
assertThrows(IllegalArgumentException.class, badCode, "vulcan is not a level in FactorCovariate pet");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAllSubsets(){
|
||||
final FactorCovariate petCovariate = createTestCovariate();
|
||||
|
||||
final Collection<FactorCovariate.FactorSplitRule> splitRules = petCovariate.generateSplitRules(null, 100);
|
||||
|
||||
assertEquals(splitRules.size(), 3);
|
||||
|
||||
// TODO verify the contents of the split rules
|
||||
|
||||
}
|
||||
|
||||
|
||||
private FactorCovariate createTestCovariate(){
|
||||
final List<String> levels = List.of("DOG", "CAT", "MOUSE");
|
||||
|
||||
return new FactorCovariate("pet", levels);
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,190 @@
|
|||
package ca.joeltherrien.randomforest.workshop;
|
||||
|
||||
|
||||
import ca.joeltherrien.randomforest.*;
|
||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.Node;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.DoubleStream;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class TrainSingleTreeFactor {
|
||||
|
||||
public static void main(String[] args) {
|
||||
System.out.println("Hello world!");
|
||||
|
||||
final Random random = new Random(123);
|
||||
|
||||
final int n = 10000;
|
||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||
|
||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
||||
final FactorCovariate x3Covariate = new FactorCovariate("x3", List.of("cat", "dog", "mouse"));
|
||||
|
||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||
.generate(() -> random.nextDouble()*10.0)
|
||||
.limit(n)
|
||||
.mapToObj(x1 -> x1Covariate.createValue(x1))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
final List<Covariate.Value<Double>> x2List = DoubleStream
|
||||
.generate(() -> random.nextDouble()*10.0)
|
||||
.limit(n)
|
||||
.mapToObj(x2 -> x1Covariate.createValue(x2))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
final List<Covariate.Value<String>> x3List = DoubleStream
|
||||
.generate(() -> random.nextDouble())
|
||||
.limit(n)
|
||||
.mapToObj(db -> {
|
||||
if(db < 0.333){
|
||||
return "cat";
|
||||
}
|
||||
else if(db < 0.5){
|
||||
return "mouse";
|
||||
}
|
||||
else{
|
||||
return "dog";
|
||||
}
|
||||
})
|
||||
.map(str -> x3Covariate.createValue(str))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
final Covariate.Value<Double> x1 = x1List.get(i);
|
||||
final Covariate.Value<Double> x2 = x2List.get(i);
|
||||
final Covariate.Value<String> x3 = x3List.get(i);
|
||||
|
||||
trainingSet.add(generateRow(x1, x2, x3, i));
|
||||
}
|
||||
|
||||
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.maxNodeDepth(30)
|
||||
.nodeSize(5)
|
||||
.numberOfSplits(5)
|
||||
.build();
|
||||
|
||||
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
||||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
|
||||
final long endTime = System.currentTimeMillis();
|
||||
|
||||
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||
|
||||
|
||||
|
||||
final Covariate.Value<String> cat = x3Covariate.createValue("cat");
|
||||
final Covariate.Value<String> dog = x3Covariate.createValue("dog");
|
||||
final Covariate.Value<String> mouse = x3Covariate.createValue("mouse");
|
||||
|
||||
|
||||
final List<CovariateRow> testSet = new ArrayList<>();
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), cat, 1)); // expect 1
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), dog, 5));
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), cat, 3));
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), dog, 0));
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), cat, 8));
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), dog, 10));
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), cat, 3));
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), dog, 2));
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), cat, 4));
|
||||
|
||||
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(9.0), mouse, 0));
|
||||
testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), mouse, 5));
|
||||
|
||||
for(final CovariateRow testCase : testSet){
|
||||
System.out.println(testCase);
|
||||
System.out.println(baseNode.evaluate(testCase));
|
||||
System.out.println();
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
|
||||
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
|
||||
|
||||
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
|
||||
|
||||
return new Row<>(map, id, y);
|
||||
|
||||
}
|
||||
|
||||
|
||||
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
|
||||
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2, "x3", x3);
|
||||
|
||||
return new CovariateRow(map, id);
|
||||
|
||||
}
|
||||
|
||||
|
||||
public static double generateResponse(double x1, double x2, String x3){
|
||||
|
||||
if(x3.equalsIgnoreCase("mouse")){
|
||||
if(x1 <= 5){
|
||||
return 0;
|
||||
}
|
||||
else{
|
||||
return 5;
|
||||
}
|
||||
}
|
||||
|
||||
// cat & dog below
|
||||
|
||||
if(x2 <= 3){
|
||||
if(x1 <= 3){
|
||||
return 3;
|
||||
}
|
||||
else if(x1 <= 7){
|
||||
return 5;
|
||||
}
|
||||
else{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
else if(x1 >= 5){
|
||||
if(x2 > 6){
|
||||
return 2;
|
||||
}
|
||||
else if(x1 >= 8){
|
||||
return 0;
|
||||
}
|
||||
else{
|
||||
return 8;
|
||||
}
|
||||
}
|
||||
else if(x1 <= 2){
|
||||
if(x2 >= 7){
|
||||
return 4;
|
||||
}
|
||||
else{
|
||||
return 3;
|
||||
}
|
||||
}
|
||||
else{
|
||||
return 10;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue