Covariates track if they have any NA values, and skip NA handling code if possible
This commit is contained in:
parent
a57741b726
commit
31d6ce9b3e
5 changed files with 50 additions and 14 deletions
|
@ -17,6 +17,8 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
@Getter
|
@Getter
|
||||||
private final int index;
|
private final int index;
|
||||||
|
|
||||||
|
private boolean hasNAs = false;
|
||||||
|
|
||||||
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -32,6 +34,7 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
@Override
|
@Override
|
||||||
public Value<Boolean> createValue(String value) {
|
public Value<Boolean> createValue(String value) {
|
||||||
if(value == null || value.equalsIgnoreCase("na")){
|
if(value == null || value.equalsIgnoreCase("na")){
|
||||||
|
hasNAs = true;
|
||||||
return createValue( (Boolean) null);
|
return createValue( (Boolean) null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,6 +49,11 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNAs() {
|
||||||
|
return hasNAs;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString(){
|
public String toString(){
|
||||||
return "BooleanCovariate(name=" + name + ")";
|
return "BooleanCovariate(name=" + name + ")";
|
||||||
|
|
|
@ -26,6 +26,8 @@ public interface Covariate<V> extends Serializable {
|
||||||
*/
|
*/
|
||||||
Value<V> createValue(String value);
|
Value<V> createValue(String value);
|
||||||
|
|
||||||
|
boolean hasNAs();
|
||||||
|
|
||||||
interface Value<V> extends Serializable{
|
interface Value<V> extends Serializable{
|
||||||
|
|
||||||
Covariate<V> getParent();
|
Covariate<V> getParent();
|
||||||
|
|
|
@ -19,6 +19,8 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
private final FactorValue naValue;
|
private final FactorValue naValue;
|
||||||
private final int numberOfPossiblePairings;
|
private final int numberOfPossiblePairings;
|
||||||
|
|
||||||
|
private boolean hasNAs;
|
||||||
|
|
||||||
|
|
||||||
public FactorCovariate(final String name, final int index, List<String> levels){
|
public FactorCovariate(final String name, final int index, List<String> levels){
|
||||||
this.name = name;
|
this.name = name;
|
||||||
|
@ -72,6 +74,7 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
@Override
|
@Override
|
||||||
public FactorValue createValue(String value) {
|
public FactorValue createValue(String value) {
|
||||||
if(value == null || value.equalsIgnoreCase("na")){
|
if(value == null || value.equalsIgnoreCase("na")){
|
||||||
|
this.hasNAs = true;
|
||||||
return this.naValue;
|
return this.naValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,6 +87,12 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
return factorValue;
|
return factorValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNAs() {
|
||||||
|
return hasNAs;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString(){
|
public String toString(){
|
||||||
return "FactorCovariate(name=" + name + ")";
|
return "FactorCovariate(name=" + name + ")";
|
||||||
|
|
|
@ -12,6 +12,7 @@ import lombok.ToString;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@ToString
|
@ToString
|
||||||
|
@ -23,10 +24,17 @@ public final class NumericCovariate implements Covariate<Double> {
|
||||||
@Getter
|
@Getter
|
||||||
private final int index;
|
private final int index;
|
||||||
|
|
||||||
|
private boolean hasNAs = false;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
data = data.stream()
|
Stream<Row<Y>> stream = data.stream();
|
||||||
.filter(row -> !row.getCovariateValue(this).isNA())
|
|
||||||
|
if(hasNAs()){
|
||||||
|
stream = stream.filter(row -> !row.getCovariateValue(this).isNA());
|
||||||
|
}
|
||||||
|
|
||||||
|
data = stream
|
||||||
.sorted((r1, r2) -> {
|
.sorted((r1, r2) -> {
|
||||||
Double d1 = r1.getCovariateValue(this).getValue();
|
Double d1 = r1.getCovariateValue(this).getValue();
|
||||||
Double d2 = r2.getCovariateValue(this).getValue();
|
Double d2 = r2.getCovariateValue(this).getValue();
|
||||||
|
@ -37,7 +45,6 @@ public final class NumericCovariate implements Covariate<Double> {
|
||||||
|
|
||||||
Iterator<Double> sortedDataIterator = data.stream()
|
Iterator<Double> sortedDataIterator = data.stream()
|
||||||
.map(row -> row.getCovariateValue(this).getValue())
|
.map(row -> row.getCovariateValue(this).getValue())
|
||||||
.filter(v -> v != null)
|
|
||||||
.iterator();
|
.iterator();
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,7 +63,7 @@ public final class NumericCovariate implements Covariate<Double> {
|
||||||
|
|
||||||
dataIterator = new UniqueSubsetValueIterator<>(
|
dataIterator = new UniqueSubsetValueIterator<>(
|
||||||
new UniqueValueIterator<>(sortedDataIterator),
|
new UniqueValueIterator<>(sortedDataIterator),
|
||||||
indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered
|
indexSet.toArray(new Integer[indexSet.size()])
|
||||||
);
|
);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -73,12 +80,19 @@ public final class NumericCovariate implements Covariate<Double> {
|
||||||
@Override
|
@Override
|
||||||
public NumericValue createValue(String value) {
|
public NumericValue createValue(String value) {
|
||||||
if(value == null || value.equalsIgnoreCase("na")){
|
if(value == null || value.equalsIgnoreCase("na")){
|
||||||
|
this.hasNAs = true;
|
||||||
return createValue((Double) null);
|
return createValue((Double) null);
|
||||||
}
|
}
|
||||||
|
|
||||||
return createValue(Double.parseDouble(value));
|
return createValue(Double.parseDouble(value));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNAs() {
|
||||||
|
return hasNAs;
|
||||||
|
}
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public class NumericValue implements Covariate.Value<Double>{
|
public class NumericValue implements Covariate.Value<Double>{
|
||||||
|
|
||||||
|
|
|
@ -76,22 +76,25 @@ public class TreeTrainer<Y, O> {
|
||||||
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
||||||
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
||||||
|
|
||||||
// Assign missing values to the split
|
// Assign missing values to the split if necessary
|
||||||
for(Row<Y> row : data) {
|
if(bestSplit.getSplitRule().getParent().hasNAs()){
|
||||||
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
|
for(Row<Y> row : data) {
|
||||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
|
||||||
|
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||||
|
|
||||||
if(randomDecision){
|
if(randomDecision){
|
||||||
bestSplit.getLeftHand().add(row);
|
bestSplit.getLeftHand().add(row);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
bestSplit.getRightHand().add(row);
|
bestSplit.getRightHand().add(row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
|
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
|
||||||
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
|
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue