Fix bug where NAs cause crash
This commit is contained in:
parent
91cf299362
commit
8014bd4629
3 changed files with 75 additions and 2 deletions
|
@ -20,6 +20,7 @@ import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -36,4 +37,11 @@ public final class Split<Y, V> {
|
||||||
public final List<Row<Y>> rightHand;
|
public final List<Row<Y>> rightHand;
|
||||||
public final List<Row<Y>> naHand;
|
public final List<Row<Y>> naHand;
|
||||||
|
|
||||||
|
public Split<Y, V> modifiableClone(){
|
||||||
|
return new Split<>(splitRule,
|
||||||
|
new ArrayList<>(leftHand),
|
||||||
|
new ArrayList<>(rightHand),
|
||||||
|
new ArrayList<>(naHand));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,7 +72,7 @@ public class TreeTrainer<Y, O> {
|
||||||
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
|
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
|
||||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||||
final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
|
final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
|
||||||
final Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
|
Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
|
||||||
|
|
||||||
|
|
||||||
if(bestSplit == null){
|
if(bestSplit == null){
|
||||||
|
@ -92,8 +92,12 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
// Assign missing values to the split if necessary
|
// Assign missing values to the split if necessary
|
||||||
if(bestSplit.getSplitRule().getParent().hasNAs()){
|
if(bestSplit.getSplitRule().getParent().hasNAs()){
|
||||||
|
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
||||||
|
|
||||||
for(Row<Y> row : data) {
|
for(Row<Y> row : data) {
|
||||||
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
|
final Covariate<?> covariate = bestSplit.getSplitRule().getParent();
|
||||||
|
|
||||||
|
if(row.getCovariateValue(covariate).isNA()) {
|
||||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||||
|
|
||||||
if(randomDecision){
|
if(randomDecision){
|
||||||
|
|
61
src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java
Normal file
61
src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package ca.joeltherrien.randomforest.nas;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
public class TestNAs {
|
||||||
|
|
||||||
|
private List<Row<Double>> generateData(List<Covariate> covariates){
|
||||||
|
final List<Row<Double>> dataList = new ArrayList<>();
|
||||||
|
|
||||||
|
|
||||||
|
// We must include an NA for one of the values
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0));
|
||||||
|
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0));
|
||||||
|
|
||||||
|
|
||||||
|
return dataList;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testException(){
|
||||||
|
// There was a bug with NAs where we tried to randomly assign NAs during a split to the best split produced by NumericSplitRuleUpdater,
|
||||||
|
// but NumericSplitRuleUpdater had unmodifiable lists when creating the split.
|
||||||
|
// This test verifies that this no longer causes a crash
|
||||||
|
|
||||||
|
final List<Covariate> covariates = Collections.singletonList(new NumericCovariate("x", 0));
|
||||||
|
final List<Row<Double>> dataset = generateData(covariates);
|
||||||
|
|
||||||
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
|
.checkNodePurity(false)
|
||||||
|
.covariates(covariates)
|
||||||
|
.numberOfSplits(0)
|
||||||
|
.nodeSize(1)
|
||||||
|
.maxNodeDepth(1000)
|
||||||
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
treeTrainer.growTree(dataset, new Random(123));
|
||||||
|
|
||||||
|
// As long as no exception occurs, we passed
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue