diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java index 480485e..741fc2a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java @@ -17,6 +17,8 @@ public final class BooleanCovariate implements Covariate { @Getter private final int index; + private boolean hasNAs = false; + private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. @Override @@ -32,6 +34,7 @@ public final class BooleanCovariate implements Covariate { @Override public Value createValue(String value) { if(value == null || value.equalsIgnoreCase("na")){ + hasNAs = true; return createValue( (Boolean) null); } @@ -46,6 +49,11 @@ public final class BooleanCovariate implements Covariate { } } + @Override + public boolean hasNAs() { + return hasNAs; + } + @Override public String toString(){ return "BooleanCovariate(name=" + name + ")"; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index 1f04a0f..46c7849 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -26,6 +26,8 @@ public interface Covariate extends Serializable { */ Value createValue(String value); + boolean hasNAs(); + interface Value extends Serializable{ Covariate getParent(); diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java index 89f7a02..e0f4101 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java @@ -19,6 +19,8 @@ public final class FactorCovariate implements Covariate{ private final FactorValue naValue; private final int numberOfPossiblePairings; + private boolean hasNAs; + public FactorCovariate(final String name, final int index, List levels){ this.name = name; @@ -72,6 +74,7 @@ public final class FactorCovariate implements Covariate{ @Override public FactorValue createValue(String value) { if(value == null || value.equalsIgnoreCase("na")){ + this.hasNAs = true; return this.naValue; } @@ -84,6 +87,12 @@ public final class FactorCovariate implements Covariate{ return factorValue; } + + @Override + public boolean hasNAs() { + return hasNAs; + } + @Override public String toString(){ return "FactorCovariate(name=" + name + ")"; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java index 8924626..1f00d51 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java @@ -12,6 +12,7 @@ import lombok.ToString; import java.util.*; import java.util.stream.Collectors; +import java.util.stream.Stream; @RequiredArgsConstructor @ToString @@ -23,10 +24,17 @@ public final class NumericCovariate implements Covariate { @Getter private final int index; + private boolean hasNAs = false; + @Override public NumericSplitRuleUpdater generateSplitRuleUpdater(List> data, int number, Random random) { - data = data.stream() - .filter(row -> !row.getCovariateValue(this).isNA()) + Stream> stream = data.stream(); + + if(hasNAs()){ + stream = stream.filter(row -> !row.getCovariateValue(this).isNA()); + } + + data = stream .sorted((r1, r2) -> { Double d1 = r1.getCovariateValue(this).getValue(); Double d2 = r2.getCovariateValue(this).getValue(); @@ -37,7 +45,6 @@ public final class NumericCovariate implements Covariate { Iterator sortedDataIterator = data.stream() .map(row -> row.getCovariateValue(this).getValue()) - .filter(v -> v != null) .iterator(); @@ -56,7 +63,7 @@ public final class NumericCovariate implements Covariate { dataIterator = new UniqueSubsetValueIterator<>( new UniqueValueIterator<>(sortedDataIterator), - indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered + indexSet.toArray(new Integer[indexSet.size()]) ); } @@ -73,12 +80,19 @@ public final class NumericCovariate implements Covariate { @Override public NumericValue createValue(String value) { if(value == null || value.equalsIgnoreCase("na")){ + this.hasNAs = true; return createValue((Double) null); } return createValue(Double.parseDouble(value)); } + + @Override + public boolean hasNAs() { + return hasNAs; + } + @EqualsAndHashCode public class NumericValue implements Covariate.Value{ diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 27589f8..d1d141d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -76,22 +76,25 @@ public class TreeTrainer { final double probabilityLeftHand = (double) bestSplit.leftHand.size() / (double) (bestSplit.leftHand.size() + bestSplit.rightHand.size()); - // Assign missing values to the split - for(Row row : data) { - if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) { - final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; + // Assign missing values to the split if necessary + if(bestSplit.getSplitRule().getParent().hasNAs()){ + for(Row row : data) { + if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) { + final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; - if(randomDecision){ - bestSplit.getLeftHand().add(row); - } - else{ - bestSplit.getRightHand().add(row); - } + if(randomDecision){ + bestSplit.getLeftHand().add(row); + } + else{ + bestSplit.getRightHand().add(row); + } + } } } + final Node leftNode = growNode(bestSplit.leftHand, depth+1, random); final Node rightNode = growNode(bestSplit.rightHand, depth+1, random);