Covariates track if they have any NA values, and skip NA handling code if possible

This commit is contained in:
Joel Therrien 2019-01-10 14:09:43 -08:00
parent a57741b726
commit 31d6ce9b3e
5 changed files with 50 additions and 14 deletions

View file

@ -17,6 +17,8 @@ public final class BooleanCovariate implements Covariate<Boolean> {
@Getter @Getter
private final int index; private final int index;
private boolean hasNAs = false;
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
@Override @Override
@ -32,6 +34,7 @@ public final class BooleanCovariate implements Covariate<Boolean> {
@Override @Override
public Value<Boolean> createValue(String value) { public Value<Boolean> createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){ if(value == null || value.equalsIgnoreCase("na")){
hasNAs = true;
return createValue( (Boolean) null); return createValue( (Boolean) null);
} }
@ -46,6 +49,11 @@ public final class BooleanCovariate implements Covariate<Boolean> {
} }
} }
/**
 * Reports whether this covariate has recorded a missing (NA) value.
 *
 * <p>The backing {@code hasNAs} field starts as {@code false} and, in the code visible here, is
 * set to {@code true} by {@code createValue(String)} when it is given {@code null} or the text
 * "NA" (compared case-insensitively). It is never reset to {@code false} in the visible code.
 *
 * <p>NOTE(review): NA values created through other paths (for example calling the typed
 * {@code createValue((Boolean) null)} overload directly) do not appear to set the flag -- confirm,
 * since callers use this method to skip NA handling entirely.
 *
 * @return {@code true} if an NA value has been recorded for this covariate, {@code false} otherwise
 */
