Add on-the-fly imputation (OTFI) of missing covariate values when training the forest.
No tests have been written yet, so this is still a work in progress (WIP).
This commit is contained in:
parent
c048a285a1
commit
662a6cf761
5 changed files with 81 additions and 47 deletions
|
@ -1,12 +1,9 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class BooleanCovariate implements Covariate<Boolean>{
|
||||
|
@ -28,9 +25,9 @@ public class BooleanCovariate implements Covariate<Boolean>{
|
|||
|
||||
public class BooleanValue implements Value<Boolean>{
|
||||
|
||||
private final boolean value;
|
||||
private final Boolean value;
|
||||
|
||||
// Constructor: stores the given value in this BooleanValue's field unchanged.
// NOTE(review): this diff view lists two signatures back to back with no +/-
// markers; the boxed Boolean one is presumably the post-commit version, since
// isNA() compares the field against null — confirm against the repository.
private BooleanValue(final boolean value){
|
||||
private BooleanValue(final Boolean value){
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
|
@ -43,6 +40,11 @@ public class BooleanCovariate implements Covariate<Boolean>{
|
|||
public Boolean getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isNA() {
|
||||
return value == null;
|
||||
}
|
||||
}
|
||||
|
||||
public class BooleanSplitRule implements SplitRule<Boolean>{
|
||||
|
@ -58,15 +60,12 @@ public class BooleanCovariate implements Covariate<Boolean>{
|
|||
}
|
||||
|
||||
/**
 * Decides which side of the split a Boolean value falls on.
 * The Value-based version returns the negation of the wrapped Boolean, so a
 * {@code false} value is left-hand and a {@code true} value is not.
 *
 * @throws IllegalArgumentException if {@code value.isNA()} is true
 */
// NOTE(review): this diff view interleaves two versions of the method with no
// +/- markers. The CovariateRow signature, its null check throwing
// MissingValueException, and the xBoolean lines look like the pre-commit body
// (the same commit drops the MissingValueException import) — confirm against
// the repository.
@Override
|
||||
public boolean isLeftHand(CovariateRow row) {
|
||||
final Value<?> x = row.getCovariateValue(getParent().getName());
|
||||
if(x == null) {
|
||||
throw new MissingValueException(row, this);
|
||||
public boolean isLeftHand(final Value<Boolean> value) {
|
||||
if(value.isNA()) {
|
||||
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||
}
|
||||
|
||||
final boolean xBoolean = (Boolean) x.getValue();
|
||||
|
||||
return !xBoolean;
|
||||
return !value.getValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
public interface Covariate<V> extends Serializable {
|
||||
|
||||
|
@ -19,6 +18,8 @@ public interface Covariate<V> extends Serializable {
|
|||
|
||||
V getValue();
|
||||
|
||||
boolean isNA();
|
||||
|
||||
}
|
||||
|
||||
interface SplitRule<V> extends Serializable{
|
||||
|
@ -37,9 +38,22 @@ public interface Covariate<V> extends Serializable {
|
|||
final List<Row<Y>> leftHand = new LinkedList<>();
|
||||
final List<Row<Y>> rightHand = new LinkedList<>();
|
||||
|
||||
for(final Row<Y> row : rows) {
|
||||
final List<Boolean> nonMissingDecisions = new ArrayList<>();
|
||||
final List<Row<Y>> missingValueRows = new ArrayList<>();
|
||||
|
||||
if(isLeftHand(row)){
|
||||
|
||||
for(final Row<Y> row : rows) {
|
||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
||||
|
||||
if(value.isNA()){
|
||||
missingValueRows.add(row);
|
||||
continue;
|
||||
}
|
||||
|
||||
final boolean isLeftHand = isLeftHand(value);
|
||||
nonMissingDecisions.add(isLeftHand);
|
||||
|
||||
if(isLeftHand){
|
||||
leftHand.add(row);
|
||||
}
|
||||
else{
|
||||
|
@ -48,10 +62,31 @@ public interface Covariate<V> extends Serializable {
|
|||
|
||||
}
|
||||
|
||||
if(nonMissingDecisions.size() == 0 && missingValueRows.size() > 0){
|
||||
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
||||
}
|
||||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
for(final Row<Y> missingValueRow : missingValueRows){
|
||||
final boolean randomDecision = nonMissingDecisions.get(random.nextInt(nonMissingDecisions.size()));
|
||||
|
||||
if(randomDecision){
|
||||
leftHand.add(missingValueRow);
|
||||
}
|
||||
else{
|
||||
rightHand.add(missingValueRow);
|
||||
}
|
||||
}
|
||||
|
||||
return new Split<>(leftHand, rightHand);
|
||||
}
|
||||
|
||||
boolean isLeftHand(CovariateRow row);
|
||||
default boolean isLeftHand(CovariateRow row){
|
||||
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
||||
return isLeftHand(value);
|
||||
}
|
||||
|
||||
boolean isLeftHand(Value<V> value);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
|
||||
private final String name;
|
||||
private final Map<String, FactorValue> factorLevels;
|
||||
private final FactorValue naValue;
|
||||
private final int numberOfPossiblePairings;
|
||||
|
||||
|
||||
|
@ -28,6 +29,7 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
}
|
||||
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
||||
|
||||
this.naValue = new FactorValue(null);
|
||||
}
|
||||
|
||||
|
||||
|
@ -67,6 +69,10 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
|
||||
@Override
|
||||
public FactorValue createValue(String value) {
|
||||
if(value == null){
|
||||
return this.naValue;
|
||||
}
|
||||
|
||||
final FactorValue factorValue = factorLevels.get(value);
|
||||
|
||||
if(factorValue == null){
|
||||
|
@ -94,6 +100,11 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
public String getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isNA() {
|
||||
return value == null;
|
||||
}
|
||||
}
|
||||
|
||||
@EqualsAndHashCode
|
||||
|
@ -111,12 +122,12 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
}
|
||||
|
||||
/**
 * Decides which side of the split a factor value falls on.
 * The Value-based version returns whether {@code leftSideValues} contains the
 * given value.
 *
 * @throws IllegalArgumentException if {@code value.isNA()} is true
 */
// NOTE(review): this diff view interleaves two versions of the method with no
// +/- markers. The CovariateRow signature and the FactorValue cast line look
// like the pre-commit body — confirm against the repository.
@Override
|
||||
public boolean isLeftHand(CovariateRow row) {
|
||||
final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue();
|
||||
public boolean isLeftHand(final Value<String> value) {
|
||||
if(value.isNA()){
|
||||
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||
}
|
||||
|
||||
return leftSideValues.contains(value);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
|
@ -19,6 +18,9 @@ public class NumericCovariate implements Covariate<Double>{
|
|||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
|
||||
// only work with non-NA values
|
||||
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
|
||||
|
||||
// for this implementation we need to shuffle the data
|
||||
final List<Value<Double>> shuffledData;
|
||||
if(number > data.size()){
|
||||
|
@ -55,9 +57,9 @@ public class NumericCovariate implements Covariate<Double>{
|
|||
|
||||
public class NumericValue implements Covariate.Value<Double>{
|
||||
|
||||
private final double value;
|
||||
private final Double value; // may be null
|
||||
|
||||
// Constructor: stores the given value in this NumericValue's field unchanged.
// NOTE(review): this diff view lists two signatures back to back with no +/-
// markers; the boxed Double one is presumably the post-commit version, since
// the field is commented "may be null" and isNA() compares it against null —
// confirm against the repository.
private NumericValue(final double value){
|
||||
private NumericValue(final Double value){
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
|
@ -70,6 +72,11 @@ public class NumericCovariate implements Covariate<Double>{
|
|||
public Double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isNA() {
|
||||
return value == null;
|
||||
}
|
||||
}
|
||||
|
||||
public class NumericSplitRule implements Covariate.SplitRule<Double>{
|
||||
|
@ -91,13 +98,12 @@ public class NumericCovariate implements Covariate<Double>{
|
|||
}
|
||||
|
||||
/**
 * Decides which side of the split a numeric value falls on.
 * The Value-based version unboxes the wrapped Double and returns whether it is
 * less than or equal to {@code threshold}.
 *
 * @throws IllegalArgumentException if {@code x.isNA()} is true
 */
// NOTE(review): this diff view interleaves two versions of the method with no
// +/- markers. The CovariateRow signature, its null check throwing
// MissingValueException, and the xNum line with the explicit (Double) cast look
// like the pre-commit body (the same commit drops the MissingValueException
// import) — confirm against the repository.
@Override
|
||||
public boolean isLeftHand(CovariateRow row) {
|
||||
final Covariate.Value<?> x = row.getCovariateValue(getParent().getName());
|
||||
if(x == null) {
|
||||
throw new MissingValueException(row, this);
|
||||
public boolean isLeftHand(final Value<Double> x) {
|
||||
if(x.isNA()) {
|
||||
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||
}
|
||||
|
||||
final double xNum = (Double) x.getValue();
|
||||
final double xNum = x.getValue();
|
||||
|
||||
return xNum <= threshold;
|
||||
}
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
package ca.joeltherrien.randomforest.exceptions;
|
||||
|
||||
import ca.joeltherrien.randomforest.Covariate;
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
|
||||
public class MissingValueException extends RuntimeException{
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = 6808060079431207726L;
|
||||
|
||||
public MissingValueException(CovariateRow row, Covariate.SplitRule rule) {
|
||||
super("Missing value at CovariateRow " + row + rule);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue