Add functionality to train a random
forest in serial.
This commit is contained in:
parent
6192643e12
commit
df7835869a
6 changed files with 234 additions and 7 deletions
30
src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
Normal file
30
src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class Bootstrapper<T> {
|
||||||
|
|
||||||
|
final private List<T> originalData;
|
||||||
|
final private Random random = new Random();
|
||||||
|
|
||||||
|
public List<T> bootstrap(){
|
||||||
|
final int n = originalData.size();
|
||||||
|
|
||||||
|
final List<T> newList = new ArrayList<>(n);
|
||||||
|
|
||||||
|
for(int i=0; i<n; i++){
|
||||||
|
final int index = random.nextInt(n);
|
||||||
|
|
||||||
|
newList.add(originalData.get(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
return newList;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
21
src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
Normal file
21
src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public class Forest<Y> {
|
||||||
|
|
||||||
|
private final List<Node<Y>> trees;
|
||||||
|
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
||||||
|
|
||||||
|
public Y evaluate(CovariateRow row){
|
||||||
|
return trees.parallelStream()
|
||||||
|
.map(node -> node.evaluate(row))
|
||||||
|
.collect(treeResponseCombiner);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Bootstrapper;
|
||||||
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public class ForestTrainer<Y> {
|
||||||
|
|
||||||
|
private final TreeTrainer<Y> treeTrainer;
|
||||||
|
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||||
|
private final List<String> covariatesToTry;
|
||||||
|
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
||||||
|
|
||||||
|
// number of covariates to randomly try
|
||||||
|
private final int mtry;
|
||||||
|
|
||||||
|
// number of trees to try
|
||||||
|
private final int ntree;
|
||||||
|
|
||||||
|
private final boolean displayProgress;
|
||||||
|
|
||||||
|
public Forest<Y> trainSerial(){
|
||||||
|
|
||||||
|
final List<Node<Y>> trees = new ArrayList<>(ntree);
|
||||||
|
|
||||||
|
for(int j=0; j<ntree; j++){
|
||||||
|
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
|
||||||
|
Collections.shuffle(treeCovariates);
|
||||||
|
|
||||||
|
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
|
||||||
|
treeCovariates.remove(treeIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
||||||
|
|
||||||
|
trees.add(treeTrainer.growTree(bootstrappedData, treeCovariates));
|
||||||
|
|
||||||
|
if(displayProgress){
|
||||||
|
if(j==0) {
|
||||||
|
System.out.println();
|
||||||
|
}
|
||||||
|
System.out.print("\rFinished tree " + (j+1) + "/" + ntree);
|
||||||
|
if(j==ntree-1){
|
||||||
|
System.out.println();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Forest.<Y>builder()
|
||||||
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
|
.trees(trees)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -6,8 +6,7 @@ import ca.joeltherrien.randomforest.Split;
|
||||||
import ca.joeltherrien.randomforest.SplitRule;
|
import ca.joeltherrien.randomforest.SplitRule;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.*;
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
|
@ -24,6 +23,8 @@ public class TreeTrainer<Y> {
|
||||||
private final int nodeSize;
|
private final int nodeSize;
|
||||||
private final int maxNodeDepth;
|
private final int maxNodeDepth;
|
||||||
|
|
||||||
|
private final Random random = new Random();
|
||||||
|
|
||||||
|
|
||||||
public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
|
public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
|
||||||
return growNode(data, covariatesToTry, 0);
|
return growNode(data, covariatesToTry, 0);
|
||||||
|
@ -60,11 +61,30 @@ public class TreeTrainer<Y> {
|
||||||
boolean first = true;
|
boolean first = true;
|
||||||
|
|
||||||
for(final String covariate : covariatesToTry){
|
for(final String covariate : covariatesToTry){
|
||||||
Collections.shuffle(data);
|
|
||||||
|
final List<Row<Y>> shuffledData;
|
||||||
|
if(numberOfSplits == 0 || numberOfSplits > data.size()){
|
||||||
|
shuffledData = new ArrayList<>(data);
|
||||||
|
Collections.shuffle(shuffledData);
|
||||||
|
}
|
||||||
|
else{ // only need the top numberOfSplits entries
|
||||||
|
shuffledData = new ArrayList<>(numberOfSplits);
|
||||||
|
final Set<Integer> indexesToUse = new HashSet<>();
|
||||||
|
|
||||||
|
while(indexesToUse.size() < numberOfSplits){
|
||||||
|
final int index = random.nextInt(data.size());
|
||||||
|
|
||||||
|
if(indexesToUse.add(index)){
|
||||||
|
shuffledData.add(data.get(index));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
int tries = 0;
|
int tries = 0;
|
||||||
while(tries <= numberOfSplits || (numberOfSplits == 0 && tries < data.size())){
|
while(tries < shuffledData.size()){
|
||||||
final SplitRule possibleRule = data.get(tries).getCovariate(covariate).generateSplitRule(covariate);
|
final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
|
||||||
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
||||||
|
|
||||||
final Double score = groupDifferentiator.differentiate(
|
final Double score = groupDifferentiator.differentiate(
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
package ca.joeltherrien.randomforest.workshop;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.*;
|
||||||
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
public class TrainForest {
|
||||||
|
|
||||||
|
public static void main(String[] args){
|
||||||
|
// test creating a regression tree on a problem and see if the results are sensible.
|
||||||
|
|
||||||
|
final int n = 10000;
|
||||||
|
final int p =5;
|
||||||
|
|
||||||
|
final Random random = new Random();
|
||||||
|
|
||||||
|
final List<Row<Double>> data = new ArrayList<>(n);
|
||||||
|
|
||||||
|
double minY = 1000.0;
|
||||||
|
|
||||||
|
for(int i=0; i<n; i++){
|
||||||
|
double y = 0.0;
|
||||||
|
final Map<String, Value> map = new HashMap<>();
|
||||||
|
|
||||||
|
for(int j=0; j<p; j++){
|
||||||
|
final double x = random.nextDouble();
|
||||||
|
y+=x;
|
||||||
|
|
||||||
|
map.put("x"+j, new NumericValue(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
data.add(i, new Row<>(map, i, y));
|
||||||
|
|
||||||
|
if(y < minY){
|
||||||
|
minY = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
final List<String> covariateNames = IntStream.range(0, p).mapToObj(j -> "x"+j).collect(Collectors.toList());
|
||||||
|
|
||||||
|
|
||||||
|
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||||
|
.numberOfSplits(10)
|
||||||
|
.nodeSize(5)
|
||||||
|
.maxNodeDepth(100000000)
|
||||||
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
|
||||||
|
.treeTrainer(treeTrainer)
|
||||||
|
.bootstrapper(new Bootstrapper<>(data))
|
||||||
|
.covariatesToTry(covariateNames)
|
||||||
|
.mtry(4)
|
||||||
|
.ntree(100)
|
||||||
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
|
.displayProgress(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final long startTime = System.currentTimeMillis();
|
||||||
|
final Forest<Double> forest = forestTrainer.trainSerial();
|
||||||
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
|
System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");
|
||||||
|
|
||||||
|
|
||||||
|
final Value zeroValue = new NumericValue(0.1);
|
||||||
|
final Value point5Value = new NumericValue(0.5);
|
||||||
|
|
||||||
|
// test row
|
||||||
|
final CovariateRow testRow1 = new CovariateRow(Map.of("x0", zeroValue, "x1",zeroValue,"x2",zeroValue,"x3",zeroValue,"x4",zeroValue), 0);
|
||||||
|
final CovariateRow testRow2 = new CovariateRow(Map.of("x0", point5Value, "x1",point5Value,"x2",point5Value,"x3",point5Value,"x4",point5Value), 2);
|
||||||
|
|
||||||
|
|
||||||
|
System.out.println(forest.evaluate(testRow1));
|
||||||
|
System.out.println(forest.evaluate(testRow2));
|
||||||
|
|
||||||
|
System.out.println("MinY = " + minY);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,6 +1,10 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest.workshop;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.NumericValue;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.Value;
|
||||||
import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
|
import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||||
|
@ -11,7 +15,7 @@ import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.DoubleStream;
|
import java.util.stream.DoubleStream;
|
||||||
|
|
||||||
public class Main {
|
public class TrainSingleTree {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
System.out.println("Hello world!");
|
System.out.println("Hello world!");
|
Loading…
Reference in a new issue