Refactored code to allow a class of covariates to determine which SplitRules are tested.

Most of the refactoring involved the creation of a Covariate class (one
instance per column), with SplitRule and Value being folded in as inner
classes.
This commit is contained in:
Joel Therrien 2018-07-03 17:00:02 -07:00
parent e7af65e8fd
commit e96a578ac9
14 changed files with 233 additions and 189 deletions

View file

@ -0,0 +1,58 @@
package ca.joeltherrien.randomforest;
import java.io.Serializable;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
/**
 * Describes one covariate (one instance per data column). A covariate knows its name,
 * can wrap raw values of type {@code V}, and can propose the split rules to be tested on it.
 *
 * @param <V> the raw type of the values held for this covariate
 */
public interface Covariate<V> extends Serializable {

    /**
     * @return the name identifying this covariate
     */
    String getName();

    /**
     * Proposes split rules for this covariate based on observed values.
     *
     * @param data   values of this covariate to derive rules from
     * @param number how many rules/values to consider; the exact interpretation is left to
     *               the implementation
     * @return the candidate split rules
     */
    Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);

    /**
     * Wraps a raw value as a {@link Value} belonging to this covariate.
     *
     * @param value the raw value
     * @return the wrapped value
     */
    Value<V> createValue(V value);

    /**
     * A single observed value of a covariate.
     *
     * @param <V> the raw value type
     */
    interface Value<V> extends Serializable{

        /**
         * @return the covariate this value belongs to
         */
        Covariate<V> getParent();

        /**
         * @return the raw value
         */
        V getValue();
    }

    /**
     * A rule that sends each row either to the left-hand or to the right-hand side of a split.
     *
     * @param <V> the raw value type of the covariate the rule is based on
     */
    interface SplitRule<V> extends Serializable{

        /**
         * @return the covariate this rule belongs to
         */
        Covariate<V> getParent();

        /**
         * Partitions the given rows with this rule and returns the two sides as a Split.
         * Mainly used while training a tree, when candidate splits are being evaluated.
         *
         * @param rows the rows to partition
         * @param <Y>  the response type carried by the rows
         * @return a Split holding the left-hand rows and the right-hand rows, each in input order
         */
        default <Y> Split<Y> applyRule(List<Row<Y>> rows) {
            final List<Row<Y>> left = new LinkedList<>();
            final List<Row<Y>> right = new LinkedList<>();

            // Route every row to exactly one side, preserving the order of the input.
            rows.forEach(row -> (isLeftHand(row) ? left : right).add(row));

            return new Split<>(left, right);
        }

        /**
         * @param row the row to classify
         * @return true if the row belongs on the left-hand side of the split
         */
        boolean isLeftHand(CovariateRow row);
    }
}

View file

@ -8,12 +8,12 @@ import java.util.Map;
@RequiredArgsConstructor @RequiredArgsConstructor
public class CovariateRow { public class CovariateRow {
private final Map<String, Value> valueMap; private final Map<String, Covariate.Value> valueMap;
@Getter @Getter
private final int id; private final int id;
public Value<?> getCovariate(String name){ public Covariate.Value<?> getCovariateValue(String name){
return valueMap.get(name); return valueMap.get(name);
} }

View file

@ -0,0 +1,103 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
/**
 * A {@link Covariate} holding double values. Its split rules send a row to the left-hand
 * side when the row's value for this covariate is {@code <=} the rule's threshold.
 */
@RequiredArgsConstructor
public class NumericCovariate implements Covariate<Double>{

    @Getter
    private final String name;

    /**
     * Builds candidate split rules using observed values as thresholds.
     *
     * If {@code number} is at least {@code data.size()}, every value in {@code data} is used.
     * Otherwise {@code number} distinct positions of {@code data} are sampled at random.
     * Equal thresholds are collapsed, so the result may hold fewer than {@code number} rules.
     *
     * @param data   observed values of this covariate; elements must be non-null
     * @param number maximum number of values to draw thresholds from
     * @return a set of split rules with distinct thresholds
     */
    @Override
    public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {
        final List<Value<Double>> candidates;

        if(number >= data.size()){
            // Every value is a candidate. The result is an unordered set, so neither a copy
            // nor a shuffle is needed. Using >= (not >) keeps number == data.size() out of
            // the rejection-sampling loop below.
            candidates = data;
        }
        else{ // only need `number` randomly chosen entries
            candidates = new ArrayList<>(number);
            final Set<Integer> indexesToUse = new HashSet<>();

            while(indexesToUse.size() < number){
                final int index = ThreadLocalRandom.current().nextInt(data.size());

                // add() returns false for an index that was already drawn
                if(indexesToUse.add(index)){
                    candidates.add(data.get(index));
                }
            }
        }

        // Repeated thresholds are common because of bootstrapping; drop them before creating
        // rule objects so that each distinct split is only evaluated once.
        return candidates.stream()
                .mapToDouble(Value::getValue)
                .distinct()
                .mapToObj(NumericSplitRule::new)
                .collect(Collectors.toSet());
    }

    /**
     * @param value the raw value to wrap
     * @return a value whose parent is this covariate
     */
    @Override
    public NumericValue createValue(Double value) {
        return new NumericValue(value);
    }

    /**
     * A double value belonging to the enclosing covariate.
     */
    public class NumericValue implements Covariate.Value<Double>{

        private final double value;

        private NumericValue(final double value){
            this.value = value;
        }

        @Override
        public Covariate<Double> getParent() {
            return NumericCovariate.this;
        }

        @Override
        public Double getValue() {
            return value;
        }

        @Override
        public String toString(){
            return "" + value;
        }
    }

    /**
     * Split rule of the form "value of the enclosing covariate <= threshold".
     * Two rules are equal when they share the same parent covariate and threshold.
     */
    public class NumericSplitRule implements Covariate.SplitRule<Double>{

        private final double threshold;

        private NumericSplitRule(final double threshold){
            this.threshold = threshold;
        }

        @Override
        public final String toString() {
            return "NumericSplitRule on " + getParent().getName() + " at " + threshold;
        }

        @Override
        public Covariate<Double> getParent() {
            return NumericCovariate.this;
        }

        /**
         * @param row the row to classify
         * @return true if the row's value for the parent covariate is {@code <=} the threshold
         * @throws MissingValueException if the row has no value for the parent covariate
         */
        @Override
        public boolean isLeftHand(CovariateRow row) {
            final Covariate.Value<?> x = row.getCovariateValue(getParent().getName());

            if(x == null) {
                throw new MissingValueException(row, this);
            }

            final double xNum = (Double) x.getValue();

            return xNum <= threshold;
        }

        @Override
        public boolean equals(Object other) {
            if(this == other){
                return true;
            }
            if(!(other instanceof NumericSplitRule)){
                return false;
            }

            final NumericSplitRule that = (NumericSplitRule) other;

            // Double.compare matches the comparison used by DoubleStream.distinct()
            return Double.compare(threshold, that.threshold) == 0
                    && getParent().equals(that.getParent());
        }

        @Override
        public int hashCode() {
            return Objects.hash(getParent(), threshold);
        }
    }
}

View file

@ -1,33 +0,0 @@
package ca.joeltherrien.randomforest;
import java.util.LinkedList;
import java.util.List;
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
import lombok.AllArgsConstructor;
/**
 * Split rule that compares a row's numeric value for one named covariate against a threshold.
 */
@AllArgsConstructor
public class NumericSplitRule extends SplitRule{

    /** Name of the covariate whose value is looked up in each row. */
    public final String covariateName;

    /** Rows whose value is less than or equal to this go to the left-hand side. */
    public final double threshold;

    @Override
    public final String toString() {
        return "NumericSplitRule on " + covariateName + " at " + threshold;
    }

    /**
     * @param row the row to classify
     * @return true if the row's value for the covariate is {@code <=} the threshold
     * @throws MissingValueException if the row has no value for the covariate
     */
    @Override
    public boolean isLeftHand(CovariateRow row) {
        final Value<?> covariateValue = row.getCovariate(covariateName);

        if(covariateValue != null) {
            final double observed = (Double) covariateValue.getValue();
            return observed <= threshold;
        }

        throw new MissingValueException(row, this);
    }
}

View file

@ -1,24 +0,0 @@
package ca.joeltherrien.randomforest;
import lombok.RequiredArgsConstructor;
/**
 * A {@link Value} wrapping a primitive double.
 */
@RequiredArgsConstructor
public class NumericValue implements Value<Double> {

    private final double value;

    /**
     * @return the wrapped double, boxed
     */
    @Override
    public Double getValue() {
        return Double.valueOf(value);
    }

    /**
     * @param covariateName the covariate the rule should apply to
     * @return a numeric split rule on that covariate using this value as its threshold
     */
    @Override
    public SplitRule generateSplitRule(final String covariateName) {
        final double threshold = value;
        return new NumericSplitRule(covariateName, threshold);
    }

