Covariates track if they have any NA values, and skip NA handling code if possible

This commit is contained in:
Joel Therrien 2019-01-10 14:09:43 -08:00
parent a57741b726
commit 31d6ce9b3e
5 changed files with 50 additions and 14 deletions

View file

@ -17,6 +17,8 @@ public final class BooleanCovariate implements Covariate<Boolean> {
@Getter
private final int index;
private boolean hasNAs = false;
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
@Override
@ -32,6 +34,7 @@ public final class BooleanCovariate implements Covariate<Boolean> {
@Override
public Value<Boolean> createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){
hasNAs = true;
return createValue( (Boolean) null);
}
@ -46,6 +49,11 @@ public final class BooleanCovariate implements Covariate<Boolean> {
}
}
/**
 * Reports whether an NA has been recorded for this covariate.
 *
 * @return {@code true} once {@link #createValue(String)} has been given a
 *         {@code null} or "na" (case-insensitive) input; {@code false} until then
 */
@Override
public boolean hasNAs() {
    return this.hasNAs;
}
@Override
public String toString(){
return "BooleanCovariate(name=" + name + ")";

View file

@ -26,6 +26,8 @@ public interface Covariate<V> extends Serializable {
*/
Value<V> createValue(String value);
/**
 * Indicates whether this covariate has any NA values, so that callers can
 * skip NA-handling code when it returns {@code false}.
 *
 * @return {@code true} if an NA value has been recorded for this covariate
 */
boolean hasNAs();
interface Value<V> extends Serializable{
Covariate<V> getParent();

View file

@ -19,6 +19,8 @@ public final class FactorCovariate implements Covariate<String>{
private final FactorValue naValue;
private final int numberOfPossiblePairings;
private boolean hasNAs;
public FactorCovariate(final String name, final int index, List<String> levels){
this.name = name;
@ -72,6 +74,7 @@ public final class FactorCovariate implements Covariate<String>{
@Override
public FactorValue createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){
this.hasNAs = true;
return this.naValue;
}
@ -84,6 +87,12 @@ public final class FactorCovariate implements Covariate<String>{
return factorValue;
}
/**
 * Reports whether an NA has been recorded for this factor covariate.
 *
 * @return {@code true} once {@link #createValue(String)} has been given a
 *         {@code null} or "na" (case-insensitive) input; {@code false} until then
 */
@Override
public boolean hasNAs() {
    return this.hasNAs;
}
@Override
public String toString(){
return "FactorCovariate(name=" + name + ")";

View file

@ -12,6 +12,7 @@ import lombok.ToString;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@RequiredArgsConstructor
@ToString
@ -23,10 +24,17 @@ public final class NumericCovariate implements Covariate<Double> {
@Getter
private final int index;
private boolean hasNAs = false;
@Override
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
data = data.stream()
.filter(row -> !row.getCovariateValue(this).isNA())
Stream<Row<Y>> stream = data.stream();
if(hasNAs()){
stream = stream.filter(row -> !row.getCovariateValue(this).isNA());
}
data = stream
.sorted((r1, r2) -> {
Double d1 = r1.getCovariateValue(this).getValue();
Double d2 = r2.getCovariateValue(this).getValue();
@ -37,7 +45,6 @@ public final class NumericCovariate implements Covariate<Double> {
Iterator<Double> sortedDataIterator = data.stream()
.map(row -> row.getCovariateValue(this).getValue())
.filter(v -> v != null)
.iterator();
@ -56,7 +63,7 @@ public final class NumericCovariate implements Covariate<Double> {
dataIterator = new UniqueSubsetValueIterator<>(
new UniqueValueIterator<>(sortedDataIterator),
indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered
indexSet.toArray(new Integer[indexSet.size()])
);
}
@ -73,12 +80,19 @@ public final class NumericCovariate implements Covariate<Double> {
/**
 * Parses a raw string into a {@code NumericValue} for this covariate.
 *
 * <p>A {@code null} input or the text "na" (compared case-insensitively) is
 * treated as missing: the covariate is flagged as containing NAs and the
 * value is built from a {@code null} {@code Double}. Any other input is
 * handed to {@link Double#parseDouble(String)}.
 *
 * @param value the raw text to convert; may be {@code null}
 * @return the value produced by the {@code Double} overload of {@code createValue}
 * @throws NumberFormatException if {@code value} is neither missing nor a parsable double
 */
@Override
public NumericValue createValue(String value) {
    final boolean missing = (value == null) || "na".equalsIgnoreCase(value);
    if (!missing) {
        return createValue(Double.parseDouble(value));
    }

    // Remember that an NA was seen so hasNAs() can report it.
    this.hasNAs = true;
    return createValue((Double) null);
}
/**
 * Reports whether an NA has been recorded for this numeric covariate.
 *
 * @return {@code true} once {@link #createValue(String)} has been given a
 *         {@code null} or "na" (case-insensitive) input; {@code false} until then
 */
@Override
public boolean hasNAs() {
    return this.hasNAs;
}
@EqualsAndHashCode
public class NumericValue implements Covariate.Value<Double>{

View file

@ -76,22 +76,25 @@ public class TreeTrainer<Y, O> {
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
// Assign missing values to the split
for(Row<Y> row : data) {
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
// Assign missing values to the split if necessary
if(bestSplit.getSplitRule().getParent().hasNAs()){
for(Row<Y> row : data) {
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
bestSplit.getLeftHand().add(row);
}
else{
bestSplit.getRightHand().add(row);
}
if(randomDecision){
bestSplit.getLeftHand().add(row);
}
else{
bestSplit.getRightHand().add(row);
}
}
}
}
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);