Add functionality to train a random
forest in serial.
This commit is contained in:
parent
6192643e12
commit
df7835869a
6 changed files with 234 additions and 7 deletions
30
src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
Normal file
30
src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
Normal file
|
@ -0,0 +1,30 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class Bootstrapper<T> {
|
||||
|
||||
final private List<T> originalData;
|
||||
final private Random random = new Random();
|
||||
|
||||
public List<T> bootstrap(){
|
||||
final int n = originalData.size();
|
||||
|
||||
final List<T> newList = new ArrayList<>(n);
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
final int index = random.nextInt(n);
|
||||
|
||||
newList.add(originalData.get(index));
|
||||
}
|
||||
|
||||
return newList;
|
||||
|
||||
}
|
||||
|
||||
}
|
21
src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
Normal file
21
src/main/java/ca/joeltherrien/randomforest/tree/Forest.java
Normal file
|
@ -0,0 +1,21 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Builder
|
||||
public class Forest<Y> {
|
||||
|
||||
private final List<Node<Y>> trees;
|
||||
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
||||
|
||||
public Y evaluate(CovariateRow row){
|
||||
return trees.parallelStream()
|
||||
.map(node -> node.evaluate(row))
|
||||
.collect(treeResponseCombiner);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.Bootstrapper;
|
||||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@Builder
|
||||
public class ForestTrainer<Y> {
|
||||
|
||||
private final TreeTrainer<Y> treeTrainer;
|
||||
private final Bootstrapper<Row<Y>> bootstrapper;
|
||||
private final List<String> covariatesToTry;
|
||||
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
||||
|
||||
// number of covariates to randomly try
|
||||
private final int mtry;
|
||||
|
||||
// number of trees to try
|
||||
private final int ntree;
|
||||
|
||||
private final boolean displayProgress;
|
||||
|
||||
public Forest<Y> trainSerial(){
|
||||
|
||||
final List<Node<Y>> trees = new ArrayList<>(ntree);
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
|
||||
Collections.shuffle(treeCovariates);
|
||||
|
||||
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
|
||||
treeCovariates.remove(treeIndex);
|
||||
}
|
||||
|
||||
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
||||
|
||||
trees.add(treeTrainer.growTree(bootstrappedData, treeCovariates));
|
||||
|
||||
if(displayProgress){
|
||||
if(j==0) {
|
||||
System.out.println();
|
||||
}
|
||||
System.out.print("\rFinished tree " + (j+1) + "/" + ntree);
|
||||
if(j==ntree-1){
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Forest.<Y>builder()
|
||||
.treeResponseCombiner(treeResponseCombiner)
|
||||
.trees(trees)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -6,8 +6,7 @@ import ca.joeltherrien.randomforest.Split;
|
|||
import ca.joeltherrien.randomforest.SplitRule;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Builder
|
||||
|
@ -24,6 +23,8 @@ public class TreeTrainer<Y> {
|
|||
private final int nodeSize;
|
||||
private final int maxNodeDepth;
|
||||
|
||||
private final Random random = new Random();
|
||||
|
||||
|
||||
public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
|
||||
return growNode(data, covariatesToTry, 0);
|
||||
|
@ -60,11 +61,30 @@ public class TreeTrainer<Y> {
|
|||
boolean first = true;
|
||||
|
||||
for(final String covariate : covariatesToTry){
|
||||
Collections.shuffle(data);
|
||||
|
||||
final List<Row<Y>> shuffledData;
|
||||
if(numberOfSplits == 0 || numberOfSplits > data.size()){
|
||||
shuffledData = new ArrayList<>(data);
|
||||
Collections.shuffle(shuffledData);
|
||||
}
|
||||
else{ // only need the top numberOfSplits entries
|
||||
shuffledData = new ArrayList<>(numberOfSplits);
|
||||
final Set<Integer> indexesToUse = new HashSet<>();
|
||||
|
||||
while(indexesToUse.size() < numberOfSplits){
|
||||
final int index = random.nextInt(data.size());
|
||||
|
||||
if(indexesToUse.add(index)){
|
||||
shuffledData.add(data.get(index));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
int tries = 0;
|
||||
while(tries <= numberOfSplits || (numberOfSplits == 0 && tries < data.size())){
|
||||
final SplitRule possibleRule = data.get(tries).getCovariate(covariate).generateSplitRule(covariate);
|
||||
while(tries < shuffledData.size()){
|
||||
final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
|
||||
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
||||
|
||||
final Double score = groupDifferentiator.differentiate(
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
package ca.joeltherrien.randomforest.workshop;
|
||||
|
||||
import ca.joeltherrien.randomforest.*;
|
||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
public class TrainForest {
|
||||
|
||||
public static void main(String[] args){
|
||||
// test creating a regression tree on a problem and see if the results are sensible.
|
||||
|
||||
final int n = 10000;
|
||||
final int p =5;
|
||||
|
||||
final Random random = new Random();
|
||||
|
||||
final List<Row<Double>> data = new ArrayList<>(n);
|
||||
|
||||
double minY = 1000.0;
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
double y = 0.0;
|
||||
final Map<String, Value> map = new HashMap<>();
|
||||
|
||||
for(int j=0; j<p; j++){
|
||||
final double x = random.nextDouble();
|
||||
y+=x;
|
||||
|
||||
map.put("x"+j, new NumericValue(x));
|
||||
}
|
||||
|
||||
data.add(i, new Row<>(map, i, y));
|
||||
|
||||
if(y < minY){
|
||||
minY = y;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
final List<String> covariateNames = IntStream.range(0, p).mapToObj(j -> "x"+j).collect(Collectors.toList());
|
||||
|
||||
|
||||
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||
.numberOfSplits(10)
|
||||
.nodeSize(5)
|
||||
.maxNodeDepth(100000000)
|
||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.build();
|
||||
|
||||
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
|
||||
.treeTrainer(treeTrainer)
|
||||
.bootstrapper(new Bootstrapper<>(data))
|
||||
.covariatesToTry(covariateNames)
|
||||
.mtry(4)
|
||||
.ntree(100)
|
||||
.treeResponseCombiner(new MeanResponseCombiner())
|
||||
.displayProgress(true)
|
||||
.build();
|
||||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
final Forest<Double> forest = forestTrainer.trainSerial();
|
||||
final long endTime = System.currentTimeMillis();
|
||||
|
||||
System.out.println("Took " + (double)(endTime - startTime)/1000.0 + " seconds.");
|
||||
|
||||
|
||||
final Value zeroValue = new NumericValue(0.1);
|
||||
final Value point5Value = new NumericValue(0.5);
|
||||
|
||||
// test row
|
||||
final CovariateRow testRow1 = new CovariateRow(Map.of("x0", zeroValue, "x1",zeroValue,"x2",zeroValue,"x3",zeroValue,"x4",zeroValue), 0);
|
||||
final CovariateRow testRow2 = new CovariateRow(Map.of("x0", point5Value, "x1",point5Value,"x2",point5Value,"x3",point5Value,"x4",point5Value), 2);
|
||||
|
||||
|
||||
System.out.println(forest.evaluate(testRow1));
|
||||
System.out.println(forest.evaluate(testRow2));
|
||||
|
||||
System.out.println("MinY = " + minY);
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -1,6 +1,10 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
package ca.joeltherrien.randomforest.workshop;
|
||||
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.NumericValue;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Value;
|
||||
import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||
|
@ -11,7 +15,7 @@ import java.util.*;
|
|||
import java.util.stream.Collectors;
|
||||
import java.util.stream.DoubleStream;
|
||||
|
||||
public class Main {
|
||||
public class TrainSingleTree {
|
||||
|
||||
public static void main(String[] args) {
|
||||
System.out.println("Hello world!");
|
Loading…
Reference in a new issue