Covariates track whether they have seen any NA values, so NA-handling code is skipped when a covariate has none
This commit is contained in:
parent
a57741b726
commit
31d6ce9b3e
5 changed files with 50 additions and 14 deletions
|
@ -17,6 +17,8 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
|||
@Getter
|
||||
private final int index;
|
||||
|
||||
private boolean hasNAs = false;
|
||||
|
||||
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
||||
|
||||
@Override
|
||||
|
@ -32,6 +34,7 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
|||
@Override
|
||||
public Value<Boolean> createValue(String value) {
|
||||
if(value == null || value.equalsIgnoreCase("na")){
|
||||
hasNAs = true;
|
||||
return createValue( (Boolean) null);
|
||||
}
|
||||
|
||||
|
@ -46,6 +49,11 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Reports whether this covariate has been asked to create an NA value.
 *
 * The backing flag starts as {@code false} and is switched to {@code true}
 * by {@code createValue(String)} when it receives {@code null} or the text
 * "na" (compared case-insensitively). Nothing visible here resets it.
 *
 * NOTE(review): values built through other paths (e.g. calling
 * {@code createValue((Boolean) null)} directly) may not set the flag;
 * confirm against that overload.
 *
 * @return {@code true} once an NA has been seen via {@code createValue(String)}
 */
@Override
|
||||
public boolean hasNAs() {
|
||||
return hasNAs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
return "BooleanCovariate(name=" + name + ")";
|
||||
|
|
|
@ -26,6 +26,8 @@ public interface Covariate<V> extends Serializable {
|
|||
*/
|
||||
Value<V> createValue(String value);
|
||||
|
||||
/**
 * Reports whether this covariate has any NA (missing) values.
 *
 * Callers can use a {@code false} result to skip per-row NA checks.
 * NOTE(review): presumably implementations set this while values are
 * being created from strings; confirm for each implementing class.
 *
 * @return {@code true} if an NA value has been recorded for this covariate
 */
boolean hasNAs();
|
||||
|
||||
interface Value<V> extends Serializable{
|
||||
|
||||
Covariate<V> getParent();
|
||||
|
|
|
@ -19,6 +19,8 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
private final FactorValue naValue;
|
||||
private final int numberOfPossiblePairings;
|
||||
|
||||
private boolean hasNAs;
|
||||
|
||||
|
||||
public FactorCovariate(final String name, final int index, List<String> levels){
|
||||
this.name = name;
|
||||
|
@ -72,6 +74,7 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
@Override
|
||||
public FactorValue createValue(String value) {
|
||||
if(value == null || value.equalsIgnoreCase("na")){
|
||||
this.hasNAs = true;
|
||||
return this.naValue;
|
||||
}
|
||||
|
||||
|
@ -84,6 +87,12 @@ public final class FactorCovariate implements Covariate<String>{
|
|||
return factorValue;
|
||||
}
|
||||
|
||||
|
||||
/**
 * Reports whether this covariate has been asked to create an NA value.
 *
 * The backing field has no explicit initializer (so it starts as
 * {@code false}) and is set to {@code true} by {@code createValue(String)}
 * when it receives {@code null} or the text "na" (case-insensitive).
 *
 * @return {@code true} once an NA has been seen via {@code createValue(String)}
 */
@Override
|
||||
public boolean hasNAs() {
|
||||
return hasNAs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
return "FactorCovariate(name=" + name + ")";
|
||||
|
|
|
@ -12,6 +12,7 @@ import lombok.ToString;
|
|||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
@ToString
|
||||
|
@ -23,10 +24,17 @@ public final class NumericCovariate implements Covariate<Double> {
|
|||
@Getter
|
||||
private final int index;
|
||||
|
||||
private boolean hasNAs = false;
|
||||
|
||||
@Override
|
||||
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||
data = data.stream()
|
||||
.filter(row -> !row.getCovariateValue(this).isNA())
|
||||
Stream<Row<Y>> stream = data.stream();
|
||||
|
||||
if(hasNAs()){
|
||||
stream = stream.filter(row -> !row.getCovariateValue(this).isNA());
|
||||
}
|
||||
|
||||
data = stream
|
||||
.sorted((r1, r2) -> {
|
||||
Double d1 = r1.getCovariateValue(this).getValue();
|
||||
Double d2 = r2.getCovariateValue(this).getValue();
|
||||
|
@ -37,7 +45,6 @@ public final class NumericCovariate implements Covariate<Double> {
|
|||
|
||||
Iterator<Double> sortedDataIterator = data.stream()
|
||||
.map(row -> row.getCovariateValue(this).getValue())
|
||||
.filter(v -> v != null)
|
||||
.iterator();
|
||||
|
||||
|
||||
|
@ -56,7 +63,7 @@ public final class NumericCovariate implements Covariate<Double> {
|
|||
|
||||
dataIterator = new UniqueSubsetValueIterator<>(
|
||||
new UniqueValueIterator<>(sortedDataIterator),
|
||||
indexSet.toArray(new Integer[indexSet.size()]) // TODO verify this is ordered
|
||||
indexSet.toArray(new Integer[indexSet.size()])
|
||||
);
|
||||
|
||||
}
|
||||
|
@ -73,12 +80,19 @@ public final class NumericCovariate implements Covariate<Double> {
|
|||
/**
 * Converts a raw string into a value of this covariate.
 *
 * {@code null} and the text "na" (compared case-insensitively) are treated
 * as missing: the covariate is flagged as having NAs and the result of
 * {@code createValue((Double) null)} is returned. Any other input is parsed
 * with {@link Double#parseDouble(String)}.
 *
 * @param value the raw text to convert; may be {@code null}
 * @return the value produced by the {@code Double} overload of createValue
 * @throws NumberFormatException if {@code value} is neither NA nor a parsable double
 */
@Override
|
||||
public NumericValue createValue(String value) {
|
||||
if(value == null || value.equalsIgnoreCase("na")){
|
||||
// Record that an NA was seen so hasNAs() callers know NA handling is required.
this.hasNAs = true;
|
||||
return createValue((Double) null);
|
||||
}
|
||||
|
||||
return createValue(Double.parseDouble(value));
|
||||
}
|
||||
|
||||
|
||||
/**
 * Reports whether this covariate has been asked to create an NA value.
 *
 * The backing flag starts as {@code false} and is set to {@code true} by
 * {@code createValue(String)} when it receives {@code null} or "na"
 * (case-insensitive). {@code generateSplitRuleUpdater} consults this to
 * decide whether rows need to be filtered for NA values.
 *
 * NOTE(review): a direct call to {@code createValue((Double) null)} may not
 * set the flag; confirm against that overload.
 *
 * @return {@code true} once an NA has been seen via {@code createValue(String)}
 */
@Override
|
||||
public boolean hasNAs() {
|
||||
return hasNAs;
|
||||
}
|
||||
|
||||
@EqualsAndHashCode
|
||||
public class NumericValue implements Covariate.Value<Double>{
|
||||
|
||||
|
|
|
@ -76,22 +76,25 @@ public class TreeTrainer<Y, O> {
|
|||
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
|
||||
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
||||
|
||||
// Assign missing values to the split
|
||||
for(Row<Y> row : data) {
|
||||
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
// Assign missing values to the split if necessary
|
||||
if(bestSplit.getSplitRule().getParent().hasNAs()){
|
||||
for(Row<Y> row : data) {
|
||||
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
|
||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||
|
||||
if(randomDecision){
|
||||
bestSplit.getLeftHand().add(row);
|
||||
}
|
||||
else{
|
||||
bestSplit.getRightHand().add(row);
|
||||
}
|
||||
if(randomDecision){
|
||||
bestSplit.getLeftHand().add(row);
|
||||
}
|
||||
else{
|
||||
bestSplit.getRightHand().add(row);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
|
||||
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
|
||||
|
||||
|
|
Loading…
Reference in a new issue