    /**
     * @return the decimal string form of the wrapped double
     */
    @Override
    public String toString(){
        return String.valueOf(value);
    }
}

View file

@ -7,7 +7,7 @@ public class Row<Y> extends CovariateRow {
private final Y response; private final Y response;
public Row(Map<String, Value> valueMap, int id, Y response){ public Row(Map<String, Covariate.Value> valueMap, int id, Y response){
super(valueMap, id); super(valueMap, id);
this.response = response; this.response = response;
} }

View file

@ -1,37 +0,0 @@
package ca.joeltherrien.randomforest;
import java.io.Serializable;
import java.util.LinkedList;
import java.util.List;
/**
 * Base class for rules that send each row to the left-hand or right-hand side of a split.
 */
public abstract class SplitRule implements Serializable {

    /**
     * Partitions the given rows with this rule and returns the two sides as a Split.
     * Mainly used while training a tree, when candidate splits are being evaluated.
     *
     * @param rows the rows to partition
     * @param <Y>  the response type carried by the rows
     * @return a Split holding the left-hand rows and the right-hand rows, each in input order
     */
    public <Y> Split<Y> applyRule(List<Row<Y>> rows) {
        final List<Row<Y>> left = new LinkedList<>();
        final List<Row<Y>> right = new LinkedList<>();

        // Route every row to exactly one side, preserving the order of the input.
        rows.forEach(row -> (isLeftHand(row) ? left : right).add(row));

        return new Split<>(left, right);
    }

    /**
     * @param row the row to classify
     * @return true if the row belongs on the left-hand side of the split
     */
    public abstract boolean isLeftHand(CovariateRow row);
}

View file

@ -1,11 +0,0 @@
package ca.joeltherrien.randomforest;
/**
 * A single covariate value that can also produce a split rule based on itself.
 *
 * @param <V> the raw type of the wrapped value
 */
public interface Value<V> {
/**
 * @return the raw wrapped value
 */
V getValue();
/**
 * @param covariateName name of the covariate the generated rule should apply to
 * @return a split rule derived from this value
 */
SplitRule generateSplitRule(String covariateName);
}

View file

@ -1,8 +1,7 @@
package ca.joeltherrien.randomforest.exceptions; package ca.joeltherrien.randomforest.exceptions;
import ca.joeltherrien.randomforest.Covariate;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.SplitRule;
public class MissingValueException extends RuntimeException{ public class MissingValueException extends RuntimeException{
@ -11,7 +10,7 @@ public class MissingValueException extends RuntimeException{
*/ */
private static final long serialVersionUID = 6808060079431207726L; private static final long serialVersionUID = 6808060079431207726L;
public MissingValueException(CovariateRow row, SplitRule rule) { public MissingValueException(CovariateRow row, Covariate.SplitRule rule) {
super("Missing value at CovariateRow " + row + rule); super("Missing value at CovariateRow " + row + rule);
} }

View file

@ -1,6 +1,7 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Bootstrapper; import ca.joeltherrien.randomforest.Bootstrapper;
import ca.joeltherrien.randomforest.Covariate;
import ca.joeltherrien.randomforest.ResponseCombiner; import ca.joeltherrien.randomforest.ResponseCombiner;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import lombok.Builder; import lombok.Builder;
@ -21,7 +22,7 @@ import java.util.stream.Stream;
public class ForestTrainer<Y> { public class ForestTrainer<Y> {
private final TreeTrainer<Y> treeTrainer; private final TreeTrainer<Y> treeTrainer;
private final List<String> covariatesToTry; private final List<Covariate> covariatesToTry;
private final ResponseCombiner<Y, ?> treeResponseCombiner; private final ResponseCombiner<Y, ?> treeResponseCombiner;
private final List<Row<Y>> data; private final List<Row<Y>> data;
@ -140,7 +141,7 @@ public class ForestTrainer<Y> {
} }
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){ private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
final List<String> treeCovariates = new ArrayList<>(covariatesToTry); final List<Covariate> treeCovariates = new ArrayList<>(covariatesToTry);
Collections.shuffle(treeCovariates); Collections.shuffle(treeCovariates);
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){ for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){

View file

@ -1,8 +1,7 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Covariate;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.SplitRule;
import lombok.Builder; import lombok.Builder;
@Builder @Builder
@ -10,7 +9,7 @@ public class SplitNode<Y> implements Node<Y> {
private final Node<Y> leftHand; private final Node<Y> leftHand;
private final Node<Y> rightHand; private final Node<Y> rightHand;
private final SplitRule splitRule; private final Covariate.SplitRule splitRule;
@Override @Override
public Y evaluate(CovariateRow row) { public Y evaluate(CovariateRow row) {

View file

@ -1,8 +1,6 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import lombok.Builder; import lombok.Builder;
import java.util.*; import java.util.*;
@ -22,17 +20,15 @@ public class TreeTrainer<Y> {
private final int nodeSize; private final int nodeSize;
private final int maxNodeDepth; private final int maxNodeDepth;
private final Random random = new Random();
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
return growNode(data, covariatesToTry, 0); return growNode(data, covariatesToTry, 0);
} }
private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){ private Node<Y> growNode(List<Row<Y>> data, List<Covariate> covariatesToTry, int depth){
// TODO; what is minimum per tree? // TODO; what is minimum per tree?
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry); final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
if(bestSplitRule == null){ if(bestSplitRule == null){
return new TerminalNode<>( return new TerminalNode<>(
@ -63,37 +59,24 @@ public class TreeTrainer<Y> {
} }
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){ private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
SplitRule bestSplitRule = null; Covariate.SplitRule bestSplitRule = null;
double bestSplitScore = 0.0; double bestSplitScore = 0.0;
boolean first = true; boolean first = true;
for(final String covariate : covariatesToTry){ for(final Covariate covariate : covariatesToTry){
final List<Row<Y>> shuffledData; final int numberToTry = numberOfSplits==0 ? data.size() : numberOfSplits;
if(numberOfSplits == 0 || numberOfSplits > data.size()){
shuffledData = new ArrayList<>(data);
Collections.shuffle(shuffledData);
}
else{ // only need the top numberOfSplits entries
shuffledData = new ArrayList<>(numberOfSplits);
final Set<Integer> indexesToUse = new HashSet<>();
while(indexesToUse.size() < numberOfSplits){ final Collection<Covariate.SplitRule> splitRulesToTry = covariate
final int index = random.nextInt(data.size()); .generateSplitRules(
data
.stream()
.map(row -> row.getCovariateValue(covariate.getName()))
.collect(Collectors.toList())
, numberToTry);
if(indexesToUse.add(index)){ for(final Covariate.SplitRule possibleRule : splitRulesToTry){
shuffledData.add(data.get(index));
}
}
}
int tries = 0;
while(tries < shuffledData.size()){
final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
final Split<Y> possibleSplit = possibleRule.applyRule(data); final Split<Y> possibleSplit = possibleRule.applyRule(data);
final Double score = groupDifferentiator.differentiate( final Double score = groupDifferentiator.differentiate(
@ -106,8 +89,6 @@ public class TreeTrainer<Y> {
bestSplitScore = score; bestSplitScore = score;
first = false; first = false;
} }
tries++;
} }
} }

View file

@ -19,21 +19,28 @@ public class TrainForest {
final int n = 10000; final int n = 10000;
final int p = 5; final int p = 5;
final Random random = new Random(); final Random random = new Random();
final List<Row<Double>> data = new ArrayList<>(n); final List<Row<Double>> data = new ArrayList<>(n);
double minY = 1000.0; double minY = 1000.0;
final List<Covariate> covariateList = new ArrayList<>(p);
for(int j =0; j < p; j++){
final NumericCovariate covariate = new NumericCovariate("x"+j);
covariateList.add(covariate);
}
for(int i=0; i<n; i++){ for(int i=0; i<n; i++){
double y = 0.0; double y = 0.0;
final Map<String, Value> map = new HashMap<>(); final Map<String, Covariate.Value> map = new HashMap<>();
for(int j=0; j<p; j++){ for(final Covariate covariate : covariateList) {
final double x = random.nextDouble(); final double x = random.nextDouble();
y += x; y += x;
map.put("x"+j, new NumericValue(x)); map.put(covariate.getName(), covariate.createValue(x));
} }
data.add(i, new Row<>(map, i, y)); data.add(i, new Row<>(map, i, y));
@ -44,10 +51,8 @@ public class TrainForest {
} }
final List<String> covariateNames = IntStream.range(0, p).mapToObj(j -> "x"+j).collect(Collectors.toList());
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
.numberOfSplits(5) .numberOfSplits(5)
.nodeSize(5) .nodeSize(5)
.maxNodeDepth(100000000) .maxNodeDepth(100000000)
@ -58,7 +63,7 @@ public class TrainForest {
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder() final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
.treeTrainer(treeTrainer) .treeTrainer(treeTrainer)
.data(data) .data(data)
.covariatesToTry(covariateNames) .covariatesToTry(covariateList)
.mtry(4) .mtry(4)
.ntree(100) .ntree(100)
.treeResponseCombiner(new MeanResponseCombiner()) .treeResponseCombiner(new MeanResponseCombiner())
@ -69,7 +74,7 @@ public class TrainForest {
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
//final Forest<Double> forest = forestTrainer.trainSerial(); //final Forest<Double> forest = forestTrainer.trainSerial();
//final Forest<Double> forest = forestTrainer.trainParallel(8); //final Forest<Double> forest = forestTrainer.trainParallelInMemory(3);
forestTrainer.trainParallelOnDisk(3); forestTrainer.trainParallelOnDisk(3);
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();
@ -88,9 +93,9 @@ public class TrainForest {
System.out.println(forest.evaluate(testRow1)); System.out.println(forest.evaluate(testRow1));
System.out.println(forest.evaluate(testRow2)); System.out.println(forest.evaluate(testRow2));
System.out.println("MinY = " + minY);
*/ */
System.out.println("MinY = " + minY);
} }
} }

View file

@ -1,11 +1,10 @@
package ca.joeltherrien.randomforest.workshop; package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.Covariate;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.NumericValue; import ca.joeltherrien.randomforest.NumericCovariate;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Value;
import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.Node;
@ -25,30 +24,30 @@ public class TrainSingleTree {
final int n = 1000; final int n = 1000;
final List<Row<Double>> trainingSet = new ArrayList<>(n); final List<Row<Double>> trainingSet = new ArrayList<>(n);
final List<Value<Double>> x1List = DoubleStream final Covariate<Double> x1Covariate = new NumericCovariate("x1");
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0) .generate(() -> random.nextDouble()*10.0)
.limit(n) .limit(n)
.mapToObj(x1 -> new NumericValue(x1)) .mapToObj(x1 -> x1Covariate.createValue(x1))
.collect(Collectors.toList()); .collect(Collectors.toList());
// x2 values must be created by x2Covariate so that getParent() reports the x2 column.
final List<Value<Double>> x2List = DoubleStream final List<Covariate.Value<Double>> x2List = DoubleStream
.generate(() -> random.nextDouble()*10.0) .generate(() -> random.nextDouble()*10.0)
.limit(n) .limit(n)
.mapToObj(x1 -> new NumericValue(x1)) .mapToObj(x2 -> x2Covariate.createValue(x2))
.collect(Collectors.toList()); .collect(Collectors.toList());
for(int i=0; i<n; i++){ for(int i=0; i<n; i++){
double x1 = x1List.get(i).getValue(); final Covariate.Value<Double> x1 = x1List.get(i);
double x2 = x2List.get(i).getValue(); final Covariate.Value<Double> x2 = x2List.get(i);
trainingSet.add(generateRow(x1, x2, i)); trainingSet.add(generateRow(x1, x2, i));
} }
final long startTime = System.currentTimeMillis();
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder() final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
@ -57,25 +56,29 @@ public class TrainSingleTree {
.numberOfSplits(0) .numberOfSplits(0)
.build(); .build();
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
final long startTime = System.currentTimeMillis();
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();
System.out.println(((double)(endTime - startTime))/1000.0); System.out.println(((double)(endTime - startTime))/1000.0);
final List<String> covariateNames = List.of("x1", "x2");
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
final List<CovariateRow> testSet = new ArrayList<>(); final List<CovariateRow> testSet = new ArrayList<>();
testSet.add(generateCovariateRow(9, 2, 1)); // expect 1 testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), 1)); // expect 1
testSet.add(generateCovariateRow(5, 2, 5)); testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), 5));
testSet.add(generateCovariateRow(2, 2, 3)); testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), 3));
testSet.add(generateCovariateRow(9, 5, 0)); testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), 0));
testSet.add(generateCovariateRow(6, 5, 8)); testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), 8));
testSet.add(generateCovariateRow(3, 5, 10)); testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), 10));
testSet.add(generateCovariateRow(1, 5, 3)); testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), 3));
testSet.add(generateCovariateRow(7, 9, 2)); testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), 2));
testSet.add(generateCovariateRow(1, 9, 4)); testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), 4));
for(final CovariateRow testCase : testSet){ for(final CovariateRow testCase : testSet){
System.out.println(testCase); System.out.println(testCase);
@ -91,18 +94,18 @@ public class TrainSingleTree {
} }
public static Row<Double> generateRow(double x1, double x2, int id){ public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
double y = generateResponse(x1, x2); double y = generateResponse(x1.getValue(), x2.getValue());
final Map<String, Value> map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2)); final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
return new Row<>(map, id, y); return new Row<>(map, id, y);
} }
public static CovariateRow generateCovariateRow(double x1, double x2, int id){ public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
final Map<String, Value> map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2)); final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
return new CovariateRow(map, id); return new CovariateRow(map, id);