Add optimizations. #13
45 changed files with 741 additions and 212 deletions
|
@ -4,20 +4,20 @@ import lombok.RequiredArgsConstructor;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.Random;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class Bootstrapper<T> {
|
||||
|
||||
final private List<T> originalData;
|
||||
|
||||
public List<T> bootstrap(){
|
||||
public List<T> bootstrap(Random random){
|
||||
final int n = originalData.size();
|
||||
|
||||
final List<T> newList = new ArrayList<>(n);
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
final int index = ThreadLocalRandom.current().nextInt(n);
|
||||
final int index = random.nextInt(n);
|
||||
|
||||
newList.add(originalData.get(index));
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ public class CovariateRow implements Serializable {
|
|||
@Getter
|
||||
private final int id;
|
||||
|
||||
public Covariate.Value<?> getCovariateValue(Covariate covariate){
|
||||
public <V> Covariate.Value<V> getCovariateValue(Covariate<V> covariate){
|
||||
return valueArray[covariate.getIndex()];
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
|
|
|
@ -17,6 +17,8 @@ public class Row<Y> extends CovariateRow {
|
|||
}
|
||||
|
||||
|
||||
|
||||
|
||||
public Y getResponse() {
|
||||
return this.response;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.CovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||
|
@ -10,7 +10,6 @@ import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayL
|
|||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
|
@ -68,9 +67,6 @@ public class Settings {
|
|||
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
|
||||
}
|
||||
static{
|
||||
registerGroupDifferentiatorConstructor("MeanGroupDifferentiator",
|
||||
(node) -> new MeanGroupDifferentiator()
|
||||
);
|
||||
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
|
||||
(node) -> new WeightedVarianceGroupDifferentiator()
|
||||
);
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.tree.Split;
|
||||
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public final class BooleanCovariate implements Covariate<Boolean> {
|
||||
|
@ -19,8 +20,8 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
|||
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
||||
|
||||
@Override
|
||||
public Collection<BooleanSplitRule> generateSplitRules(List<Value<Boolean>> data, int number) {
|
||||
return Collections.singleton(splitRule);
|
||||
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||
return new SingletonIterator<>(BooleanCovariate.this.splitRule.applyRule(data));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -74,6 +75,7 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
public class BooleanSplitRule implements SplitRule<Boolean>{
|
||||
|
||||
@Override
|
||||
|
|
|
@ -5,10 +5,7 @@ import ca.joeltherrien.randomforest.Row;
|
|||
import ca.joeltherrien.randomforest.tree.Split;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
public interface Covariate<V> extends Serializable {
|
||||
|
@ -17,7 +14,7 @@ public interface Covariate<V> extends Serializable {
|
|||
|
||||
int getIndex();
|
||||
|
||||
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
|
||||
<Y> Iterator<Split<Y, V>> generateSplitRuleUpdater(final List<Row<Y>> data, final int number, final Random random);
|
||||
|
||||
Value<V> createValue(V value);
|
||||
|
||||
|
@ -39,6 +36,16 @@ public interface Covariate<V> extends Serializable {
|
|||
|
||||
}
|
||||
|
||||
/**
 * An iterator over candidate {@code Split}s that also exposes the split it is currently
 * positioned on and a description of the change made by each step.
 *
 * @param <Y> response type carried by the rows being split
 * @param <V> value type of the covariate the splits are based on
 */
interface SplitRuleUpdater<Y, V> extends Iterator<Split<Y, V>>{
|
||||
// The split this updater is currently positioned on.
Split<Y, V> currentSplit();
|
||||
// Produces the next SplitUpdate.
// NOTE(review): presumably this advances the updater just like next() does, but reports only
// what changed instead of the whole split - confirm against the implementations.
SplitUpdate<Y, V> nextUpdate();
|
||||
}
|
||||
|
||||
/**
 * Describes a single step of a {@code SplitRuleUpdater}: the rule that now applies and the rows
 * affected by the step.
 *
 * @param <Y> response type carried by the rows
 * @param <V> value type of the covariate the rule is based on
 */
interface SplitUpdate<Y, V> {
|
||||
// The split rule associated with this update.
SplitRule<V> getSplitRule();
|
||||
// The rows that moved to the left hand side as part of this update.
// NOTE(review): presumably these are rows that were on the right hand side before the step - confirm.
Collection<Row<Y>> rowsMovedToLeftHand();
|
||||
}
|
||||
|
||||
interface SplitRule<V> extends Serializable{
|
||||
|
||||
Covariate<V> getParent();
|
||||
|
@ -51,7 +58,7 @@ public interface Covariate<V> extends Serializable {
|
|||
* @param <Y>
|
||||
* @return
|
||||
*/
|
||||
default <Y> Split<Y> applyRule(List<Row<Y>> rows) {
|
||||
default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
|
||||
final List<Row<Y>> leftHand = new LinkedList<>();
|
||||
final List<Row<Y>> rightHand = new LinkedList<>();
|
||||
|
||||
|
@ -77,7 +84,7 @@ public interface Covariate<V> extends Serializable {
|
|||
}
|
||||
|
||||
|
||||
return new Split<>(leftHand, rightHand, missingValueRows);
|
||||
return new Split<>(this, leftHand, rightHand, missingValueRows);
|
||||
}
|
||||
|
||||
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.tree.Split;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
|
||||
|
@ -42,17 +44,14 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
|
||||
|
||||
@Override
|
||||
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
||||
final Set<FactorSplitRule> splitRules = new HashSet<>();
|
||||
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||
final Set<Split<Y, String>> splits = new HashSet<>();
|
||||
|
||||
// This is to ensure we don't get stuck in an infinite loop for small factors
|
||||
number = Math.min(number, numberOfPossiblePairings);
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
final List<FactorValue> levels = new ArrayList<>(factorLevels.values());
|
||||
|
||||
|
||||
|
||||
while(splitRules.size() < number){
|
||||
while(splits.size() < number){
|
||||
Collections.shuffle(levels, random);
|
||||
final Set<FactorValue> leftSideValues = new HashSet<>();
|
||||
leftSideValues.add(levels.get(0));
|
||||
|
@ -63,13 +62,14 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
}
|
||||
}
|
||||
|
||||
splitRules.add(new FactorSplitRule(leftSideValues));
|
||||
splits.add(new FactorSplitRule(leftSideValues).applyRule(data));
|
||||
}
|
||||
|
||||
return splitRules;
|
||||
return splits.iterator();
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public FactorValue createValue(String value) {
|
||||
if(value == null || value.equalsIgnoreCase("na")){
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
package ca.joeltherrien.randomforest.covariates.numeric;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.utils.IndexedIterator;
|
||||
import ca.joeltherrien.randomforest.utils.UniqueSubsetValueIterator;
|
||||
import ca.joeltherrien.randomforest.utils.UniqueValueIterator;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
|
@ -20,40 +24,44 @@ public final class NumericCovariate implements Covariate<Double>{
|
|||
private final int index;
|
||||
|
||||
@Override
|
||||
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {
|
||||
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||
data = data.stream()
|
||||
.filter(row -> !row.getCovariateValue(this).isNA())
|
||||
.sorted((r1, r2) -> {
|
||||
Double d1 = r1.getCovariateValue(this).getValue();
|
||||
Double d2 = r2.getCovariateValue(this).getValue();
|
||||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
return d1.compareTo(d2);
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// only work with non-NA values
|
||||
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
|
||||
//data = data.stream().filter(value -> !value.isNA()).distinct().collect(Collectors.toList()); // TODO which to use?
|
||||
Iterator<Double> sortedDataIterator = data.stream()
|
||||
.map(row -> row.getCovariateValue(this).getValue())
|
||||
.filter(v -> v != null)
|
||||
.iterator();
|
||||
|
||||
// for this implementation we need to shuffle the data
|
||||
final List<Value<Double>> shuffledData;
|
||||
if(number >= data.size()){
|
||||
shuffledData = data;
|
||||
|
||||
final IndexedIterator<Double> dataIterator;
|
||||
if(number == 0){
|
||||
dataIterator = new UniqueValueIterator<>(sortedDataIterator);
|
||||
}
|
||||
else{ // only need the top number entries
|
||||
shuffledData = new ArrayList<>(number);
|
||||
final Set<Integer> indexesToUse = new HashSet<>();
|
||||
//final List<Integer> indexesToUse = new ArrayList<>(); // TODO which to use?
|
||||
else{ // random splitting; we will not weight based on how many times a Row appears in the bootstrap sample
|
||||
final TreeSet<Integer> indexSet = new TreeSet<>();
|
||||
|
||||
while(indexesToUse.size() < number){
|
||||
final int index = random.nextInt(data.size());
|
||||
final int maxIndex = data.size();
|
||||
|
||||
if(indexesToUse.add(index)){
|
||||
shuffledData.add(data.get(index));
|
||||
}
|
||||
for(int i=0; i<number; i++){
|
||||
indexSet.add(random.nextInt(maxIndex));
|
||||
}
|
||||
|
||||
dataIterator = new UniqueSubsetValueIterator<>(
|
||||
new UniqueValueIterator<>(sortedDataIterator),
|
||||
indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
return shuffledData.stream()
|
||||
.mapToDouble(v -> v.getValue())
|
||||
.mapToObj(threshold -> new NumericSplitRule(threshold))
|
||||
.collect(Collectors.toSet());
|
||||
// by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping
|
||||
|
||||
return new NumericSplitRuleUpdater<>(this, data, dataIterator);
|
||||
|
||||
}
|
||||
|
||||
|
@ -101,7 +109,7 @@ public final class NumericCovariate implements Covariate<Double>{
|
|||
|
||||
private final double threshold;
|
||||
|
||||
private NumericSplitRule(final double threshold){
|
||||
NumericSplitRule(final double threshold){
|
||||
this.threshold = threshold;
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
package ca.joeltherrien.randomforest.covariates.numeric;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.tree.Split;
|
||||
import ca.joeltherrien.randomforest.utils.IndexedIterator;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y, Double> {
|
||||
|
||||
private final NumericCovariate covariate;
|
||||
private final List<Row<Y>> orderedData;
|
||||
private final IndexedIterator<Double> dataIterator;
|
||||
|
||||
private Split<Y, Double> currentSplit;
|
||||
|
||||
public NumericSplitRuleUpdater(final NumericCovariate covariate, final List<Row<Y>> orderedData, final IndexedIterator<Double> iterator){
|
||||
this.covariate = covariate;
|
||||
this.orderedData = orderedData;
|
||||
this.dataIterator = iterator;
|
||||
|
||||
final List<Row<Y>> leftHandList = Collections.emptyList();
|
||||
final List<Row<Y>> rightHandList = orderedData;
|
||||
|
||||
this.currentSplit = new Split<>(
|
||||
covariate.new NumericSplitRule(Double.MIN_VALUE),
|
||||
leftHandList,
|
||||
rightHandList,
|
||||
Collections.emptyList());
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Split<Y, Double> currentSplit() {
|
||||
return this.currentSplit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NumericSplitUpdate<Y> nextUpdate() {
|
||||
if(hasNext()){
|
||||
final int currentPosition = dataIterator.getIndex();
|
||||
final Double splitValue = dataIterator.next();
|
||||
final int newPosition = dataIterator.getIndex();
|
||||
|
||||
final List<Row<Y>> rowsMoved = orderedData.subList(currentPosition, newPosition);
|
||||
|
||||
final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue);
|
||||
|
||||
// Update current split
|
||||
this.currentSplit = new Split<>(
|
||||
splitRule,
|
||||
orderedData.subList(0, newPosition),
|
||||
orderedData.subList(newPosition, orderedData.size()),
|
||||
Collections.emptyList());
|
||||
|
||||
|
||||
return new NumericSplitUpdate<>(splitRule, rowsMoved);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return dataIterator.hasNext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Split<Y, Double> next() {
|
||||
if(hasNext()){
|
||||
nextUpdate();
|
||||
}
|
||||
|
||||
return this.currentSplit();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package ca.joeltherrien.randomforest.covariates.numeric;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.AllArgsConstructor;
|
||||
|
||||
import java.util.Collection;
|
||||
|
||||
@AllArgsConstructor
|
||||
public class NumericSplitUpdate<Y> implements Covariate.SplitUpdate<Y, Double> {
|
||||
|
||||
private final NumericCovariate.NumericSplitRule numericSplitRule;
|
||||
private final Collection<Row<Y>> rowsMoved;
|
||||
|
||||
@Override
|
||||
public NumericCovariate.NumericSplitRule getSplitRule() {
|
||||
return numericSplitRule;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<Row<Y>> rowsMovedToLeftHand() {
|
||||
return rowsMoved;
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
package ca.joeltherrien.randomforest.covariates.settings;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariate;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
package ca.joeltherrien.randomforest.covariates.settings;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
import lombok.Getter;
|
|
@ -1,5 +1,6 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
package ca.joeltherrien.randomforest.covariates.settings;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
package ca.joeltherrien.randomforest.covariates;
|
||||
package ca.joeltherrien.randomforest.covariates.settings;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
|
@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
|||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
|
@ -14,11 +15,7 @@ import java.util.stream.Stream;
|
|||
* modifies the abstract method.
|
||||
*
|
||||
*/
|
||||
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y>{
|
||||
|
||||
@Override
|
||||
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||
|
||||
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> extends SimpleGroupDifferentiator<Y> {
|
||||
|
||||
/**
|
||||
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
||||
|
|
|
@ -17,7 +17,7 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
|
|||
private final int[] events;
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
|
|||
private final int[] events;
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
|
|||
private final int[] events;
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
|
|||
private final int[] events;
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
package ca.joeltherrien.randomforest.responses.regression;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||
|
||||
double leftHandSize = leftHand.size();
|
||||
double rightHandSize = rightHand.size();
|
||||
|
||||
if(leftHandSize == 0 || rightHandSize == 0){
|
||||
return null;
|
||||
}
|
||||
|
||||
double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
|
||||
double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
|
||||
|
||||
return Math.abs(leftHandMean - rightHandMean);
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -1,13 +1,13 @@
|
|||
package ca.joeltherrien.randomforest.responses.regression;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||
public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator<Double> {
|
||||
|
||||
@Override
|
||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
||||
public Double getScore(List<Double> leftHand, List<Double> rightHand) {
|
||||
|
||||
final double leftHandSize = leftHand.size();
|
||||
final double rightHandSize = rightHand.size();
|
||||
|
|
|
@ -14,8 +14,10 @@ import java.io.IOException;
|
|||
import java.io.ObjectOutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
@ -45,17 +47,17 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
this.covariates = covariates;
|
||||
this.treeResponseCombiner = settings.getTreeCombiner();
|
||||
this.treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||
|
||||
}
|
||||
|
||||
public Forest<TO, FO> trainSerial(){
|
||||
|
||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||
final Random random = new Random();
|
||||
|
||||
for(int j=0; j<ntree; j++){
|
||||
|
||||
trees.add(trainTree(bootstrapper));
|
||||
trees.add(trainTree(bootstrapper, random));
|
||||
|
||||
if(displayProgress){
|
||||
if(j==0) {
|
||||
|
@ -162,9 +164,9 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
|
||||
}
|
||||
|
||||
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
||||
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
||||
return treeTrainer.growTree(bootstrappedData);
|
||||
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
||||
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap(random);
|
||||
return treeTrainer.growTree(bootstrappedData, random);
|
||||
}
|
||||
|
||||
public void saveTree(final Tree<TO> tree, String name) throws IOException {
|
||||
|
@ -193,7 +195,8 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
@Override
|
||||
public void run() {
|
||||
|
||||
final Tree<TO> tree = trainTree(bootstrapper);
|
||||
// ThreadLocalRandom should make sure we don't duplicate seeds
|
||||
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
|
||||
|
||||
// should be okay as the list structure isn't changing
|
||||
treeList.set(treeIndex, tree);
|
||||
|
@ -216,7 +219,8 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
@Override
|
||||
public void run() {
|
||||
|
||||
final Tree<TO> tree = trainTree(bootstrapper);
|
||||
// ThreadLocalRandom should make sure we don't duplicate seeds
|
||||
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
|
||||
|
||||
try {
|
||||
saveTree(tree, filename);
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
||||
* The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score,
|
||||
* the greater the difference.
|
||||
* The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an
|
||||
* instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits)
|
||||
*
|
||||
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
|
||||
* SimpleGroupDifferentiator.
|
||||
*/
|
||||
public interface GroupDifferentiator<Y> {
|
||||
|
||||
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
||||
<V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator);
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {
|
||||
|
||||
@Override
|
||||
public <V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator) {
|
||||
Double bestScore = null;
|
||||
Split<Y, V> bestSplit = null;
|
||||
|
||||
while(splitIterator.hasNext()){
|
||||
final Split<Y, V> candidateSplit = splitIterator.next();
|
||||
|
||||
final List<Y> leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
final List<Y> rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
|
||||
if(leftHand.isEmpty() || rightHand.isEmpty()){
|
||||
continue;
|
||||
}
|
||||
|
||||
final Double score = getScore(leftHand, rightHand);
|
||||
|
||||
if(score != null && (bestScore == null || score > bestScore)){
|
||||
bestScore = score;
|
||||
bestSplit = candidateSplit;
|
||||
}
|
||||
}
|
||||
|
||||
if(bestSplit == null){
|
||||
return null;
|
||||
}
|
||||
|
||||
return new SplitAndScore<>(bestSplit, bestScore);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a score; higher is better.
|
||||
*
|
||||
* @param leftHand
|
||||
* @param rightHand
|
||||
* @return
|
||||
*/
|
||||
public abstract Double getScore(List<Y> leftHand, List<Y> rightHand);
|
||||
|
||||
}
|
|
@ -1,19 +1,21 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Very simple class that contains three lists; it's essentially a thruple.
|
||||
* Very simple class that contains three lists and a SplitRule.
|
||||
*
|
||||
* @author joel
|
||||
*
|
||||
*/
|
||||
@Data
|
||||
public class Split<Y> {
|
||||
public final class Split<Y, V> {
|
||||
|
||||
public final Covariate.SplitRule<V> splitRule;
|
||||
public final List<Row<Y>> leftHand;
|
||||
public final List<Row<Y>> rightHand;
|
||||
public final List<Row<Y>> naHand;
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@AllArgsConstructor
|
||||
public class SplitAndScore<Y, V> {
|
||||
|
||||
@Getter
|
||||
private final Split<Y, V> split;
|
||||
|
||||
@Getter
|
||||
private final Double score;
|
||||
|
||||
}
|
|
@ -3,10 +3,11 @@ package ca.joeltherrien.randomforest.tree;
|
|||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.*;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Builder
|
||||
|
@ -45,20 +46,21 @@ public class TreeTrainer<Y, O> {
|
|||
this.covariates = covariates;
|
||||
}
|
||||
|
||||
public Tree<O> growTree(List<Row<Y>> data){
|
||||
public Tree<O> growTree(List<Row<Y>> data, Random random){
|
||||
|
||||
final Node<O> rootNode = growNode(data, 0);
|
||||
final Node<O> rootNode = growNode(data, 0, random);
|
||||
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
||||
|
||||
}
|
||||
|
||||
private Node<O> growNode(List<Row<Y>> data, int depth){
|
||||
private Node<O> growNode(List<Row<Y>> data, int depth, Random random){
|
||||
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
|
||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
|
||||
final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry);
|
||||
final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
|
||||
final Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
|
||||
|
||||
if(bestSplitRuleAndSplit.splitRule == null){
|
||||
|
||||
if(bestSplit == null){
|
||||
|
||||
return new TerminalNode<>(
|
||||
responseCombiner.combine(
|
||||
|
@ -69,14 +71,31 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
}
|
||||
|
||||
final Split<Y> split = bestSplitRuleAndSplit.split;
|
||||
// Note that NAs have already been handled
|
||||
|
||||
// Now that we have the best split; we need to handle any NAs that were dropped off
|
||||
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
||||
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
||||
|
||||
// Assign missing values to the split
|
||||
for(Row<Y> row : data) {
|
||||
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
|
||||
if(randomDecision){
|
||||
bestSplit.getLeftHand().add(row);
|
||||
}
|
||||
else{
|
||||
bestSplit.getRightHand().add(row);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
final Node<O> leftNode = growNode(split.leftHand, depth+1);
|
||||
final Node<O> rightNode = growNode(split.rightHand, depth+1);
|
||||
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
|
||||
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
|
||||
|
||||
return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand);
|
||||
return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);
|
||||
|
||||
}
|
||||
else{
|
||||
|
@ -90,13 +109,13 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
}
|
||||
|
||||
private List<Covariate> selectCovariates(int mtry){
|
||||
private List<Covariate> selectCovariates(int mtry, Random random){
|
||||
if(mtry >= covariates.size()){
|
||||
return covariates;
|
||||
}
|
||||
|
||||
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
|
||||
Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
|
||||
Collections.shuffle(splitCovariates, random);
|
||||
|
||||
if (splitCovariates.size() > mtry) {
|
||||
splitCovariates.subList(mtry, splitCovariates.size()).clear();
|
||||
|
@ -105,63 +124,29 @@ public class TreeTrainer<Y, O> {
|
|||
return splitCovariates;
|
||||
}
|
||||
|
||||
private SplitRuleAndSplit findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit();
|
||||
double bestSplitScore = 0.0;
|
||||
boolean first = true;
|
||||
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
||||
|
||||
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
||||
final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck
|
||||
|
||||
for(final Covariate covariate : covariatesToTry) {
|
||||
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
|
||||
|
||||
final int numberToTry = numberOfSplits==0 ? data.size() : numberOfSplits;
|
||||
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
|
||||
|
||||
final Collection<Covariate.SplitRule> splitRulesToTry = covariate
|
||||
.generateSplitRules(
|
||||
data
|
||||
.stream()
|
||||
.map(row -> row.getCovariateValue(covariate))
|
||||
.collect(Collectors.toList())
|
||||
, numberToTry);
|
||||
|
||||
for(final Covariate.SplitRule possibleRule : splitRulesToTry){
|
||||
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
||||
|
||||
// We have to handle any NAs
|
||||
if(possibleSplit.leftHand.size() == 0 && possibleSplit.rightHand.size() == 0 && possibleSplit.naHand.size() > 0){
|
||||
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
||||
}
|
||||
|
||||
final double probabilityLeftHand = (double) possibleSplit.leftHand.size() / (double) (possibleSplit.leftHand.size() + possibleSplit.rightHand.size());
|
||||
|
||||
for(final Row<Y> missingValueRow : possibleSplit.naHand){
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
if(randomDecision){
|
||||
possibleSplit.leftHand.add(missingValueRow);
|
||||
}
|
||||
else{
|
||||
possibleSplit.rightHand.add(missingValueRow);
|
||||
}
|
||||
}
|
||||
|
||||
final Double score = groupDifferentiator.differentiate(
|
||||
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
|
||||
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||
);
|
||||
|
||||
|
||||
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
|
||||
bestSplitRuleAndSplit.splitRule = possibleRule;
|
||||
bestSplitRuleAndSplit.split = possibleSplit;
|
||||
bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand;
|
||||
|
||||
bestSplitScore = score;
|
||||
first = false;
|
||||
if(candidateSplitAndScore != null) {
|
||||
if (bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
|
||||
bestSplitAndScore = candidateSplitAndScore;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return bestSplitRuleAndSplit;
|
||||
if(bestSplitAndScore == null){
|
||||
return null;
|
||||
}
|
||||
|
||||
return bestSplitAndScore.getSplit();
|
||||
|
||||
}
|
||||
|
||||
|
@ -184,10 +169,4 @@ public class TreeTrainer<Y, O> {
|
|||
return true;
|
||||
}
|
||||
|
||||
private class SplitRuleAndSplit{
|
||||
private Covariate.SplitRule splitRule = null;
|
||||
private Split<Y> split = null;
|
||||
private double probabilityLeftHand;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
 * An {@link Iterator} that can additionally report an integer position via {@link #getIndex()}.
 *
 * NOTE(review): this interface does not itself define what the index counts; in the implementations that
 * accompany it the value appears to be the number of elements consumed from an underlying source - confirm
 * against each implementation before relying on it.
 *
 * @param <E> the type of element returned by this iterator
 */
public interface IndexedIterator<E> extends Iterator<E> {
|
||||
|
||||
/**
 * @return the iterator's current index
 */
int getIndex();
|
||||
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
 * Iterator that yields exactly one value.
 *
 * {@link #hasNext()} is true until {@link #next()} has been called once. Any further call to
 * {@link #next()} returns {@code null} instead of throwing.
 *
 * @param <E> the type of the single element
 */
public class SingletonIterator<E> implements Iterator<E> {

    /** The one element this iterator hands out. */
    private final E value;

    /** Flips to true as soon as the element has been handed out. */
    private boolean consumed;

    /**
     * @param value the single element to iterate over
     */
    public SingletonIterator(final E value) {
        this.value = value;
    }

    @Override
    public boolean hasNext() {
        return !consumed;
    }

    @Override
    public E next() {
        // Once the element is gone every further call yields null.
        if (consumed) {
            return null;
        }

        consumed = true;
        return value;
    }
}
|
|
@ -0,0 +1,65 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* Iterator that wraps around a UniqueValueIterator. It continues to iterate until it gets to one of the prespecified indexes,
|
||||
* and then proceeds just past that to the end of the existing values it's at.
|
||||
*
|
||||
* The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together.
|
||||
* I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations
|
||||
*
|
||||
* @param <E>
|
||||
*/
|
||||
public class UniqueSubsetValueIterator<E> implements IndexedIterator<E> {
|
||||
|
||||
private final UniqueValueIterator<E> iterator;
|
||||
private final Integer[] indexValues;
|
||||
|
||||
private int currentIndexSpot = 0;
|
||||
|
||||
public UniqueSubsetValueIterator(final UniqueValueIterator<E> iterator, final Integer[] indexValues){
|
||||
this.iterator = iterator;
|
||||
this.indexValues = indexValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return iterator.hasNext() && iterator.getIndex() <= indexValues[indexValues.length-1];
|
||||
}
|
||||
|
||||
@Override
|
||||
public E next() {
|
||||
if(hasNext()){
|
||||
final int indexToStopBy = indexValues[currentIndexSpot];
|
||||
|
||||
while(iterator.getIndex() <= indexToStopBy){
|
||||
iterator.next();
|
||||
}
|
||||
|
||||
for(int i = currentIndexSpot + 1; i < indexValues.length; i++){
|
||||
if(iterator.getIndex() <= indexValues[i]){
|
||||
currentIndexSpot = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return iterator.getCurrentValue();
|
||||
|
||||
}
|
||||
|
||||
return null;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getIndex(){
|
||||
return iterator.getIndex();
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* Iterator that wraps around another iterator. It continues to iterate until it gets to the *end* of a sequence of identical values.
|
||||
* It also tracks the current index in the original iterator.
|
||||
*
|
||||
* The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together.
|
||||
* I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations
|
||||
*
|
||||
* @param <E>
|
||||
*/
|
||||
public class UniqueValueIterator<E> implements IndexedIterator<E> {
|
||||
|
||||
private final Iterator<E> wrappedIterator;
|
||||
|
||||
@Getter private E currentValue = null;
|
||||
@Getter private E nextValue;
|
||||
|
||||
public UniqueValueIterator(final Iterator<E> wrappedIterator){
|
||||
this.wrappedIterator = wrappedIterator;
|
||||
this.nextValue = wrappedIterator.next();
|
||||
}
|
||||
|
||||
// Count must return the index of the last value of the sequence returned by next()
|
||||
@Getter
|
||||
private int index = 0;
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return nextValue != null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public E next() {
|
||||
|
||||
int count = 1;
|
||||
while(wrappedIterator.hasNext()){
|
||||
final E currentIteratorValue = wrappedIterator.next();
|
||||
|
||||
if(currentIteratorValue.equals(nextValue)){
|
||||
count++;
|
||||
}
|
||||
else{
|
||||
index +=count;
|
||||
currentValue = nextValue;
|
||||
nextValue = currentIteratorValue;
|
||||
|
||||
return currentValue;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if(nextValue != null){
|
||||
index += count;
|
||||
currentValue = nextValue;
|
||||
nextValue = null;
|
||||
|
||||
return currentValue;
|
||||
}
|
||||
else{
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
|
|
|
@ -2,6 +2,8 @@ package ca.joeltherrien.randomforest.competingrisk;
|
|||
|
||||
import ca.joeltherrien.randomforest.*;
|
||||
import ca.joeltherrien.randomforest.covariates.*;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.tree.Forest;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
|
@ -18,6 +20,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
public class TestCompetingRisk {
|
||||
|
||||
|
@ -104,7 +107,7 @@ public class TestCompetingRisk {
|
|||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
|
||||
|
||||
final CovariateRow newRow = getPredictionRow(covariates);
|
||||
|
||||
|
@ -157,7 +160,7 @@ public class TestCompetingRisk {
|
|||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||
|
||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
|
||||
|
||||
final CovariateRow newRow = getPredictionRow(covariates);
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.DataLoader;
|
|||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
|
@ -53,14 +53,14 @@ public class TestLogRankMultipleGroupDifferentiator {
|
|||
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
|
||||
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
|
||||
|
||||
final double scoreBad = groupDifferentiator.differentiate(
|
||||
final double scoreBad = groupDifferentiator.getScore(
|
||||
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
|
||||
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
|
||||
|
||||
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199);
|
||||
final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size());
|
||||
|
||||
final double scoreGood = groupDifferentiator.differentiate(
|
||||
final double scoreGood = groupDifferentiator.getScore(
|
||||
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
|
||||
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
package ca.joeltherrien.randomforest.competingrisk;
|
||||
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
@ -51,7 +50,7 @@ public class TestLogRankSingleGroupDifferentiator {
|
|||
|
||||
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
|
||||
|
||||
final double score = differentiator.differentiate(data1, data2);
|
||||
final double score = differentiator.getScore(data1, data2);
|
||||
final double margin = 0.000001;
|
||||
|
||||
// Tested using 855 method
|
||||
|
@ -71,14 +70,14 @@ public class TestLogRankSingleGroupDifferentiator {
|
|||
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
|
||||
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
|
||||
|
||||
final double scoreGood = groupDifferentiator.differentiate(
|
||||
final double scoreGood = groupDifferentiator.getScore(
|
||||
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
|
||||
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
|
||||
|
||||
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
|
||||
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
|
||||
|
||||
final double scoreBad = groupDifferentiator.differentiate(
|
||||
final double scoreBad = groupDifferentiator.getScore(
|
||||
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
|
||||
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
|
||||
|
||||
|
|
|
@ -5,8 +5,9 @@ import ca.joeltherrien.randomforest.utils.Utils;
|
|||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.function.Executable;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
|
@ -46,7 +47,10 @@ public class FactorCovariateTest {
|
|||
void testAllSubsets(){
|
||||
final FactorCovariate petCovariate = createTestCovariate();
|
||||
|
||||
final Collection<FactorCovariate.FactorSplitRule> splitRules = petCovariate.generateSplitRules(null, 100);
|
||||
final List<Covariate.SplitRule<String>> splitRules = new ArrayList<>();
|
||||
|
||||
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
|
||||
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
|
||||
|
||||
assertEquals(splitRules.size(), 3);
|
||||
|
||||
|
|
|
@ -3,10 +3,10 @@ package ca.joeltherrien.randomforest.csv;
|
|||
import ca.joeltherrien.randomforest.DataLoader;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.Settings;
|
||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
|
@ -15,7 +15,6 @@ import org.junit.jupiter.api.Test;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
|
|
@ -3,7 +3,9 @@ package ca.joeltherrien.randomforest.settings;
|
|||
import ca.joeltherrien.randomforest.Settings;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.*;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.utils.Utils;
|
||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestSingletonIterator {
|
||||
|
||||
@Test
|
||||
public void verifyBehaviour(){
|
||||
final Integer element = 5;
|
||||
|
||||
final SingletonIterator<Integer> iterator = new SingletonIterator<>(element);
|
||||
|
||||
assertTrue(iterator.hasNext());
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(Integer.valueOf(5), iterator.next());
|
||||
|
||||
assertFalse(iterator.hasNext());
|
||||
assertNull(iterator.next());
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestUniqueSubsetValueIterator {
|
||||
|
||||
@Test
|
||||
public void testIterator1(){
|
||||
final List<Integer> testData = Arrays.asList(
|
||||
1,1,2,3,5,5,5,6,7,7
|
||||
);
|
||||
|
||||
final Integer[] indexes = new Integer[]{2,3,4,5};
|
||||
|
||||
final UniqueValueIterator<Integer> uniqueValueIterator = new UniqueValueIterator<>(testData.iterator());
|
||||
final UniqueSubsetValueIterator<Integer> iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes);
|
||||
|
||||
// we expect to get 2, 3, and 5 back. 5 should happen only once
|
||||
|
||||
assertEquals(iterator.getIndex(), 0);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 2);
|
||||
assertEquals(iterator.getIndex(), 3);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 3);
|
||||
assertEquals(iterator.getIndex(), 4);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 5);
|
||||
assertEquals(iterator.getIndex(), 7);
|
||||
assertFalse(iterator.hasNext());
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterator2(){
|
||||
final List<Integer> testData = Arrays.asList(
|
||||
1,1,2,3,5,5,5,6,7,7
|
||||
);
|
||||
|
||||
final Integer[] indexes = new Integer[]{1,8};
|
||||
|
||||
final UniqueValueIterator<Integer> uniqueValueIterator = new UniqueValueIterator<>(testData.iterator());
|
||||
final UniqueSubsetValueIterator<Integer> iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes);
|
||||
|
||||
assertEquals(iterator.getIndex(), 0);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 1);
|
||||
assertEquals(iterator.getIndex(), 2);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 7);
|
||||
assertEquals(iterator.getIndex(), 10);
|
||||
assertFalse(iterator.hasNext());
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestUniqueValueIterator {
|
||||
|
||||
@Test
|
||||
public void testIterator1(){
|
||||
final List<Integer> testData = Arrays.asList(
|
||||
1,1,2,3,5,5,5,6,7,7
|
||||
);
|
||||
|
||||
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
|
||||
|
||||
assertEquals(iterator.getIndex(), 0);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 1);
|
||||
assertEquals(iterator.getIndex(), 2);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 2);
|
||||
assertEquals(iterator.getIndex(), 3);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 3);
|
||||
assertEquals(iterator.getIndex(), 4);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 5);
|
||||
assertEquals(iterator.getIndex(), 7);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 6);
|
||||
assertEquals(iterator.getIndex(), 8);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 7);
|
||||
assertEquals(iterator.getIndex(), 10);
|
||||
assertTrue(!iterator.hasNext());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterator2(){
|
||||
final List<Integer> testData = Arrays.asList(
|
||||
1,2,3,5,5,5,6,7 // same numbers; but 1 and 7 only appear once each
|
||||
);
|
||||
|
||||
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
|
||||
|
||||
assertEquals(iterator.getIndex(), 0);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 1);
|
||||
assertEquals(iterator.getIndex(), 1);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 2);
|
||||
assertEquals(iterator.getIndex(), 2);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 3);
|
||||
assertEquals(iterator.getIndex(), 3);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 5);
|
||||
assertEquals(iterator.getIndex(), 6);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 6);
|
||||
assertEquals(iterator.getIndex(), 7);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 7);
|
||||
assertEquals(iterator.getIndex(), 8);
|
||||
assertFalse(iterator.hasNext());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterator3(){
|
||||
final List<Integer> testData = Arrays.asList(
|
||||
1,1,1,1,1,1,1,2,2,2,2,2,3
|
||||
);
|
||||
|
||||
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
|
||||
|
||||
assertEquals(iterator.getIndex(), 0);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 1);
|
||||
assertEquals(iterator.getIndex(), 7);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 2);
|
||||
assertEquals(iterator.getIndex(), 12);
|
||||
assertTrue(iterator.hasNext());
|
||||
|
||||
assertEquals(iterator.next().intValue(), 3);
|
||||
assertEquals(iterator.getIndex(), 13);
|
||||
assertFalse(iterator.hasNext());
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.workshop;
|
|||
|
||||
import ca.joeltherrien.randomforest.*;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||
|
|
|
@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.workshop;
|
|||
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||
|
@ -54,13 +54,14 @@ public class TrainSingleTree {
|
|||
.covariates(covariateNames)
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
.maxNodeDepth(30)
|
||||
.mtry(2)
|
||||
.nodeSize(5)
|
||||
.numberOfSplits(0)
|
||||
.build();
|
||||
|
||||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
|
||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, new Random());
|
||||
final long endTime = System.currentTimeMillis();
|
||||
|
||||
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||
|
|
|
@ -4,7 +4,7 @@ package ca.joeltherrien.randomforest.workshop;
|
|||
import ca.joeltherrien.randomforest.*;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.Node;
|
||||
|
@ -13,7 +13,6 @@ import ca.joeltherrien.randomforest.utils.Utils;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.DoubleStream;
|
||||
|
@ -21,8 +20,6 @@ import java.util.stream.DoubleStream;
|
|||
public class TrainSingleTreeFactor {
|
||||
|
||||
public static void main(String[] args) {
|
||||
System.out.println("Hello world!");
|
||||
|
||||
final Random random = new Random(123);
|
||||
|
||||
final int n = 10000;
|
||||
|
@ -84,7 +81,7 @@ public class TrainSingleTreeFactor {
|
|||
|
||||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
|
||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, random);
|
||||
final long endTime = System.currentTimeMillis();
|
||||
|
||||
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||
|
|
Loading…
Reference in a new issue