Add functionality to train a random forest in serial.
This commit is contained in:
Joel Therrien 2018-07-02 17:58:53 -07:00
parent 6192643e12
commit df7835869a
6 changed files with 234 additions and 7 deletions

View file

@ -0,0 +1,30 @@
package ca.joeltherrien.randomforest;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * Produces bootstrap samples (sampling with replacement) from a fixed data set.
 *
 * <p>Constructors are written out explicitly so that callers can optionally inject
 * the {@link Random} used for sampling; passing a seeded instance makes the drawn
 * samples reproducible.
 *
 * @param <T> type of the elements being resampled
 */
public class Bootstrapper<T> {

    private final List<T> originalData;
    private final Random random;

    /**
     * Creates a bootstrapper backed by a freshly created, unseeded {@link Random}.
     *
     * @param originalData the data to resample from
     */
    public Bootstrapper(List<T> originalData){
        this(originalData, new Random());
    }

    /**
     * Creates a bootstrapper that draws its indexes from the supplied random source.
     *
     * @param originalData the data to resample from
     * @param random source of randomness; supply a seeded instance for reproducible samples
     * @throws NullPointerException if {@code random} is null
     */
    public Bootstrapper(List<T> originalData, Random random){
        if(random == null){
            throw new NullPointerException("random must not be null");
        }
        this.originalData = originalData;
        this.random = random;
    }

    /**
     * Draws a sample of the same size as the original data, with replacement.
     *
     * @return a new list holding {@code originalData.size()} elements, each picked
     *         uniformly at random from the original data; empty if the original data is empty
     */
    public List<T> bootstrap(){
        final int n = originalData.size();
        final List<T> sample = new ArrayList<>(n);

        // When n == 0 the loop body never runs, so nextInt(0) is never invoked.
        for(int i = 0; i < n; i++){
            sample.add(originalData.get(random.nextInt(n)));
        }

        return sample;
    }
}

View file

@ -0,0 +1,21 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.ResponseCombiner;
import lombok.Builder;
import java.util.List;
/**
 * A collection of trained trees together with the combiner used to merge
 * their individual predictions into a single response.
 *
 * @param <Y> type of the predicted response
 */
@Builder
public class Forest<Y> {

    // root nodes of the trees making up this forest
    private final List<Node<Y>> trees;

    // collector that reduces the per-tree predictions to one response
    private final ResponseCombiner<Y, ?> treeResponseCombiner;

    /**
     * Evaluates every tree on the given row (via a parallel stream) and
     * combines the per-tree results with the tree response combiner.
     *
     * @param row the covariates to predict for
     * @return the combined prediction
     */
    public Y evaluate(CovariateRow row){
        final java.util.stream.Stream<Y> perTreePredictions = trees
                .parallelStream()
                .map(tree -> tree.evaluate(row));

        return perTreePredictions.collect(treeResponseCombiner);
    }
}

View file

@ -0,0 +1,62 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Bootstrapper;
import ca.joeltherrien.randomforest.ResponseCombiner;
import ca.joeltherrien.randomforest.Row;
import lombok.Builder;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Trains a {@link Forest} by growing {@code ntree} trees one after another,
 * each on a bootstrapped data set and a random subset of the covariates.
 *
 * @param <Y> type of the response
 */
@Builder
public class ForestTrainer<Y> {

    private final TreeTrainer<Y> treeTrainer;
    private final Bootstrapper<Row<Y>> bootstrapper;
    private final List<String> covariatesToTry;
    private final ResponseCombiner<Y, ?> treeResponseCombiner;

    // number of covariates randomly handed to each tree
    private final int mtry;

    // number of trees to grow
    private final int ntree;

    // whether to print a progress line to standard output after each tree
    private final boolean displayProgress;

    /**
     * Grows all trees sequentially on the calling thread.
     *
     * @return a forest holding the grown trees and the configured tree response combiner
     */
    public Forest<Y> trainSerial(){
        final List<Node<Y>> grownTrees = new ArrayList<>(ntree);

        for(int treeNumber = 1; treeNumber <= ntree; treeNumber++){
            final List<String> chosenCovariates = pickCovariates();
            final List<Row<Y>> sample = bootstrapper.bootstrap();

            grownTrees.add(treeTrainer.growTree(sample, chosenCovariates));

            if(displayProgress){
                reportProgress(treeNumber);
            }
        }

        return Forest.<Y>builder()
                .trees(grownTrees)
                .treeResponseCombiner(treeResponseCombiner)
                .build();
    }

    // Shuffles a copy of the covariate names and trims it from the tail until at most mtry remain.
    private List<String> pickCovariates(){
        final List<String> candidates = new ArrayList<>(covariatesToTry);
        Collections.shuffle(candidates);

        while(candidates.size() > mtry){
            candidates.remove(candidates.size() - 1);
        }

        return candidates;
    }

    // Rewrites the progress line in place; a blank line is emitted before the first tree and a newline after the last.
    private void reportProgress(int treeNumber){
        if(treeNumber == 1){
            System.out.println();
        }

        System.out.print("\rFinished tree " + treeNumber + "/" + ntree);

        if(treeNumber == ntree){
            System.out.println();
        }
    }
}

