diff --git a/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java index 8920273..6e0743f 100644 --- a/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/BooleanCovariate.java @@ -1,12 +1,9 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.exceptions.MissingValueException; import lombok.Getter; import lombok.RequiredArgsConstructor; import java.util.*; -import java.util.concurrent.ThreadLocalRandom; -import java.util.stream.Collectors; @RequiredArgsConstructor public class BooleanCovariate implements Covariate{ @@ -28,9 +25,9 @@ public class BooleanCovariate implements Covariate{ public class BooleanValue implements Value{ - private final boolean value; + private final Boolean value; - private BooleanValue(final boolean value){ + private BooleanValue(final Boolean value){ this.value = value; } @@ -43,6 +40,11 @@ public class BooleanCovariate implements Covariate{ public Boolean getValue() { return value; } + + @Override + public boolean isNA() { + return value == null; + } } public class BooleanSplitRule implements SplitRule{ @@ -58,15 +60,12 @@ public class BooleanCovariate implements Covariate{ } @Override - public boolean isLeftHand(CovariateRow row) { - final Value x = row.getCovariateValue(getParent().getName()); - if(x == null) { - throw new MissingValueException(row, this); + public boolean isLeftHand(final Value value) { + if(value.isNA()) { + throw new IllegalArgumentException("Trying to determine split on missing value"); } - final boolean xBoolean = (Boolean) x.getValue(); - - return !xBoolean; + return !value.getValue(); } } } diff --git a/src/main/java/ca/joeltherrien/randomforest/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/Covariate.java index 86848d8..0cf40c8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/Covariate.java @@ -1,9 +1,8 @@ package ca.joeltherrien.randomforest; import java.io.Serializable; -import java.util.Collection; -import java.util.LinkedList; -import java.util.List; +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; public interface Covariate extends Serializable { @@ -19,6 +18,8 @@ public interface Covariate extends Serializable { V getValue(); + boolean isNA(); + } interface SplitRule extends Serializable{ @@ -37,9 +38,22 @@ public interface Covariate extends Serializable { final List> leftHand = new LinkedList<>(); final List> rightHand = new LinkedList<>(); - for(final Row row : rows) { + final List nonMissingDecisions = new ArrayList<>(); + final List> missingValueRows = new ArrayList<>(); - if(isLeftHand(row)){ + + for(final Row row : rows) { + final Value value = (Value) row.getCovariateValue(getParent().getName()); + + if(value.isNA()){ + missingValueRows.add(row); + continue; + } + + final boolean isLeftHand = isLeftHand(value); + nonMissingDecisions.add(isLeftHand); + + if(isLeftHand){ leftHand.add(row); } else{ @@ -48,10 +62,31 @@ public interface Covariate extends Serializable { } + if(nonMissingDecisions.size() == 0 && missingValueRows.size() > 0){ + throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows"); + } + + final Random random = ThreadLocalRandom.current(); + for(final Row missingValueRow : missingValueRows){ + final boolean randomDecision = nonMissingDecisions.get(random.nextInt(nonMissingDecisions.size())); + + if(randomDecision){ + leftHand.add(missingValueRow); + } + else{ + rightHand.add(missingValueRow); + } + } + return new Split<>(leftHand, rightHand); } - boolean isLeftHand(CovariateRow row); + default boolean isLeftHand(CovariateRow row){ + final Value value = (Value) row.getCovariateValue(getParent().getName()); + return isLeftHand(value); + } + + boolean isLeftHand(Value value); } diff --git a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java index 5200e02..ed4c4e9 100644 --- a/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/FactorCovariate.java @@ -9,6 +9,7 @@ public final class FactorCovariate implements Covariate{ private final String name; private final Map factorLevels; + private final FactorValue naValue; private final int numberOfPossiblePairings; @@ -28,6 +29,7 @@ public final class FactorCovariate implements Covariate{ } this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1; + this.naValue = new FactorValue(null); } @@ -67,6 +69,10 @@ public final class FactorCovariate implements Covariate{ @Override public FactorValue createValue(String value) { + if(value == null){ + return this.naValue; + } + final FactorValue factorValue = factorLevels.get(value); if(factorValue == null){ @@ -94,6 +100,11 @@ public final class FactorCovariate implements Covariate{ public String getValue() { return value; } + + @Override + public boolean isNA() { + return value == null; + } } @EqualsAndHashCode @@ -111,12 +122,12 @@ public final class FactorCovariate implements Covariate{ } @Override - public boolean isLeftHand(CovariateRow row) { - final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue(); + public boolean isLeftHand(final Value value) { + if(value.isNA()){ + throw new IllegalArgumentException("Trying to determine split on missing value"); + } return leftSideValues.contains(value); - - } } } diff --git a/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java index b9be60a..014bbfc 100644 --- a/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/NumericCovariate.java @@ -1,6 +1,5 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.exceptions.MissingValueException; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,6 +18,9 @@ public class NumericCovariate implements Covariate{ final Random random = ThreadLocalRandom.current(); + // only work with non-NA values + data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList()); + // for this implementation we need to shuffle the data final List> shuffledData; if(number > data.size()){ @@ -55,9 +57,9 @@ public class NumericCovariate implements Covariate{ public class NumericValue implements Covariate.Value{ - private final double value; + private final Double value; // may be null - private NumericValue(final double value){ + private NumericValue(final Double value){ this.value = value; } @@ -70,6 +72,11 @@ public class NumericCovariate implements Covariate{ public Double getValue() { return value; } + + @Override + public boolean isNA() { + return value == null; + } } public class NumericSplitRule implements Covariate.SplitRule{ @@ -91,13 +98,12 @@ public class NumericCovariate implements Covariate{ } @Override - public boolean isLeftHand(CovariateRow row) { - final Covariate.Value x = row.getCovariateValue(getParent().getName()); - if(x == null) { - throw new MissingValueException(row, this); + public boolean isLeftHand(final Value x) { + if(x.isNA()) { + throw new IllegalArgumentException("Trying to determine split on missing value"); } - final double xNum = (Double) x.getValue(); + final double xNum = x.getValue(); return xNum <= threshold; } diff --git a/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java b/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java deleted file mode 100644 index d340fdc..0000000 --- a/src/main/java/ca/joeltherrien/randomforest/exceptions/MissingValueException.java +++ /dev/null @@ -1,17 +0,0 @@ -package ca.joeltherrien.randomforest.exceptions; - -import ca.joeltherrien.randomforest.Covariate; -import ca.joeltherrien.randomforest.CovariateRow; - -public class MissingValueException extends RuntimeException{ - - /** - * - */ - private static final long serialVersionUID = 6808060079431207726L; - - public MissingValueException(CovariateRow row, Covariate.SplitRule rule) { - super("Missing value at CovariateRow " + row + rule); - } - -}