Add on-the-fly imputation (OTFI) of missing covariate values when training the forest.

No tests have been written yet, so this is still a work in progress (WIP).
This commit is contained in:
Joel Therrien 2018-07-05 12:05:07 -07:00
parent c048a285a1
commit 662a6cf761
5 changed files with 81 additions and 47 deletions

View file

@ -1,12 +1,9 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
@RequiredArgsConstructor
public class BooleanCovariate implements Covariate<Boolean>{
@ -28,9 +25,9 @@ public class BooleanCovariate implements Covariate<Boolean>{
public class BooleanValue implements Value<Boolean>{
private final boolean value;
private final Boolean value;
private BooleanValue(final boolean value){
private BooleanValue(final Boolean value){
this.value = value;
}
@ -43,6 +40,11 @@ public class BooleanCovariate implements Covariate<Boolean>{
public Boolean getValue() {
return value;
}
/**
 * Reports whether this value is missing.
 *
 * @return {@code true} when the wrapped {@code Boolean} is {@code null}
 */
@Override
public boolean isNA() {
    return this.value == null;
}
}
public class BooleanSplitRule implements SplitRule<Boolean>{
@ -58,15 +60,12 @@ public class BooleanCovariate implements Covariate<Boolean>{
}
@Override
public boolean isLeftHand(CovariateRow row) {
final Value<?> x = row.getCovariateValue(getParent().getName());
if(x == null) {
throw new MissingValueException(row, this);
public boolean isLeftHand(final Value<Boolean> value) {
if(value.isNA()) {
throw new IllegalArgumentException("Trying to determine split on missing value");
}
final boolean xBoolean = (Boolean) x.getValue();
return !xBoolean;
return !value.getValue();
}
}
}

View file

@ -1,9 +1,8 @@
package ca.joeltherrien.randomforest;
import java.io.Serializable;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
public interface Covariate<V> extends Serializable {
@ -19,6 +18,8 @@ public interface Covariate<V> extends Serializable {
V getValue();
boolean isNA();
}
interface SplitRule<V> extends Serializable{
@ -37,9 +38,22 @@ public interface Covariate<V> extends Serializable {
final List<Row<Y>> leftHand = new LinkedList<>();
final List<Row<Y>> rightHand = new LinkedList<>();
for(final Row<Y> row : rows) {
final List<Boolean> nonMissingDecisions = new ArrayList<>();
final List<Row<Y>> missingValueRows = new ArrayList<>();
if(isLeftHand(row)){
for(final Row<Y> row : rows) {
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
if(value.isNA()){
missingValueRows.add(row);
continue;
}
final boolean isLeftHand = isLeftHand(value);
nonMissingDecisions.add(isLeftHand);
if(isLeftHand){
leftHand.add(row);
}
else{
@ -48,10 +62,31 @@ public interface Covariate<V> extends Serializable {
}
if(nonMissingDecisions.size() == 0 && missingValueRows.size() > 0){
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
}
final Random random = ThreadLocalRandom.current();
for(final Row<Y> missingValueRow : missingValueRows){
final boolean randomDecision = nonMissingDecisions.get(random.nextInt(nonMissingDecisions.size()));
if(randomDecision){
leftHand.add(missingValueRow);
}
else{
rightHand.add(missingValueRow);
}
}
return new Split<>(leftHand, rightHand);
}
/**
 * Determines which side of this split a row falls on by looking up the row's value
 * for the parent covariate and delegating to {@link #isLeftHand(Value)}.
 *
 * @param row the row to classify
 * @return the result of {@link #isLeftHand(Value)} for the row's value
 * @throws IllegalArgumentException if the row holds no value object for the parent covariate
 */
default boolean isLeftHand(CovariateRow row){
    // The cast is unchecked: nothing here verifies that the stored value's type matches V.
    @SuppressWarnings("unchecked")
    final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());

    // Fail with a descriptive error instead of a NullPointerException further down
    // if the lookup yields no value object at all.
    if(value == null){
        throw new IllegalArgumentException("Row " + row + " has no value for covariate " + getParent().getName());
    }

    return isLeftHand(value);
}

/**
 * Determines which side of this split a single covariate value falls on.
 *
 * @param value the value to test; implementations are expected to reject NA values
 * @return {@code true} if the value belongs on the left-hand side of the split
 */
boolean isLeftHand(Value<V> value);
}

View file

@ -9,6 +9,7 @@ public final class FactorCovariate implements Covariate<String>{
private final String name;
private final Map<String, FactorValue> factorLevels;
private final FactorValue naValue;
private final int numberOfPossiblePairings;
@ -28,6 +29,7 @@ public final class FactorCovariate implements Covariate<String>{
}
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
this.naValue = new FactorValue(null);
}
@ -67,6 +69,10 @@ public final class FactorCovariate implements Covariate<String>{
@Override
public FactorValue createValue(String value) {
if(value == null){
return this.naValue;
}
final FactorValue factorValue = factorLevels.get(value);
if(factorValue == null){
@ -94,6 +100,11 @@ public final class FactorCovariate implements Covariate<String>{
public String getValue() {
return value;
}
/**
 * Reports whether this value is missing.
 *
 * @return {@code true} when the wrapped {@code String} is {@code null}
 */
@Override
public boolean isNA() {
    return this.value == null;
}
}
@EqualsAndHashCode
@ -111,12 +122,12 @@ public final class FactorCovariate implements Covariate<String>{
}
@Override
public boolean isLeftHand(CovariateRow row) {
final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue();
public boolean isLeftHand(final Value<String> value) {
if(value.isNA()){
throw new IllegalArgumentException("Trying to determine split on missing value");
}
return leftSideValues.contains(value);
}
}
}

View file

@ -1,6 +1,5 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
@ -19,6 +18,9 @@ public class NumericCovariate implements Covariate<Double>{
final Random random = ThreadLocalRandom.current();
// only work with non-NA values
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
// for this implementation we need to shuffle the data
final List<Value<Double>> shuffledData;
if(number > data.size()){
@ -55,9 +57,9 @@ public class NumericCovariate implements Covariate<Double>{
public class NumericValue implements Covariate.Value<Double>{
private final double value;
private final Double value; // may be null
private NumericValue(final double value){
private NumericValue(final Double value){
this.value = value;
}
@ -70,6 +72,11 @@ public class NumericCovariate implements Covariate<Double>{
public Double getValue() {
return value;
}
/**
 * Reports whether this value is missing.
 *
 * @return {@code true} when the wrapped {@code Double} is {@code null}
 */
@Override
public boolean isNA() {
    return this.value == null;
}
}
public class NumericSplitRule implements Covariate.SplitRule<Double>{
@ -91,13 +98,12 @@ public class NumericCovariate implements Covariate<Double>{
}
@Override
public boolean isLeftHand(CovariateRow row) {
final Covariate.Value<?> x = row.getCovariateValue(getParent().getName());
if(x == null) {
throw new MissingValueException(row, this);
public boolean isLeftHand(final Value<Double> x) {
if(x.isNA()) {
throw new IllegalArgumentException("Trying to determine split on missing value");
}
final double xNum = (Double) x.getValue();
final double xNum = x.getValue();
return xNum <= threshold;
}

View file

@ -1,17 +0,0 @@
package ca.joeltherrien.randomforest.exceptions;
import ca.joeltherrien.randomforest.Covariate;
import ca.joeltherrien.randomforest.CovariateRow;
/**
 * Unchecked exception signalling that a split rule was applied to a row
 * that has no value for the covariate the rule needs.
 */
public class MissingValueException extends RuntimeException{

    private static final long serialVersionUID = 6808060079431207726L;

    /**
     * Builds the exception with a message naming both the row and the rule.
     *
     * @param row  the row that lacked the value
     * @param rule the split rule that was being applied
     */
    public MissingValueException(CovariateRow row, Covariate.SplitRule<?> rule) {
        // Separate the two descriptions so the message stays readable.
        super("Missing value at CovariateRow " + row + " for split rule " + rule);
    }

}