package ca.joeltherrien.randomforest.covariates; import lombok.Getter; import lombok.RequiredArgsConstructor; import java.util.*; import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; @RequiredArgsConstructor public final class NumericCovariate implements Covariate<Double>{ @Getter private final String name; @Override public Collection<NumericSplitRule> generateSplitRules(List<Value<Double>> data, int number) { final Random random = ThreadLocalRandom.current(); // only work with non-NA values data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList()); // for this implementation we need to shuffle the data final List<Value<Double>> shuffledData; if(number > data.size()){ shuffledData = new ArrayList<>(data); Collections.shuffle(shuffledData, random); } else{ // only need the top number entries shuffledData = new ArrayList<>(number); final Set<Integer> indexesToUse = new HashSet<>(); while(indexesToUse.size() < number){ final int index = random.nextInt(data.size()); if(indexesToUse.add(index)){ shuffledData.add(data.get(index)); } } } return shuffledData.stream() .mapToDouble(v -> v.getValue()) .mapToObj(threshold -> new NumericSplitRule(threshold)) .collect(Collectors.toSet()); // by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping } @Override public NumericValue createValue(Double value) { return new NumericValue(value); } @Override public Value<Double> createValue(String value) { if(value == null || value.equalsIgnoreCase("na")){ return createValue((Double) null); } return createValue(Double.parseDouble(value)); } public class NumericValue implements Covariate.Value<Double>{ private final Double value; // may be null private NumericValue(final Double value){ this.value = value; } @Override public NumericCovariate getParent() { return NumericCovariate.this; } @Override public Double getValue() { return value; } @Override public boolean isNA() { return value == null; } } public class NumericSplitRule implements Covariate.SplitRule<Double>{ private final double threshold; private NumericSplitRule(final double threshold){ this.threshold = threshold; } @Override public final String toString() { return "NumericSplitRule on " + getParent().getName() + " at " + threshold; } @Override public NumericCovariate getParent() { return NumericCovariate.this; } @Override public boolean isLeftHand(final Value<Double> x) { if(x.isNA()) { throw new IllegalArgumentException("Trying to determine split on missing value"); } final double xNum = x.getValue(); return xNum <= threshold; } } }