Add FactorCovariate; testing required.
This commit is contained in:
parent
2259528c22
commit
e0cfed632f
1 changed files with 122 additions and 0 deletions
122
src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java
Normal file
122
src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
|
||||||
|
public final class FactorCovariate implements Covariate<String>{
|
||||||
|
|
||||||
|
private final String name;
|
||||||
|
private final Map<String, FactorValue> factorLevels;
|
||||||
|
private final int numberOfPossiblePairings;
|
||||||
|
|
||||||
|
|
||||||
|
public FactorCovariate(final String name, List<String> levels){
|
||||||
|
this.name = name;
|
||||||
|
this.factorLevels = new HashMap<>();
|
||||||
|
|
||||||
|
for(final String level : levels){
|
||||||
|
final FactorValue newValue = new FactorValue(level);
|
||||||
|
|
||||||
|
factorLevels.put(level, newValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
int numberOfPossiblePairingsTemp = 1;
|
||||||
|
for(int i=0; i<levels.size()-1; i++){
|
||||||
|
numberOfPossiblePairingsTemp *= 2;
|
||||||
|
}
|
||||||
|
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Collection<FactorSplitRule> generateSplitRules(List<Value<String>> data, int number) {
|
||||||
|
final Set<FactorSplitRule> splitRules = new HashSet<>();
|
||||||
|
|
||||||
|
// This is to ensure we don't get stuck in an infinite loop for small factors
|
||||||
|
number = Math.min(number, numberOfPossiblePairings);
|
||||||
|
final Random random = ThreadLocalRandom.current();
|
||||||
|
final List<FactorValue> levels = new ArrayList<>(factorLevels.values());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
while(splitRules.size() < number){
|
||||||
|
Collections.shuffle(levels, random);
|
||||||
|
final Set<FactorValue> leftSideValues = new HashSet<>();
|
||||||
|
leftSideValues.add(levels.get(0));
|
||||||
|
|
||||||
|
for(int i=1; i<levels.size()/2; i++){
|
||||||
|
if(random.nextBoolean()){
|
||||||
|
leftSideValues.add(levels.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
splitRules.add(new FactorSplitRule(leftSideValues));
|
||||||
|
}
|
||||||
|
|
||||||
|
return splitRules;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FactorValue createValue(String value) {
|
||||||
|
final FactorValue factorValue = factorLevels.get(value);
|
||||||
|
|
||||||
|
if(factorValue == null){
|
||||||
|
throw new IllegalArgumentException(value + " is not a level in FactorCovariate " + name);
|
||||||
|
}
|
||||||
|
|
||||||
|
return factorValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@EqualsAndHashCode
|
||||||
|
public final class FactorValue implements Covariate.Value<String>{
|
||||||
|
|
||||||
|
private final String value;
|
||||||
|
|
||||||
|
private FactorValue(final String value){
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FactorCovariate getParent() {
|
||||||
|
return FactorCovariate.this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@EqualsAndHashCode
|
||||||
|
public final class FactorSplitRule implements Covariate.SplitRule<String>{
|
||||||
|
|
||||||
|
private final Set<FactorValue> leftSideValues;
|
||||||
|
|
||||||
|
private FactorSplitRule(final Set<FactorValue> leftSideValues){
|
||||||
|
this.leftSideValues = leftSideValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FactorCovariate getParent() {
|
||||||
|
return FactorCovariate.this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeftHand(CovariateRow row) {
|
||||||
|
final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue();
|
||||||
|
|
||||||
|
return leftSideValues.contains(value);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue