diff --git a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java new file mode 100644 index 0000000..ba3ac81 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java @@ -0,0 +1,122 @@ +package ca.joeltherrien.randomforest; + +import lombok.EqualsAndHashCode; + +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; + +public final class FactorCovariate implements Covariate{ + + private final String name; + private final Map factorLevels; + private final int numberOfPossiblePairings; + + + public FactorCovariate(final String name, List levels){ + this.name = name; + this.factorLevels = new HashMap<>(); + + for(final String level : levels){ + final FactorValue newValue = new FactorValue(level); + + factorLevels.put(level, newValue); + } + + int numberOfPossiblePairingsTemp = 1; + for(int i=0; i generateSplitRules(List> data, int number) { + final Set splitRules = new HashSet<>(); + + // This is to ensure we don't get stuck in an infinite loop for small factors + number = Math.min(number, numberOfPossiblePairings); + final Random random = ThreadLocalRandom.current(); + final List levels = new ArrayList<>(factorLevels.values()); + + + + while(splitRules.size() < number){ + Collections.shuffle(levels, random); + final Set leftSideValues = new HashSet<>(); + leftSideValues.add(levels.get(0)); + + for(int i=1; i{ + + private final String value; + + private FactorValue(final String value){ + this.value = value; + } + + @Override + public FactorCovariate getParent() { + return FactorCovariate.this; + } + + @Override + public String getValue() { + return value; + } + } + + @EqualsAndHashCode + public final class FactorSplitRule implements Covariate.SplitRule{ + + private final Set leftSideValues; + + private FactorSplitRule(final Set leftSideValues){ + this.leftSideValues = leftSideValues; + } + + @Override + public FactorCovariate getParent() { + return FactorCovariate.this; + } + + @Override + public boolean isLeftHand(CovariateRow row) { + final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue(); + + return leftSideValues.contains(value); + + + } + } +}