Massive refactor; Use Iterators/Updaters when calculating difference scores for faster calculations.

Changed the covariates to be more clever with how they produce the different splits. In the future (not yet implemented) a clever GroupDifferentiator
could update the current score calculation based just on how many rows moved from one hand to the other. There were a few other changes as well;
TreeTrainer#growTree now accepts a Random as a parameter which is used throughout the entire growing process. This means it's now theoretically
possible to grow trees using a seed, so that results can be fully reproducible.
This commit is contained in:
Joel Therrien 2019-01-09 21:31:27 -08:00
parent e892076a05
commit a5fe856857
45 changed files with 741 additions and 212 deletions

View file

@ -4,20 +4,20 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.Random;
@RequiredArgsConstructor
public class Bootstrapper<T> {
final private List<T> originalData;
public List<T> bootstrap(){
public List<T> bootstrap(Random random){
final int n = originalData.size();
final List<T> newList = new ArrayList<>(n);
for(int i=0; i<n; i++){
final int index = ThreadLocalRandom.current().nextInt(n);
final int index = random.nextInt(n);
newList.add(originalData.get(index));
}

View file

@ -17,7 +17,7 @@ public class CovariateRow implements Serializable {
@Getter
private final int id;
public Covariate.Value<?> getCovariateValue(Covariate covariate){
public <V> Covariate.Value<V> getCovariateValue(Covariate<V> covariate){
return valueArray[covariate.getIndex()];
}

View file

@ -1,9 +1,9 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;

View file

@ -17,6 +17,8 @@ public class Row<Y> extends CovariateRow {
}
public Y getResponse() {
return this.response;
}

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
@ -10,7 +10,6 @@ import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayL
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
@ -68,9 +67,6 @@ public class Settings {
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
}
static{
registerGroupDifferentiatorConstructor("MeanGroupDifferentiator",
(node) -> new MeanGroupDifferentiator()
);
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
(node) -> new WeightedVarianceGroupDifferentiator()
);

View file

@ -1,11 +1,12 @@
package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.SingletonIterator;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.*;
@RequiredArgsConstructor
public final class BooleanCovariate implements Covariate<Boolean> {
@ -19,8 +20,8 @@ public final class BooleanCovariate implements Covariate<Boolean>{
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
@Override
public Collection<BooleanSplitRule> generateSplitRules(List<Value<Boolean>> data, int number) {
return Collections.singleton(splitRule);
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
return new SingletonIterator<>(BooleanCovariate.this.splitRule.applyRule(data));
}
@Override
@ -74,6 +75,7 @@ public final class BooleanCovariate implements Covariate<Boolean>{
}
}
public class BooleanSplitRule implements SplitRule<Boolean>{
@Override

View file

@ -5,10 +5,7 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
public interface Covariate<V> extends Serializable {
@ -17,7 +14,7 @@ public interface Covariate<V> extends Serializable {
int getIndex();
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
<Y> Iterator<Split<Y, V>> generateSplitRuleUpdater(final List<Row<Y>> data, final int number, final Random random);
Value<V> createValue(V value);
@ -39,6 +36,16 @@ public interface Covariate<V> extends Serializable {
}
interface SplitRuleUpdater<Y, V> extends Iterator<Split<Y, V>>{
Split<Y, V> currentSplit();
SplitUpdate<Y, V> nextUpdate();
}
interface SplitUpdate<Y, V> {
SplitRule<V> getSplitRule();
Collection<Row<Y>> rowsMovedToLeftHand();
}
interface SplitRule<V> extends Serializable{
Covariate<V> getParent();
@ -51,7 +58,7 @@ public interface Covariate<V> extends Serializable {
* @param <Y>
* @return
*/
default <Y> Split<Y> applyRule(List<Row<Y>> rows) {
default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
final List<Row<Y>> leftHand = new LinkedList<>();
final List<Row<Y>> rightHand = new LinkedList<>();
@ -77,7 +84,7 @@ public interface Covariate<V> extends Serializable {
}
return new Split<>(leftHand, rightHand, missingValueRows);
return new Split<>(this, leftHand, rightHand, missingValueRows);
}
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){

View file

@ -1,5 +1,7 @@
package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split;
import lombok.EqualsAndHashCode;
import lombok.Getter;
@ -42,17 +44,14 @@ public final class FactorCovariate implements Covariate<String>{
@Override
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
final Set<FactorSplitRule> splitRules = new HashSet<>();
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
final Set<Split<Y, String>> splits = new HashSet<>();
// This is to ensure we don't get stuck in an infinite loop for small factors
number = Math.min(number, numberOfPossiblePairings);
final Random random = ThreadLocalRandom.current();
final List<FactorValue> levels = new ArrayList<>(factorLevels.values());
while(splitRules.size() < number){
while(splits.size() < number){
Collections.shuffle(levels, random);
final Set<FactorValue> leftSideValues = new HashSet<>();
leftSideValues.add(levels.get(0));
@ -63,13 +62,14 @@ public final class FactorCovariate implements Covariate<String>{
}
}
splitRules.add(new FactorSplitRule(leftSideValues));
splits.add(new FactorSplitRule(leftSideValues).applyRule(data));
}
return splitRules;
return splits.iterator();
}
@Override
public FactorValue createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){

View file

@ -1,12 +1,16 @@
package ca.joeltherrien.randomforest.covariates;
package ca.joeltherrien.randomforest.covariates.numeric;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.utils.IndexedIterator;
import ca.joeltherrien.randomforest.utils.UniqueSubsetValueIterator;
import ca.joeltherrien.randomforest.utils.UniqueValueIterator;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
@RequiredArgsConstructor
@ -20,40 +24,44 @@ public final class NumericCovariate implements Covariate<Double>{
private final int index;
@Override
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
data = data.stream()
.filter(row -> !row.getCovariateValue(this).isNA())
.sorted((r1, r2) -> {
Double d1 = r1.getCovariateValue(this).getValue();
Double d2 = r2.getCovariateValue(this).getValue();
final Random random = ThreadLocalRandom.current();
return d1.compareTo(d2);
})
.collect(Collectors.toList());
// only work with non-NA values
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
//data = data.stream().filter(value -> !value.isNA()).distinct().collect(Collectors.toList()); // TODO which to use?
Iterator<Double> sortedDataIterator = data.stream()
.map(row -> row.getCovariateValue(this).getValue())
.filter(v -> v != null)
.iterator();
// for this implementation we need to shuffle the data
final List<Value<Double>> shuffledData;
if(number >= data.size()){
shuffledData = data;
final IndexedIterator<Double> dataIterator;
if(number == 0){
dataIterator = new UniqueValueIterator<>(sortedDataIterator);
}
else{ // only need the top number entries
shuffledData = new ArrayList<>(number);
final Set<Integer> indexesToUse = new HashSet<>();
//final List<Integer> indexesToUse = new ArrayList<>(); // TODO which to use?
else{ // random splitting; we will not weight based on how many times a Row appears in the bootstrap sample
final TreeSet<Integer> indexSet = new TreeSet<>();
while(indexesToUse.size() < number){
final int index = random.nextInt(data.size());
final int maxIndex = data.size();
if(indexesToUse.add(index)){
shuffledData.add(data.get(index));
}
for(int i=0; i<number; i++){
indexSet.add(random.nextInt(maxIndex));
}
dataIterator = new UniqueSubsetValueIterator<>(
new UniqueValueIterator<>(sortedDataIterator),
indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered
);
}
return shuffledData.stream()
.mapToDouble(v -> v.getValue())
.mapToObj(threshold -> new NumericSplitRule(threshold))
.collect(Collectors.toSet());
// by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping
return new NumericSplitRuleUpdater<>(this, data, dataIterator);
}
@ -101,7 +109,7 @@ public final class NumericCovariate implements Covariate<Double>{
private final double threshold;
private NumericSplitRule(final double threshold){
NumericSplitRule(final double threshold){
this.threshold = threshold;
}

View file

@ -0,0 +1,79 @@
package ca.joeltherrien.randomforest.covariates.numeric;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.IndexedIterator;
import java.util.Collections;
import java.util.List;
public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y, Double> {
private final NumericCovariate covariate;
private final List<Row<Y>> orderedData;
private final IndexedIterator<Double> dataIterator;
private Split<Y, Double> currentSplit;
public NumericSplitRuleUpdater(final NumericCovariate covariate, final List<Row<Y>> orderedData, final IndexedIterator<Double> iterator){
this.covariate = covariate;
this.orderedData = orderedData;
this.dataIterator = iterator;
final List<Row<Y>> leftHandList = Collections.emptyList();
final List<Row<Y>> rightHandList = orderedData;
this.currentSplit = new Split<>(
covariate.new NumericSplitRule(Double.MIN_VALUE),
leftHandList,
rightHandList,
Collections.emptyList());
}
@Override
public Split<Y, Double> currentSplit() {
return this.currentSplit;
}
@Override
public NumericSplitUpdate<Y> nextUpdate() {
if(hasNext()){
final int currentPosition = dataIterator.getIndex();
final Double splitValue = dataIterator.next();
final int newPosition = dataIterator.getIndex();
final List<Row<Y>> rowsMoved = orderedData.subList(currentPosition, newPosition);
final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue);
// Update current split
this.currentSplit = new Split<>(
splitRule,
orderedData.subList(0, newPosition),
orderedData.subList(newPosition, orderedData.size()),
Collections.emptyList());
return new NumericSplitUpdate<>(splitRule, rowsMoved);
}
return null;
}
@Override
public boolean hasNext() {
return dataIterator.hasNext();
}
@Override
public Split<Y, Double> next() {
if(hasNext()){
nextUpdate();
}
return this.currentSplit();
}
}

View file

@ -0,0 +1,24 @@
package ca.joeltherrien.randomforest.covariates.numeric;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.AllArgsConstructor;
import java.util.Collection;
@AllArgsConstructor
public class NumericSplitUpdate<Y> implements Covariate.SplitUpdate<Y, Double> {
private final NumericCovariate.NumericSplitRule numericSplitRule;
private final Collection<Row<Y>> rowsMoved;
@Override
public NumericCovariate.NumericSplitRule getSplitRule() {
return numericSplitRule;
}
@Override
public Collection<Row<Y>> rowsMovedToLeftHand() {
return rowsMoved;
}
}

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.covariates;
package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.BooleanCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.covariates;
package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import lombok.Getter;

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.covariates;
package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.covariates;
package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;

View file

@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
import lombok.AllArgsConstructor;
import lombok.Data;
@ -14,11 +15,7 @@ import java.util.stream.Stream;
* modifies the abstract method.
*
*/
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y>{
@Override
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> extends SimpleGroupDifferentiator<Y> {
/**
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.

View file

@ -17,7 +17,7 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
private final int[] events;
@Override
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}

View file

@ -18,7 +18,7 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
private final int[] events;
@Override
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}

View file

@ -17,7 +17,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
private final int[] events;
@Override
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}

View file

@ -18,7 +18,7 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
private final int[] events;
@Override
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){
return null;
}

View file

@ -1,26 +0,0 @@
package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import java.util.List;
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
@Override
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
double leftHandSize = leftHand.size();
double rightHandSize = rightHand.size();
if(leftHandSize == 0 || rightHandSize == 0){
return null;
}
double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
return Math.abs(leftHandMean - rightHandMean);
}
}

View file

@ -1,13 +1,13 @@
package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
import java.util.List;
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator<Double> {
@Override
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
public Double getScore(List<Double> leftHand, List<Double> rightHand) {
final double leftHandSize = leftHand.size();
final double rightHandSize = rightHand.size();

View file

@ -14,8 +14,10 @@ import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -45,17 +47,17 @@ public class ForestTrainer<Y, TO, FO> {
this.covariates = covariates;
this.treeResponseCombiner = settings.getTreeCombiner();
this.treeTrainer = new TreeTrainer<>(settings, covariates);
}
public Forest<TO, FO> trainSerial(){
final List<Tree<TO>> trees = new ArrayList<>(ntree);
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
final Random random = new Random();
for(int j=0; j<ntree; j++){
trees.add(trainTree(bootstrapper));
trees.add(trainTree(bootstrapper, random));
if(displayProgress){
if(j==0) {
@ -162,9 +164,9 @@ public class ForestTrainer<Y, TO, FO> {
}
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
return treeTrainer.growTree(bootstrappedData);
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap(random);
return treeTrainer.growTree(bootstrappedData, random);
}
public void saveTree(final Tree<TO> tree, String name) throws IOException {
@ -193,7 +195,8 @@ public class ForestTrainer<Y, TO, FO> {
@Override
public void run() {
final Tree<TO> tree = trainTree(bootstrapper);
// ThreadLocalRandom should make sure we don't duplicate seeds
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
// should be okay as the list structure isn't changing
treeList.set(treeIndex, tree);
@ -216,7 +219,8 @@ public class ForestTrainer<Y, TO, FO> {
@Override
public void run() {
final Tree<TO> tree = trainTree(bootstrapper);
// ThreadLocalRandom should make sure we don't duplicate seeds
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
try {
saveTree(tree, filename);

View file

@ -1,15 +1,17 @@
package ca.joeltherrien.randomforest.tree;
import java.util.List;
import java.util.Iterator;
/**
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
* The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score,
* the greater the difference.
* The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an
* instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits)
*
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
* SimpleGroupDifferentiator.
*/
public interface GroupDifferentiator<Y> {
Double differentiate(List<Y> leftHand, List<Y> rightHand);
<V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator);
}

View file

@ -0,0 +1,50 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {
@Override
public <V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator) {
Double bestScore = null;
Split<Y, V> bestSplit = null;
while(splitIterator.hasNext()){
final Split<Y, V> candidateSplit = splitIterator.next();
final List<Y> leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
final List<Y> rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
if(leftHand.isEmpty() || rightHand.isEmpty()){
continue;
}
final Double score = getScore(leftHand, rightHand);
if(score != null && (bestScore == null || score > bestScore)){
bestScore = score;
bestSplit = candidateSplit;
}
}
if(bestSplit == null){
return null;
}
return new SplitAndScore<>(bestSplit, bestScore);
}
/**
* Return a score; higher is better.
*
* @param leftHand
* @param rightHand
* @return
*/
public abstract Double getScore(List<Y> leftHand, List<Y> rightHand);
}

View file

@ -1,19 +1,21 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Data;
import java.util.List;
/**
* Very simple class that contains three lists; it's essentially a thruple.
* Very simple class that contains three lists and a SplitRule.
*
* @author joel
*
*/
@Data
public class Split<Y> {
public final class Split<Y, V> {
public final Covariate.SplitRule<V> splitRule;
public final List<Row<Y>> leftHand;
public final List<Row<Y>> rightHand;
public final List<Row<Y>> naHand;

View file

@ -0,0 +1,15 @@
package ca.joeltherrien.randomforest.tree;
import lombok.AllArgsConstructor;
import lombok.Getter;
@AllArgsConstructor
public class SplitAndScore<Y, V> {
@Getter
private final Split<Y, V> split;
@Getter
private final Double score;
}

View file

@ -3,10 +3,11 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.*;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
@Builder
@ -45,20 +46,21 @@ public class TreeTrainer<Y, O> {
this.covariates = covariates;
}
public Tree<O> growTree(List<Row<Y>> data){
public Tree<O> growTree(List<Row<Y>> data, Random random){
final Node<O> rootNode = growNode(data, 0);
final Node<O> rootNode = growNode(data, 0, random);
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
}
private Node<O> growNode(List<Row<Y>> data, int depth){
private Node<O> growNode(List<Row<Y>> data, int depth, Random random){
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry);
final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
final Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
if(bestSplitRuleAndSplit.splitRule == null){
if(bestSplit == null){
return new TerminalNode<>(
responseCombiner.combine(
@ -69,14 +71,31 @@ public class TreeTrainer<Y, O> {
}
final Split<Y> split = bestSplitRuleAndSplit.split;
// Note that NAs have already been handled
// Now that we have the best split; we need to handle any NAs that were dropped off
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
// Assign missing values to the split
for(Row<Y> row : data) {
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
bestSplit.getLeftHand().add(row);
}
else{
bestSplit.getRightHand().add(row);
}
}
}
final Node<O> leftNode = growNode(split.leftHand, depth+1);
final Node<O> rightNode = growNode(split.rightHand, depth+1);
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand);
return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);
}
else{
@ -90,13 +109,13 @@ public class TreeTrainer<Y, O> {
}
private List<Covariate> selectCovariates(int mtry){
private List<Covariate> selectCovariates(int mtry, Random random){
if(mtry >= covariates.size()){
return covariates;
}
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
Collections.shuffle(splitCovariates, random);
if (splitCovariates.size() > mtry) {
splitCovariates.subList(mtry, splitCovariates.size()).clear();
@ -105,63 +124,29 @@ public class TreeTrainer<Y, O> {
return splitCovariates;
}
private SplitRuleAndSplit findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
final Random random = ThreadLocalRandom.current();
SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit();
double bestSplitScore = 0.0;
boolean first = true;
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
SplitAndScore<Y, ?> bestSplitAndScore = null;
final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck
for(final Covariate covariate : covariatesToTry) {
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
final int numberToTry = numberOfSplits==0 ? data.size() : numberOfSplits;
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
final Collection<Covariate.SplitRule> splitRulesToTry = covariate
.generateSplitRules(
data
.stream()
.map(row -> row.getCovariateValue(covariate))
.collect(Collectors.toList())
, numberToTry);
for(final Covariate.SplitRule possibleRule : splitRulesToTry){
final Split<Y> possibleSplit = possibleRule.applyRule(data);
// We have to handle any NAs
if(possibleSplit.leftHand.size() == 0 && possibleSplit.rightHand.size() == 0 && possibleSplit.naHand.size() > 0){
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
}
final double probabilityLeftHand = (double) possibleSplit.leftHand.size() / (double) (possibleSplit.leftHand.size() + possibleSplit.rightHand.size());
for(final Row<Y> missingValueRow : possibleSplit.naHand){
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
possibleSplit.leftHand.add(missingValueRow);
}
else{
possibleSplit.rightHand.add(missingValueRow);
}
}
final Double score = groupDifferentiator.differentiate(
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
);
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
bestSplitRuleAndSplit.splitRule = possibleRule;
bestSplitRuleAndSplit.split = possibleSplit;
bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand;
bestSplitScore = score;
first = false;
if(candidateSplitAndScore != null) {
if (bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
bestSplitAndScore = candidateSplitAndScore;
}
}
}
return bestSplitRuleAndSplit;
if(bestSplitAndScore == null){
return null;
}
return bestSplitAndScore.getSplit();
}
@ -184,10 +169,4 @@ public class TreeTrainer<Y, O> {
return true;
}
private class SplitRuleAndSplit{
private Covariate.SplitRule splitRule = null;
private Split<Y> split = null;
private double probabilityLeftHand;
}
}

View file

@ -0,0 +1,9 @@
package ca.joeltherrien.randomforest.utils;
import java.util.Iterator;
public interface IndexedIterator<E> extends Iterator<E> {
int getIndex();
}

View file

@ -0,0 +1,29 @@
package ca.joeltherrien.randomforest.utils;
import lombok.RequiredArgsConstructor;
import java.util.Iterator;
@RequiredArgsConstructor
public class SingletonIterator<E> implements Iterator<E> {
private final E value;
private boolean beenCalled = false;
@Override
public boolean hasNext() {
return !beenCalled;
}
@Override
public E next() {
if(!beenCalled){
beenCalled = true;
return value;
}
return null;
}
}

View file

@ -0,0 +1,65 @@
package ca.joeltherrien.randomforest.utils;
import lombok.Getter;
import java.util.Iterator;
/**
* Iterator that wraps around a UniqueValueIterator. It continues to iterate until it gets to one of the prespecified indexes,
* and then proceeds just past that to the end of the existing values it's at.
*
* The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together.
* I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations
*
* @param <E>
*/
public class UniqueSubsetValueIterator<E> implements IndexedIterator<E> {
private final UniqueValueIterator<E> iterator;
private final Integer[] indexValues;
private int currentIndexSpot = 0;
public UniqueSubsetValueIterator(final UniqueValueIterator<E> iterator, final Integer[] indexValues){
this.iterator = iterator;
this.indexValues = indexValues;
}
@Override
public boolean hasNext() {
return iterator.hasNext() && iterator.getIndex() <= indexValues[indexValues.length-1];
}
@Override
public E next() {
if(hasNext()){
final int indexToStopBy = indexValues[currentIndexSpot];
while(iterator.getIndex() <= indexToStopBy){
iterator.next();
}
for(int i = currentIndexSpot + 1; i < indexValues.length; i++){
if(iterator.getIndex() <= indexValues[i]){
currentIndexSpot = i;
break;
}
}
return iterator.getCurrentValue();
}
return null;
}
@Override
public int getIndex(){
return iterator.getIndex();
}
}

View file

@ -0,0 +1,71 @@
package ca.joeltherrien.randomforest.utils;
import lombok.Getter;
import java.util.Iterator;
/**
* Iterator that wraps around another iterator. It continues to iterate until it gets to the *end* of a sequence of identical values.
* It also tracks the current index in the original iterator.
*
* The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together.
* I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations
*
* @param <E>
*/
public class UniqueValueIterator<E> implements IndexedIterator<E> {
private final Iterator<E> wrappedIterator;
@Getter private E currentValue = null;
@Getter private E nextValue;
public UniqueValueIterator(final Iterator<E> wrappedIterator){
this.wrappedIterator = wrappedIterator;
this.nextValue = wrappedIterator.next();
}
// Count must return the index of the last value of the sequence returned by next()
@Getter
private int index = 0;
@Override
public boolean hasNext() {
return nextValue != null;
}
@Override
public E next() {
int count = 1;
while(wrappedIterator.hasNext()){
final E currentIteratorValue = wrappedIterator.next();
if(currentIteratorValue.equals(nextValue)){
count++;
}
else{
index +=count;
currentValue = nextValue;
nextValue = currentIteratorValue;
return currentValue;
}
}
if(nextValue != null){
index += count;
currentValue = nextValue;
nextValue = null;
return currentValue;
}
else{
return null;
}
}
}

View file

@ -1,8 +1,8 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;

View file

@ -2,6 +2,8 @@ package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
@ -18,6 +20,7 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException;
import java.util.List;
import java.util.Random;
public class TestCompetingRisk {
@ -104,7 +107,7 @@ public class TestCompetingRisk {
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
final CovariateRow newRow = getPredictionRow(covariates);
@ -157,7 +160,7 @@ public class TestCompetingRisk {
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
final CovariateRow newRow = getPredictionRow(covariates);

View file

@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.DataLoader;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.utils.Utils;
@ -53,14 +53,14 @@ public class TestLogRankMultipleGroupDifferentiator {
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
final double scoreBad = groupDifferentiator.differentiate(
final double scoreBad = groupDifferentiator.getScore(
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199);
final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size());
final double scoreGood = groupDifferentiator.differentiate(
final double scoreGood = groupDifferentiator.getScore(
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));

View file

@ -1,8 +1,7 @@
package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
import org.junit.jupiter.api.Test;
@ -51,7 +50,7 @@ public class TestLogRankSingleGroupDifferentiator {
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
final double score = differentiator.differentiate(data1, data2);
final double score = differentiator.getScore(data1, data2);
final double margin = 0.000001;
// Tested using 855 method
@ -71,14 +70,14 @@ public class TestLogRankSingleGroupDifferentiator {
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
final double scoreGood = groupDifferentiator.differentiate(
final double scoreGood = groupDifferentiator.getScore(
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
final double scoreBad = groupDifferentiator.differentiate(
final double scoreBad = groupDifferentiator.getScore(
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));

View file

@ -5,8 +5,9 @@ import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import java.util.Collection;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
@ -46,7 +47,10 @@ public class FactorCovariateTest {
void testAllSubsets(){
final FactorCovariate petCovariate = createTestCovariate();
final Collection<FactorCovariate.FactorSplitRule> splitRules = petCovariate.generateSplitRules(null, 100);
final List<Covariate.SplitRule<String>> splitRules = new ArrayList<>();
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
assertEquals(splitRules.size(), 3);

View file

@ -3,10 +3,10 @@ package ca.joeltherrien.randomforest.csv;
import ca.joeltherrien.randomforest.DataLoader;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
@ -15,7 +15,6 @@ import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

View file

@ -3,7 +3,9 @@ package ca.joeltherrien.randomforest.settings;
import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals;
import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;

View file

@ -0,0 +1,25 @@
package ca.joeltherrien.randomforest.utils;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class TestSingletonIterator {
@Test
public void verifyBehaviour(){
final Integer element = 5;
final SingletonIterator<Integer> iterator = new SingletonIterator<>(element);
assertTrue(iterator.hasNext());
assertTrue(iterator.hasNext());
assertEquals(Integer.valueOf(5), iterator.next());
assertFalse(iterator.hasNext());
assertNull(iterator.next());
}
}

View file

@ -0,0 +1,68 @@
package ca.joeltherrien.randomforest.utils;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
public class TestUniqueSubsetValueIterator {
@Test
public void testIterator1(){
final List<Integer> testData = Arrays.asList(
1,1,2,3,5,5,5,6,7,7
);
final Integer[] indexes = new Integer[]{2,3,4,5};
final UniqueValueIterator<Integer> uniqueValueIterator = new UniqueValueIterator<>(testData.iterator());
final UniqueSubsetValueIterator<Integer> iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes);
// we expect to get 2, 3, and 5 back. 5 should happen only once
assertEquals(iterator.getIndex(), 0);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 2);
assertEquals(iterator.getIndex(), 3);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 3);
assertEquals(iterator.getIndex(), 4);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 5);
assertEquals(iterator.getIndex(), 7);
assertFalse(iterator.hasNext());
}
@Test
public void testIterator2(){
final List<Integer> testData = Arrays.asList(
1,1,2,3,5,5,5,6,7,7
);
final Integer[] indexes = new Integer[]{1,8};
final UniqueValueIterator<Integer> uniqueValueIterator = new UniqueValueIterator<>(testData.iterator());
final UniqueSubsetValueIterator<Integer> iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes);
assertEquals(iterator.getIndex(), 0);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 1);
assertEquals(iterator.getIndex(), 2);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 7);
assertEquals(iterator.getIndex(), 10);
assertFalse(iterator.hasNext());
}
}

View file

@ -0,0 +1,112 @@
package ca.joeltherrien.randomforest.utils;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
public class TestUniqueValueIterator {
@Test
public void testIterator1(){
final List<Integer> testData = Arrays.asList(
1,1,2,3,5,5,5,6,7,7
);
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
assertEquals(iterator.getIndex(), 0);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 1);
assertEquals(iterator.getIndex(), 2);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 2);
assertEquals(iterator.getIndex(), 3);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 3);
assertEquals(iterator.getIndex(), 4);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 5);
assertEquals(iterator.getIndex(), 7);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 6);
assertEquals(iterator.getIndex(), 8);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 7);
assertEquals(iterator.getIndex(), 10);
assertTrue(!iterator.hasNext());
}
@Test
public void testIterator2(){
final List<Integer> testData = Arrays.asList(
1,2,3,5,5,5,6,7 // same numbers; but 1 and 7 only appear once each
);
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
assertEquals(iterator.getIndex(), 0);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 1);
assertEquals(iterator.getIndex(), 1);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 2);
assertEquals(iterator.getIndex(), 2);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 3);
assertEquals(iterator.getIndex(), 3);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 5);
assertEquals(iterator.getIndex(), 6);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 6);
assertEquals(iterator.getIndex(), 7);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 7);
assertEquals(iterator.getIndex(), 8);
assertFalse(iterator.hasNext());
}
@Test
public void testIterator3(){
final List<Integer> testData = Arrays.asList(
1,1,1,1,1,1,1,2,2,2,2,2,3
);
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
assertEquals(iterator.getIndex(), 0);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 1);
assertEquals(iterator.getIndex(), 7);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 2);
assertEquals(iterator.getIndex(), 12);
assertTrue(iterator.hasNext());
assertEquals(iterator.next().intValue(), 3);
assertEquals(iterator.getIndex(), 13);
assertFalse(iterator.hasNext());
}
}

View file

@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ForestTrainer;

View file

@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
@ -54,13 +54,14 @@ public class TrainSingleTree {
.covariates(covariateNames)
.responseCombiner(new MeanResponseCombiner())
.maxNodeDepth(30)
.mtry(2)
.nodeSize(5)
.numberOfSplits(0)
.build();
final long startTime = System.currentTimeMillis();
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, new Random());
final long endTime = System.currentTimeMillis();
System.out.println(((double)(endTime - startTime))/1000.0);

View file

@ -4,7 +4,7 @@ package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Node;
@ -13,7 +13,6 @@ import ca.joeltherrien.randomforest.utils.Utils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
@ -21,8 +20,6 @@ import java.util.stream.DoubleStream;
public class TrainSingleTreeFactor {
public static void main(String[] args) {
System.out.println("Hello world!");
final Random random = new Random(123);
final int n = 10000;
@ -84,7 +81,7 @@ public class TrainSingleTreeFactor {
final long startTime = System.currentTimeMillis();
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, random);
final long endTime = System.currentTimeMillis();
System.out.println(((double)(endTime - startTime))/1000.0);