diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java index cf48933..ebf6cf4 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java @@ -69,7 +69,7 @@ public final class FactorCovariate implements Covariate{ @Override public FactorValue createValue(String value) { - if(value == null){ + if(value == null || value.equalsIgnoreCase("na")){ return this.naValue; } diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index 5a1faf7..8032185 100644 --- a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -22,6 +22,7 @@ public class TestLoadingCSV { 5,3.0,"mouse",true 2,1.0,"dog",false 9,1.5,"cat",true + -3,NA,NA,NA */ @Test @@ -41,7 +42,7 @@ public class TestLoadingCSV { final List> data = Main.loadData(covariates, settings); - assertEquals(3, data.size()); + assertEquals(4, data.size()); Row row = data.get(0); assertEquals(5.0, (double)row.getResponse()); @@ -61,6 +62,12 @@ public class TestLoadingCSV { assertEquals("cat", row.getCovariateValue("x2").getValue()); assertEquals(true, row.getCovariateValue("x3").getValue()); + row = data.get(3); + assertEquals(-3.0, (double)row.getResponse()); + assertEquals(true, row.getCovariateValue("x1").isNA()); + assertEquals(true, row.getCovariateValue("x2").isNA()); + assertEquals(true, row.getCovariateValue("x3").isNA()); + } } diff --git a/src/test/resources/testCSV.csv b/src/test/resources/testCSV.csv index ca1d181..7083420 100644 --- a/src/test/resources/testCSV.csv +++ b/src/test/resources/testCSV.csv @@ -1,4 +1,5 @@ y,x1,x2,x3 5,3.0,"mouse",true 2,1.0,"dog",false -9,1.5,"cat",true \ No newline at end of file +9,1.5,"cat",true +-3,NA,NA,NA \ No newline at end of file