View file

@ -6,8 +6,7 @@ import ca.joeltherrien.randomforest.Split;
import ca.joeltherrien.randomforest.SplitRule; import ca.joeltherrien.randomforest.SplitRule;
import lombok.Builder; import lombok.Builder;
import java.util.Collections; import java.util.*;
import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Builder @Builder
@ -24,6 +23,8 @@ public class TreeTrainer<Y> {
private final int nodeSize; private final int nodeSize;
private final int maxNodeDepth; private final int maxNodeDepth;
private final Random random = new Random();
public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){ public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
return growNode(data, covariatesToTry, 0); return growNode(data, covariatesToTry, 0);
@ -60,11 +61,30 @@ public class TreeTrainer<Y> {
boolean first = true; boolean first = true;
for(final String covariate : covariatesToTry){ for(final String covariate : covariatesToTry){
Collections.shuffle(data);
final List<Row<Y>> shuffledData;
if(numberOfSplits == 0 || numberOfSplits > data.size()){
shuffledData = new ArrayList<>(data);
Collections.shuffle(shuffledData);
}
else{ // only need the top numberOfSplits entries
shuffledData = new ArrayList<>(numberOfSplits);
final Set<Integer> indexesToUse = new HashSet<>();
while(indexesToUse.size() < numberOfSplits){
final int index = random.nextInt(data.size());
if(indexesToUse.add(index)){
shuffledData.add(data.get(index));
}
}
}
int tries = 0; int tries = 0;
while(tries <= numberOfSplits || (numberOfSplits == 0 && tries < data.size())){ while(tries < shuffledData.size()){
final SplitRule possibleRule = data.get(tries).getCovariate(covariate).generateSplitRule(covariate); final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
final Split<Y> possibleSplit = possibleRule.applyRule(data); final Split<Y> possibleSplit = possibleRule.applyRule(data);
final Double score = groupDifferentiator.differentiate( final Double score = groupDifferentiator.differentiate(

View file

@ -0,0 +1,90 @@
package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Workshop program: trains a regression forest on synthetic data and prints
 * a few predictions so the results can be eyeballed for plausibility.
 */
public class TrainForest {

    /**
     * Generates rows whose response is the sum of their random covariates,
     * trains a forest serially on them, and prints the elapsed time, the
     * predictions for two hand-made rows, and the smallest generated response.
     *
     * @param args unused
     */
    public static void main(String[] args){
        // test creating a regression forest on a problem and see if the results are sensible.
        final int n = 10000;
        final int p = 5;

        final Random random = new Random();

        final List<Row<Double>> data = new ArrayList<>(n);
        double minY = 1000.0;

        for(int rowId = 0; rowId < n; rowId++){
            final Map<String, Value> covariates = new HashMap<>();
            double response = 0.0;

            for(int k = 0; k < p; k++){
                final double draw = random.nextDouble();
                covariates.put("x" + k, new NumericValue(draw));
                response += draw;
            }

            data.add(new Row<>(covariates, rowId, response));
            minY = Math.min(minY, response);
        }

        // covariate names x0 .. x(p-1), matching the keys used when building the rows
        final List<String> covariateNames = new ArrayList<>(p);
        for(int k = 0; k < p; k++){
            covariateNames.add("x" + k);
        }

        final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
                .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
                .responseCombiner(new MeanResponseCombiner())
                .numberOfSplits(10)
                .nodeSize(5)
                .maxNodeDepth(100000000)
                .build();

        final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
                .treeTrainer(treeTrainer)
                .bootstrapper(new Bootstrapper<>(data))
                .treeResponseCombiner(new MeanResponseCombiner())
                .covariatesToTry(covariateNames)
                .mtry(4)
                .ntree(100)
                .displayProgress(true)
                .build();

        final long startTime = System.currentTimeMillis();
        final Forest<Double> forest = forestTrainer.trainSerial();
        final long elapsedMillis = System.currentTimeMillis() - startTime;

        System.out.println("Took " + (double) elapsedMillis / 1000.0 + " seconds.");

        final Value lowValue = new NumericValue(0.1);
        final Value midValue = new NumericValue(0.5);

        // test rows: every covariate set to 0.1, and every covariate set to 0.5
        final CovariateRow lowRow = new CovariateRow(
                Map.of("x0", lowValue, "x1", lowValue, "x2", lowValue, "x3", lowValue, "x4", lowValue), 0);
        final CovariateRow midRow = new CovariateRow(
                Map.of("x0", midValue, "x1", midValue, "x2", midValue, "x3", midValue, "x4", midValue), 2);

        System.out.println(forest.evaluate(lowRow));
        System.out.println(forest.evaluate(midRow));

        System.out.println("MinY = " + minY);
    }
}

View file

@ -1,6 +1,10 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.NumericValue;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Value;
import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator; import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
@ -11,7 +15,7 @@ import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.DoubleStream; import java.util.stream.DoubleStream;
public class Main { public class TrainSingleTree {
public static void main(String[] args) { public static void main(String[] args) {
System.out.println("Hello world!"); System.out.println("Hello world!");