From 38e70dd3a1d02f3399f107a9019c16b964bcf97a Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Wed, 4 Jul 2018 10:54:07 -0700 Subject: [PATCH] Add BooleanCovariate --- .../randomforest/BooleanCovariate.java | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java diff --git a/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java new file mode 100644 index 0000000..8920273 --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java @@ -0,0 +1,72 @@ +package ca.joeltherrien.randomforest; + +import ca.joeltherrien.randomforest.exceptions.MissingValueException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +@RequiredArgsConstructor +public class BooleanCovariate implements Covariate{ + + @Getter + private final String name; + + private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. + + @Override + public Collection generateSplitRules(List> data, int number) { + return Collections.singleton(splitRule); + } + + @Override + public BooleanValue createValue(Boolean value) { + return new BooleanValue(value); + } + + public class BooleanValue implements Value{ + + private final boolean value; + + private BooleanValue(final boolean value){ + this.value = value; + } + + @Override + public BooleanCovariate getParent() { + return BooleanCovariate.this; + } + + @Override + public Boolean getValue() { + return value; + } + } + + public class BooleanSplitRule implements SplitRule{ + + @Override + public final String toString() { + return "BooleanSplitRule"; + } + + @Override + public BooleanCovariate getParent() { + return BooleanCovariate.this; + } + + @Override + public boolean isLeftHand(CovariateRow row) { + final Value x = row.getCovariateValue(getParent().getName()); + if(x == null) { + throw new MissingValueException(row, this); + } + + final boolean xBoolean = (Boolean) x.getValue(); + + return !xBoolean; + } + } +}