Fix bug where NAs cause crash

This commit is contained in:
Joel Therrien 2019-03-04 11:36:21 -08:00
parent 91cf299362
commit 8014bd4629
3 changed files with 75 additions and 2 deletions

View file

@ -20,6 +20,7 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
/**
@ -36,4 +37,11 @@ public final class Split<Y, V> {
public final List<Row<Y>> rightHand;
public final List<Row<Y>> naHand;
/**
 * Creates a copy of this split whose row lists can be modified.
 *
 * The returned split shares this split's {@code splitRule}, while
 * {@code leftHand}, {@code rightHand} and {@code naHand} are each copied
 * into a fresh {@link ArrayList}. The {@code Row} objects themselves are
 * not copied.
 *
 * @return a new {@code Split} backed by newly allocated {@code ArrayList}s
 */
public Split<Y, V> modifiableClone(){
    final List<Row<Y>> leftCopy = new ArrayList<>(leftHand);
    final List<Row<Y>> rightCopy = new ArrayList<>(rightHand);
    final List<Row<Y>> naCopy = new ArrayList<>(naHand);

    return new Split<>(splitRule, leftCopy, rightCopy, naCopy);
}
}

View file

@ -72,7 +72,7 @@ public class TreeTrainer<Y, O> {
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
final Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
if(bestSplit == null){
@ -92,8 +92,12 @@ public class TreeTrainer<Y, O> {
// Assign missing values to the split if necessary
if(bestSplit.getSplitRule().getParent().hasNAs()){
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
for(Row<Y> row : data) {
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
final Covariate<?> covariate = bestSplit.getSplitRule().getParent();
if(row.getCovariateValue(covariate).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){

View file

@ -0,0 +1,61 @@
package ca.joeltherrien.randomforest.nas;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
/**
 * Regression tests for growing a tree on data that contains NA covariate values.
 */
public class TestNAs {

    /**
     * Builds a six-row dataset keyed on the covariate "x".
     * The first row deliberately carries an "NA" value for "x".
     *
     * @param covariates the covariates handed to every created row
     * @return the rows, in the order they are listed below
     */
    private List<Row<Double>> generateData(List<Covariate> covariates){
        // Parallel arrays: xValues[i] is the raw covariate text, responses[i] the row's response.
        final String[] xValues = {"NA", "1", "2", "7", "8", "8.4"};
        final double[] responses = {5.0, 6.0, 5.5, 0.0, 1.0, 1.0};

        final List<Row<Double>> rows = new ArrayList<>();
        for(int i = 0; i < xValues.length; i++){
            rows.add(Row.createSimple(Utils.easyMap("x", xValues[i]), covariates, 1, responses[i]));
        }

        return rows;
    }

    /**
     * Grows a tree on data that includes an NA value.
     *
     * Background: randomly assigning NA rows to one side of the best split used to
     * crash, because the split produced by NumericSplitRuleUpdater was built with
     * unmodifiable lists. This test verifies that the crash no longer happens.
     */
    @Test
    public void testException(){
        final List<Covariate> covariates = Collections.singletonList(new NumericCovariate("x", 0));
        final List<Row<Double>> rows = generateData(covariates);

        final TreeTrainer<Double, Double> trainer = TreeTrainer.<Double, Double>builder()
                .checkNodePurity(false)
                .covariates(covariates)
                .numberOfSplits(0)
                .nodeSize(1)
                .maxNodeDepth(1000)
                .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
                .responseCombiner(new MeanResponseCombiner())
                .build();

        final Random random = new Random(123);
        trainer.growTree(rows, random);

        // Reaching this point without an exception means the test passed.
    }
}