@Override
public boolean hasNAs() {
return hasNAs;
}
@Override @Override
public String toString(){ public String toString(){
return "BooleanCovariate(name=" + name + ")"; return "BooleanCovariate(name=" + name + ")";

View file

@ -26,6 +26,8 @@ public interface Covariate<V> extends Serializable {
*/ */
Value<V> createValue(String value); Value<V> createValue(String value);
/**
 * Indicates whether this covariate has seen any missing (NA) values.
 *
 * <p>Callers use this as a shortcut: when it returns {@code false} they skip their NA-handling
 * code (such as filtering NA rows or assigning NA rows to a side of a split).
 *
 * <p>NOTE(review): a {@code false} result is only safe if every way of creating an NA value
 * updates the implementation's flag -- confirm for each implementing class.
 *
 * @return {@code true} if an NA value has been recorded for this covariate, {@code false} otherwise
 */
boolean hasNAs();
interface Value<V> extends Serializable{ interface Value<V> extends Serializable{
Covariate<V> getParent(); Covariate<V> getParent();

View file

@ -19,6 +19,8 @@ public final class FactorCovariate implements Covariate<String>{
private final FactorValue naValue; private final FactorValue naValue;
private final int numberOfPossiblePairings; private final int numberOfPossiblePairings;
private boolean hasNAs;
public FactorCovariate(final String name, final int index, List<String> levels){ public FactorCovariate(final String name, final int index, List<String> levels){
this.name = name; this.name = name;
@ -72,6 +74,7 @@ public final class FactorCovariate implements Covariate<String>{
@Override @Override
public FactorValue createValue(String value) { public FactorValue createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){ if(value == null || value.equalsIgnoreCase("na")){
this.hasNAs = true;
return this.naValue; return this.naValue;
} }
@ -84,6 +87,12 @@ public final class FactorCovariate implements Covariate<String>{
return factorValue; return factorValue;
} }
/**
 * Reports whether this covariate has recorded a missing (NA) value.
 *
 * <p>In the code visible here, the backing {@code hasNAs} field is set to {@code true} by
 * {@code createValue(String)} when it is given {@code null} or the text "NA" (compared
 * case-insensitively), in which case the shared {@code naValue} is returned. The field has no
 * explicit initializer, so it holds Java's default of {@code false} until then.
 *
 * <p>NOTE(review): obtaining {@code naValue} through any other route would not set the flag --
 * confirm no such route exists.
 *
 * @return {@code true} if an NA value has been recorded for this covariate, {@code false} otherwise
 */
@Override
public boolean hasNAs() {
return hasNAs;
}
@Override @Override
public String toString(){ public String toString(){
return "FactorCovariate(name=" + name + ")"; return "FactorCovariate(name=" + name + ")";

View file

@ -12,6 +12,7 @@ import lombok.ToString;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
@RequiredArgsConstructor @RequiredArgsConstructor
@ToString @ToString
@ -23,10 +24,17 @@ public final class NumericCovariate implements Covariate<Double> {
@Getter @Getter
private final int index; private final int index;
private boolean hasNAs = false;
@Override @Override
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) { public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
data = data.stream() Stream<Row<Y>> stream = data.stream();
.filter(row -> !row.getCovariateValue(this).isNA())
if(hasNAs()){
stream = stream.filter(row -> !row.getCovariateValue(this).isNA());
}
data = stream
.sorted((r1, r2) -> { .sorted((r1, r2) -> {
Double d1 = r1.getCovariateValue(this).getValue(); Double d1 = r1.getCovariateValue(this).getValue();
Double d2 = r2.getCovariateValue(this).getValue(); Double d2 = r2.getCovariateValue(this).getValue();
@ -37,7 +45,6 @@ public final class NumericCovariate implements Covariate<Double> {
Iterator<Double> sortedDataIterator = data.stream() Iterator<Double> sortedDataIterator = data.stream()
.map(row -> row.getCovariateValue(this).getValue()) .map(row -> row.getCovariateValue(this).getValue())
.filter(v -> v != null)
.iterator(); .iterator();
@ -56,7 +63,7 @@ public final class NumericCovariate implements Covariate<Double> {
dataIterator = new UniqueSubsetValueIterator<>( dataIterator = new UniqueSubsetValueIterator<>(
new UniqueValueIterator<>(sortedDataIterator), new UniqueValueIterator<>(sortedDataIterator),
indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered indexSet.toArray(new Integer[indexSet.size()])
); );
} }
@ -73,12 +80,19 @@ public final class NumericCovariate implements Covariate<Double> {
@Override @Override
public NumericValue createValue(String value) { public NumericValue createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){ if(value == null || value.equalsIgnoreCase("na")){
this.hasNAs = true;
return createValue((Double) null); return createValue((Double) null);
} }
return createValue(Double.parseDouble(value)); return createValue(Double.parseDouble(value));
} }
/**
 * Reports whether this covariate has recorded a missing (NA) value.
 *
 * <p>The backing {@code hasNAs} field starts as {@code false} and, in the code visible here, is
 * set to {@code true} by {@code createValue(String)} when it is given {@code null} or the text
 * "NA" (compared case-insensitively). {@code generateSplitRuleUpdater} consults this method and
 * only filters out NA rows before sorting when it returns {@code true}.
 *
 * <p>NOTE(review): the typed {@code createValue((Double) null)} overload does not appear to set
 * the flag when called directly. If that can happen, NA rows would reach the sort with null
 * values -- confirm.
 *
 * @return {@code true} if an NA value has been recorded for this covariate, {@code false} otherwise
 */
@Override
public boolean hasNAs() {
return hasNAs;
}
@EqualsAndHashCode @EqualsAndHashCode
public class NumericValue implements Covariate.Value<Double>{ public class NumericValue implements Covariate.Value<Double>{

View file

@ -76,22 +76,25 @@ public class TreeTrainer<Y, O> {
final double probabilityLeftHand = (double) bestSplit.leftHand.size() / final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size()); (double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
// Assign missing values to the split // Assign missing values to the split if necessary
for(Row<Y> row : data) { if(bestSplit.getSplitRule().getParent().hasNAs()){
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) { for(Row<Y> row : data) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){ if(randomDecision){
bestSplit.getLeftHand().add(row); bestSplit.getLeftHand().add(row);
} }
else{ else{
bestSplit.getRightHand().add(row); bestSplit.getRightHand().add(row);
} }
}
} }
} }
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random); final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random); final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);