diff --git a/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java new file mode 100644 index 0000000..8920273 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java @@ -0,0 +1,72 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.exceptions.MissingValueException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +@RequiredArgsConstructor +public class BooleanCovariate implements Covariate{ + + @Getter + private final String name; + + private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. + + @Override + public Collection generateSplitRules(List> data, int number) { + return Collections.singleton(splitRule); + } + + @Override + public BooleanValue createValue(Boolean value) { + return new BooleanValue(value); + } + + public class BooleanValue implements Value{ + + private final boolean value; + + private BooleanValue(final boolean value){ + this.value = value; + } + + @Override + public BooleanCovariate getParent() { + return BooleanCovariate.this; + } + + @Override + public Boolean getValue() { + return value; + } + } + + public class BooleanSplitRule implements SplitRule{ + + @Override + public final String toString() { + return "BooleanSplitRule"; + } + + @Override + public BooleanCovariate getParent() { + return BooleanCovariate.this; + } + + @Override + public boolean isLeftHand(CovariateRow row) { + final Value x = row.getCovariateValue(getParent().getName()); + if(x == null) { + throw new MissingValueException(row, this); + } + + final boolean xBoolean = (Boolean) x.getValue(); + + return !xBoolean; + } + } +}