Several changes -

Fixed some tests that weren't running.
Fixed a bug where training crashed if FactorCovariates had any NA
Fixed a bug where FactorCovariates were ignored in splitting if nsplit==0
Added a covariate-specific option for whether splitting on a variable containing NAs should incur a penalty.

This penalty is accomplished by first calculating the split score and best split for a covariate
without NAs, as was done previously. Then NAs are randomly assigned to the two hands, and the split
score is recalculated on that best split. The final score is the lower of the recalculated score and the original.
This commit is contained in:
Joel Therrien 2019-08-28 18:07:35 -07:00
parent c24626ff61
commit 79a9522ba7
27 changed files with 422 additions and 108 deletions

View file

@ -227,9 +227,9 @@ public class Main {
return Settings.builder()
.covariateSettings(Utils.easyList(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
new NumericCovariateSettings("x1", true),
new BooleanCovariateSettings("x2", false),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"), true)
)
)
.trainingDataLocation("training_data.csv")

View file

@ -24,12 +24,12 @@ import lombok.NoArgsConstructor;
@Data
public final class BooleanCovariateSettings extends CovariateSettings<Boolean> {
public BooleanCovariateSettings(String name){
super(name);
public BooleanCovariateSettings(String name, boolean naSplitPenalty){
super(name, naSplitPenalty);
}
@Override
public BooleanCovariate build(int index) {
return new BooleanCovariate(name, index);
return new BooleanCovariate(name, index, naSplitPenalty);
}
}

View file

@ -40,9 +40,11 @@ import lombok.NoArgsConstructor;
public abstract class CovariateSettings<V> {
String name;
boolean naSplitPenalty;
CovariateSettings(String name){
CovariateSettings(String name, boolean naSplitPenalty){
this.name = name;
this.naSplitPenalty = naSplitPenalty;
}
public abstract Covariate<V> build(int index);

View file

@ -29,13 +29,13 @@ public final class FactorCovariateSettings extends CovariateSettings<String> {
private List<String> levels;
public FactorCovariateSettings(String name, List<String> levels){
super(name);
public FactorCovariateSettings(String name, List<String> levels, boolean naSplitPenalty){
super(name, naSplitPenalty);
this.levels = new ArrayList<>(levels); // Jackson struggles with List.of(...)
}
@Override
public FactorCovariate build(int index) {
return new FactorCovariate(name, index, levels);
return new FactorCovariate(name, index, levels, naSplitPenalty);
}
}

View file

@ -24,12 +24,12 @@ import lombok.NoArgsConstructor;
@Data
public final class NumericCovariateSettings extends CovariateSettings<Double> {
public NumericCovariateSettings(String name){
super(name);
public NumericCovariateSettings(String name, boolean naSplitPenalty){
super(name, naSplitPenalty);
}
@Override
public NumericCovariate build(int index) {
return new NumericCovariate(name, index);
return new NumericCovariate(name, index, naSplitPenalty);
}
}

View file

@ -49,9 +49,9 @@ public class TestPersistence {
final Settings settingsOriginal = Settings.builder()
.covariateSettings(Utils.easyList(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
new NumericCovariateSettings("x1", true),
new BooleanCovariateSettings("x2", false),
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"), true)
)
)
.trainingDataLocation("training_data.csv")

View file

@ -55,9 +55,9 @@ public class TestLoadingCSV {
final Settings settings = Settings.builder()
.trainingDataLocation(filename)
.covariateSettings(
Utils.easyList(new NumericCovariateSettings("x1"),
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
new BooleanCovariateSettings("x3"))
Utils.easyList(new NumericCovariateSettings("x1", true),
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse"), false),
new BooleanCovariateSettings("x3", true))
)
.yVarSettings(yVarSettings)
.build();
@ -71,14 +71,14 @@ public class TestLoadingCSV {
}
@Test
public void verifyLoadingNormal(final List<Covariate> covariates) throws IOException {
public void testLoadingNormal(final List<Covariate> covariates) throws IOException {
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
assertData(data, covariates);
}
@Test
public void verifyLoadingGz(final List<Covariate> covariates) throws IOException {
public void testLoadingGz(final List<Covariate> covariates) throws IOException {
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
assertData(data, covariates);

View file

@ -49,6 +49,8 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
return getIndex() - other.getIndex();
}
boolean haveNASplitPenalty();
interface Value<V> extends Serializable{
Covariate<V> getParent();

View file

@ -25,6 +25,7 @@ import lombok.Getter;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
public final class BooleanCovariate implements Covariate<Boolean> {
@ -40,14 +41,26 @@ public final class BooleanCovariate implements Covariate<Boolean> {
private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
public BooleanCovariate(String name, int index){
private final boolean haveNASplitPenalty;
@Override
public boolean haveNASplitPenalty(){
// penalty would add worthless computational time if there are no NAs
return hasNAs && haveNASplitPenalty;
}
public BooleanCovariate(String name, int index, boolean haveNASplitPenalty){
this.name = name;
this.index = index;
splitRule = new BooleanSplitRule(this);
this.splitRule = new BooleanSplitRule(this);
this.haveNASplitPenalty = haveNASplitPenalty;
}
@Override
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
if(hasNAs){
data = data.stream().filter(row -> !row.getValueByIndex(index).isNA()).collect(Collectors.toList());
}
return new SingletonIterator<>(this.splitRule.applyRule(data));
}

View file

@ -23,6 +23,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import java.util.*;
import java.util.stream.Collectors;
public final class FactorCovariate implements Covariate<String> {
@ -40,8 +41,15 @@ public final class FactorCovariate implements Covariate<String> {
private boolean hasNAs;
private final boolean haveNASplitPenalty;
@Override
public boolean haveNASplitPenalty(){
// penalty would add worthless computational time if there are no NAs
return hasNAs && haveNASplitPenalty;
}
public FactorCovariate(final String name, final int index, List<String> levels){
public FactorCovariate(final String name, final int index, List<String> levels, final boolean haveNASplitPenalty){
this.name = name;
this.index = index;
this.factorLevels = new HashMap<>();
@ -63,12 +71,22 @@ public final class FactorCovariate implements Covariate<String> {
this.numberOfPossiblePairings = numberOfPossiblePairingsTemp-1;
this.naValue = new FactorValue(null);
this.haveNASplitPenalty = haveNASplitPenalty;
}
@Override
public <Y> Iterator<Split<Y, String>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
if(hasNAs()){
data = data.stream().filter(row -> !row.getCovariateValue(this).isNA()).collect(Collectors.toList());
}
if(number == 0){ // nsplit = 0 => try every possibility, although we limit it to the number of observations.
number = data.size();
}
final Set<Split<Y, String>> splits = new HashSet<>();
// This is to ensure we don't get stuck in an infinite loop for small factors

View file

@ -47,6 +47,13 @@ public final class NumericCovariate implements Covariate<Double> {
private boolean hasNAs = false;
private final boolean haveNASplitPenalty;
@Override
public boolean haveNASplitPenalty(){
// penalty would add worthless computational time if there are no NAs
return hasNAs && haveNASplitPenalty;
}
@Override
public <Y> NumericSplitRuleUpdater<Y> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
Stream<Row<Y>> stream = data.stream();

View file

@ -17,15 +17,13 @@
package ca.joeltherrien.randomforest.tree;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Data;
@AllArgsConstructor
@Data
public class SplitAndScore<Y, V> {
@Getter
private final Split<Y, V> split;
@Getter
private final Double score;
private Split<Y, V> split;
private Double score;
}

View file

@ -17,7 +17,9 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.VisibleForTesting;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.utils.SingletonIterator;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
@ -72,31 +74,12 @@ public class TreeTrainer<Y, O> {
}
// Now that we have the best split; we need to handle any NAs that were dropped off
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
// Assign missing values to the split if necessary
if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
for(Row<Y> row : data) {
final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex();
if(row.getValueByIndex(covariateIndex).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
bestSplit.getLeftHand().add(row);
}
else{
bestSplit.getRightHand().add(row);
}
}
}
}
bestSplit = randomlyAssignNAs(data, bestSplit, random);
final Node<O> leftNode;
final Node<O> rightNode;
@ -144,7 +127,8 @@ public class TreeTrainer<Y, O> {
return splitCovariates;
}
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
@VisibleForTesting
public Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
SplitAndScore<Y, ?> bestSplitAndScore = null;
final SplitFinder noGenericSplitFinder = splitFinder; // cause Java generics are sometimes too frustrating
@ -157,10 +141,32 @@ public class TreeTrainer<Y, O> {
continue;
}
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
SplitAndScore<Y, ?> candidateSplitAndScore = noGenericSplitFinder.findBestSplit(iterator);
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {
if(candidateSplitAndScore == null){
continue;
}
// This score was based on splitting only non-NA values. However, there might be a similar covariate we are also considering
// that is just as good at splitting but has less NAs; we should thus penalize the split score for variables with NAs
// We do this by randomly assigning the NAs and then recalculating the split score on the best split we already have.
//
// We only have to penalize the score though if we know it's possible that this might be the best split. If it's not,
// then we can skip the computations.
final boolean mayBeGoodSplit = bestSplitAndScore == null ||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore();
if(mayBeGoodSplit && covariate.haveNASplitPenalty()){
Split<Y, ?> candiateSplitWithNAs = randomlyAssignNAs(data, candidateSplitAndScore.getSplit(), random);
final Iterator<Split<Y, ?>> newSplitWithRandomNAs = new SingletonIterator<>(candiateSplitWithNAs);
final double newScore = splitFinder.findBestSplit(newSplitWithRandomNAs).getScore();
// There's a chance that NAs might add noise to *improve* the score; but we want to ensure we penalize it.
// Thus we only change the score if its worse.
candidateSplitAndScore.setScore(Math.min(newScore, candidateSplitAndScore.getScore()));
}
if(bestSplitAndScore == null || candidateSplitAndScore.getScore() > bestSplitAndScore.getScore()) {
bestSplitAndScore = candidateSplitAndScore;
}
@ -174,6 +180,38 @@ public class TreeTrainer<Y, O> {
}
private <V> Split<Y, V> randomlyAssignNAs(List<Row<Y>> data, Split<Y, V> existingSplit, Random random){
// Now that we have the best split; we need to handle any NAs that were dropped off
final double probabilityLeftHand = (double) existingSplit.leftHand.size() /
(double) (existingSplit.leftHand.size() + existingSplit.rightHand.size());
final int covariateIndex = existingSplit.getSplitRule().getParentCovariateIndex();
// Assign missing values to the split if necessary
if(covariates.get(existingSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
existingSplit = existingSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
for(Row<Y> row : data) {
if(row.getValueByIndex(covariateIndex).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
existingSplit.getLeftHand().add(row);
}
else{
existingSplit.getRightHand().add(row);
}
}
}
}
return existingSplit;
}
private boolean nodeIsPure(List<Row<Y>> data){
if(!checkNodePurity){
return false;

View file

@ -45,20 +45,20 @@ public class TestDeterministicForests {
int index = 0;
for(int j=0; j<5; j++){
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index);
final NumericCovariate numericCovariate = new NumericCovariate("numeric"+j, index, false);
covariateList.add(numericCovariate);
index++;
}
for(int j=0; j<5; j++){
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index);
final BooleanCovariate booleanCovariate = new BooleanCovariate("boolean"+j, index, false);
covariateList.add(booleanCovariate);
index++;
}
final List<String> levels = Utils.easyList("cat", "dog", "mouse");
for(int j=0; j<5; j++){
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels);
final FactorCovariate factorCovariate = new FactorCovariate("factor"+j, index, levels, false);
covariateList.add(factorCovariate);
index++;
}

View file

@ -44,7 +44,7 @@ public class TestProvidingInitialForest {
private List<Row<Double>> data;
public TestProvidingInitialForest(){
covariateList = Collections.singletonList(new NumericCovariate("x", 0));
covariateList = Collections.singletonList(new NumericCovariate("x", 0, false));
data = Utils.easyList(
Row.createSimple(Utils.easyMap("x", "1.0"), covariateList, 1, 1.0),
@ -198,7 +198,7 @@ public class TestProvidingInitialForest {
it's not clear if the forest being provided is the same one that trees were saved from.
*/
@Test
public void verifyExceptions(){
public void testExceptions(){
final String filePath = "src/test/resources/trees/";
final File directory = new File(filePath);
if(directory.exists()){

View file

@ -47,10 +47,10 @@ public class TestSavingLoading {
public List<Covariate> getCovariates(){
return Utils.easyList(
new NumericCovariate("ageatfda", 0),
new BooleanCovariate("idu", 1),
new BooleanCovariate("black", 2),
new NumericCovariate("cd4nadir", 3)
new NumericCovariate("ageatfda", 0, false),
new BooleanCovariate("idu", 1, false),
new BooleanCovariate("black", 2, false),
new NumericCovariate("cd4nadir", 3, false)
);
}

View file

@ -156,7 +156,7 @@ public class TestUtils {
}
@Test
public void reduceListToSize(){
public void testReduceListToSize(){
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
final Random random = new Random();
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness

View file

@ -52,7 +52,7 @@ public class IBSCalculatorTest {
*/
@Test
public void resultsWithoutCensoringDistribution(){
public void testResultsWithoutCensoringDistribution(){
final IBSCalculator calculator = new IBSCalculator();
final double errorDifferentEvent = calculator.calculateError(
@ -74,7 +74,7 @@ public class IBSCalculatorTest {
}
@Test
public void resultsWithCensoringDistribution(){
public void testResultsWithCensoringDistribution(){
final RightContinuousStepFunction censorSurvivalFunction = RightContinuousStepFunction.constructFromPoints(
Utils.easyList(
new Point(0.0, 0.75),

View file

@ -53,10 +53,10 @@ public class TestCompetingRisk {
public List<Covariate> getCovariates(){
return Utils.easyList(
new NumericCovariate("ageatfda", 0),
new BooleanCovariate("idu", 1),
new BooleanCovariate("black", 2),
new NumericCovariate("cd4nadir", 3)
new NumericCovariate("ageatfda", 0, false),
new BooleanCovariate("idu", 1, false),
new BooleanCovariate("black", 2, false),
new NumericCovariate("cd4nadir", 3, false)
);
}
@ -109,8 +109,8 @@ public class TestCompetingRisk {
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = Utils.easyList(
new BooleanCovariate("idu", 0),
new BooleanCovariate("black", 1)
new BooleanCovariate("idu", 0, false),
new BooleanCovariate("black", 1, false)
);
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, "src/test/resources/wihs.bootstrapped.csv");
@ -210,8 +210,8 @@ public class TestCompetingRisk {
public void testLogRankSplitFinderTwoBooleans() throws IOException {
// by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
final List<Covariate> covariates = Utils.easyList(
new BooleanCovariate("idu", 0),
new BooleanCovariate("black", 1)
new BooleanCovariate("idu", 0, false),
new BooleanCovariate("black", 1, false)
);
@ -259,7 +259,7 @@ public class TestCompetingRisk {
}
@Test
public void verifyDataset() throws IOException {
public void testDataset() throws IOException {
final List<Covariate> covariates = getCovariates();
final List<Row<CompetingRiskResponse>> dataset = getData(covariates, DEFAULT_FILEPATH);

View file

@ -46,7 +46,7 @@ public class TestLogRankSplitFinder {
public static Data<CompetingRiskResponse> loadData(String filename) throws IOException {
final List<Covariate> covariates = Utils.easyList(
new NumericCovariate("x2", 0)
new NumericCovariate("x2", 0, false)
);
final List<Row<CompetingRiskResponse>> rows = TestUtils.loadData(covariates, new ResponseLoader.CompetingRisksResponseLoader("delta", "u"), filename);

View file

@ -17,12 +17,15 @@
package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
@ -31,7 +34,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FactorCovariateTest {
@Test
void verifyEqualLevels() {
public void testEqualLevels() {
final FactorCovariate petCovariate = createTestCovariate();
final FactorCovariate.FactorValue dog1 = petCovariate.createValue("DOG");
@ -53,7 +56,7 @@ public class FactorCovariateTest {
}
@Test
void verifyBadLevelException(){
public void testBadLevelException(){
final FactorCovariate petCovariate = createTestCovariate();
final Executable badCode = () -> petCovariate.createValue("vulcan");
@ -61,25 +64,169 @@ public class FactorCovariateTest {
}
@Test
void testAllSubsets(){
public void testAllSubsets(){
final int n = 2*3; // ensure that n is a multiple of 3 for the test
final FactorCovariate petCovariate = createTestCovariate();
final List<Row<Double>> data = generateSampleData(petCovariate, n);
final List<SplitRule<String>> splitRules = new ArrayList<>();
final List<Split<Double, String>> splits = new ArrayList<>();
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
petCovariate.generateSplitRuleUpdater(data, 100, new Random())
.forEachRemaining(split -> splits.add(split));
assertEquals(splitRules.size(), 3);
assertEquals(splits.size(), 3);
// TODO verify the contents of the split rules
// These are the 3 possibilities
boolean dog_catmouse = false;
boolean cat_dogmouse = false;
boolean mouse_dogcat = false;
for(Split<Double, String> split : splits){
List<Row<Double>> smallerHand;
List<Row<Double>> largerHand;
if(split.getLeftHand().size() < split.getRightHand().size()){
smallerHand = split.getLeftHand();
largerHand = split.getRightHand();
} else{
smallerHand = split.getRightHand();
largerHand = split.getLeftHand();
}
// There should be exactly one distinct value in the smaller list
assertEquals(n/3, smallerHand.size());
assertEquals(1,
smallerHand.stream()
.map(row -> row.getCovariateValue(petCovariate).getValue())
.distinct()
.count()
);
// There should be exactly two distinct values in the larger list
assertEquals(2*n/3, largerHand.size());
assertEquals(2,
largerHand.stream()
.map(row -> row.getCovariateValue(petCovariate).getValue())
.distinct()
.count()
);
switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){
case "DOG":
dog_catmouse = true;
case "CAT":
cat_dogmouse = true;
case "MOUSE":
mouse_dogcat = true;
}
}
assertTrue(dog_catmouse);
assertTrue(cat_dogmouse);
assertTrue(mouse_dogcat);
}
/*
* There was a bug where if number==0 in generateSplitRuleUpdater, then the result was empty.
*/
@Test
public void testNumber0Subsets(){
final int n = 2*3; // ensure that n is a multiple of 3 for the test
final FactorCovariate petCovariate = createTestCovariate();
final List<Row<Double>> data = generateSampleData(petCovariate, n);
final List<Split<Double, String>> splits = new ArrayList<>();
petCovariate.generateSplitRuleUpdater(data, 0, new Random())
.forEachRemaining(split -> splits.add(split));
assertEquals(splits.size(), 3);
// These are the 3 possibilities
boolean dog_catmouse = false;
boolean cat_dogmouse = false;
boolean mouse_dogcat = false;
for(Split<Double, String> split : splits){
List<Row<Double>> smallerHand;
List<Row<Double>> largerHand;
if(split.getLeftHand().size() < split.getRightHand().size()){
smallerHand = split.getLeftHand();
largerHand = split.getRightHand();
} else{
smallerHand = split.getRightHand();
largerHand = split.getLeftHand();
}
// There should be exactly one distinct value in the smaller list
assertEquals(n/3, smallerHand.size());
assertEquals(1,
smallerHand.stream()
.map(row -> row.getCovariateValue(petCovariate).getValue())
.distinct()
.count()
);
// There should be exactly two distinct values in the larger list
assertEquals(2*n/3, largerHand.size());
assertEquals(2,
largerHand.stream()
.map(row -> row.getCovariateValue(petCovariate).getValue())
.distinct()
.count()
);
switch(smallerHand.get(0).getCovariateValue(petCovariate).getValue()){
case "DOG":
dog_catmouse = true;
case "CAT":
cat_dogmouse = true;
case "MOUSE":
mouse_dogcat = true;
}
}
assertTrue(dog_catmouse);
assertTrue(cat_dogmouse);
assertTrue(mouse_dogcat);
}
@Test
public void testSpitRuleUpdaterWithNAs(){
// When some NAs were present calling generateSplitRuleUpdater caused an exception.
final FactorCovariate covariate = createTestCovariate();
final List<Row<Double>> sampleData = generateSampleData(covariate, 10);
sampleData.add(Row.createSimple(Utils.easyMap("pet", "NA"), Collections.singletonList(covariate), 11, 5.0));
covariate.generateSplitRuleUpdater(sampleData, 0, new Random());
// Test passes if no exception has occurred.
}
private FactorCovariate createTestCovariate(){
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
return new FactorCovariate("pet", 0, levels);
return new FactorCovariate("pet", 0, levels, false);
}
private List<Row<Double>> generateSampleData(Covariate covariate, int n){
final List<Covariate> covariateList = Collections.singletonList(covariate);
final List<Row<Double>> dataList = new ArrayList<>(n);
final String[] levels = new String[]{"DOG", "CAT", "MOUSE"};
for(int i=0; i<n; i++){
dataList.add(Row.createSimple(Utils.easyMap("pet", levels[i % 3]), covariateList, 1, 1.0));
}
return dataList;
}

View file

@ -70,7 +70,7 @@ public class NumericCovariateTest {
@Test
public void testNumericCovariateDeterministic(){
final NumericCovariate covariate = new NumericCovariate("x", 0);
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
final List<Row<Double>> dataset = createTestDataset(covariate);
@ -158,7 +158,7 @@ public class NumericCovariateTest {
@Test
public void testNumericSplitRuleUpdaterWithIndexes(){
final NumericCovariate covariate = new NumericCovariate("x", 0);
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
final List<Row<Double>> dataset = createTestDataset(covariate);
@ -223,7 +223,7 @@ public class NumericCovariateTest {
*/
@Test
public void testNumericSplitRuleUpdaterWithIndexesAllMissingData(){
final NumericCovariate covariate = new NumericCovariate("x", 0);
final NumericCovariate covariate = new NumericCovariate("x", 0, false);
final List<Row<Double>> dataset = createTestDatasetMissingValues(covariate);
final NumericSplitRuleUpdater<Double> updater = covariate.generateSplitRuleUpdater(dataset, 5, new Random());

View file

@ -18,31 +18,34 @@ package ca.joeltherrien.randomforest.nas;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
public class TestNAs {
private List<Row<Double>> generateData(List<Covariate> covariates){
private List<Row<Double>> generateData1(List<Covariate> covariates){
final List<Row<Double>> dataList = new ArrayList<>();
// We must include an NA for one of the values
dataList.add(Row.createSimple(Utils.easyMap("x", "NA"), covariates, 1, 5.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "1"), covariates, 1, 6.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "2"), covariates, 1, 5.5));
dataList.add(Row.createSimple(Utils.easyMap("x", "7"), covariates, 1, 0.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "8"), covariates, 1, 1.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4"), covariates, 1, 1.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "NA", "y", "true", "z", "green"), covariates, 1, 5.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "1", "y", "NA", "z", "blue"), covariates, 2, 6.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "2", "y", "true", "z", "NA"), covariates, 3, 5.5));
dataList.add(Row.createSimple(Utils.easyMap("x", "7", "y", "false", "z", "green"), covariates, 4, 0.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "8", "y", "true", "z", "blue"), covariates, 5, 1.0));
dataList.add(Row.createSimple(Utils.easyMap("x", "8.4", "y", "false", "z", "yellow"), covariates, 6, 1.0));
return dataList;
@ -54,14 +57,19 @@ public class TestNAs {
// but NumericSplitRuleUpdater had unmodifiable lists when creating the split.
// This test verifies that this no longer causes a crash
final List<Covariate> covariates = Collections.singletonList(new NumericCovariate("x", 0));
final List<Row<Double>> dataset = generateData(covariates);
final List<Covariate> covariates = Utils.easyList(
new NumericCovariate("x", 0, false),
new BooleanCovariate("y", 1, true),
new FactorCovariate("z", 2, Utils.easyList("green", "blue", "yellow"), true)
);
final List<Row<Double>> dataset = generateData1(covariates);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.checkNodePurity(false)
.covariates(covariates)
.numberOfSplits(0)
.nodeSize(1)
.mtry(3)
.maxNodeDepth(1000)
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
@ -70,6 +78,87 @@ public class TestNAs {
treeTrainer.growTree(dataset, new Random(123));
// As long as no exception occurs, we passed
}
private List<Row<Double>> generateData2(List<Covariate> covariates){
final List<Row<Double>> dataList = new ArrayList<>();
// Idea - when ignoring NAs, BadVar gives a perfect split.
// GoodVar is slightly worse than BadVar when NAs are excluded.
// However, BadVar has a ton of NAs that will degrade its performance.
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "-1.0", "GoodVar", "true") // GoodVars one error
, covariates, 1, 5.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "NA", "GoodVar", "false")
, covariates, 2, 5.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "NA", "GoodVar", "false")
, covariates, 3, 5.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "0.5", "GoodVar", "true")
, covariates, 4, 10.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "NA", "GoodVar", "true")
, covariates, 5, 10.0)
);
dataList.add(Row.createSimple(
Utils.easyMap("BadVar", "NA", "GoodVar", "true")
, covariates, 6, 10.0)
);
return dataList;
}
@Test
// Test that the NA penalty works when selecting a best split.
public void testNAPenalty(){
final List<Covariate> covariates1 = Utils.easyList(
new NumericCovariate("BadVar", 0, true),
new BooleanCovariate("GoodVar", 1, false)
);
final List<Row<Double>> dataList1 = generateData2(covariates1);
final TreeTrainer<Double, Double> treeTrainer1 = TreeTrainer.<Double, Double>builder()
.checkNodePurity(false)
.covariates(covariates1)
.numberOfSplits(0)
.nodeSize(1)
.mtry(2)
.maxNodeDepth(1000)
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
.build();
final Split<Double, ?> bestSplit1 = treeTrainer1.findBestSplitRule(dataList1, covariates1, new Random(123));
assertEquals(1, bestSplit1.getSplitRule().getParentCovariateIndex()); // 1 corresponds to GoodVar
// Run again without the penalty; verify that we get different results
final List<Covariate> covariates2 = Utils.easyList(
new NumericCovariate("BadVar", 0, false),
new BooleanCovariate("GoodVar", 1, false)
);
final List<Row<Double>> dataList2 = generateData2(covariates2);
final TreeTrainer<Double, Double> treeTrainer2 = TreeTrainer.<Double, Double>builder()
.checkNodePurity(false)
.covariates(covariates2)
.numberOfSplits(0)
.nodeSize(1)
.mtry(2)
.maxNodeDepth(1000)
.splitFinder(new WeightedVarianceSplitFinder())
.responseCombiner(new MeanResponseCombiner())
.build();
final Split<Double, ?> bestSplit2 = treeTrainer2.findBestSplitRule(dataList2, covariates2, new Random(123));
assertEquals(0, bestSplit2.getSplitRule().getParentCovariateIndex()); // 0 corresponds to BadVar, which wins when the NA penalty is disabled
}

View file

@ -39,10 +39,10 @@ public class VariableImportanceCalculatorTest {
*/
public VariableImportanceCalculatorTest(){
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0);
final NumericCovariate numericCovariate = new NumericCovariate("y", 1);
final BooleanCovariate booleanCovariate = new BooleanCovariate("x", 0, false);
final NumericCovariate numericCovariate = new NumericCovariate("y", 1, false);
final FactorCovariate factorCovariate = new FactorCovariate("z", 2,
Utils.easyList("red", "blue", "green"));
Utils.easyList("red", "blue", "green"), false);
this.covariates = Utils.easyList(booleanCovariate, numericCovariate, factorCovariate);

View file

@ -43,7 +43,7 @@ public class TrainForest {
final List<Covariate> covariateList = new ArrayList<>(p);
for(int j =0; j < p; j++){
final NumericCovariate covariate = new NumericCovariate("x"+j, j);
final NumericCovariate covariate = new NumericCovariate("x"+j, j, false);
covariateList.add(covariate);
}

View file

@ -39,8 +39,8 @@ public class TrainSingleTree {
final int n = 1000;
final List<Row<Double>> trainingSet = new ArrayList<>(n);
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0, false);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1, false);
final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0)

View file

@ -41,9 +41,9 @@ public class TrainSingleTreeFactor {
final int n = 10000;
final List<Row<Double>> trainingSet = new ArrayList<>(n);
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1);
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"));
final Covariate<Double> x1Covariate = new NumericCovariate("x1", 0, false);
final Covariate<Double> x2Covariate = new NumericCovariate("x2", 1, false);
final FactorCovariate x3Covariate = new FactorCovariate("x3", 2, Utils.easyList("cat", "dog", "mouse"), false);
final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0)