diff --git a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java index 97ebbfc..09d5e96 100644 --- a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java +++ b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java @@ -12,6 +12,7 @@ import org.apache.commons.csv.CSVRecord; import java.io.*; import java.util.*; +import java.util.zip.GZIPInputStream; public class DataLoader { @@ -19,7 +20,18 @@ public class DataLoader { final List> dataset = new ArrayList<>(); - final Reader input = new FileReader(filename); + final Reader input; + if(filename.endsWith(".gz")){ + final FileInputStream inputStream = new FileInputStream(filename); + final GZIPInputStream gzipInputStream = new GZIPInputStream(inputStream); + + input = new InputStreamReader(gzipInputStream); + } + else{ + input = new FileReader(filename); + } + + final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input); diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index 96a2a48..3037ea6 100644 --- a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -28,14 +28,14 @@ public class TestLoadingCSV { -3,NA,NA,NA */ - @Test - public void verifyLoading() throws IOException, ClassNotFoundException { + + public List> loadData(String filename) throws IOException { final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance); yVarSettings.set("type", new TextNode("Double")); yVarSettings.set("name", new TextNode("y")); final Settings settings = Settings.builder() - .dataFileLocation("src/test/resources/testCSV.csv") + .dataFileLocation(filename) .covariates( List.of(new NumericCovariateSettings("x1"), new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")), @@ -52,6 +52,25 @@ public class TestLoadingCSV { final List> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation()); + return data; + } + + @Test + public void verifyLoadingNormal() throws IOException { + final List> data = loadData("src/test/resources/testCSV.csv"); + + assertData(data); + } + + @Test + public void verifyLoadingGz() throws IOException { + final List> data = loadData("src/test/resources/testCSV.csv.gz"); + + assertData(data); + } + + + private void assertData(final List> data){ assertEquals(4, data.size()); Row row = data.get(0); @@ -77,7 +96,6 @@ public class TestLoadingCSV { assertEquals(true, row.getCovariateValue("x1").isNA()); assertEquals(true, row.getCovariateValue("x2").isNA()); assertEquals(true, row.getCovariateValue("x3").isNA()); - } } diff --git a/src/test/resources/testCSV.csv.gz b/src/test/resources/testCSV.csv.gz new file mode 100644 index 0000000..41e9bf9 Binary files /dev/null and b/src/test/resources/testCSV.csv.gz differ