Add OTFI imputation when training forest.
No tests have been written yet so this is still WIP.
This commit is contained in:
parent
c048a285a1
commit
662a6cf761
5 changed files with 81 additions and 47 deletions
|
@ -1,12 +1,9 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class BooleanCovariate implements Covariate<Boolean>{
|
public class BooleanCovariate implements Covariate<Boolean>{
|
||||||
|
@ -28,9 +25,9 @@ public class BooleanCovariate implements Covariate<Boolean>{
|
||||||
|
|
||||||
public class BooleanValue implements Value<Boolean>{
|
public class BooleanValue implements Value<Boolean>{
|
||||||
|
|
||||||
private final boolean value;
|
private final Boolean value;
|
||||||
|
|
||||||
private BooleanValue(final boolean value){
|
private BooleanValue(final Boolean value){
|
||||||
this.value = value;
|
this.value = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,6 +40,11 @@ public class BooleanCovariate implements Covariate<Boolean>{
|
||||||
public Boolean getValue() {
|
public Boolean getValue() {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isNA() {
|
||||||
|
return value == null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public class BooleanSplitRule implements SplitRule<Boolean>{
|
public class BooleanSplitRule implements SplitRule<Boolean>{
|
||||||
|
@ -58,15 +60,12 @@ public class BooleanCovariate implements Covariate<Boolean>{
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isLeftHand(CovariateRow row) {
|
public boolean isLeftHand(final Value<Boolean> value) {
|
||||||
final Value<?> x = row.getCovariateValue(getParent().getName());
|
if(value.isNA()) {
|
||||||
if(x == null) {
|
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||||
throw new MissingValueException(row, this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final boolean xBoolean = (Boolean) x.getValue();
|
return !value.getValue();
|
||||||
|
|
||||||
return !xBoolean;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Collection;
|
import java.util.*;
|
||||||
import java.util.LinkedList;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public interface Covariate<V> extends Serializable {
|
public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
|
@ -19,6 +18,8 @@ public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
V getValue();
|
V getValue();
|
||||||
|
|
||||||
|
boolean isNA();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface SplitRule<V> extends Serializable{
|
interface SplitRule<V> extends Serializable{
|
||||||
|
@ -37,9 +38,22 @@ public interface Covariate<V> extends Serializable {
|
||||||
final List<Row<Y>> leftHand = new LinkedList<>();
|
final List<Row<Y>> leftHand = new LinkedList<>();
|
||||||
final List<Row<Y>> rightHand = new LinkedList<>();
|
final List<Row<Y>> rightHand = new LinkedList<>();
|
||||||
|
|
||||||
for(final Row<Y> row : rows) {
|
final List<Boolean> nonMissingDecisions = new ArrayList<>();
|
||||||
|
final List<Row<Y>> missingValueRows = new ArrayList<>();
|
||||||
|
|
||||||
if(isLeftHand(row)){
|
|
||||||
|
for(final Row<Y> row : rows) {
|
||||||
|
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
||||||
|
|
||||||
|
if(value.isNA()){
|
||||||
|
missingValueRows.add(row);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
final boolean isLeftHand = isLeftHand(value);
|
||||||
|
nonMissingDecisions.add(isLeftHand);
|
||||||
|
|
||||||
|
if(isLeftHand){
|
||||||
leftHand.add(row);
|
leftHand.add(row);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
|
@ -48,10 +62,31 @@ public interface Covariate<V> extends Serializable {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(nonMissingDecisions.size() == 0 && missingValueRows.size() > 0){
|
||||||
|
throw new IllegalArgumentException("Can't apply " + this + " when there are rows with missing data and no non-missing value rows");
|
||||||
|
}
|
||||||
|
|
||||||
|
final Random random = ThreadLocalRandom.current();
|
||||||
|
for(final Row<Y> missingValueRow : missingValueRows){
|
||||||
|
final boolean randomDecision = nonMissingDecisions.get(random.nextInt(nonMissingDecisions.size()));
|
||||||
|
|
||||||
|
if(randomDecision){
|
||||||
|
leftHand.add(missingValueRow);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
rightHand.add(missingValueRow);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return new Split<>(leftHand, rightHand);
|
return new Split<>(leftHand, rightHand);
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean isLeftHand(CovariateRow row);
|
default boolean isLeftHand(CovariateRow row){
|
||||||
|
final Value<V> value = (Value<V>) row.getCovariateValue(getParent().getName());
|
||||||
|
return isLeftHand(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean isLeftHand(Value<V> value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
|
|
||||||
private final String name;
|
private final String name;
|
||||||
private final Map<String, FactorValue> factorLevels;
|
private final Map<String, FactorValue> factorLevels;
|
||||||
|
private final FactorValue naValue;
|
||||||
private final int numberOfPossiblePairings;
|
private final int numberOfPossiblePairings;
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,6 +29,7 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
}
|
}
|
||||||
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
|
||||||
|
|
||||||
|
this.naValue = new FactorValue(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,6 +69,10 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public FactorValue createValue(String value) {
|
public FactorValue createValue(String value) {
|
||||||
|
if(value == null){
|
||||||
|
return this.naValue;
|
||||||
|
}
|
||||||
|
|
||||||
final FactorValue factorValue = factorLevels.get(value);
|
final FactorValue factorValue = factorLevels.get(value);
|
||||||
|
|
||||||
if(factorValue == null){
|
if(factorValue == null){
|
||||||
|
@ -94,6 +100,11 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
public String getValue() {
|
public String getValue() {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isNA() {
|
||||||
|
return value == null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
|
@ -111,12 +122,12 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isLeftHand(CovariateRow row) {
|
public boolean isLeftHand(final Value<String> value) {
|
||||||
final FactorValue value = (FactorValue) row.getCovariateValue(getName()).getValue();
|
if(value.isNA()){
|
||||||
|
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||||
|
}
|
||||||
|
|
||||||
return leftSideValues.contains(value);
|
return leftSideValues.contains(value);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.exceptions.MissingValueException;
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
|
@ -19,6 +18,9 @@ public class NumericCovariate implements Covariate<Double>{
|
||||||
|
|
||||||
final Random random = ThreadLocalRandom.current();
|
final Random random = ThreadLocalRandom.current();
|
||||||
|
|
||||||
|
// only work with non-NA values
|
||||||
|
data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
|
||||||
|
|
||||||
// for this implementation we need to shuffle the data
|
// for this implementation we need to shuffle the data
|
||||||
final List<Value<Double>> shuffledData;
|
final List<Value<Double>> shuffledData;
|
||||||
if(number > data.size()){
|
if(number > data.size()){
|
||||||
|
@ -55,9 +57,9 @@ public class NumericCovariate implements Covariate<Double>{
|
||||||
|
|
||||||
public class NumericValue implements Covariate.Value<Double>{
|
public class NumericValue implements Covariate.Value<Double>{
|
||||||
|
|
||||||
private final double value;
|
private final Double value; // may be null
|
||||||
|
|
||||||
private NumericValue(final double value){
|
private NumericValue(final Double value){
|
||||||
this.value = value;
|
this.value = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,6 +72,11 @@ public class NumericCovariate implements Covariate<Double>{
|
||||||
public Double getValue() {
|
public Double getValue() {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isNA() {
|
||||||
|
return value == null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public class NumericSplitRule implements Covariate.SplitRule<Double>{
|
public class NumericSplitRule implements Covariate.SplitRule<Double>{
|
||||||
|
@ -91,13 +98,12 @@ public class NumericCovariate implements Covariate<Double>{
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isLeftHand(CovariateRow row) {
|
public boolean isLeftHand(final Value<Double> x) {
|
||||||
final Covariate.Value<?> x = row.getCovariateValue(getParent().getName());
|
if(x.isNA()) {
|
||||||
if(x == null) {
|
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||||
throw new MissingValueException(row, this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final double xNum = (Double) x.getValue();
|
final double xNum = x.getValue();
|
||||||
|
|
||||||
return xNum <= threshold;
|
return xNum <= threshold;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,17 +0,0 @@
|
||||||
package ca.joeltherrien.randomforest.exceptions;
|
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Covariate;
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
|
||||||
|
|
||||||
public class MissingValueException extends RuntimeException{
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private static final long serialVersionUID = 6808060079431207726L;
|
|
||||||
|
|
||||||
public MissingValueException(CovariateRow row, Covariate.SplitRule rule) {
|
|
||||||
super("Missing value at CovariateRow " + row + rule);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
Loading…
Reference in a new issue