Add optimizations. #13
45 changed files with 741 additions and 212 deletions
|
@ -4,20 +4,20 @@ import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.Random;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class Bootstrapper<T> {
|
public class Bootstrapper<T> {
|
||||||
|
|
||||||
final private List<T> originalData;
|
final private List<T> originalData;
|
||||||
|
|
||||||
public List<T> bootstrap(){
|
public List<T> bootstrap(Random random){
|
||||||
final int n = originalData.size();
|
final int n = originalData.size();
|
||||||
|
|
||||||
final List<T> newList = new ArrayList<>(n);
|
final List<T> newList = new ArrayList<>(n);
|
||||||
|
|
||||||
for(int i=0; i<n; i++){
|
for(int i=0; i<n; i++){
|
||||||
final int index = ThreadLocalRandom.current().nextInt(n);
|
final int index = random.nextInt(n);
|
||||||
|
|
||||||
newList.add(originalData.get(index));
|
newList.add(originalData.get(index));
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ public class CovariateRow implements Serializable {
|
||||||
@Getter
|
@Getter
|
||||||
private final int id;
|
private final int id;
|
||||||
|
|
||||||
public Covariate.Value<?> getCovariateValue(Covariate covariate){
|
public <V> Covariate.Value<V> getCovariateValue(Covariate<V> covariate){
|
||||||
return valueArray[covariate.getIndex()];
|
return valueArray[covariate.getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
|
|
@ -17,6 +17,8 @@ public class Row<Y> extends CovariateRow {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public Y getResponse() {
|
public Y getResponse() {
|
||||||
return this.response;
|
return this.response;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.CovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.CovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
|
@ -10,7 +10,6 @@ import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayL
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
|
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
@ -68,9 +67,6 @@ public class Settings {
|
||||||
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
|
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
|
||||||
}
|
}
|
||||||
static{
|
static{
|
||||||
registerGroupDifferentiatorConstructor("MeanGroupDifferentiator",
|
|
||||||
(node) -> new MeanGroupDifferentiator()
|
|
||||||
);
|
|
||||||
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
|
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
|
||||||
(node) -> new WeightedVarianceGroupDifferentiator()
|
(node) -> new WeightedVarianceGroupDifferentiator()
|
||||||
);
|
);
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.*;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public final class BooleanCovariate implements Covariate<Boolean>{
|
public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
@ -19,8 +20,8 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
||||||
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<BooleanSplitRule> generateSplitRules(List<Value<Boolean>> data, int number) {
|
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
return Collections.singleton(splitRule);
|
return new SingletonIterator<>(BooleanCovariate.this.splitRule.applyRule(data));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -74,6 +75,7 @@ public final class BooleanCovariate implements Covariate<Boolean>{
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public class BooleanSplitRule implements SplitRule<Boolean>{
|
public class BooleanSplitRule implements SplitRule<Boolean>{
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -5,10 +5,7 @@ import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.tree.Split;
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.LinkedList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
|
||||||
public interface Covariate<V> extends Serializable {
|
public interface Covariate<V> extends Serializable {
|
||||||
|
@ -17,7 +14,7 @@ public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
int getIndex();
|
int getIndex();
|
||||||
|
|
||||||
Collection<? extends SplitRule<V>> generateSplitRules(final List<Value<V>> data, final int number);
|
<Y> Iterator<Split<Y, V>> generateSplitRuleUpdater(final List<Row<Y>> data, final int number, final Random random);
|
||||||
|
|
||||||
Value<V> createValue(V value);
|
Value<V> createValue(V value);
|
||||||
|
|
||||||
|
@ -39,6 +36,16 @@ public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface SplitRuleUpdater<Y, V> extends Iterator<Split<Y, V>>{
|
||||||
|
Split<Y, V> currentSplit();
|
||||||
|
SplitUpdate<Y, V> nextUpdate();
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SplitUpdate<Y, V> {
|
||||||
|
SplitRule<V> getSplitRule();
|
||||||
|
Collection<Row<Y>> rowsMovedToLeftHand();
|
||||||
|
}
|
||||||
|
|
||||||
interface SplitRule<V> extends Serializable{
|
interface SplitRule<V> extends Serializable{
|
||||||
|
|
||||||
Covariate<V> getParent();
|
Covariate<V> getParent();
|
||||||
|
@ -51,7 +58,7 @@ public interface Covariate<V> extends Serializable {
|
||||||
* @param <Y>
|
* @param <Y>
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
default <Y> Split<Y> applyRule(List<Row<Y>> rows) {
|
default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
|
||||||
final List<Row<Y>> leftHand = new LinkedList<>();
|
final List<Row<Y>> leftHand = new LinkedList<>();
|
||||||
final List<Row<Y>> rightHand = new LinkedList<>();
|
final List<Row<Y>> rightHand = new LinkedList<>();
|
||||||
|
|
||||||
|
@ -77,7 +84,7 @@ public interface Covariate<V> extends Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return new Split<>(leftHand, rightHand, missingValueRows);
|
return new Split<>(this, leftHand, rightHand, missingValueRows);
|
||||||
}
|
}
|
||||||
|
|
||||||
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
|
@ -42,17 +44,14 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Set<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
final Set<FactorSplitRule> splitRules = new HashSet<>();
|
final Set<Split<Y, String>> splits = new HashSet<>();
|
||||||
|
|
||||||
// This is to ensure we don't get stuck in an infinite loop for small factors
|
// This is to ensure we don't get stuck in an infinite loop for small factors
|
||||||
number = Math.min(number, numberOfPossiblePairings);
|
number = Math.min(number, numberOfPossiblePairings);
|
||||||
final Random random = ThreadLocalRandom.current();
|
|
||||||
final List<FactorValue> levels = new ArrayList<>(factorLevels.values());
|
final List<FactorValue> levels = new ArrayList<>(factorLevels.values());
|
||||||
|
|
||||||
|
while(splits.size() < number){
|
||||||
|
|
||||||
while(splitRules.size() < number){
|
|
||||||
Collections.shuffle(levels, random);
|
Collections.shuffle(levels, random);
|
||||||
final Set<FactorValue> leftSideValues = new HashSet<>();
|
final Set<FactorValue> leftSideValues = new HashSet<>();
|
||||||
leftSideValues.add(levels.get(0));
|
leftSideValues.add(levels.get(0));
|
||||||
|
@ -63,13 +62,14 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
splitRules.add(new FactorSplitRule(leftSideValues));
|
splits.add(new FactorSplitRule(leftSideValues).applyRule(data));
|
||||||
}
|
}
|
||||||
|
|
||||||
return splitRules;
|
return splits.iterator();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public FactorValue createValue(String value) {
|
public FactorValue createValue(String value) {
|
||||||
if(value == null || value.equalsIgnoreCase("na")){
|
if(value == null || value.equalsIgnoreCase("na")){
|
||||||
|
|
|
@ -1,17 +1,21 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates.numeric;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.utils.IndexedIterator;
|
||||||
|
import ca.joeltherrien.randomforest.utils.UniqueSubsetValueIterator;
|
||||||
|
import ca.joeltherrien.randomforest.utils.UniqueValueIterator;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@ToString
|
@ToString
|
||||||
public final class NumericCovariate implements Covariate<Double>{
|
public final class NumericCovariate implements Covariate<Double> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
@ -20,40 +24,44 @@ public final class NumericCovariate implements Covariate<Double>{
|
||||||
private final int index;
|
private final int index;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) {
|
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
|
data = data.stream()
|
||||||
|
.filter(row -> !row.getCovariateValue(this).isNA())
|
||||||
|
.sorted((r1, r2) -> {
|
||||||
|
Double d1 = r1.getCovariateValue(this).getValue();
|
||||||
|
Double d2 = r2.getCovariateValue(this).getValue();
|
||||||
|
|
||||||
final Random random = ThreadLocalRandom.current();
|
return d1.compareTo(d2);
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
// only work with non-NA values
|
Iterator<Double> sortedDataIterator = data.stream()
|
||||||
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
|
.map(row -> row.getCovariateValue(this).getValue())
|
||||||
//data = data.stream().filter(value -> !value.isNA()).distinct().collect(Collectors.toList()); // TODO which to use?
|
.filter(v -> v != null)
|
||||||
|
.iterator();
|
||||||
|
|
||||||
// for this implementation we need to shuffle the data
|
|
||||||
final List<Value<Double>> shuffledData;
|
final IndexedIterator<Double> dataIterator;
|
||||||
if(number >= data.size()){
|
if(number == 0){
|
||||||
shuffledData = data;
|
dataIterator = new UniqueValueIterator<>(sortedDataIterator);
|
||||||
}
|
}
|
||||||
else{ // only need the top number entries
|
else{ // random splitting; we will not weight based on how many times a Row appears in the bootstrap sample
|
||||||
shuffledData = new ArrayList<>(number);
|
final TreeSet<Integer> indexSet = new TreeSet<>();
|
||||||
final Set<Integer> indexesToUse = new HashSet<>();
|
|
||||||
//final List<Integer> indexesToUse = new ArrayList<>(); // TODO which to use?
|
|
||||||
|
|
||||||
while(indexesToUse.size() < number){
|
final int maxIndex = data.size();
|
||||||
final int index = random.nextInt(data.size());
|
|
||||||
|
|
||||||
if(indexesToUse.add(index)){
|
for(int i=0; i<number; i++){
|
||||||
shuffledData.add(data.get(index));
|
indexSet.add(random.nextInt(maxIndex));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dataIterator = new UniqueSubsetValueIterator<>(
|
||||||
|
new UniqueValueIterator<>(sortedDataIterator),
|
||||||
|
indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered
|
||||||
|
);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return shuffledData.stream()
|
return new NumericSplitRuleUpdater<>(this, data, dataIterator);
|
||||||
.mapToDouble(v -> v.getValue())
|
|
||||||
.mapToObj(threshold -> new NumericSplitRule(threshold))
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
// by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +109,7 @@ public final class NumericCovariate implements Covariate<Double>{
|
||||||
|
|
||||||
private final double threshold;
|
private final double threshold;
|
||||||
|
|
||||||
private NumericSplitRule(final double threshold){
|
NumericSplitRule(final double threshold){
|
||||||
this.threshold = threshold;
|
this.threshold = threshold;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
package ca.joeltherrien.randomforest.covariates.numeric;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
import ca.joeltherrien.randomforest.utils.IndexedIterator;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y, Double> {
|
||||||
|
|
||||||
|
private final NumericCovariate covariate;
|
||||||
|
private final List<Row<Y>> orderedData;
|
||||||
|
private final IndexedIterator<Double> dataIterator;
|
||||||
|
|
||||||
|
private Split<Y, Double> currentSplit;
|
||||||
|
|
||||||
|
public NumericSplitRuleUpdater(final NumericCovariate covariate, final List<Row<Y>> orderedData, final IndexedIterator<Double> iterator){
|
||||||
|
this.covariate = covariate;
|
||||||
|
this.orderedData = orderedData;
|
||||||
|
this.dataIterator = iterator;
|
||||||
|
|
||||||
|
final List<Row<Y>> leftHandList = Collections.emptyList();
|
||||||
|
final List<Row<Y>> rightHandList = orderedData;
|
||||||
|
|
||||||
|
this.currentSplit = new Split<>(
|
||||||
|
covariate.new NumericSplitRule(Double.MIN_VALUE),
|
||||||
|
leftHandList,
|
||||||
|
rightHandList,
|
||||||
|
Collections.emptyList());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Split<Y, Double> currentSplit() {
|
||||||
|
return this.currentSplit;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NumericSplitUpdate<Y> nextUpdate() {
|
||||||
|
if(hasNext()){
|
||||||
|
final int currentPosition = dataIterator.getIndex();
|
||||||
|
final Double splitValue = dataIterator.next();
|
||||||
|
final int newPosition = dataIterator.getIndex();
|
||||||
|
|
||||||
|
final List<Row<Y>> rowsMoved = orderedData.subList(currentPosition, newPosition);
|
||||||
|
|
||||||
|
final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue);
|
||||||
|
|
||||||
|
// Update current split
|
||||||
|
this.currentSplit = new Split<>(
|
||||||
|
splitRule,
|
||||||
|
orderedData.subList(0, newPosition),
|
||||||
|
orderedData.subList(newPosition, orderedData.size()),
|
||||||
|
Collections.emptyList());
|
||||||
|
|
||||||
|
|
||||||
|
return new NumericSplitUpdate<>(splitRule, rowsMoved);
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return dataIterator.hasNext();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Split<Y, Double> next() {
|
||||||
|
if(hasNext()){
|
||||||
|
nextUpdate();
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.currentSplit();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package ca.joeltherrien.randomforest.covariates.numeric;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class NumericSplitUpdate<Y> implements Covariate.SplitUpdate<Y, Double> {
|
||||||
|
|
||||||
|
private final NumericCovariate.NumericSplitRule numericSplitRule;
|
||||||
|
private final Collection<Row<Y>> rowsMoved;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NumericCovariate.NumericSplitRule getSplitRule() {
|
||||||
|
return numericSplitRule;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Collection<Row<Y>> rowsMovedToLeftHand() {
|
||||||
|
return rowsMoved;
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.BooleanCovariate;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
|
@ -1,5 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
|
@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
@ -14,11 +15,7 @@ import java.util.stream.Stream;
|
||||||
* modifies the abstract method.
|
* modifies the abstract method.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> implements GroupDifferentiator<Y>{
|
public abstract class CompetingRiskGroupDifferentiator<Y extends CompetingRiskResponse> extends SimpleGroupDifferentiator<Y> {
|
||||||
|
|
||||||
@Override
|
|
||||||
public abstract Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
|
||||||
|
|
|
@ -17,7 +17,7 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
public Double getScore(List<CompetingRiskResponseWithCensorTime> leftHand, List<CompetingRiskResponseWithCensorTime> rightHand) {
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
|
||||||
private final int[] events;
|
private final int[] events;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
public Double getScore(List<CompetingRiskResponse> leftHand, List<CompetingRiskResponse> rightHand) {
|
||||||
if(leftHand.size() == 0 || rightHand.size() == 0){
|
if(leftHand.size() == 0 || rightHand.size() == 0){
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,26 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest.responses.regression;
|
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class MeanGroupDifferentiator implements GroupDifferentiator<Double> {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
|
||||||
|
|
||||||
double leftHandSize = leftHand.size();
|
|
||||||
double rightHandSize = rightHand.size();
|
|
||||||
|
|
||||||
if(leftHandSize == 0 || rightHandSize == 0){
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
|
|
||||||
double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
|
|
||||||
|
|
||||||
return Math.abs(leftHandMean - rightHandMean);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,13 +1,13 @@
|
||||||
package ca.joeltherrien.randomforest.responses.regression;
|
package ca.joeltherrien.randomforest.responses.regression;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator<Double> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double differentiate(List<Double> leftHand, List<Double> rightHand) {
|
public Double getScore(List<Double> leftHand, List<Double> rightHand) {
|
||||||
|
|
||||||
final double leftHandSize = leftHand.size();
|
final double leftHandSize = leftHand.size();
|
||||||
final double rightHandSize = rightHand.size();
|
final double rightHandSize = rightHand.size();
|
||||||
|
|
|
@ -14,8 +14,10 @@ import java.io.IOException;
|
||||||
import java.io.ObjectOutputStream;
|
import java.io.ObjectOutputStream;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
@ -45,17 +47,17 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
this.covariates = covariates;
|
this.covariates = covariates;
|
||||||
this.treeResponseCombiner = settings.getTreeCombiner();
|
this.treeResponseCombiner = settings.getTreeCombiner();
|
||||||
this.treeTrainer = new TreeTrainer<>(settings, covariates);
|
this.treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Forest<TO, FO> trainSerial(){
|
public Forest<TO, FO> trainSerial(){
|
||||||
|
|
||||||
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
final List<Tree<TO>> trees = new ArrayList<>(ntree);
|
||||||
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
final Bootstrapper<Row<Y>> bootstrapper = new Bootstrapper<>(data);
|
||||||
|
final Random random = new Random();
|
||||||
|
|
||||||
for(int j=0; j<ntree; j++){
|
for(int j=0; j<ntree; j++){
|
||||||
|
|
||||||
trees.add(trainTree(bootstrapper));
|
trees.add(trainTree(bootstrapper, random));
|
||||||
|
|
||||||
if(displayProgress){
|
if(displayProgress){
|
||||||
if(j==0) {
|
if(j==0) {
|
||||||
|
@ -162,9 +164,9 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper){
|
private Tree<TO> trainTree(final Bootstrapper<Row<Y>> bootstrapper, Random random){
|
||||||
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap();
|
final List<Row<Y>> bootstrappedData = bootstrapper.bootstrap(random);
|
||||||
return treeTrainer.growTree(bootstrappedData);
|
return treeTrainer.growTree(bootstrappedData, random);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void saveTree(final Tree<TO> tree, String name) throws IOException {
|
public void saveTree(final Tree<TO> tree, String name) throws IOException {
|
||||||
|
@ -193,7 +195,8 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
|
||||||
final Tree<TO> tree = trainTree(bootstrapper);
|
// ThreadLocalRandom should make sure we don't duplicate seeds
|
||||||
|
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
|
||||||
|
|
||||||
// should be okay as the list structure isn't changing
|
// should be okay as the list structure isn't changing
|
||||||
treeList.set(treeIndex, tree);
|
treeList.set(treeIndex, tree);
|
||||||
|
@ -216,7 +219,8 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
|
||||||
final Tree<TO> tree = trainTree(bootstrapper);
|
// ThreadLocalRandom should make sure we don't duplicate seeds
|
||||||
|
final Tree<TO> tree = trainTree(bootstrapper, ThreadLocalRandom.current());
|
||||||
|
|
||||||
try {
|
try {
|
||||||
saveTree(tree, filename);
|
saveTree(tree, filename);
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.Iterator;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
|
||||||
* The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score,
|
* The GroupDifferentiator has one method that cycles through an iterator of Splits (FYI; check if the iterator is an
|
||||||
* the greater the difference.
|
* instance of Covariate.SplitRuleUpdater; in which case you get access to the rows that change between splits)
|
||||||
*
|
*
|
||||||
|
* If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
|
||||||
|
* SimpleGroupDifferentiator.
|
||||||
*/
|
*/
|
||||||
public interface GroupDifferentiator<Y> {
|
public interface GroupDifferentiator<Y> {
|
||||||
|
|
||||||
Double differentiate(List<Y> leftHand, List<Y> rightHand);
|
<V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
public abstract class SimpleGroupDifferentiator<Y> implements GroupDifferentiator<Y> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <V> SplitAndScore<Y, V> differentiate(Iterator<Split<Y, V>> splitIterator) {
|
||||||
|
Double bestScore = null;
|
||||||
|
Split<Y, V> bestSplit = null;
|
||||||
|
|
||||||
|
while(splitIterator.hasNext()){
|
||||||
|
final Split<Y, V> candidateSplit = splitIterator.next();
|
||||||
|
|
||||||
|
final List<Y> leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Y> rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
if(leftHand.isEmpty() || rightHand.isEmpty()){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
final Double score = getScore(leftHand, rightHand);
|
||||||
|
|
||||||
|
if(score != null && (bestScore == null || score > bestScore)){
|
||||||
|
bestScore = score;
|
||||||
|
bestSplit = candidateSplit;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(bestSplit == null){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new SplitAndScore<>(bestSplit, bestScore);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return a score; higher is better.
|
||||||
|
*
|
||||||
|
* @param leftHand
|
||||||
|
* @param rightHand
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public abstract Double getScore(List<Y> leftHand, List<Y> rightHand);
|
||||||
|
|
||||||
|
}
|
|
@ -1,19 +1,21 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Very simple class that contains three lists; it's essentially a thruple.
|
* Very simple class that contains three lists and a SplitRule.
|
||||||
*
|
*
|
||||||
* @author joel
|
* @author joel
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
public class Split<Y> {
|
public final class Split<Y, V> {
|
||||||
|
|
||||||
|
public final Covariate.SplitRule<V> splitRule;
|
||||||
public final List<Row<Y>> leftHand;
|
public final List<Row<Y>> leftHand;
|
||||||
public final List<Row<Y>> rightHand;
|
public final List<Row<Y>> rightHand;
|
||||||
public final List<Row<Y>> naHand;
|
public final List<Row<Y>> naHand;
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class SplitAndScore<Y, V> {
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final Split<Y, V> split;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final Double score;
|
||||||
|
|
||||||
|
}
|
|
@ -3,10 +3,11 @@ package ca.joeltherrien.randomforest.tree;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.Settings;
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import lombok.*;
|
import lombok.AccessLevel;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
|
@ -45,20 +46,21 @@ public class TreeTrainer<Y, O> {
|
||||||
this.covariates = covariates;
|
this.covariates = covariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Tree<O> growTree(List<Row<Y>> data){
|
public Tree<O> growTree(List<Row<Y>> data, Random random){
|
||||||
|
|
||||||
final Node<O> rootNode = growNode(data, 0);
|
final Node<O> rootNode = growNode(data, 0, random);
|
||||||
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node<O> growNode(List<Row<Y>> data, int depth){
|
private Node<O> growNode(List<Row<Y>> data, int depth, Random random){
|
||||||
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
|
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
|
||||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||||
final List<Covariate> covariatesToTry = selectCovariates(this.mtry);
|
final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
|
||||||
final SplitRuleAndSplit bestSplitRuleAndSplit = findBestSplitRule(data, covariatesToTry);
|
final Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
|
||||||
|
|
||||||
if(bestSplitRuleAndSplit.splitRule == null){
|
|
||||||
|
if(bestSplit == null){
|
||||||
|
|
||||||
return new TerminalNode<>(
|
return new TerminalNode<>(
|
||||||
responseCombiner.combine(
|
responseCombiner.combine(
|
||||||
|
@ -69,14 +71,31 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final Split<Y> split = bestSplitRuleAndSplit.split;
|
|
||||||
// Note that NAs have already been handled
|
// Now that we have the best split; we need to handle any NAs that were dropped off
|
||||||
|
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
||||||
|
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
||||||
|
|
||||||
|
// Assign missing values to the split
|
||||||
|
for(Row<Y> row : data) {
|
||||||
|
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
|
||||||
|
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||||
|
|
||||||
|
if(randomDecision){
|
||||||
|
bestSplit.getLeftHand().add(row);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
bestSplit.getRightHand().add(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
final Node<O> leftNode = growNode(split.leftHand, depth+1);
|
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
|
||||||
final Node<O> rightNode = growNode(split.rightHand, depth+1);
|
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
|
||||||
|
|
||||||
return new SplitNode<>(leftNode, rightNode, bestSplitRuleAndSplit.splitRule, bestSplitRuleAndSplit.probabilityLeftHand);
|
return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);
|
||||||
|
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
|
@ -90,13 +109,13 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Covariate> selectCovariates(int mtry){
|
private List<Covariate> selectCovariates(int mtry, Random random){
|
||||||
if(mtry >= covariates.size()){
|
if(mtry >= covariates.size()){
|
||||||
return covariates;
|
return covariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
|
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
|
||||||
Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
|
Collections.shuffle(splitCovariates, random);
|
||||||
|
|
||||||
if (splitCovariates.size() > mtry) {
|
if (splitCovariates.size() > mtry) {
|
||||||
splitCovariates.subList(mtry, splitCovariates.size()).clear();
|
splitCovariates.subList(mtry, splitCovariates.size()).clear();
|
||||||
|
@ -105,63 +124,29 @@ public class TreeTrainer<Y, O> {
|
||||||
return splitCovariates;
|
return splitCovariates;
|
||||||
}
|
}
|
||||||
|
|
||||||
private SplitRuleAndSplit findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
|
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
|
||||||
final Random random = ThreadLocalRandom.current();
|
|
||||||
SplitRuleAndSplit bestSplitRuleAndSplit = new SplitRuleAndSplit();
|
|
||||||
double bestSplitScore = 0.0;
|
|
||||||
boolean first = true;
|
|
||||||
|
|
||||||
for(final Covariate covariate : covariatesToTry){
|
SplitAndScore<Y, ?> bestSplitAndScore = null;
|
||||||
|
final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck
|
||||||
|
|
||||||
final int numberToTry = numberOfSplits==0 ? data.size() : numberOfSplits;
|
for(final Covariate covariate : covariatesToTry) {
|
||||||
|
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
|
||||||
|
|
||||||
final Collection<Covariate.SplitRule> splitRulesToTry = covariate
|
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
|
||||||
.generateSplitRules(
|
|
||||||
data
|
|
||||||
.stream()
|
|
||||||
.map(row -> row.getCovariateValue(covariate))
|
|
||||||
.collect(Collectors.toList())
|
|
||||||
, numberToTry);
|
|
||||||
|
|
||||||
for(final Covariate.SplitRule possibleRule : splitRulesToTry){
|
if(candidateSplitAndScore != null) {
|
||||||
final Split<Y> possibleSplit = possibleRule.applyRule(data);
|
if (bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
|
||||||
|
bestSplitAndScore = candidateSplitAndScore;
|
||||||
// We have to handle any NAs
|
|
||||||
if(possibleSplit.leftHand.size() == 0 && possibleSplit.rightHand.size() == 0 && possibleSplit.naHand.size() > 0){
|
|
||||||
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
|
||||||
}
|
|
||||||
|
|
||||||
final double probabilityLeftHand = (double) possibleSplit.leftHand.size() / (double) (possibleSplit.leftHand.size() + possibleSplit.rightHand.size());
|
|
||||||
|
|
||||||
for(final Row<Y> missingValueRow : possibleSplit.naHand){
|
|
||||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
|
||||||
if(randomDecision){
|
|
||||||
possibleSplit.leftHand.add(missingValueRow);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
possibleSplit.rightHand.add(missingValueRow);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
final Double score = groupDifferentiator.differentiate(
|
|
||||||
possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
|
|
||||||
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
|
||||||
);
|
|
||||||
|
|
||||||
|
|
||||||
if(score != null && !Double.isNaN(score) && (score > bestSplitScore || first)){
|
|
||||||
bestSplitRuleAndSplit.splitRule = possibleRule;
|
|
||||||
bestSplitRuleAndSplit.split = possibleSplit;
|
|
||||||
bestSplitRuleAndSplit.probabilityLeftHand = probabilityLeftHand;
|
|
||||||
|
|
||||||
bestSplitScore = score;
|
|
||||||
first = false;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return bestSplitRuleAndSplit;
|
if(bestSplitAndScore == null){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestSplitAndScore.getSplit();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -184,10 +169,4 @@ public class TreeTrainer<Y, O> {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private class SplitRuleAndSplit{
|
|
||||||
private Covariate.SplitRule splitRule = null;
|
|
||||||
private Split<Y> split = null;
|
|
||||||
private double probabilityLeftHand;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import java.util.Iterator;
|
||||||
|
|
||||||
|
public interface IndexedIterator<E> extends Iterator<E> {
|
||||||
|
|
||||||
|
int getIndex();
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Iterator;
|
||||||
|
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class SingletonIterator<E> implements Iterator<E> {
|
||||||
|
|
||||||
|
private final E value;
|
||||||
|
|
||||||
|
private boolean beenCalled = false;
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return !beenCalled;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public E next() {
|
||||||
|
if(!beenCalled){
|
||||||
|
beenCalled = true;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import java.util.Iterator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Iterator that wraps around a UniqueValueIterator. It continues to iterate until it gets to one of the prespecified indexes,
|
||||||
|
* and then proceeds just past that to the end of the existing values it's at.
|
||||||
|
*
|
||||||
|
* The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together.
|
||||||
|
* I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations
|
||||||
|
*
|
||||||
|
* @param <E>
|
||||||
|
*/
|
||||||
|
public class UniqueSubsetValueIterator<E> implements IndexedIterator<E> {
|
||||||
|
|
||||||
|
private final UniqueValueIterator<E> iterator;
|
||||||
|
private final Integer[] indexValues;
|
||||||
|
|
||||||
|
private int currentIndexSpot = 0;
|
||||||
|
|
||||||
|
public UniqueSubsetValueIterator(final UniqueValueIterator<E> iterator, final Integer[] indexValues){
|
||||||
|
this.iterator = iterator;
|
||||||
|
this.indexValues = indexValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return iterator.hasNext() && iterator.getIndex() <= indexValues[indexValues.length-1];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public E next() {
|
||||||
|
if(hasNext()){
|
||||||
|
final int indexToStopBy = indexValues[currentIndexSpot];
|
||||||
|
|
||||||
|
while(iterator.getIndex() <= indexToStopBy){
|
||||||
|
iterator.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i = currentIndexSpot + 1; i < indexValues.length; i++){
|
||||||
|
if(iterator.getIndex() <= indexValues[i]){
|
||||||
|
currentIndexSpot = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return iterator.getCurrentValue();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getIndex(){
|
||||||
|
return iterator.getIndex();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,71 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import java.util.Iterator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Iterator that wraps around another iterator. It continues to iterate until it gets to the *end* of a sequence of identical values.
|
||||||
|
* It also tracks the current index in the original iterator.
|
||||||
|
*
|
||||||
|
* The wrapped iterator must be from a sorted collection of some sort such that equal values are clumped together.
|
||||||
|
* I.e. "b b c c c d d a a" is okay but "a b b c c a" is not as 'a' appears twice at different locations
|
||||||
|
*
|
||||||
|
* @param <E>
|
||||||
|
*/
|
||||||
|
public class UniqueValueIterator<E> implements IndexedIterator<E> {
|
||||||
|
|
||||||
|
private final Iterator<E> wrappedIterator;
|
||||||
|
|
||||||
|
@Getter private E currentValue = null;
|
||||||
|
@Getter private E nextValue;
|
||||||
|
|
||||||
|
public UniqueValueIterator(final Iterator<E> wrappedIterator){
|
||||||
|
this.wrappedIterator = wrappedIterator;
|
||||||
|
this.nextValue = wrappedIterator.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count must return the index of the last value of the sequence returned by next()
|
||||||
|
@Getter
|
||||||
|
private int index = 0;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return nextValue != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public E next() {
|
||||||
|
|
||||||
|
int count = 1;
|
||||||
|
while(wrappedIterator.hasNext()){
|
||||||
|
final E currentIteratorValue = wrappedIterator.next();
|
||||||
|
|
||||||
|
if(currentIteratorValue.equals(nextValue)){
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
index +=count;
|
||||||
|
currentValue = nextValue;
|
||||||
|
nextValue = currentIteratorValue;
|
||||||
|
|
||||||
|
return currentValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if(nextValue != null){
|
||||||
|
index += count;
|
||||||
|
currentValue = nextValue;
|
||||||
|
nextValue = null;
|
||||||
|
|
||||||
|
return currentValue;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,8 +1,8 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
|
|
|
@ -2,6 +2,8 @@ package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.*;
|
import ca.joeltherrien.randomforest.covariates.*;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
@ -18,6 +20,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
public class TestCompetingRisk {
|
public class TestCompetingRisk {
|
||||||
|
|
||||||
|
@ -104,7 +107,7 @@ public class TestCompetingRisk {
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
|
||||||
|
|
||||||
final CovariateRow newRow = getPredictionRow(covariates);
|
final CovariateRow newRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
|
@ -157,7 +160,7 @@ public class TestCompetingRisk {
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer = new TreeTrainer<>(settings, covariates);
|
||||||
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);
|
final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset, new Random());
|
||||||
|
|
||||||
final CovariateRow newRow = getPredictionRow(covariates);
|
final CovariateRow newRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import ca.joeltherrien.randomforest.DataLoader;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.Settings;
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
@ -53,14 +53,14 @@ public class TestLogRankMultipleGroupDifferentiator {
|
||||||
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
|
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 196);
|
||||||
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
|
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(196, data.size());
|
||||||
|
|
||||||
final double scoreBad = groupDifferentiator.differentiate(
|
final double scoreBad = groupDifferentiator.getScore(
|
||||||
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
|
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
|
||||||
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
|
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199);
|
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 199);
|
||||||
final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size());
|
final List<Row<CompetingRiskResponse>> group2Good= data.subList(199, data.size());
|
||||||
|
|
||||||
final double scoreGood = groupDifferentiator.differentiate(
|
final double scoreGood = groupDifferentiator.getScore(
|
||||||
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
|
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
|
||||||
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
|
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.competingrisk;
|
package ca.joeltherrien.randomforest.competingrisk;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
|
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
@ -51,7 +50,7 @@ public class TestLogRankSingleGroupDifferentiator {
|
||||||
|
|
||||||
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
|
final LogRankSingleGroupDifferentiator differentiator = new LogRankSingleGroupDifferentiator(1, new int[]{1});
|
||||||
|
|
||||||
final double score = differentiator.differentiate(data1, data2);
|
final double score = differentiator.getScore(data1, data2);
|
||||||
final double margin = 0.000001;
|
final double margin = 0.000001;
|
||||||
|
|
||||||
// Tested using 855 method
|
// Tested using 855 method
|
||||||
|
@ -71,14 +70,14 @@ public class TestLogRankSingleGroupDifferentiator {
|
||||||
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
|
final List<Row<CompetingRiskResponse>> group1Good = data.subList(0, 221);
|
||||||
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
|
final List<Row<CompetingRiskResponse>> group2Good = data.subList(221, data.size());
|
||||||
|
|
||||||
final double scoreGood = groupDifferentiator.differentiate(
|
final double scoreGood = groupDifferentiator.getScore(
|
||||||
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
|
group1Good.stream().map(Row::getResponse).collect(Collectors.toList()),
|
||||||
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
|
group2Good.stream().map(Row::getResponse).collect(Collectors.toList()));
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
|
final List<Row<CompetingRiskResponse>> group1Bad = data.subList(0, 222);
|
||||||
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
|
final List<Row<CompetingRiskResponse>> group2Bad = data.subList(222, data.size());
|
||||||
|
|
||||||
final double scoreBad = groupDifferentiator.differentiate(
|
final double scoreBad = groupDifferentiator.getScore(
|
||||||
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
|
group1Bad.stream().map(Row::getResponse).collect(Collectors.toList()),
|
||||||
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
|
group2Bad.stream().map(Row::getResponse).collect(Collectors.toList()));
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,9 @@ import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.function.Executable;
|
import org.junit.jupiter.api.function.Executable;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
@ -46,7 +47,10 @@ public class FactorCovariateTest {
|
||||||
void testAllSubsets(){
|
void testAllSubsets(){
|
||||||
final FactorCovariate petCovariate = createTestCovariate();
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
|
||||||
final Collection<FactorCovariate.FactorSplitRule> splitRules = petCovariate.generateSplitRules(null, 100);
|
final List<Covariate.SplitRule<String>> splitRules = new ArrayList<>();
|
||||||
|
|
||||||
|
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
|
||||||
|
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
|
||||||
|
|
||||||
assertEquals(splitRules.size(), 3);
|
assertEquals(splitRules.size(), 3);
|
||||||
|
|
||||||
|
|
|
@ -3,10 +3,10 @@ package ca.joeltherrien.randomforest.csv;
|
||||||
import ca.joeltherrien.randomforest.DataLoader;
|
import ca.joeltherrien.randomforest.DataLoader;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.Settings;
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
@ -15,7 +15,6 @@ import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
|
@ -3,7 +3,9 @@ package ca.joeltherrien.randomforest.settings;
|
||||||
import ca.joeltherrien.randomforest.Settings;
|
import ca.joeltherrien.randomforest.Settings;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.*;
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
public class TestSingletonIterator {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void verifyBehaviour(){
|
||||||
|
final Integer element = 5;
|
||||||
|
|
||||||
|
final SingletonIterator<Integer> iterator = new SingletonIterator<>(element);
|
||||||
|
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(Integer.valueOf(5), iterator.next());
|
||||||
|
|
||||||
|
assertFalse(iterator.hasNext());
|
||||||
|
assertNull(iterator.next());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
public class TestUniqueSubsetValueIterator {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIterator1(){
|
||||||
|
final List<Integer> testData = Arrays.asList(
|
||||||
|
1,1,2,3,5,5,5,6,7,7
|
||||||
|
);
|
||||||
|
|
||||||
|
final Integer[] indexes = new Integer[]{2,3,4,5};
|
||||||
|
|
||||||
|
final UniqueValueIterator<Integer> uniqueValueIterator = new UniqueValueIterator<>(testData.iterator());
|
||||||
|
final UniqueSubsetValueIterator<Integer> iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes);
|
||||||
|
|
||||||
|
// we expect to get 2, 3, and 5 back. 5 should happen only once
|
||||||
|
|
||||||
|
assertEquals(iterator.getIndex(), 0);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 2);
|
||||||
|
assertEquals(iterator.getIndex(), 3);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 3);
|
||||||
|
assertEquals(iterator.getIndex(), 4);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 5);
|
||||||
|
assertEquals(iterator.getIndex(), 7);
|
||||||
|
assertFalse(iterator.hasNext());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIterator2(){
|
||||||
|
final List<Integer> testData = Arrays.asList(
|
||||||
|
1,1,2,3,5,5,5,6,7,7
|
||||||
|
);
|
||||||
|
|
||||||
|
final Integer[] indexes = new Integer[]{1,8};
|
||||||
|
|
||||||
|
final UniqueValueIterator<Integer> uniqueValueIterator = new UniqueValueIterator<>(testData.iterator());
|
||||||
|
final UniqueSubsetValueIterator<Integer> iterator = new UniqueSubsetValueIterator<>(uniqueValueIterator, indexes);
|
||||||
|
|
||||||
|
assertEquals(iterator.getIndex(), 0);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 1);
|
||||||
|
assertEquals(iterator.getIndex(), 2);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 7);
|
||||||
|
assertEquals(iterator.getIndex(), 10);
|
||||||
|
assertFalse(iterator.hasNext());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,112 @@
|
||||||
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
public class TestUniqueValueIterator {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIterator1(){
|
||||||
|
final List<Integer> testData = Arrays.asList(
|
||||||
|
1,1,2,3,5,5,5,6,7,7
|
||||||
|
);
|
||||||
|
|
||||||
|
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
|
||||||
|
|
||||||
|
assertEquals(iterator.getIndex(), 0);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 1);
|
||||||
|
assertEquals(iterator.getIndex(), 2);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 2);
|
||||||
|
assertEquals(iterator.getIndex(), 3);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 3);
|
||||||
|
assertEquals(iterator.getIndex(), 4);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 5);
|
||||||
|
assertEquals(iterator.getIndex(), 7);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 6);
|
||||||
|
assertEquals(iterator.getIndex(), 8);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 7);
|
||||||
|
assertEquals(iterator.getIndex(), 10);
|
||||||
|
assertTrue(!iterator.hasNext());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIterator2(){
|
||||||
|
final List<Integer> testData = Arrays.asList(
|
||||||
|
1,2,3,5,5,5,6,7 // same numbers; but 1 and 7 only appear once each
|
||||||
|
);
|
||||||
|
|
||||||
|
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
|
||||||
|
|
||||||
|
assertEquals(iterator.getIndex(), 0);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 1);
|
||||||
|
assertEquals(iterator.getIndex(), 1);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 2);
|
||||||
|
assertEquals(iterator.getIndex(), 2);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 3);
|
||||||
|
assertEquals(iterator.getIndex(), 3);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 5);
|
||||||
|
assertEquals(iterator.getIndex(), 6);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 6);
|
||||||
|
assertEquals(iterator.getIndex(), 7);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 7);
|
||||||
|
assertEquals(iterator.getIndex(), 8);
|
||||||
|
assertFalse(iterator.hasNext());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIterator3(){
|
||||||
|
final List<Integer> testData = Arrays.asList(
|
||||||
|
1,1,1,1,1,1,1,2,2,2,2,2,3
|
||||||
|
);
|
||||||
|
|
||||||
|
final UniqueValueIterator<Integer> iterator = new UniqueValueIterator<>(testData.iterator());
|
||||||
|
|
||||||
|
assertEquals(iterator.getIndex(), 0);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 1);
|
||||||
|
assertEquals(iterator.getIndex(), 7);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 2);
|
||||||
|
assertEquals(iterator.getIndex(), 12);
|
||||||
|
assertTrue(iterator.hasNext());
|
||||||
|
|
||||||
|
assertEquals(iterator.next().intValue(), 3);
|
||||||
|
assertEquals(iterator.getIndex(), 13);
|
||||||
|
assertFalse(iterator.hasNext());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -2,7 +2,7 @@ package ca.joeltherrien.randomforest.workshop;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
|
|
@ -3,7 +3,7 @@ package ca.joeltherrien.randomforest.workshop;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
|
@ -54,13 +54,14 @@ public class TrainSingleTree {
|
||||||
.covariates(covariateNames)
|
.covariates(covariateNames)
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
.maxNodeDepth(30)
|
.maxNodeDepth(30)
|
||||||
|
.mtry(2)
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
.numberOfSplits(0)
|
.numberOfSplits(0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
||||||
final long startTime = System.currentTimeMillis();
|
final long startTime = System.currentTimeMillis();
|
||||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
|
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, new Random());
|
||||||
final long endTime = System.currentTimeMillis();
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
System.out.println(((double)(endTime - startTime))/1000.0);
|
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||||
|
|
|
@ -4,7 +4,7 @@ package ca.joeltherrien.randomforest.workshop;
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
|
@ -13,7 +13,6 @@ import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.DoubleStream;
|
import java.util.stream.DoubleStream;
|
||||||
|
@ -21,8 +20,6 @@ import java.util.stream.DoubleStream;
|
||||||
public class TrainSingleTreeFactor {
|
public class TrainSingleTreeFactor {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
System.out.println("Hello world!");
|
|
||||||
|
|
||||||
final Random random = new Random(123);
|
final Random random = new Random(123);
|
||||||
|
|
||||||
final int n = 10000;
|
final int n = 10000;
|
||||||
|
@ -84,7 +81,7 @@ public class TrainSingleTreeFactor {
|
||||||
|
|
||||||
|
|
||||||
final long startTime = System.currentTimeMillis();
|
final long startTime = System.currentTimeMillis();
|
||||||
final Node<Double> baseNode = treeTrainer.growTree(trainingSet);
|
final Node<Double> baseNode = treeTrainer.growTree(trainingSet, random);
|
||||||
final long endTime = System.currentTimeMillis();
|
final long endTime = System.currentTimeMillis();
|
||||||
|
|
||||||
System.out.println(((double)(endTime - startTime))/1000.0);
|
System.out.println(((double)(endTime - startTime))/1000.0);
|
||||||
|
|
Loading…
Reference in a new issue