Add on-the-fly imputation (OTFI) of missing values when training the forest.

No tests have been written yet, so this is still a work in progress (WIP).
This commit is contained in:
Joel Therrien 2018-07-05 12:05:07 -07:00
parent c048a285a1
commit 662a6cf761
5 changed files with 81 additions and 47 deletions

View file

@ -1,12 +1,9 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
import lombok.Getter; import lombok.Getter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.*; import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
@RequiredArgsConstructor @RequiredArgsConstructor
public class BooleanCovariate implements Covariate<Boolean>{ public class BooleanCovariate implements Covariate<Boolean>{
@ -28,9 +25,9 @@ public class BooleanCovariate implements Covariate<Boolean>{
public class BooleanValue implements Value<Boolean>{ public class BooleanValue implements Value<Boolean>{
private final boolean value; private final Boolean value;
private BooleanValue(final boolean value){ private BooleanValue(final Boolean value){
this.value = value; this.value = value;
} }
@ -43,6 +40,11 @@ public class BooleanCovariate implements Covariate<Boolean>{
public Boolean getValue() { public Boolean getValue() {
return value; return value;
} }
@Override
public boolean isNA() {
return value == null;
}
} }
public class BooleanSplitRule implements SplitRule<Boolean>{ public class BooleanSplitRule implements SplitRule<Boolean>{
@ -58,15 +60,12 @@ public class BooleanCovariate implements Covariate<Boolean>{
} }
@Override @Override
public boolean isLeftHand(CovariateRow row) { public boolean isLeftHand(final Value<Boolean> value) {
final Value<?> x = row.getCovariateValue(getParent().getName()); if(value.isNA()) {
if(x == null) { throw new IllegalArgumentException("Trying to determine split on missing value");
throw new MissingValueException(row, this);
} }
final boolean xBoolean = (Boolean) x.getValue(); return !value.getValue();
return !xBoolean;
} }
} }
} }

View file

@ -1,9 +1,8 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import java.io.Serializable; import java.io.Serializable;
import java.util.Collection; import java.util.*;
import java.util.LinkedList; import java.util.concurrent.ThreadLocalRandom;
import java.util.List;
public interface Covariate<V> extends Serializable { public interface Covariate<V> extends Serializable {
@ -19,6 +18,8 @@ public interface Covariate<V> extends Serializable {
V getValue(); V getValue();
boolean isNA();
} }
interface SplitRule<V> extends Serializable{ interface SplitRule<V> extends Serializable{
@ -37,9 +38,22 @@ public interface Covariate<V> extends Serializable {
final List<Row<Y>> leftHand = new LinkedList<>(); final List<Row<Y>> leftHand = new LinkedList<>();
final List<Row<Y>> rightHand = new LinkedList<>(); final List<Row<Y>> rightHand = new LinkedList<>();
for(final Row<Y> row : rows) { final List<Boolean> nonMissingDecisions = new ArrayList<>();
final List<Row<Y>> missingValueRows = new ArrayList<>();
if(isLeftHand(row)){
for(final Row<Y> row : rows) {
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
if(value.isNA()){
missingValueRows.add(row);
continue;
}
final boolean isLeftHand = isLeftHand(value);
nonMissingDecisions.add(isLeftHand);
if(isLeftHand){
leftHand.add(row); leftHand.add(row);
} }
else{ else{
@ -48,10 +62,31 @@ public interface Covariate<V> extends Serializable {
} }
if(nonMissingDecisions.size() == 0 && missingValueRows.size() > 0){
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
}
final Random random = ThreadLocalRandom.current();
for(final Row<Y> missingValueRow : missingValueRows){
final boolean randomDecision = nonMissingDecisions.get(random.nextInt(nonMissingDecisions.size()));
if(randomDecision){
leftHand.add(missingValueRow);
}
else{
rightHand.add(missingValueRow);
}
}
return new Split<>(leftHand, rightHand); return new Split<>(leftHand, rightHand);
} }
boolean isLeftHand(CovariateRow row); default boolean isLeftHand(CovariateRow row){
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
return isLeftHand(value);
}
boolean isLeftHand(Value<V> value);
} }

View file

@ -9,6 +9,7 @@ public final class FactorCovariate implements Covariate<String>{
private final String name; private final String name;
private final Map<String, FactorValue> factorLevels; private final Map<String, FactorValue> factorLevels;
private final FactorValue naValue;
private final int numberOfPossiblePairings; private final int numberOfPossiblePairings;
@ -28,6 +29,7 @@ public final class FactorCovariate implements Covariate<String>{
} }
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1; this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
this.naValue = new FactorValue(null);
} }
@ -67,6 +69,10 @@ public final class FactorCovariate implements Covariate<String>{
@Override @Override
public FactorValue createValue(String value) { public FactorValue createValue(String value) {
if(value == null){
return this.naValue;
}
final FactorValue factorValue = factorLevels.get(value); final FactorValue factorValue = factorLevels.get(value);
if(factorValue == null){ if(factorValue == null){
@ -94,6 +100,11 @@ public final class FactorCovariate implements Covariate<String>{
public String getValue() { public String getValue() {
return value; return value;
} }
@Override
public boolean isNA() {
return value == null;
}
} }
@EqualsAndHashCode @EqualsAndHashCode
@ -111,12 +122,12 @@ public final class FactorCovariate implements Covariate<String>{
} }
@Override @Override
public boolean isLeftHand(CovariateRow row) { public boolean isLeftHand(final Value<String> value) {
final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue(); if(value.isNA()){
throw new IllegalArgumentException("Trying to determine split on missing value");
}
return leftSideValues.contains(value); return leftSideValues.contains(value);
} }
} }
} }

View file

@ -1,6 +1,5 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
import lombok.Getter; import lombok.Getter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@ -19,6 +18,9 @@ public class NumericCovariate implements Covariate<Double>{
final Random random = ThreadLocalRandom.current(); final Random random = ThreadLocalRandom.current();
// only work with non-NA values
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
// for this implementation we need to shuffle the data // for this implementation we need to shuffle the data
final List<Value<Double>> shuffledData; final List<Value<Double>> shuffledData;
if(number > data.size()){ if(number > data.size()){
@ -55,9 +57,9 @@ public class NumericCovariate implements Covariate<Double>{
public class NumericValue implements Covariate.Value<Double>{ public class NumericValue implements Covariate.Value<Double>{
private final double value; private final Double value; // may be null
private NumericValue(final double value){ private NumericValue(final Double value){
this.value = value; this.value = value;
} }
@ -70,6 +72,11 @@ public class NumericCovariate implements Covariate<Double>{
public Double getValue() { public Double getValue() {
return value; return value;
} }
@Override
public boolean isNA() {
return value == null;
}
} }
public class NumericSplitRule implements Covariate.SplitRule<Double>{ public class NumericSplitRule implements Covariate.SplitRule<Double>{
@ -91,13 +98,12 @@ public class NumericCovariate implements Covariate<Double>{
} }
@Override @Override
public boolean isLeftHand(CovariateRow row) { public boolean isLeftHand(final Value<Double> x) {
final Covariate.Value<?> x = row.getCovariateValue(getParent().getName()); if(x.isNA()) {
if(x == null) { throw new IllegalArgumentException("Trying to determine split on missing value");
throw new MissingValueException(row, this);
} }
final double xNum = (Double) x.getValue(); final double xNum = x.getValue();
return xNum <= threshold; return xNum <= threshold;
} }

View file

@ -1,17 +0,0 @@
package ca.joeltherrien.randomforest.exceptions;
import ca.joeltherrien.randomforest.Covariate;
import ca.joeltherrien.randomforest.CovariateRow;
/**
 * Unchecked exception signalling that a split rule was asked to evaluate a
 * {@link CovariateRow} that has no value for the covariate the rule needs.
 */
public class MissingValueException extends RuntimeException{
    /** Serialization version identifier; kept stable so previously serialized instances stay readable. */
    private static final long serialVersionUID = 6808060079431207726L;

    /**
     * Builds the exception with a message naming both the offending row and the rule.
     *
     * @param row  the row passed to the rule; rendered into the message via {@code toString()}
     * @param rule the split rule that was being evaluated; rendered via {@code toString()}
     */
    public MissingValueException(CovariateRow row, Covariate.SplitRule<?> rule) {
        // Separate the two operands explicitly; previously they were concatenated
        // back-to-back, which made the message hard to read.
        super("Missing value at CovariateRow " + row + " for split rule " + rule);
    }
}