Massive refactor: use Iterators/Updaters when calculating difference scores, for faster calculations.

Changed the covariates to be more clever about how they produce the different splits. In the future (not yet implemented), a clever GroupDifferentiator
could update the current score calculation based solely on which rows moved from one hand to the other. There were a few other changes as well:
TreeTrainer#growTree now accepts a Random as a parameter, which is used throughout the entire growing process. This means it's now theoretically
possible to grow trees using a seed, so that results can be fully reproducible.
This commit is contained in:
Joel Therrien 2019-01-09 21:31:27 -08:00
parent e892076a05
commit a5fe856857
45 changed files with 741 additions and 212 deletions

View file

@ -4,20 +4,20 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.ThreadLocalRandom; import java.util.Random;
@RequiredArgsConstructor @RequiredArgsConstructor
public class Bootstrapper<T> { public class Bootstrapper<T> {
final private List<T> originalData; final private List<T> originalData;
public List<T> bootstrap(){ public List<T> bootstrap(Random random){
final int n = originalData.size(); final int n = originalData.size();
final List<T> newList = new ArrayList<>(n); final List<T> newList = new ArrayList<>(n);
for(int i=0; i<n; i++){ for(int i=0; i<n; i++){
final int index = ThreadLocalRandom.current().nextInt(n); final int index = random.nextInt(n);
newList.add(originalData.get(index)); newList.add(originalData.get(index));
} }

View file

@ -17,7 +17,7 @@ public class CovariateRow implements Serializable {
@Getter @Getter
private final int id; private final int id;
public Covariate.Value<?> getCovariateValue(Covariate covariate){ public <V> Covariate.Value<V> getCovariateValue(Covariate<V> covariate){
return valueArray[covariate.getIndex()]; return valueArray[covariate.getIndex()];
} }

View file

@ -1,9 +1,9 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;

View file

@ -17,6 +17,8 @@ public class Row<Y> extends CovariateRow {
} }
public Y getResponse() { public Y getResponse() {
return this.response; return this.response;
} }

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.CovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
@ -10,7 +10,6 @@ import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayL
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
@ -68,9 +67,6 @@ public class Settings {
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor); GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
} }
static{ static{
registerGroupDifferentiatorConstructor("MeanGroupDifferentiator",
(node) -> new MeanGroupDifferentiator()
);
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator", registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
(node) -> new WeightedVarianceGroupDifferentiator() (node) -> new WeightedVarianceGroupDifferentiator()
); );

View file

@ -1,14 +1,15 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.SingletonIterator;
import lombok.Getter; import lombok.Getter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.Collection; import java.util.*;
import java.util.Collections;
import java.util.List;
@RequiredArgsConstructor @RequiredArgsConstructor
public final class BooleanCovariate implements Covariate<Boolean>{ public final class BooleanCovariate implements Covariate<Boolean> {
@Getter @Getter
private final String name; private final String name;
@ -19,8 +20,8 @@ public final class BooleanCovariate implements Covariate<Boolean>{
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
@Override @Override
public Collection<BooleanSplitRule> generateSplitRules(List<Value<Boolean>> data, int number) { public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
return Collections.singleton(splitRule); return new SingletonIterator<>(BooleanCovariate.this.splitRule.applyRule(data));
} }
@Override @Override
@ -74,6 +75,7 @@ public final class BooleanCovariate implements Covariate<Boolean>{
} }
} }
public class BooleanSplitRule implements SplitRule<Boolean>{ public class BooleanSplitRule implements SplitRule<Boolean>{
@Override @Override

View file

@ -5,10 +5,7 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList; import java.util.*;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
public interface Covariate<V> extends Serializable { public interface Covariate<V> extends Serializable {
@ -17,7 +14,7 @@ public interface Covariate<V> extends Serializable {
int getIndex(); int getIndex();
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number); <Y> Iterator<Split<Y, V>> generateSplitRuleUpdater(final List<Row<Y>> data, final int number, final Random random);
Value<V> createValue(V value); Value<V> createValue(V value);
@ -39,6 +36,16 @@ public interface Covariate<V> extends Serializable {
} }
interface SplitRuleUpdater<Y, V> extends Iterator<Split<Y, V>>{
Split<Y, V> currentSplit();
SplitUpdate<Y, V> nextUpdate();
}
interface SplitUpdate<Y, V> {
SplitRule<V> getSplitRule();
Collection<Row<Y>> rowsMovedToLeftHand();
}
interface SplitRule<V> extends Serializable{ interface SplitRule<V> extends Serializable{
Covariate<V> getParent(); Covariate<V> getParent();
@ -51,7 +58,7 @@ public interface Covariate<V> extends Serializable {
* @param <Y> * @param <Y>
* @return * @return
*/ */
default <Y> Split<Y> applyRule(List<Row<Y>> rows) { default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
final List<Row<Y>> leftHand = new LinkedList<>(); final List<Row<Y>> leftHand = new LinkedList<>();
final List<Row<Y>> rightHand = new LinkedList<>(); final List<Row<Y>> rightHand = new LinkedList<>();
@ -77,7 +84,7 @@ public interface Covariate<V> extends Serializable {
} }
return new Split<>(leftHand, rightHand, missingValueRows); return new Split<>(this, leftHand, rightHand, missingValueRows);
} }
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){

View file

@ -1,5 +1,7 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
@ -42,17 +44,14 @@ public final class FactorCovariate implements Covariate<String>{
@Override @Override
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) { public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
final Set<FactorSplitRule> splitRules = new HashSet<>(); final Set<Split<Y, String>> splits = new HashSet<>();
// This is to ensure we don't get stuck in an infinite loop for small factors // This is to ensure we don't get stuck in an infinite loop for small factors
number = Math.min(number, numberOfPossiblePairings); number = Math.min(number, numberOfPossiblePairings);
final Random random = ThreadLocalRandom.current();
final List<FactorValue> levels = new ArrayList<>(factorLevels.values()); final List<FactorValue> levels = new ArrayList<>(factorLevels.values());
while(splits.size() < number){
while(splitRules.size() < number){
Collections.shuffle(levels, random); Collections.shuffle(levels, random);
final Set<FactorValue> leftSideValues = new HashSet<>(); final Set<FactorValue> leftSideValues = new HashSet<>();
leftSideValues.add(levels.get(0)); leftSideValues.add(levels.get(0));
@ -63,13 +62,14 @@ public final class FactorCovariate implements Covariate<String>{
} }
} }
splitRules.add(new FactorSplitRule(leftSideValues)); splits.add(new FactorSplitRule(leftSideValues).applyRule(data));
} }
return splitRules; return splits.iterator();
} }
@Override @Override
public FactorValue createValue(String value) { public FactorValue createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){ if(value == null || value.equalsIgnoreCase("na")){

View file

@ -1,17 +1,21 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates.numeric;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.utils.IndexedIterator;
import ca.joeltherrien.randomforest.utils.UniqueSubsetValueIterator;
import ca.joeltherrien.randomforest.utils.UniqueValueIterator;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.ToString; import lombok.ToString;
import java.util.*; import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@RequiredArgsConstructor @RequiredArgsConstructor
@ToString @ToString
public final class NumericCovariate implements Covariate<Double>{ public final class NumericCovariate implements Covariate<Double> {
@Getter @Getter
private final String name; private final String name;
@ -20,40 +24,44 @@ public final class NumericCovariate implements Covariate<Double>{
private final int index; private final int index;
@Override @Override
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) { public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
data = data.stream()
.filter(row -> !row.getCovariateValue(this).isNA())
.sorted((r1, r2) -> {
Double d1 = r1.getCovariateValue(this).getValue();
Double d2 = r2.getCovariateValue(this).getValue();
final Random random = ThreadLocalRandom.current(); return d1.compareTo(d2);
})
.collect(Collectors.toList());
// only work with non-NA values Iterator<Double> sortedDataIterator = data.stream()
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList()); .map(row -> row.getCovariateValue(this).getValue())
//data = data.stream().filter(value -> !value.isNA()).distinct().collect(Collectors.toList()); // TODO which to use? .filter(v -> v != null)
.iterator();
// for this implementation we need to shuffle the data
final List<Value<Double>> shuffledData; final IndexedIterator<Double> dataIterator;
if(number >= data.size()){ if(number == 0){
shuffledData = data; dataIterator = new UniqueValueIterator<>(sortedDataIterator);
} }
else{ // only need the top number entries else{ // random splitting; we will not weight based on how many times a Row appears in the bootstrap sample
shuffledData = new ArrayList<>(number); final TreeSet<Integer> indexSet = new TreeSet<>();
final Set<Integer> indexesToUse = new HashSet<>();
//final List<Integer> indexesToUse = new ArrayList<>(); // TODO which to use?
while(indexesToUse.size() < number){ final int maxIndex = data.size();
final int index = random.nextInt(data.size());
if(indexesToUse.add(index)){ for(int i=0; i<number; i++){
shuffledData.add(data.get(index)); indexSet.add(random.nextInt(maxIndex));
}
} }
dataIterator = new UniqueSubsetValueIterator<>(
new UniqueValueIterator<>(sortedDataIterator),
indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered
);
} }
return shuffledData.stream() return new NumericSplitRuleUpdater<>(this, data, dataIterator);
.mapToDouble(v -> v.getValue())
.mapToObj(threshold -> new NumericSplitRule(threshold))
.collect(Collectors.toSet());
// by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping
} }
@ -101,7 +109,7 @@ public final class NumericCovariate implements Covariate<Double>{
private final double threshold; private final double threshold;
private NumericSplitRule(final double threshold){ NumericSplitRule(final double threshold){
this.threshold = threshold; this.threshold = threshold;
} }

View file

@ -0,0 +1,79 @@
package ca.joeltherrien.randomforest.covariates.numeric;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.IndexedIterator;
import java.util.Collections;
import java.util.List;
/**
 * Walks through the candidate thresholds of a {@link NumericCovariate} and, for each one, exposes both
 * the resulting {@link Split} and the rows that crossed from the right hand to the left hand since the
 * previous threshold.
 *
 * <p>The row list handed to the constructor is never copied; every split produced is a pair of
 * {@code subList} views over it. NumericCovariate#generateSplitRuleUpdater passes in rows that have
 * already been stripped of NA values and sorted ascending by this covariate; this class relies on that
 * ordering and does not re-check it.
 *
 * @param <Y> the response type carried by each row
 */
public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y, Double> {

    private final NumericCovariate covariate;
    private final List<Row<Y>> orderedData;
    private final IndexedIterator<Double> dataIterator;

    private Split<Y, Double> currentSplit;

    /**
     * Creates an updater positioned before the first threshold: every row starts in the right hand.
     *
     * @param covariate   the covariate the generated split rules belong to
     * @param orderedData rows sorted by this covariate's value, without NA rows
     * @param iterator    source of threshold values; its {@code getIndex()} is used as a position into
     *                    {@code orderedData} (NOTE(review): presumably the number of rows consumed so
     *                    far - confirm against IndexedIterator)
     */
    public NumericSplitRuleUpdater(final NumericCovariate covariate, final List<Row<Y>> orderedData, final IndexedIterator<Double> iterator){
        this.covariate = covariate;
        this.orderedData = orderedData;
        this.dataIterator = iterator;

        final List<Row<Y>> leftHandList = Collections.emptyList();
        final List<Row<Y>> rightHandList = orderedData;

        // Double.MIN_VALUE is the smallest *positive* double, so a rule built from it would not describe
        // the empty left hand used here. Negative infinity is below every finite value, which keeps the
        // placeholder rule consistent with "nothing has moved to the left hand yet".
        this.currentSplit = new Split<>(
                covariate.new NumericSplitRule(Double.NEGATIVE_INFINITY),
                leftHandList,
                rightHandList,
                Collections.emptyList());
    }

    /**
     * @return the split produced by the most recent call to {@link #nextUpdate()} or {@link #next()},
     *         or the initial all-right-hand split if neither has been called yet
     */
    @Override
    public Split<Y, Double> currentSplit() {
        return this.currentSplit;
    }

    /**
     * Advances to the next threshold and replaces the current split.
     *
     * @return the new split rule together with the rows that moved to the left hand, or {@code null}
     *         when the underlying iterator is exhausted (the current split is then left untouched)
     */
    @Override
    public NumericSplitUpdate<Y> nextUpdate() {
        if(hasNext()){
            // The iterator's index before and after next() brackets the rows that change sides.
            final int currentPosition = dataIterator.getIndex();
            final Double splitValue = dataIterator.next();
            final int newPosition = dataIterator.getIndex();

            final List<Row<Y>> rowsMoved = orderedData.subList(currentPosition, newPosition);
            final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue);

            // Update current split
            this.currentSplit = new Split<>(
                    splitRule,
                    orderedData.subList(0, newPosition),
                    orderedData.subList(newPosition, orderedData.size()),
                    Collections.emptyList());

            return new NumericSplitUpdate<>(splitRule, rowsMoved);
        }

        return null;
    }

    /**
     * @return {@code true} while the underlying threshold iterator has more values
     */
    @Override
    public boolean hasNext() {
        return dataIterator.hasNext();
    }

    /**
     * Advances (if possible) and returns the resulting split.
     *
     * <p>Unlike a typical {@link java.util.Iterator}, this does not throw when exhausted; it returns
     * the last split again, so callers should guard with {@link #hasNext()}.
     *
     * @return the current split after attempting to advance
     */
    @Override
    public Split<Y, Double> next() {
        if(hasNext()){
            nextUpdate();
        }

        return this.currentSplit();
    }
}

View file

@ -0,0 +1,24 @@
package ca.joeltherrien.randomforest.covariates.numeric;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.AllArgsConstructor;
import java.util.Collection;
/**
 * Value holder describing one step of a numeric split update: the rule that now applies and the rows
 * that were reported as moved to the left hand by that step.
 *
 * @param <Y> the response type carried by each row
 */
@AllArgsConstructor
public class NumericSplitUpdate<Y> implements Covariate.SplitUpdate<Y, Double> {

    /** Rule associated with this update. */
    private final NumericCovariate.NumericSplitRule rule;

    /** Rows handed over as having moved; stored as given, not copied. */
    private final Collection<Row<Y>> movedRows;

    /**
     * @return the split rule supplied at construction
     */
    @Override
    public NumericCovariate.NumericSplitRule getSplitRule() {
        return this.rule;
    }

    /**
     * @return the collection of moved rows supplied at construction
     */
    @Override
    public Collection<Row<Y>> rowsMovedToLeftHand() {
        return this.movedRows;
    }
}

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.BooleanCovariate;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo;
import lombok.Getter; import lombok.Getter;

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;

View file

@ -1,5 +1,6 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;

View file

@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
@ -14,11 +15,7 @@ import java.util.stream.Stream;
* modifies the abstract method. * modifies the abstract method.
* *
*/ */
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y>{ public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> extends SimpleGroupDifferentiator<Y> {
@Override
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
/** /**
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause. * Calculates the log rank value (or the Gray's test value) for a *specific* event cause.

View file

@ -17,7 +17,7 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
private final int[] events; private final int[] events;
@Override @Override
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) { public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){ if(leftHand.size() == 0 || rightHand.size() == 0){
return null; return null;
} }

View file

@ -18,7 +18,7 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
private final int[] events; private final int[] events;
@Override @Override
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) { public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){ if(leftHand.size() == 0 || rightHand.size() == 0){
return null; return null;
} }

View file

@ -17,7 +17,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
private final int[] events; private final int[] events;
@Override @Override
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) { public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){ if(leftHand.size() == 0 || rightHand.size() == 0){
return null; return null;
} }

View file

@ -18,7 +18,7 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
private final int[] events; private final int[] events;
@Override @Override
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) { public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
if(leftHand.size() == 0 || rightHand.size() == 0){ if(leftHand.size() == 0 || rightHand.size() == 0){
return null; return null;
} }

View file

@ -1,26 +0,0 @@
package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
import java.util.List;
/**
 * Scores a two-way grouping of numeric responses by the absolute difference between the group means.
 */
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {

    /**
     * @param leftHand  responses assigned to the left group
     * @param rightHand responses assigned to the right group
     * @return the absolute difference of the two group means, or {@code null} if either group is empty
     */
    @Override
    public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
        final double leftCount = leftHand.size();
        final double rightCount = rightHand.size();

        final boolean eitherEmpty = (leftCount == 0 || rightCount == 0);
        if (eitherEmpty) {
            return null;
        }

        final double meanGap = scaledSum(leftHand, leftCount) - scaledSum(rightHand, rightCount);
        return Math.abs(meanGap);
    }

    /** Sums each value divided by {@code count}; with count == values.size() this is the mean. */
    private static double scaledSum(final List<Double> values, final double count) {
        return values.stream().mapToDouble(value -> value / count).sum();
    }
}

View file

@ -1,13 +1,13 @@
package ca.joeltherrien.randomforest.responses.regression; package ca.joeltherrien.randomforest.responses.regression;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator; import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
import java.util.List; import java.util.List;
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> { public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator<Double> {
@Override @Override
public Double differentiate(List<Double> leftHand, List<Double> rightHand) { public Double getScore(List<Double> leftHand, List<Double> rightHand) {
final double leftHandSize = leftHand.size(); final double leftHandSize = leftHand.size();
final double rightHandSize = rightHand.size(); final double rightHandSize = rightHand.size();

View file

@ -14,8 +14,10 @@ import java.io.IOException;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -45,17 +47,17 @@ public class ForestTrainer<Y, TO, FO> {
this.covariates = covariates; this.covariates = covariates;
this.treeResponseCombiner = settings.getTreeCombiner(); this.treeResponseCombiner = settings.getTreeCombiner();
this.treeTrainer = new TreeTrainer<>(settings, covariates); this.treeTrainer = new TreeTrainer<>(settings, covariates);
} }
public Forest<TO, FO> trainSerial(){ public Forest<TO, FO> trainSerial(){
final List<Tree<TO>> trees = new ArrayList<>(ntree); final List<Tree<TO>> trees = new ArrayList<>(ntree);
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data); final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
final Random random = new Random();
for(int j=0; j<ntree; j++){ for(int j=0; j<ntree; j++){
trees.add(trainTree(bootstrapper)); trees.add(trainTree(bootstrapper, random));
if(displayProgress){ if(displayProgress){
if(j==0) { if(j==0) {
@ -162,9 +164,9 @@ public class ForestTrainer<Y, TO, FO> {
} }
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){ private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap(); final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap(random);
return treeTrainer.growTree(bootstrappedData); return treeTrainer.growTree(bootstrappedData, random);
} }
public void saveTree(final Tree<TO> tree, String name) throws IOException { public void saveTree(final Tree<TO> tree, String name) throws IOException {
@ -193,7 +195,8 @@ public class ForestTrainer<Y, TO, FO> {
@Override @Override
public void run() { public void run() {
final Tree<TO> tree = trainTree(bootstrapper); // ThreadLocalRandom should make sure we don't duplicate seeds
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
// should be okay as the list structure isn't changing // should be okay as the list structure isn't changing
treeList.set(treeIndex, tree); treeList.set(treeIndex, tree);
@ -216,7 +219,8 @@ public class ForestTrainer<Y, TO, FO> {
@Override @Override
public void run() { public void run() {
final Tree<TO> tree = trainTree(bootstrapper); // ThreadLocalRandom should make sure we don't duplicate seeds
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
try { try {
saveTree(tree, filename); saveTree(tree, filename);

View file

@ -1,15 +1,17 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import java.util.List; import java.util.Iterator;
/** /**
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups. * When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
* The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score, * The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an
* the greater the difference. * instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits)
* *
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
* SimpleGroupDifferentiator.
*/ */
public interface GroupDifferentiator<Y> { public interface GroupDifferentiator<Y> {
Double differentiate(List<Y> leftHand, List<Y> rightHand); <V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator);
} }

View file

@ -0,0 +1,50 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Base class for differentiators that only need the two lists of responses to score a split.
 * Subclasses supply {@link #getScore(List, List)}; this class handles iterating over the candidate
 * splits and remembering the best one.
 *
 * @param <Y> the response type
 */
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {

    /**
     * Scores every split offered by the iterator and keeps the one with the strictly highest score.
     * Splits with an empty left or right hand, and splits whose score is {@code null}, are ignored.
     *
     * @param splitIterator source of candidate splits; it is consumed completely
     * @return the best split paired with its score, or {@code null} if no candidate produced a score
     */
    @Override
    public <V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator) {
        Split<Y, V> winner = null;
        Double winningScore = null;

        while (splitIterator.hasNext()) {
            final Split<Y, V> candidate = splitIterator.next();

            final List<Y> leftResponses = responsesOf(candidate.getLeftHand());
            final List<Y> rightResponses = responsesOf(candidate.getRightHand());

            final boolean bothPopulated = !leftResponses.isEmpty() && !rightResponses.isEmpty();
            if (bothPopulated) {
                final Double candidateScore = getScore(leftResponses, rightResponses);

                // Ties keep the earlier split because the comparison is strict.
                if (candidateScore != null && (winningScore == null || candidateScore > winningScore)) {
                    winningScore = candidateScore;
                    winner = candidate;
                }
            }
        }

        if (winner != null) {
            return new SplitAndScore<>(winner, winningScore);
        }
        return null;
    }

    /** Collects the response of every row, in order, into a new list. */
    private static <R> List<R> responsesOf(final List<Row<R>> rows) {
        return rows.stream().map(Row::getResponse).collect(Collectors.toList());
    }

    /**
     * Return a score; higher is better.
     *
     * @param leftHand  responses of the rows in the left hand; never empty when called from
     *                  {@link #differentiate(Iterator)}
     * @param rightHand responses of the rows in the right hand; never empty when called from
     *                  {@link #differentiate(Iterator)}
     * @return the score, or {@code null} to have the split ignored
     */
    public abstract Double getScore(List<Y> leftHand, List<Y> rightHand);
}

View file

@ -1,19 +1,21 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
/** /**
* Very simple class that contains three lists; it's essentially a thruple. * Very simple class that contains three lists and a SplitRule.
* *
* @author joel * @author joel
* *
*/ */
@Data @Data
public class Split<Y> { public final class Split<Y, V> {
public final Covariate.SplitRule<V> splitRule;
public final List<Row<Y>> leftHand; public final List<Row<Y>> leftHand;
public final List<Row<Y>> rightHand; public final List<Row<Y>> rightHand;
public final List<Row<Y>> naHand; public final List<Row<Y>> naHand;

View file

@ -0,0 +1,15 @@
package ca.joeltherrien.randomforest.tree;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
 * Immutable pair of a {@link Split} and the score assigned to it.
 *
 * @param <Y> the response type of the rows in the split
 * @param <V> the value type of the covariate the split was made on
 */
@Getter
@AllArgsConstructor
public class SplitAndScore<Y, V> {

    /** The split that was scored. */
    private final Split<Y, V> split;

    /** The score assigned to {@link #split}. */
    private final Double score;
}

View file

@ -3,10 +3,11 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.*; import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
import java.util.*; import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Builder @Builder
@ -45,20 +46,21 @@ public class TreeTrainer<Y, O> {
this.covariates = covariates; this.covariates = covariates;
} }
public Tree<O> growTree(List<Row<Y>> data){ public Tree<O> growTree(List<Row<Y>> data, Random random){
final Node<O> rootNode = growNode(data, 0); final Node<O> rootNode = growNode(data, 0, random);
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray()); return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
} }
private Node<O> growNode(List<Row<Y>> data, int depth){ private Node<O> growNode(List<Row<Y>> data, int depth, Random random){
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom) // See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){ if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final List<Covariate> covariatesToTry = selectCovariates(this.mtry); final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry); final Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
if(bestSplitRuleAndSplit.splitRule == null){
if(bestSplit == null){
return new TerminalNode<>( return new TerminalNode<>(
responseCombiner.combine( responseCombiner.combine(
@ -69,14 +71,31 @@ public class TreeTrainer<Y, O> {
} }
final Split<Y> split = bestSplitRuleAndSplit.split;
// Note that NAs have already been handled // Now that we have the best split; we need to handle any NAs that were dropped off
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
// Assign missing values to the split
for(Row<Y> row : data) {
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
bestSplit.getLeftHand().add(row);
}
else{
bestSplit.getRightHand().add(row);
}
}
}
final Node<O> leftNode = growNode(split.leftHand, depth+1); final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
final Node<O> rightNode = growNode(split.rightHand, depth+1); final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand); return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);
} }
else{ else{
@ -90,13 +109,13 @@ public class TreeTrainer<Y, O> {
} }
private List<Covariate> selectCovariates(int mtry){ private List<Covariate> selectCovariates(int mtry, Random random){
if(mtry >= covariates.size()){ if(mtry >= covariates.size()){
return covariates; return covariates;
} }
final List<Covariate> splitCovariates = new ArrayList<>(covariates); final List<Covariate> splitCovariates = new ArrayList<>(covariates);
Collections.shuffle(splitCovariates, ThreadLocalRandom.current()); Collections.shuffle(splitCovariates, random);
if (splitCovariates.size() > mtry) { if (splitCovariates.size() > mtry) {
splitCovariates.subList(mtry, splitCovariates.size()).clear(); splitCovariates.subList(mtry, splitCovariates.size()).clear();
@ -105,63 +124,29 @@ public class TreeTrainer<Y, O> {
return splitCovariates; return splitCovariates;
} }
private SplitRuleAndSplit findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){ private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
final Random random = ThreadLocalRandom.current();
SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit();
double bestSplitScore = 0.0;
boolean first = true;
for(final Covariate covariate : covariatesToTry){ SplitAndScore<Y, ?> bestSplitAndScore = null;
final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck
final int numberToTry = numberOfSplits==0 ? data.size() : numberOfSplits; for(final Covariate covariate : covariatesToTry) {
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
final Collection<Covariate.SplitRule> splitRulesToTry = covariate final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
.generateSplitRules(
data
.stream()
.map(row -> row.getCovariateValue(covariate))
.collect(Collectors.toList())
, numberToTry);
for(final Covariate.SplitRule possibleRule : splitRulesToTry){ if(candidateSplitAndScore != null) {
final Split<Y> possibleSplit = possibleRule.applyRule(data); if (bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
bestSplitAndScore = candidateSplitAndScore;
// We have to handle any NAs
if(possibleSplit.leftHand.size() == 0 && possibleSplit.rightHand.size() == 0 && possibleSplit.naHand.size() > 0){
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
}
final double probabilityLeftHand = (double) possibleSplit.leftHand.size() / (double) (possibleSplit.leftHand.size() + possibleSplit.rightHand.size());
for(final Row<Y> missingValueRow : possibleSplit.naHand){
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
possibleSplit.leftHand.add(missingValueRow);
}
else{
possibleSplit.rightHand.add(missingValueRow);
}
}
final Double score = groupDifferentiator.differentiate(
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
);
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
bestSplitRuleAndSplit.splitRule = possibleRule;
bestSplitRuleAndSplit.split = possibleSplit;
bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand;
bestSplitScore = score;
first = false;
} }
} }
} }
return bestSplitRuleAndSplit; if(bestSplitAndScore == null){
return null;
}
return bestSplitAndScore.getSplit();
} }
@ -184,10 +169,4 @@ public class TreeTrainer<Y, O> {
return true; return true;
} }
private class SplitRuleAndSplit{
private Covariate.SplitRule splitRule = null;
private Split<Y> split = null;
private double probabilityLeftHand;
}
} }

View file

@ -0,0 +1,9 @@
package ca.joeltherrien.randomforest.utils;
import java.util.Iterator;
/**
 * An {@link Iterator} that can additionally report an integer index describing its current position.
 *
 * @param <E> the type of element returned by this iterator
 */
public interface IndexedIterator<E> extends Iterator<E> {

    /**
     * @return the index this iterator is currently at; exactly what the index counts is defined by the implementing class
     */
    int getIndex();

}

View file

@ -0,0 +1,29 @@
package ca.joeltherrien.randomforest.utils;
import lombok.RequiredArgsConstructor;
import java.util.Iterator;
/**
 * Iterator over exactly one element.
 *
 * The element is handed out by the first call to {@link #next()}; every later call returns null
 * (this iterator does not throw when it is exhausted).
 *
 * @param <E> the type of the single element
 */
@RequiredArgsConstructor
public class SingletonIterator<E> implements Iterator<E> {

    private final E value;

    // Flips to true once the single value has been handed out.
    private boolean exhausted = false;

    @Override
    public boolean hasNext() {
        return !exhausted;
    }

    @Override
    public E next() {
        if(exhausted){
            return null;
        }

        exhausted = true;
        return value;
    }

}

View file

@ -0,0 +1,65 @@
package ca.joeltherrien.randomforest.utils;
import lombok.Getter;
import java.util.Iterator;
/**
 * Iterator that wraps around a UniqueValueIterator. It continues to iterate until it gets to one of the prespecified indexes,
 * and then proceeds just past that to the end of the run of identical values it is in.
 *
 * The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together.
 * I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations
 *
 * The index values are expected to be in ascending order; {@link #hasNext()} treats the last entry as the largest one.
 *
 * @param <E> the type of element returned by this iterator
 */
public class UniqueSubsetValueIterator<E> implements IndexedIterator<E> {

    private final UniqueValueIterator<E> iterator;
    private final Integer[] indexValues;

    // Position in indexValues of the next index that still has to be reached.
    private int currentIndexSpot = 0;

    /**
     * @param iterator the iterator over runs of unique values that is being subsetted
     * @param indexValues the indexes (positions in the underlying data) whose values should be returned
     */
    public UniqueSubsetValueIterator(final UniqueValueIterator<E> iterator, final Integer[] indexValues){
        this.iterator = iterator;
        this.indexValues = indexValues;
    }

    /**
     * @return true while the wrapped iterator has values left and has not yet moved past the last requested index;
     * always false when no indexes were requested
     */
    @Override
    public boolean hasNext() {
        // Nothing was requested; without this guard the array access below would throw.
        if(indexValues.length == 0){
            return false;
        }

        return iterator.hasNext() && iterator.getIndex() <= indexValues[indexValues.length-1];
    }

    /**
     * Advances the wrapped iterator past the next requested index and returns the value found there.
     *
     * @return the value of the run containing the next requested index, or null if {@link #hasNext()} is false
     */
    @Override
    public E next() {
        if(!hasNext()){
            return null;
        }

        final int indexToStopBy = indexValues[currentIndexSpot];

        // Also stop once the wrapped iterator is exhausted; its index no longer advances at that point,
        // so an index beyond the end of the data would otherwise loop forever.
        while(iterator.hasNext() && iterator.getIndex() <= indexToStopBy){
            iterator.next();
        }

        // Skip any requested indexes that fell inside the run we just consumed.
        for(int i = currentIndexSpot + 1; i < indexValues.length; i++){
            if(iterator.getIndex() <= indexValues[i]){
                currentIndexSpot = i;
                break;
            }
        }

        return iterator.getCurrentValue();
    }

    /**
     * @return the current index of the wrapped iterator
     */
    @Override
    public int getIndex(){
        return iterator.getIndex();
    }

}

View file

@ -0,0 +1,71 @@
package ca.joeltherrien.randomforest.utils;
import lombok.Getter;
import java.util.Iterator;
/**
 * Iterator that wraps around another iterator. It continues to iterate until it gets to the *end* of a sequence of identical values.
 * It also tracks the current index in the original iterator.
 *
 * The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together.
 * I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations
 *
 * Null is used internally to mark the end of the data, so the wrapped iterator must not produce null elements.
 *
 * @param <E> the type of element returned by this iterator
 */
public class UniqueValueIterator<E> implements IndexedIterator<E> {

    private final Iterator<E> wrappedIterator;

    // The value most recently returned by next(); null until next() has been called.
    @Getter private E currentValue = null;

    // The first value of the run that the following call to next() will return; null once the data is exhausted.
    @Getter private E nextValue;

    // Number of elements of the wrapped iterator covered by the runs returned so far, i.e. the position
    // just past the last element of the run most recently returned by next(). It is 0 before the first call.
    @Getter
    private int index = 0;

    /**
     * @param wrappedIterator the iterator to read from; equal values must be adjacent. It may be empty,
     *                        in which case this iterator has no elements.
     */
    public UniqueValueIterator(final Iterator<E> wrappedIterator){
        this.wrappedIterator = wrappedIterator;

        // Guard against an empty source; calling next() on it unconditionally would throw.
        this.nextValue = wrappedIterator.hasNext() ? wrappedIterator.next() : null;
    }

    @Override
    public boolean hasNext() {
        return nextValue != null;
    }

    /**
     * Consumes the whole run of values equal to the pending value and returns that value.
     *
     * @return the value of the run that was consumed, or null if there is nothing left
     */
    @Override
    public E next() {

        // The pending value (nextValue) has already been read from the wrapped iterator, hence we start at 1.
        int count = 1;
        while(wrappedIterator.hasNext()){
            final E currentIteratorValue = wrappedIterator.next();
            if(currentIteratorValue.equals(nextValue)){
                count++;
            }
            else{
                // We've read the first value of the following run; remember it for the next call.
                index += count;
                currentValue = nextValue;
                nextValue = currentIteratorValue;

                return currentValue;
            }
        }

        // The wrapped iterator ran out; the pending run (if any) is the last one.
        if(nextValue != null){
            index += count;
            currentValue = nextValue;
            nextValue = null;

            return currentValue;
        }
        else{
            return null;
        }

    }

}

View file

@ -1,8 +1,8 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;

View file

@ -2,6 +2,8 @@ package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
@ -18,6 +20,7 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Random;
public class TestCompetingRisk { public class TestCompetingRisk {
@ -104,7 +107,7 @@ public class TestCompetingRisk {
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates); final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset); final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
final CovariateRow newRow = getPredictionRow(covariates); final CovariateRow newRow = getPredictionRow(covariates);
@ -157,7 +160,7 @@ public class TestCompetingRisk {
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates); final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset); final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
final CovariateRow newRow = getPredictionRow(covariates); final CovariateRow newRow = getPredictionRow(covariates);

View file

@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.DataLoader;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
@ -53,14 +53,14 @@ public class TestLogRankMultipleGroupDifferentiator {
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196); final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size()); final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
final double scoreBad = groupDifferentiator.differentiate( final double scoreBad = groupDifferentiator.getScore(
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()), group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList())); group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199); final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199);
final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size()); final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size());
final double scoreGood = groupDifferentiator.differentiate( final double scoreGood = groupDifferentiator.getScore(
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()), group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Good.stream().map(Row::getResponse).collect(Collectors.toList())); group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));

View file

@ -1,8 +1,7 @@
package ca.joeltherrien.randomforest.competingrisk; package ca.joeltherrien.randomforest.competingrisk;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator; import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -51,7 +50,7 @@ public class TestLogRankSingleGroupDifferentiator {
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1}); final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
final double score = differentiator.differentiate(data1, data2); final double score = differentiator.getScore(data1, data2);
final double margin = 0.000001; final double margin = 0.000001;
// Tested using 855 method // Tested using 855 method
@ -71,14 +70,14 @@ public class TestLogRankSingleGroupDifferentiator {
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221); final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size()); final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
final double scoreGood = groupDifferentiator.differentiate( final double scoreGood = groupDifferentiator.getScore(
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()), group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Good.stream().map(Row::getResponse).collect(Collectors.toList())); group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222); final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size()); final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
final double scoreBad = groupDifferentiator.differentiate( final double scoreBad = groupDifferentiator.getScore(
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()), group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList())); group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));

View file

@ -5,8 +5,9 @@ import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.api.function.Executable;
import java.util.Collection; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -46,7 +47,10 @@ public class FactorCovariateTest {
void testAllSubsets(){ void testAllSubsets(){
final FactorCovariate petCovariate = createTestCovariate(); final FactorCovariate petCovariate = createTestCovariate();
final Collection<FactorCovariate.FactorSplitRule> splitRules = petCovariate.generateSplitRules(null, 100); final List<Covariate.SplitRule<String>> splitRules = new ArrayList<>();
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
assertEquals(splitRules.size(), 3); assertEquals(splitRules.size(), 3);

View file

@ -3,10 +3,10 @@ package ca.joeltherrien.randomforest.csv;
import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.DataLoader;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
@ -15,7 +15,6 @@ import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;

View file

@ -3,7 +3,9 @@ package ca.joeltherrien.randomforest.settings;
import ca.joeltherrien.randomforest.Settings; import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;

View file

@ -0,0 +1,25 @@
package ca.joeltherrien.randomforest.utils;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Verifies that a SingletonIterator hands out its element exactly once.
 */
public class TestSingletonIterator {

    @Test
    public void verifyBehaviour(){
        final SingletonIterator<Integer> singleton = new SingletonIterator<>(5);

        // hasNext() must not consume the element, no matter how often it is called
        assertTrue(singleton.hasNext());
        assertTrue(singleton.hasNext());

        final Integer firstResult = singleton.next();
        assertEquals(Integer.valueOf(5), firstResult);

        // once the element has been returned the iterator is finished and only yields null
        assertFalse(singleton.hasNext());
        final Integer secondResult = singleton.next();
        assertNull(secondResult);
    }

}

View file

@ -0,0 +1,68 @@
package ca.joeltherrien.randomforest.utils;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Tests for UniqueSubsetValueIterator using a small sorted data set containing runs of repeated values.
 */
public class TestUniqueSubsetValueIterator {

    @Test
    public void testIterator1(){
        final List<Integer> testData = Arrays.asList(
                1,1,2,3,5,5,5,6,7,7
        );

        final Integer[] indexes = new Integer[]{2,3,4,5};

        final UniqueValueIterator<Integer> uniqueValueIterator = new UniqueValueIterator<>(testData.iterator());
        final UniqueSubsetValueIterator<Integer> iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes);

        // we expect to get 2, 3, and 5 back. 5 should happen only once
        assertEquals(0, iterator.getIndex());

        assertNextRun(iterator, 2, 3);
        assertNextRun(iterator, 3, 4);
        assertNextRun(iterator, 5, 7);

        assertFalse(iterator.hasNext());
    }

    @Test
    public void testIterator2(){
        final List<Integer> testData = Arrays.asList(
                1,1,2,3,5,5,5,6,7,7
        );

        final Integer[] indexes = new Integer[]{1,8};

        final UniqueValueIterator<Integer> uniqueValueIterator = new UniqueValueIterator<>(testData.iterator());
        final UniqueSubsetValueIterator<Integer> iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes);

        assertEquals(0, iterator.getIndex());

        assertNextRun(iterator, 1, 2);
        assertNextRun(iterator, 7, 10);

        assertFalse(iterator.hasNext());
    }

    // Checks that the iterator has another element, that it is expectedValue, and that the index afterwards is expectedIndex.
    private static void assertNextRun(final UniqueSubsetValueIterator<Integer> iterator, final int expectedValue, final int expectedIndex){
        assertTrue(iterator.hasNext());
        assertEquals(expectedValue, iterator.next().intValue());
        assertEquals(expectedIndex, iterator.getIndex());
    }

}

View file

@ -0,0 +1,112 @@
package ca.joeltherrien.randomforest.utils;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Tests for UniqueValueIterator; each test walks a sorted list and checks the value returned for
 * every run of identical elements together with the index reported after that run.
 */
public class TestUniqueValueIterator {

    @Test
    public void testIterator1(){
        final List<Integer> testData = Arrays.asList(
                1,1,2,3,5,5,5,6,7,7
        );

        final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());

        assertEquals(0, iterator.getIndex());

        assertNextRun(iterator, 1, 2);
        assertNextRun(iterator, 2, 3);
        assertNextRun(iterator, 3, 4);
        assertNextRun(iterator, 5, 7);
        assertNextRun(iterator, 6, 8);
        assertNextRun(iterator, 7, 10);

        assertFalse(iterator.hasNext());
    }

    @Test
    public void testIterator2(){
        final List<Integer> testData = Arrays.asList(
                1,2,3,5,5,5,6,7 // same numbers; but 1 and 7 only appear once each
        );

        final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());

        assertEquals(0, iterator.getIndex());

        assertNextRun(iterator, 1, 1);
        assertNextRun(iterator, 2, 2);
        assertNextRun(iterator, 3, 3);
        assertNextRun(iterator, 5, 6);
        assertNextRun(iterator, 6, 7);
        assertNextRun(iterator, 7, 8);

        assertFalse(iterator.hasNext());
    }

    @Test
    public void testIterator3(){
        final List<Integer> testData = Arrays.asList(
                1,1,1,1,1,1,1,2,2,2,2,2,3
        );

        final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());

        assertEquals(0, iterator.getIndex());

        assertNextRun(iterator, 1, 7);
        assertNextRun(iterator, 2, 12);
        assertNextRun(iterator, 3, 13);

        assertFalse(iterator.hasNext());
    }

    // Checks that the iterator has another element, that it is expectedValue, and that the index afterwards is expectedIndex.
    private static void assertNextRun(final UniqueValueIterator<Integer> iterator, final int expectedValue, final int expectedIndex){
        assertTrue(iterator.hasNext());
        assertEquals(expectedValue, iterator.next().intValue());
        assertEquals(expectedIndex, iterator.getIndex());
    }

}

View file

@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;

View file

@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
@ -54,13 +54,14 @@ public class TrainSingleTree {
.covariates(covariateNames) .covariates(covariateNames)
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
.maxNodeDepth(30) .maxNodeDepth(30)
.mtry(2)
.nodeSize(5) .nodeSize(5)
.numberOfSplits(0) .numberOfSplits(0)
.build(); .build();
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
final Node<Double> baseNode = treeTrainer.growTree(trainingSet); final Node<Double> baseNode = treeTrainer.growTree(trainingSet, new Random());
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();
System.out.println(((double)(endTime - startTime))/1000.0); System.out.println(((double)(endTime - startTime))/1000.0);

View file

@ -4,7 +4,7 @@ package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariate; import ca.joeltherrien.randomforest.covariates.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.Node;
@ -13,7 +13,6 @@ import ca.joeltherrien.randomforest.utils.Utils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.DoubleStream; import java.util.stream.DoubleStream;
@ -21,8 +20,6 @@ import java.util.stream.DoubleStream;
public class TrainSingleTreeFactor { public class TrainSingleTreeFactor {
public static void main(String[] args) { public static void main(String[] args) {
System.out.println("Hello world!");
final Random random = new Random(123); final Random random = new Random(123);
final int n = 10000; final int n = 10000;
@ -84,7 +81,7 @@ public class TrainSingleTreeFactor {
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
final Node<Double> baseNode = treeTrainer.growTree(trainingSet); final Node<Double> baseNode = treeTrainer.growTree(trainingSet, random);
final long endTime = System.currentTimeMillis(); final long endTime = System.currentTimeMillis();
System.out.println(((double)(endTime - startTime))/1000.0); System.out.println(((double)(endTime - startTime))/1000.0);