Basic functinality to train a single regression tree is
implemented.
This commit is contained in:
parent
7a467207a4
commit
3c9c78741f
26 changed files with 594 additions and 115 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -2,3 +2,5 @@
|
||||||
.settings
|
.settings
|
||||||
.project
|
.project
|
||||||
target/
|
target/
|
||||||
|
*.iml
|
||||||
|
.idea
|
||||||
|
|
31
pom.xml
Normal file
31
pom.xml
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<groupId>ca.joeltherrien</groupId>
|
||||||
|
<artifactId>RandomSurvivalForests</artifactId>
|
||||||
|
<version>1.0-SNAPSHOT</version>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<java.version>1.10</java.version>
|
||||||
|
<maven.compiler.target>1.10</maven.compiler.target>
|
||||||
|
<maven.compiler.source>1.10</maven.compiler.source>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.projectlombok</groupId>
|
||||||
|
<artifactId>lombok</artifactId>
|
||||||
|
<version>1.18.0</version>
|
||||||
|
<scope>provided</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
|
||||||
|
</project>
|
|
@ -1,10 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
public class Main {
|
|
||||||
|
|
||||||
public static void main(String[] args) {
|
|
||||||
System.out.println("Hello world!");
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
public class Node {
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,44 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
import java.util.LinkedList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
|
||||||
|
|
||||||
public class NumericSplitRule implements SplitRule{
|
|
||||||
|
|
||||||
public final String covariateName;
|
|
||||||
public final double threshold;
|
|
||||||
|
|
||||||
public NumericSplitRule(String covariateName, double threshold) {
|
|
||||||
super();
|
|
||||||
this.covariateName = covariateName;
|
|
||||||
this.threshold = threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public final String toString() {
|
|
||||||
return "NumericSplitRule on " + covariateName + " at " + threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public <Y> Split<Y> applyRule(List<Row<Y>> rows) {
|
|
||||||
final List<Row<Y>> leftHand = new LinkedList<>();
|
|
||||||
final List<Row<Y>> rightHand = new LinkedList<>();
|
|
||||||
|
|
||||||
for(final Row<Y> row : rows) {
|
|
||||||
final Value x = row.getCovariate(covariateName);
|
|
||||||
if(x == null) {
|
|
||||||
throw new MissingValueException(row, this);
|
|
||||||
}
|
|
||||||
|
|
||||||
final NumericValue xNum = (NumericValue) x;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO Auto-generated method stub
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,33 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public class Row<Y> {
|
|
||||||
|
|
||||||
private final Map<String, Value> covariates;
|
|
||||||
private final Y response;
|
|
||||||
private final int id;
|
|
||||||
|
|
||||||
public Row(Map<String, Value> covariates, Y response, int id) {
|
|
||||||
super();
|
|
||||||
this.covariates = covariates;
|
|
||||||
this.response = response;
|
|
||||||
this.id = id;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Value getCovariate(String name) {
|
|
||||||
return this.covariates.get(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Y getResponse() {
|
|
||||||
return this.response;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "Row " + this.id;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,9 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public interface SplitRule {
|
|
||||||
|
|
||||||
<Y> Split<Y> applyRule(List<Row<Y>> rows);
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,7 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
public interface Value {
|
|
||||||
|
|
||||||
// TODO
|
|
||||||
|
|
||||||
}
|
|
26
src/main/java/ca/joeltherrien/randomforest/CovariateRow.java
Normal file
26
src/main/java/ca/joeltherrien/randomforest/CovariateRow.java
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CovariateRow {
|
||||||
|
|
||||||
|
private final Map<String, Value> valueMap;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final int id;
|
||||||
|
|
||||||
|
public Value<?> getCovariate(String name){
|
||||||
|
return valueMap.get(name);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return "CovariateRow " + this.id;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
146
src/main/java/ca/joeltherrien/randomforest/Main.java
Normal file
146
src/main/java/ca/joeltherrien/randomforest/Main.java
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.DoubleStream;
|
||||||
|
|
||||||
|
public class Main {
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
System.out.println("Hello world!");
|
||||||
|
|
||||||
|
final Random random = new Random(123);
|
||||||
|
|
||||||
|
final int n = 1000;
|
||||||
|
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||||
|
|
||||||
|
final List<Value<Double>> x1List = DoubleStream
|
||||||
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
.limit(n)
|
||||||
|
.mapToObj(x1 -> new NumericValue(x1))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<Value<Double>> x2List = DoubleStream
|
||||||
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
.limit(n)
|
||||||
|
.mapToObj(x1 -> new NumericValue(x1))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for(int i=0; i<n; i++){
|
||||||
|
double x1 = x1List.get(i).getValue();
|
||||||
|
double x2 = x2List.get(i).getValue();
|
||||||
|
|
||||||
|
trainingSet.add(generateRow(x1, x2, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
final long startTime = System.currentTimeMillis();
|
||||||
|
|
||||||
|
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||||
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.maxNodeDepth(30)
|
||||||
|
.nodeSize(5)
|
||||||
|
.numberOfSplits(0)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
|
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||||
|
|
||||||
|
final List<String> covariateNames = List.of("x1", "x2");
|
||||||
|
|
||||||
|
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
|
||||||
|
|
||||||
|
|
||||||
|
final List<CovariateRow> testSet = new ArrayList<>();
|
||||||
|
testSet.add(generateCovariateRow(9, 2, 1)); // expect 1
|
||||||
|
testSet.add(generateCovariateRow(5, 2, 5));
|
||||||
|
testSet.add(generateCovariateRow(2, 2, 3));
|
||||||
|
testSet.add(generateCovariateRow(9, 5, 0));
|
||||||
|
testSet.add(generateCovariateRow(6, 5, 8));
|
||||||
|
testSet.add(generateCovariateRow(3, 5, 10));
|
||||||
|
testSet.add(generateCovariateRow(1, 5, 3));
|
||||||
|
testSet.add(generateCovariateRow(7, 9, 2));
|
||||||
|
testSet.add(generateCovariateRow(1, 9, 4));
|
||||||
|
|
||||||
|
for(final CovariateRow testCase : testSet){
|
||||||
|
System.out.println(testCase);
|
||||||
|
System.out.println(baseNode.evaluate(testCase));
|
||||||
|
System.out.println();
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Row<Double> generateRow(double x1, double x2, int id){
|
||||||
|
double y = generateResponse(x1, x2);
|
||||||
|
|
||||||
|
final Map<String, Value> map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2));
|
||||||
|
|
||||||
|
return new Row<>(map, id, y);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static CovariateRow generateCovariateRow(double x1, double x2, int id){
|
||||||
|
final Map<String, Value> map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2));
|
||||||
|
|
||||||
|
return new CovariateRow(map, id);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static double generateResponse(double x1, double x2){
|
||||||
|
if(x2 <= 3){
|
||||||
|
if(x1 <= 3){
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
else if(x1 <= 7){
|
||||||
|
return 5;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(x1 >= 5){
|
||||||
|
if(x2 > 6){
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
else if(x1 >= 8){
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(x1 <= 2){
|
||||||
|
if(x2 >= 7){
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class NumericSplitRule extends SplitRule{
|
||||||
|
|
||||||
|
public final String covariateName;
|
||||||
|
public final double threshold;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final String toString() {
|
||||||
|
return "NumericSplitRule on " + covariateName + " at " + threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeftHand(CovariateRow row) {
|
||||||
|
final Value<?> x = row.getCovariate(covariateName);
|
||||||
|
if(x == null) {
|
||||||
|
throw new MissingValueException(row, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
final double xNum = (Double) x.getValue();
|
||||||
|
|
||||||
|
return xNum <= threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
19
src/main/java/ca/joeltherrien/randomforest/NumericValue.java
Normal file
19
src/main/java/ca/joeltherrien/randomforest/NumericValue.java
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class NumericValue implements Value<Double> {
|
||||||
|
|
||||||
|
private final double value;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SplitRule generateSplitRule(final String covariateName) {
|
||||||
|
return new NumericSplitRule(covariateName, value);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface ResponseCombiner<Y> {
|
||||||
|
|
||||||
|
Y combine(List<Y> responses);
|
||||||
|
|
||||||
|
}
|
27
src/main/java/ca/joeltherrien/randomforest/Row.java
Normal file
27
src/main/java/ca/joeltherrien/randomforest/Row.java
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class Row<Y> extends CovariateRow {
|
||||||
|
|
||||||
|
private final Y response;
|
||||||
|
|
||||||
|
public Row(Map<String, Value> valueMap, int id, Y response){
|
||||||
|
super(valueMap, id);
|
||||||
|
this.response = response;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public Y getResponse() {
|
||||||
|
return this.response;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Row " + this.getId();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -8,13 +10,10 @@ import java.util.List;
|
||||||
* @author joel
|
* @author joel
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Data
|
||||||
public class Split<Y> {
|
public class Split<Y> {
|
||||||
|
|
||||||
public final List<Row<Y>> leftHand;
|
public final List<Row<Y>> leftHand;
|
||||||
public final List<Row<Y>> rightHand;
|
public final List<Row<Y>> rightHand;
|
||||||
|
|
||||||
public Split(List<Row<Y>> leftHand, List<Row<Y>> rightHand){
|
|
||||||
this.leftHand = leftHand;
|
|
||||||
this.rightHand = rightHand;
|
|
||||||
}
|
|
||||||
}
|
}
|
36
src/main/java/ca/joeltherrien/randomforest/SplitRule.java
Normal file
36
src/main/java/ca/joeltherrien/randomforest/SplitRule.java
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public abstract class SplitRule {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.
|
||||||
|
* This method is primarily used during the training of a tree when splits are being tested.
|
||||||
|
*
|
||||||
|
* @param rows
|
||||||
|
* @param <Y>
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public <Y> Split<Y> applyRule(List<Row<Y>> rows) {
|
||||||
|
final List<Row<Y>> leftHand = new LinkedList<>();
|
||||||
|
final List<Row<Y>> rightHand = new LinkedList<>();
|
||||||
|
|
||||||
|
for(final Row<Y> row : rows) {
|
||||||
|
|
||||||
|
if(isLeftHand(row)){
|
||||||
|
leftHand.add(row);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
rightHand.add(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Split<>(leftHand, rightHand);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract boolean isLeftHand(CovariateRow row);
|
||||||
|
|
||||||
|
}
|
11
src/main/java/ca/joeltherrien/randomforest/Value.java
Normal file
11
src/main/java/ca/joeltherrien/randomforest/Value.java
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
|
||||||
|
public interface Value<V> {
|
||||||
|
|
||||||
|
V getValue();
|
||||||
|
|
||||||
|
SplitRule generateSplitRule(String covariateName);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest.exceptions;
|
package ca.joeltherrien.randomforest.exceptions;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.SplitRule;
|
import ca.joeltherrien.randomforest.SplitRule;
|
||||||
|
|
||||||
|
@ -10,8 +11,8 @@ public class MissingValueException extends RuntimeException{
|
||||||
*/
|
*/
|
||||||
private static final long serialVersionUID = 6808060079431207726L;
|
private static final long serialVersionUID = 6808060079431207726L;
|
||||||
|
|
||||||
public MissingValueException(Row<?> row, SplitRule rule) {
|
public MissingValueException(CovariateRow row, SplitRule rule) {
|
||||||
super("Missing value at row " + row + rule);
|
super("Missing value at CovariateRow " + row + rule);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package ca.joeltherrien.randomforest.regression;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||||
|
|
||||||
|
double leftHandSize = leftHand.size();
|
||||||
|
double rightHandSize = rightHand.size();
|
||||||
|
|
||||||
|
if(leftHandSize == 0 || rightHandSize == 0){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
|
||||||
|
double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
|
||||||
|
|
||||||
|
return Math.abs(leftHandMean - rightHandMean);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
package ca.joeltherrien.randomforest.regression;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MeanResponseCombiner implements ResponseCombiner<Double> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double combine(List<Double> responses) {
|
||||||
|
double size = responses.size();
|
||||||
|
|
||||||
|
return responses.stream().mapToDouble(db -> db/size).sum();
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
package ca.joeltherrien.randomforest.regression;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||||
|
|
||||||
|
final double leftHandSize = leftHand.size();
|
||||||
|
final double rightHandSize = rightHand.size();
|
||||||
|
final double n = leftHandSize + rightHandSize;
|
||||||
|
|
||||||
|
if(leftHandSize == 0 || rightHandSize == 0){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
final double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
|
||||||
|
final double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
|
||||||
|
|
||||||
|
final double leftVariance = leftHand.stream().mapToDouble(db -> (db - leftHandMean)*(db - leftHandMean)).sum();
|
||||||
|
final double rightVariance = rightHand.stream().mapToDouble(db -> (db - rightHandMean)*(db - rightHandMean)).sum();
|
||||||
|
|
||||||
|
return -(leftVariance + rightVariance) / n;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
||||||
|
* The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score,
|
||||||
|
* the greater the difference.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public interface GroupDifferentiator<Y> {
|
||||||
|
|
||||||
|
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
|
||||||
|
public interface Node<Y> {
|
||||||
|
|
||||||
|
Y evaluate(CovariateRow row);
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.SplitRule;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public class SplitNode<Y> implements Node<Y> {
|
||||||
|
|
||||||
|
private final Node<Y> leftHand;
|
||||||
|
private final Node<Y> rightHand;
|
||||||
|
private final SplitRule splitRule;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Y evaluate(CovariateRow row) {
|
||||||
|
|
||||||
|
if(splitRule.isLeftHand(row)){
|
||||||
|
return leftHand.evaluate(row);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return rightHand.evaluate(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class TerminalNode<Y> implements Node<Y> {
|
||||||
|
|
||||||
|
private final Y responseValue;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Y evaluate(CovariateRow row){
|
||||||
|
return responseValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
105
src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
Normal file
105
src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.Split;
|
||||||
|
import ca.joeltherrien.randomforest.SplitRule;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public class TreeTrainer<Y> {
|
||||||
|
|
||||||
|
private final ResponseCombiner<Y> responseCombiner;
|
||||||
|
private final GroupDifferentiator<Y> groupDifferentiator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of splits to perform on each covariate. A value of 0 means all possible splits are tried.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
private final int numberOfSplits;
|
||||||
|
private final int nodeSize;
|
||||||
|
private final int maxNodeDepth;
|
||||||
|
|
||||||
|
|
||||||
|
public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
|
||||||
|
return growNode(data, covariatesToTry, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){
|
||||||
|
// TODO; what is minimum per tree?
|
||||||
|
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){
|
||||||
|
final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
||||||
|
|
||||||
|
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
||||||
|
|
||||||
|
final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
|
||||||
|
final Node<Y> rightNode = growNode(split.rightHand, covariatesToTry, depth+1);
|
||||||
|
|
||||||
|
return new SplitNode<>(leftNode, rightNode, bestSplitRule);
|
||||||
|
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return new TerminalNode<>(responseCombiner.combine(
|
||||||
|
data.stream()
|
||||||
|
.map(row -> row.getResponse())
|
||||||
|
.collect(Collectors.toList()))
|
||||||
|
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
|
||||||
|
SplitRule bestSplitRule = null;
|
||||||
|
double bestSplitScore = 0;
|
||||||
|
boolean first = true;
|
||||||
|
|
||||||
|
for(final String covariate : covariatesToTry){
|
||||||
|
Collections.shuffle(data);
|
||||||
|
|
||||||
|
int tries = 0;
|
||||||
|
while(tries <= numberOfSplits || (numberOfSplits == 0 && tries < data.size())){
|
||||||
|
final SplitRule possibleRule = data.get(tries).getCovariate(covariate).generateSplitRule(covariate);
|
||||||
|
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
||||||
|
|
||||||
|
final Double score = groupDifferentiator.differentiate(
|
||||||
|
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
|
||||||
|
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
|
||||||
|
/*
|
||||||
|
if( (groupDifferentiator.shouldMaximize() && score > bestSplitScore) || (!groupDifferentiator.shouldMaximize() && score < bestSplitScore) || first){
|
||||||
|
bestSplitRule = possibleRule;
|
||||||
|
bestSplitScore = score;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
if( score != null && (score > bestSplitScore || first)){
|
||||||
|
bestSplitRule = possibleRule;
|
||||||
|
bestSplitScore = score;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
tries++;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestSplitRule;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean nodeIsPure(List<Row<Y>> data, List<String> covariatesToTry){
|
||||||
|
// TODO how is this done?
|
||||||
|
|
||||||
|
final Y first = data.get(0).getResponse();
|
||||||
|
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue