diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java
index b931513..0b199fe 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java
@@ -20,6 +20,7 @@ import ca.joeltherrien.randomforest.Row;
 import ca.joeltherrien.randomforest.covariates.Covariate;
 import lombok.Data;
 
+import java.util.ArrayList;
 import java.util.List;
 
 /**
@@ -36,4 +37,11 @@ public final class Split<Y> {
     public final List<Row<Y>> rightHand;
     public final List<Row<Y>> naHand;
 
+    public Split<Y> modifiableClone(){
+        return new Split<>(splitRule,
+                new ArrayList<>(leftHand),
+                new ArrayList<>(rightHand),
+                new ArrayList<>(naHand));
+    }
+
 }
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
index 3a524b7..9756c79 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
@@ -72,7 +72,7 @@ public class TreeTrainer<Y, TO> {
         // See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
         if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
             final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
-            final Split<Y> bestSplit = findBestSplitRule(data, covariatesToTry, random);
+            Split<Y> bestSplit = findBestSplitRule(data, covariatesToTry, random);
 
             if(bestSplit == null){
 
@@ -92,8 +92,12 @@ public class TreeTrainer<Y, TO> {
 
             // Assign missing values to the split if necessary
             if(bestSplit.getSplitRule().getParent().hasNAs()){
+                bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
+
                 for(Row<Y> row : data) {
-                    if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
+                    final Covariate covariate = bestSplit.getSplitRule().getParent();
+
+                    if(row.getCovariateValue(covariate).isNA()) {
                         final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
 
                         if(randomDecision){
diff --git a/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java b/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java
new file mode 100644
index 0000000..d748f1e
--- /dev/null
+++ b/src/test/java/ca/joeltherrien/randomforest/nas/TestNAs.java
@@ -0,0 +1,61 @@
+package ca.joeltherrien.randomforest.nas;
+
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.covariates.Covariate;
+import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
+import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
+import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.TreeTrainer;
+import ca.joeltherrien.randomforest.utils.Utils;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+public class TestNAs {
+
+    private List<Row<Double>> generateData(List<Covariate> covariates){
+        final List<Row<Double>> dataList = new ArrayList<>();
+
+
+        // We must include an NA for one of the values
+        dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0));
+        dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0));
+        dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5));
+        dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0));
+        dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0));
+        dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0));
+
+
+        return dataList;
+    }
+
+    @Test
+    public void testException(){
+        // There was a bug where NA values were randomly assigned to the best split produced by NumericSplitRuleUpdater,
+        // but NumericSplitRuleUpdater created that split with unmodifiable lists.
+        // This test verifies that this no longer causes a crash.
+
+        final List<Covariate> covariates = Collections.singletonList(new NumericCovariate("x", 0));
+        final List<Row<Double>> dataset = generateData(covariates);
+
+        final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
+                .checkNodePurity(false)
+                .covariates(covariates)
+                .numberOfSplits(0)
+                .nodeSize(1)
+                .maxNodeDepth(1000)
+                .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
+                .responseCombiner(new MeanResponseCombiner())
+                .build();
+
+        treeTrainer.growTree(dataset, new Random(123));
+
+        // As long as no exception occurs, we passed
+
+
+    }
+
+}
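
Why the modifiableClone() call above is needed: per the comment in the patch, the left/right/NA lists attached to the best split are usually read-only wrappers (e.g. produced via Collections.unmodifiableList), so appending rows with missing covariate values directly to them throws an UnsupportedOperationException. The sketch below is a minimal standalone illustration of that failure and of the defensive-copy fix, using only the standard collections API; the class and variable names are illustrative and not taken from the project.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class ModifiableCloneSketch {

    public static void main(String[] args) {
        // Stand-in for one of the split's lists; the real code holds Row<Y> objects.
        final List<String> leftHand =
                Collections.unmodifiableList(new ArrayList<>(Arrays.asList("row1", "row2")));

        try {
            leftHand.add("rowWithNA"); // appending to the read-only view fails
        } catch (UnsupportedOperationException e) {
            System.out.println("Direct add throws: " + e);
        }

        // What modifiableClone() does for each list: copy into a fresh ArrayList,
        // which then accepts the randomly assigned NA rows.
        final List<String> modifiableLeftHand = new ArrayList<>(leftHand);
        modifiableLeftHand.add("rowWithNA");
        System.out.println(modifiableLeftHand); // prints [row1, row2, rowWithNA]
    }
}

Copying the three lists once per split keeps the split updaters free to return immutable views, while the NA assignment step mutates only its own private copies.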