Merge branch '01-factors' of joel/RandomSurvivalForests into master
This commit is contained in:
commit
c048a285a1
19 changed files with 710 additions and 195 deletions
13
pom.xml
13
pom.xml
|
@ -24,6 +24,19 @@
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
|
<version>5.2.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
|
<version>5.2.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class BooleanCovariate implements Covariate<Boolean>{
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final String name;
|
||||||
|
|
||||||
|
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Collection<BooleanSplitRule> generateSplitRules(List<Value<Boolean>> data, int number) {
|
||||||
|
return Collections.singleton(splitRule);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BooleanValue createValue(Boolean value) {
|
||||||
|
return new BooleanValue(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
public class BooleanValue implements Value<Boolean>{
|
||||||
|
|
||||||
|
private final boolean value;
|
||||||
|
|
||||||
|
private BooleanValue(final boolean value){
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BooleanCovariate getParent() {
|
||||||
|
return BooleanCovariate.this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Boolean getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class BooleanSplitRule implements SplitRule<Boolean>{
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final String toString() {
|
||||||
|
return "BooleanSplitRule";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BooleanCovariate getParent() {
|
||||||
|
return BooleanCovariate.this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeftHand(CovariateRow row) {
|
||||||
|
final Value<?> x = row.getCovariateValue(getParent().getName());
|
||||||
|
if(x == null) {
|
||||||
|
throw new MissingValueException(row, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
final boolean xBoolean = (Boolean) x.getValue();
|
||||||
|
|
||||||
|
return !xBoolean;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
58
src/main/java/ca/joeltherrien/randomforest/Covariate.java
Normal file
58
src/main/java/ca/joeltherrien/randomforest/Covariate.java
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
|
String getName();
|
||||||
|
|
||||||
|
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
|
||||||
|
|
||||||
|
Value<V> createValue(V value);
|
||||||
|
|
||||||
|
interface Value<V> extends Serializable{
|
||||||
|
|
||||||
|
Covariate<V> getParent();
|
||||||
|
|
||||||
|
V getValue();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SplitRule<V> extends Serializable{
|
||||||
|
|
||||||
|
Covariate<V> getParent();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.
|
||||||
|
* This method is primarily used during the training of a tree when splits are being tested.
|
||||||
|
*
|
||||||
|
* @param rows
|
||||||
|
* @param <Y>
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
default <Y> Split<Y> applyRule(List<Row<Y>> rows) {
|
||||||
|
final List<Row<Y>> leftHand = new LinkedList<>();
|
||||||
|
final List<Row<Y>> rightHand = new LinkedList<>();
|
||||||
|
|
||||||
|
for(final Row<Y> row : rows) {
|
||||||
|
|
||||||
|
if(isLeftHand(row)){
|
||||||
|
leftHand.add(row);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
rightHand.add(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Split<>(leftHand, rightHand);
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean isLeftHand(CovariateRow row);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -8,12 +8,12 @@ import java.util.Map;
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class CovariateRow {
|
public class CovariateRow {
|
||||||
|
|
||||||
private final Map<String, Value> valueMap;
|
private final Map<String, Covariate.Value> valueMap;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final int id;
|
private final int id;
|
||||||
|
|
||||||
public Value<?> getCovariate(String name){
|
public Covariate.Value<?> getCovariateValue(String name){
|
||||||
return valueMap.get(name);
|
return valueMap.get(name);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
122
src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java
Normal file
122
src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
|
||||||
|
public final class FactorCovariate implements Covariate<String>{
|
||||||
|
|
||||||
|
private final String name;
|
||||||
|
private final Map<String, FactorValue> factorLevels;
|
||||||
|
private final int numberOfPossiblePairings;
|
||||||
|
|
||||||
|
|
||||||
|
public FactorCovariate(final String name, List<String> levels){
|
||||||
|
this.name = name;
|
||||||
|
this.factorLevels = new HashMap<>();
|
||||||
|
|
||||||
|
for(final String level : levels){
|
||||||
|
final FactorValue newValue = new FactorValue(level);
|
||||||
|
|
||||||
|
factorLevels.put(level, newValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
int numberOfPossiblePairingsTemp = 1;
|
||||||
|
for(int i=0; i<levels.size()-1; i++){
|
||||||
|
numberOfPossiblePairingsTemp *= 2;
|
||||||
|
}
|
||||||
|
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
||||||
|
final Set<FactorSplitRule> splitRules = new HashSet<>();
|
||||||
|
|
||||||
|
// This is to ensure we don't get stuck in an infinite loop for small factors
|
||||||
|
number = Math.min(number, numberOfPossiblePairings);
|
||||||
|
final Random random = ThreadLocalRandom.current();
|
||||||
|
final List<FactorValue> levels = new ArrayList<>(factorLevels.values());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
while(splitRules.size() < number){
|
||||||
|
Collections.shuffle(levels, random);
|
||||||
|
final Set<FactorValue> leftSideValues = new HashSet<>();
|
||||||
|
leftSideValues.add(levels.get(0));
|
||||||
|
|
||||||
|
for(int i=1; i<levels.size()/2; i++){
|
||||||
|
if(random.nextBoolean()){
|
||||||
|
leftSideValues.add(levels.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
splitRules.add(new FactorSplitRule(leftSideValues));
|
||||||
|
}
|
||||||
|
|
||||||
|
return splitRules;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FactorValue createValue(String value) {
|
||||||
|
final FactorValue factorValue = factorLevels.get(value);
|
||||||
|
|
||||||
|
if(factorValue == null){
|
||||||
|
throw new IllegalArgumentException(value + " is not a level in FactorCovariate " + name);
|
||||||
|
}
|
||||||
|
|
||||||
|
return factorValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@EqualsAndHashCode
|
||||||
|
public final class FactorValue implements Covariate.Value<String>{
|
||||||
|
|
||||||
|
private final String value;
|
||||||
|
|
||||||
|
private FactorValue(final String value){
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FactorCovariate getParent() {
|
||||||
|
return FactorCovariate.this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@EqualsAndHashCode
|
||||||
|
public final class FactorSplitRule implements Covariate.SplitRule<String>{
|
||||||
|
|
||||||
|
private final Set<FactorValue> leftSideValues;
|
||||||
|
|
||||||
|
private FactorSplitRule(final Set<FactorValue> leftSideValues){
|
||||||
|
this.leftSideValues = leftSideValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FactorCovariate getParent() {
|
||||||
|
return FactorCovariate.this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeftHand(CovariateRow row) {
|
||||||
|
final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue();
|
||||||
|
|
||||||
|
return leftSideValues.contains(value);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
105
src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java
Normal file
105
src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class NumericCovariate implements Covariate<Double>{
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final String name;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {
|
||||||
|
|
||||||
|
final Random random = ThreadLocalRandom.current();
|
||||||
|
|
||||||
|
// for this implementation we need to shuffle the data
|
||||||
|
final List<Value<Double>> shuffledData;
|
||||||
|
if(number > data.size()){
|
||||||
|
shuffledData = new ArrayList<>(data);
|
||||||
|
Collections.shuffle(shuffledData, random);
|
||||||
|
}
|
||||||
|
else{ // only need the top number entries
|
||||||
|
shuffledData = new ArrayList<>(number);
|
||||||
|
final Set<Integer> indexesToUse = new HashSet<>();
|
||||||
|
|
||||||
|
while(indexesToUse.size() < number){
|
||||||
|
final int index = random.nextInt(data.size());
|
||||||
|
|
||||||
|
if(indexesToUse.add(index)){
|
||||||
|
shuffledData.add(data.get(index));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return shuffledData.stream()
|
||||||
|
.mapToDouble(v -> v.getValue())
|
||||||
|
.mapToObj(threshold -> new NumericSplitRule(threshold))
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
// by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NumericValue createValue(Double value) {
|
||||||
|
return new NumericValue(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
public class NumericValue implements Covariate.Value<Double>{
|
||||||
|
|
||||||
|
private final double value;
|
||||||
|
|
||||||
|
private NumericValue(final double value){
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NumericCovariate getParent() {
|
||||||
|
return NumericCovariate.this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class NumericSplitRule implements Covariate.SplitRule<Double>{
|
||||||
|
|
||||||
|
private final double threshold;
|
||||||
|
|
||||||
|
private NumericSplitRule(final double threshold){
|
||||||
|
this.threshold = threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final String toString() {
|
||||||
|
return "NumericSplitRule on " + getParent().getName() + " at " + threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NumericCovariate getParent() {
|
||||||
|
return NumericCovariate.this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeftHand(CovariateRow row) {
|
||||||
|
final Covariate.Value<?> x = row.getCovariateValue(getParent().getName());
|
||||||
|
if(x == null) {
|
||||||
|
throw new MissingValueException(row, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
final double xNum = (Double) x.getValue();
|
||||||
|
|
||||||
|
return xNum <= threshold;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,33 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
import java.util.LinkedList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
public class NumericSplitRule extends SplitRule{
|
|
||||||
|
|
||||||
public final String covariateName;
|
|
||||||
public final double threshold;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public final String toString() {
|
|
||||||
return "NumericSplitRule on " + covariateName + " at " + threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeftHand(CovariateRow row) {
|
|
||||||
final Value<?> x = row.getCovariate(covariateName);
|
|
||||||
if(x == null) {
|
|
||||||
throw new MissingValueException(row, this);
|
|
||||||
}
|
|
||||||
|
|
||||||
final double xNum = (Double) x.getValue();
|
|
||||||
|
|
||||||
return xNum <= threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,24 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class NumericValue implements Value<Double> {
|
|
||||||
|
|
||||||
private final double value;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Double getValue() {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public SplitRule generateSplitRule(final String covariateName) {
|
|
||||||
return new NumericSplitRule(covariateName, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString(){
|
|
||||||
return "" + value;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -7,7 +7,7 @@ public class Row<Y> extends CovariateRow {
|
||||||
|
|
||||||
private final Y response;
|
private final Y response;
|
||||||
|
|
||||||
public Row(Map<String, Value> valueMap, int id, Y response){
|
public Row(Map<String, Covariate.Value> valueMap, int id, Y response){
|
||||||
super(valueMap, id);
|
super(valueMap, id);
|
||||||
this.response = response;
|
this.response = response;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,37 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.LinkedList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public abstract class SplitRule implements Serializable {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.
|
|
||||||
* This method is primarily used during the training of a tree when splits are being tested.
|
|
||||||
*
|
|
||||||
* @param rows
|
|
||||||
* @param <Y>
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public <Y> Split<Y> applyRule(List<Row<Y>> rows) {
|
|
||||||
final List<Row<Y>> leftHand = new LinkedList<>();
|
|
||||||
final List<Row<Y>> rightHand = new LinkedList<>();
|
|
||||||
|
|
||||||
for(final Row<Y> row : rows) {
|
|
||||||
|
|
||||||
if(isLeftHand(row)){
|
|
||||||
leftHand.add(row);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
rightHand.add(row);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Split<>(leftHand, rightHand);
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract boolean isLeftHand(CovariateRow row);
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest;
|
|
||||||
|
|
||||||
|
|
||||||
public interface Value<V> {
|
|
||||||
|
|
||||||
V getValue();
|
|
||||||
|
|
||||||
SplitRule generateSplitRule(String covariateName);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,8 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.exceptions;
|
package ca.joeltherrien.randomforest.exceptions;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Covariate;
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
|
||||||
import ca.joeltherrien.randomforest.SplitRule;
|
|
||||||
|
|
||||||
public class MissingValueException extends RuntimeException{
|
public class MissingValueException extends RuntimeException{
|
||||||
|
|
||||||
|
@ -11,7 +10,7 @@ public class MissingValueException extends RuntimeException{
|
||||||
*/
|
*/
|
||||||
private static final long serialVersionUID = 6808060079431207726L;
|
private static final long serialVersionUID = 6808060079431207726L;
|
||||||
|
|
||||||
public MissingValueException(CovariateRow row, SplitRule rule) {
|
public MissingValueException(CovariateRow row, Covariate.SplitRule rule) {
|
||||||
super("Missing value at CovariateRow " + row + rule);
|
super("Missing value at CovariateRow " + row + rule);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Bootstrapper;
|
import ca.joeltherrien.randomforest.Bootstrapper;
|
||||||
|
import ca.joeltherrien.randomforest.Covariate;
|
||||||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
@ -21,7 +22,7 @@ import java.util.stream.Stream;
|
||||||
public class ForestTrainer<Y> {
|
public class ForestTrainer<Y> {
|
||||||
|
|
||||||
private final TreeTrainer<Y> treeTrainer;
|
private final TreeTrainer<Y> treeTrainer;
|
||||||
private final List<String> covariatesToTry;
|
private final List<Covariate> covariatesToTry;
|
||||||
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
private final ResponseCombiner<Y, ?> treeResponseCombiner;
|
||||||
private final List<Row<Y>> data;
|
private final List<Row<Y>> data;
|
||||||
|
|
||||||
|
@ -140,7 +141,7 @@ public class ForestTrainer<Y> {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
private Node<Y> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
||||||
final List<String> treeCovariates = new ArrayList<>(covariatesToTry);
|
final List<Covariate> treeCovariates = new ArrayList<>(covariatesToTry);
|
||||||
Collections.shuffle(treeCovariates);
|
Collections.shuffle(treeCovariates);
|
||||||
|
|
||||||
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
|
for(int treeIndex = covariatesToTry.size()-1; treeIndex >= mtry; treeIndex--){
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Covariate;
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
|
||||||
import ca.joeltherrien.randomforest.SplitRule;
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
|
@ -10,7 +9,7 @@ public class SplitNode<Y> implements Node<Y> {
|
||||||
|
|
||||||
private final Node<Y> leftHand;
|
private final Node<Y> leftHand;
|
||||||
private final Node<Y> rightHand;
|
private final Node<Y> rightHand;
|
||||||
private final SplitRule splitRule;
|
private final Covariate.SplitRule splitRule;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Y evaluate(CovariateRow row) {
|
public Y evaluate(CovariateRow row) {
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
|
||||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
@ -22,17 +20,24 @@ public class TreeTrainer<Y> {
|
||||||
private final int nodeSize;
|
private final int nodeSize;
|
||||||
private final int maxNodeDepth;
|
private final int maxNodeDepth;
|
||||||
|
|
||||||
private final Random random = new Random();
|
|
||||||
|
|
||||||
|
public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||||
public Node<Y> growTree(List<Row<Y>> data, List<String> covariatesToTry){
|
|
||||||
return growNode(data, covariatesToTry, 0);
|
return growNode(data, covariatesToTry, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){
|
private Node<Y> growNode(List<Row<Y>> data, List<Covariate> covariatesToTry, int depth){
|
||||||
// TODO; what is minimum per tree?
|
// TODO; what is minimum per tree?
|
||||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){
|
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||||
final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
||||||
|
|
||||||
|
if(bestSplitRule == null){
|
||||||
|
return new TerminalNode<>(
|
||||||
|
data.stream()
|
||||||
|
.map(row -> row.getResponse())
|
||||||
|
.collect(responseCombiner)
|
||||||
|
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
||||||
|
|
||||||
|
@ -54,37 +59,24 @@ public class TreeTrainer<Y> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
|
private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||||
SplitRule bestSplitRule = null;
|
Covariate.SplitRule bestSplitRule = null;
|
||||||
Double bestSplitScore = 0.0; // may be null
|
double bestSplitScore = 0.0;
|
||||||
boolean first = true;
|
boolean first = true;
|
||||||
|
|
||||||
for(final String covariate : covariatesToTry){
|
for(final Covariate covariate : covariatesToTry){
|
||||||
|
|
||||||
final List<Row<Y>> shuffledData;
|
final int numberToTry = numberOfSplits==0 ? data.size() : numberOfSplits;
|
||||||
if(numberOfSplits == 0 || numberOfSplits > data.size()){
|
|
||||||
shuffledData = new ArrayList<>(data);
|
|
||||||
Collections.shuffle(shuffledData);
|
|
||||||
}
|
|
||||||
else{ // only need the top numberOfSplits entries
|
|
||||||
shuffledData = new ArrayList<>(numberOfSplits);
|
|
||||||
final Set<Integer> indexesToUse = new HashSet<>();
|
|
||||||
|
|
||||||
while(indexesToUse.size() < numberOfSplits){
|
final Collection<Covariate.SplitRule> splitRulesToTry = covariate
|
||||||
final int index = random.nextInt(data.size());
|
.generateSplitRules(
|
||||||
|
data
|
||||||
|
.stream()
|
||||||
|
.map(row -> row.getCovariateValue(covariate.getName()))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
, numberToTry);
|
||||||
|
|
||||||
if(indexesToUse.add(index)){
|
for(final Covariate.SplitRule possibleRule : splitRulesToTry){
|
||||||
shuffledData.add(data.get(index));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
int tries = 0;
|
|
||||||
|
|
||||||
while(tries < shuffledData.size()){
|
|
||||||
final SplitRule possibleRule = shuffledData.get(tries).getCovariate(covariate).generateSplitRule(covariate);
|
|
||||||
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
||||||
|
|
||||||
final Double score = groupDifferentiator.differentiate(
|
final Double score = groupDifferentiator.differentiate(
|
||||||
|
@ -92,13 +84,11 @@ public class TreeTrainer<Y> {
|
||||||
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
);
|
);
|
||||||
|
|
||||||
if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){
|
if(score != null && (score > bestSplitScore || first)){
|
||||||
bestSplitRule = possibleRule;
|
bestSplitRule = possibleRule;
|
||||||
bestSplitScore = score;
|
bestSplitScore = score;
|
||||||
first = false;
|
first = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
tries++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -107,9 +97,7 @@ public class TreeTrainer<Y> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean nodeIsPure(List<Row<Y>> data, List<String> covariatesToTry){
|
private boolean nodeIsPure(List<Row<Y>> data){
|
||||||
// TODO how is this done?
|
|
||||||
|
|
||||||
final Y first = data.get(0).getResponse();
|
final Y first = data.get(0).getResponse();
|
||||||
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.FactorCovariate;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.function.Executable;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
public class FactorCovariateTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void verifyEqualLevels() {
|
||||||
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
|
||||||
|
final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
|
||||||
|
final FactorCovariate.FactorValue dog2 = petCovariate.createValue("DO" + "G");
|
||||||
|
|
||||||
|
assertSame(dog1, dog2);
|
||||||
|
|
||||||
|
final FactorCovariate.FactorValue cat1 = petCovariate.createValue("CAT");
|
||||||
|
final FactorCovariate.FactorValue cat2 = petCovariate.createValue("CA" + "T");
|
||||||
|
|
||||||
|
assertSame(cat1, cat2);
|
||||||
|
|
||||||
|
final FactorCovariate.FactorValue mouse1 = petCovariate.createValue("MOUSE");
|
||||||
|
final FactorCovariate.FactorValue mouse2 = petCovariate.createValue("MOUS" + "E");
|
||||||
|
|
||||||
|
assertSame(mouse1, mouse2);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void verifyBadLevelException(){
|
||||||
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
final Executable badCode = () -> petCovariate.createValue("vulcan");
|
||||||
|
|
||||||
|
assertThrows(IllegalArgumentException.class, badCode, "vulcan is not a level in FactorCovariate pet");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testAllSubsets(){
|
||||||
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
|
||||||
|
final Collection<FactorCovariate.FactorSplitRule> splitRules = petCovariate.generateSplitRules(null, 100);
|
||||||
|
|
||||||
|
assertEquals(splitRules.size(), 3);
|
||||||
|
|
||||||
|
// TODO verify the contents of the split rules
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private FactorCovariate createTestCovariate(){
|
||||||
|
final List<String> levels = List.of("DOG", "CAT", "MOUSE");
|
||||||
|
|
||||||
|
return new FactorCovariate("pet", levels);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -19,21 +19,28 @@ public class TrainForest {
|
||||||
final int n = 10000;
|
final int n = 10000;
|
||||||
final int p = 5;
|
final int p = 5;
|
||||||
|
|
||||||
|
|
||||||
final Random random = new Random();
|
final Random random = new Random();
|
||||||
|
|
||||||
final List<Row<Double>> data = new ArrayList<>(n);
|
final List<Row<Double>> data = new ArrayList<>(n);
|
||||||
|
|
||||||
double minY = 1000.0;
|
double minY = 1000.0;
|
||||||
|
|
||||||
|
final List<Covariate> covariateList = new ArrayList<>(p);
|
||||||
|
for(int j =0; j < p; j++){
|
||||||
|
final NumericCovariate covariate = new NumericCovariate("x"+j);
|
||||||
|
covariateList.add(covariate);
|
||||||
|
}
|
||||||
|
|
||||||
for(int i=0; i<n; i++){
|
for(int i=0; i<n; i++){
|
||||||
double y = 0.0;
|
double y = 0.0;
|
||||||
final Map<String, Value> map = new HashMap<>();
|
final Map<String, Covariate.Value> map = new HashMap<>();
|
||||||
|
|
||||||
for(int j=0; j<p; j++){
|
for(final Covariate covariate : covariateList) {
|
||||||
final double x = random.nextDouble();
|
final double x = random.nextDouble();
|
||||||
y+=x;
|
y += x;
|
||||||
|
|
||||||
map.put("x"+j, new NumericValue(x));
|
map.put(covariate.getName(), covariate.createValue(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
data.add(i, new Row<>(map, i, y));
|
data.add(i, new Row<>(map, i, y));
|
||||||
|
@ -44,10 +51,8 @@ public class TrainForest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<String> covariateNames = IntStream.range(0, p).mapToObj(j -> "x"+j).collect(Collectors.toList());
|
|
||||||
|
|
||||||
|
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||||
TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
|
||||||
.numberOfSplits(5)
|
.numberOfSplits(5)
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
.maxNodeDepth(100000000)
|
.maxNodeDepth(100000000)
|
||||||
|
@ -58,7 +63,7 @@ public class TrainForest {
|
||||||
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
|
final ForestTrainer<Double> forestTrainer = ForestTrainer.<Double>builder()
|
||||||
.treeTrainer(treeTrainer)
|
.treeTrainer(treeTrainer)
|
||||||
.data(data)
|
.data(data)
|
||||||
.covariatesToTry(covariateNames)
|
.covariatesToTry(covariateList)
|
||||||
.mtry(4)
|
.mtry(4)
|
||||||
.ntree(100)
|
.ntree(100)
|
||||||
.treeResponseCombiner(new MeanResponseCombiner())
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
|
@ -69,7 +74,7 @@ public class TrainForest {
|
||||||
final long startTime = System.currentTimeMillis();
|
final long startTime = System.currentTimeMillis();
|
||||||
|
|
||||||
//final Forest<Double> forest = forestTrainer.trainSerial();
|
//final Forest<Double> forest = forestTrainer.trainSerial();
|
||||||
//final Forest<Double> forest = forestTrainer.trainParallel(8);
|
//final Forest<Double> forest = forestTrainer.trainParallelInMemory(3);
|
||||||
forestTrainer.trainParallelOnDisk(3);
|
forestTrainer.trainParallelOnDisk(3);
|
||||||
|
|
||||||
final long endTime = System.currentTimeMillis();
|
final long endTime = System.currentTimeMillis();
|
||||||
|
@ -88,9 +93,9 @@ public class TrainForest {
|
||||||
|
|
||||||
System.out.println(forest.evaluate(testRow1));
|
System.out.println(forest.evaluate(testRow1));
|
||||||
System.out.println(forest.evaluate(testRow2));
|
System.out.println(forest.evaluate(testRow2));
|
||||||
|
|
||||||
System.out.println("MinY = " + minY);
|
|
||||||
*/
|
*/
|
||||||
|
System.out.println("MinY = " + minY);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -1,11 +1,10 @@
|
||||||
package ca.joeltherrien.randomforest.workshop;
|
package ca.joeltherrien.randomforest.workshop;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Covariate;
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.NumericValue;
|
import ca.joeltherrien.randomforest.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.Value;
|
|
||||||
import ca.joeltherrien.randomforest.regression.MeanGroupDifferentiator;
|
|
||||||
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
|
@ -25,30 +24,30 @@ public class TrainSingleTree {
|
||||||
final int n = 1000;
|
final int n = 1000;
|
||||||
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||||
|
|
||||||
final List<Value<Double>> x1List = DoubleStream
|
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
||||||
|
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
||||||
|
|
||||||
|
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||||
.generate(() -> random.nextDouble()*10.0)
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
.limit(n)
|
.limit(n)
|
||||||
.mapToObj(x1 -> new NumericValue(x1))
|
.mapToObj(x1 -> x1Covariate.createValue(x1))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
final List<Value<Double>> x2List = DoubleStream
|
final List<Covariate.Value<Double>> x2List = DoubleStream
|
||||||
.generate(() -> random.nextDouble()*10.0)
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
.limit(n)
|
.limit(n)
|
||||||
.mapToObj(x1 -> new NumericValue(x1))
|
.mapToObj(x2 -> x1Covariate.createValue(x2))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for(int i=0; i<n; i++){
|
for(int i=0; i<n; i++){
|
||||||
double x1 = x1List.get(i).getValue();
|
final Covariate.Value<Double> x1 = x1List.get(i);
|
||||||
double x2 = x2List.get(i).getValue();
|
final Covariate.Value<Double> x2 = x2List.get(i);
|
||||||
|
|
||||||
trainingSet.add(generateRow(x1, x2, i));
|
trainingSet.add(generateRow(x1, x2, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
final long startTime = System.currentTimeMillis();
|
|
||||||
|
|
||||||
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
@ -57,25 +56,29 @@ public class TrainSingleTree {
|
||||||
.numberOfSplits(0)
|
.numberOfSplits(0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
||||||
|
|
||||||
|
final long startTime = System.currentTimeMillis();
|
||||||
|
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
|
||||||
final long endTime = System.currentTimeMillis();
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
System.out.println(((double)(endTime - startTime))/1000.0);
|
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||||
|
|
||||||
final List<String> covariateNames = List.of("x1", "x2");
|
|
||||||
|
|
||||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
final List<CovariateRow> testSet = new ArrayList<>();
|
final List<CovariateRow> testSet = new ArrayList<>();
|
||||||
testSet.add(generateCovariateRow(9, 2, 1)); // expect 1
|
testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), 1)); // expect 1
|
||||||
testSet.add(generateCovariateRow(5, 2, 5));
|
testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), 5));
|
||||||
testSet.add(generateCovariateRow(2, 2, 3));
|
testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), 3));
|
||||||
testSet.add(generateCovariateRow(9, 5, 0));
|
testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), 0));
|
||||||
testSet.add(generateCovariateRow(6, 5, 8));
|
testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), 8));
|
||||||
testSet.add(generateCovariateRow(3, 5, 10));
|
testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), 10));
|
||||||
testSet.add(generateCovariateRow(1, 5, 3));
|
testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), 3));
|
||||||
testSet.add(generateCovariateRow(7, 9, 2));
|
testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), 2));
|
||||||
testSet.add(generateCovariateRow(1, 9, 4));
|
testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), 4));
|
||||||
|
|
||||||
for(final CovariateRow testCase : testSet){
|
for(final CovariateRow testCase : testSet){
|
||||||
System.out.println(testCase);
|
System.out.println(testCase);
|
||||||
|
@ -91,18 +94,18 @@ public class TrainSingleTree {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Row<Double> generateRow(double x1, double x2, int id){
|
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
|
||||||
double y = generateResponse(x1, x2);
|
double y = generateResponse(x1.getValue(), x2.getValue());
|
||||||
|
|
||||||
final Map<String, Value> map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2));
|
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
|
||||||
|
|
||||||
return new Row<>(map, id, y);
|
return new Row<>(map, id, y);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static CovariateRow generateCovariateRow(double x1, double x2, int id){
|
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
|
||||||
final Map<String, Value> map = Map.of("x1", new NumericValue(x1), "x2", new NumericValue(x2));
|
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
|
||||||
|
|
||||||
return new CovariateRow(map, id);
|
return new CovariateRow(map, id);
|
||||||
|
|
|
@ -0,0 +1,190 @@
|
||||||
|
package ca.joeltherrien.randomforest.workshop;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.*;
|
||||||
|
import ca.joeltherrien.randomforest.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.regression.WeightedVarianceGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.DoubleStream;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
public class TrainSingleTreeFactor {
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
System.out.println("Hello world!");
|
||||||
|
|
||||||
|
final Random random = new Random(123);
|
||||||
|
|
||||||
|
final int n = 10000;
|
||||||
|
final List<Row<Double>> trainingSet = new ArrayList<>(n);
|
||||||
|
|
||||||
|
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
||||||
|
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
||||||
|
final FactorCovariate x3Covariate = new FactorCovariate("x3", List.of("cat", "dog", "mouse"));
|
||||||
|
|
||||||
|
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||||
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
.limit(n)
|
||||||
|
.mapToObj(x1 -> x1Covariate.createValue(x1))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<Covariate.Value<Double>> x2List = DoubleStream
|
||||||
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
.limit(n)
|
||||||
|
.mapToObj(x2 -> x1Covariate.createValue(x2))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
final List<Covariate.Value<String>> x3List = DoubleStream
|
||||||
|
.generate(() -> random.nextDouble())
|
||||||
|
.limit(n)
|
||||||
|
.mapToObj(db -> {
|
||||||
|
if(db < 0.333){
|
||||||
|
return "cat";
|
||||||
|
}
|
||||||
|
else if(db < 0.5){
|
||||||
|
return "mouse";
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return "dog";
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.map(str -> x3Covariate.createValue(str))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
|
||||||
|
for(int i=0; i<n; i++){
|
||||||
|
final Covariate.Value<Double> x1 = x1List.get(i);
|
||||||
|
final Covariate.Value<Double> x2 = x2List.get(i);
|
||||||
|
final Covariate.Value<String> x3 = x3List.get(i);
|
||||||
|
|
||||||
|
trainingSet.add(generateRow(x1, x2, x3, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
final TreeTrainer<Double> treeTrainer = TreeTrainer.<Double>builder()
|
||||||
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.maxNodeDepth(30)
|
||||||
|
.nodeSize(5)
|
||||||
|
.numberOfSplits(5)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
||||||
|
|
||||||
|
final long startTime = System.currentTimeMillis();
|
||||||
|
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, covariateNames);
|
||||||
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
|
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
final Covariate.Value<String> cat = x3Covariate.createValue("cat");
|
||||||
|
final Covariate.Value<String> dog = x3Covariate.createValue("dog");
|
||||||
|
final Covariate.Value<String> mouse = x3Covariate.createValue("mouse");
|
||||||
|
|
||||||
|
|
||||||
|
final List<CovariateRow> testSet = new ArrayList<>();
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(2.0), cat, 1)); // expect 1
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(5.0), x2Covariate.createValue(2.0), dog, 5));
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(2.0), x2Covariate.createValue(2.0), cat, 3));
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(9.0), x2Covariate.createValue(5.0), dog, 0));
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(6.0), x2Covariate.createValue(5.0), cat, 8));
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(5.0), dog, 10));
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(5.0), cat, 3));
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), dog, 2));
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(1.0), x2Covariate.createValue(9.0), cat, 4));
|
||||||
|
|
||||||
|
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(3.0), x2Covariate.createValue(9.0), mouse, 0));
|
||||||
|
testSet.add(generateCovariateRow(x1Covariate.createValue(7.0), x2Covariate.createValue(9.0), mouse, 5));
|
||||||
|
|
||||||
|
for(final CovariateRow testCase : testSet){
|
||||||
|
System.out.println(testCase);
|
||||||
|
System.out.println(baseNode.evaluate(testCase));
|
||||||
|
System.out.println();
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
|
||||||
|
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
|
||||||
|
|
||||||
|
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
|
||||||
|
|
||||||
|
return new Row<>(map, id, y);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
|
||||||
|
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2, "x3", x3);
|
||||||
|
|
||||||
|
return new CovariateRow(map, id);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static double generateResponse(double x1, double x2, String x3){
|
||||||
|
|
||||||
|
if(x3.equalsIgnoreCase("mouse")){
|
||||||
|
if(x1 <= 5){
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 5;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cat & dog below
|
||||||
|
|
||||||
|
if(x2 <= 3){
|
||||||
|
if(x1 <= 3){
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
else if(x1 <= 7){
|
||||||
|
return 5;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(x1 >= 5){
|
||||||
|
if(x2 > 6){
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
else if(x1 >= 8){
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(x1 <= 2){
|
||||||
|
if(x2 >= 7){
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue