From 76614ee68bcc1a3cf89379da9708ea5df51b4601 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Mon, 25 Mar 2019 10:59:26 -0700 Subject: [PATCH] Better memory management to help prevent OutOfMemoryExceptions Covariate.applyRule now records which hand each row belongs to in a byte array during a first pass and only then allocates the three result lists at their exact sizes, so the lists never grow or over-allocate. A row whose covariate value is NA is recorded as missing and is not tested with isLeftHand. DataLoader is renamed to utils/DataUtils and gains saveObject, which writes a gzip-compressed serialized object and closes the stream even when writing fails. --- .../randomforest/covariates/Covariate.java | 53 +++++++++++++++---- .../{DataLoader.java => utils/DataUtils.java} | 9 +++- 2 files changed, 51 insertions(+), 11 deletions(-) rename src/main/java/ca/joeltherrien/randomforest/{DataLoader.java => utils/DataUtils.java} (91%) diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index 193889f..ae21287 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -82,31 +82,64 @@ public interface Covariate extends Serializable, Comparable { * @return */ default Split applyRule(List> rows) { - final List> leftHand = new ArrayList<>(rows.size()*3/4); - final List> rightHand = new ArrayList<>(rows.size()*3/4); - final List> missingValueRows = new ArrayList<>(); + /* + When working with really large List> we need to be careful about memory. + If the lefthand and righthand lists are too small they grow, but for a moment copies exist and memory issues arise. 
+ + If they're too large, we waste memory yet again + */ + + // value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand + final byte[] whichHand = new byte[rows.size()]; + int countLeftHand = 0; + int countRightHand = 0; + int countMissingHand = 0; - for(final Row row : rows) { + + for(int i=0; i<rows.size(); i++){ + final Row row = rows.get(i); + final Value value = row.getCovariateValue(getParent()); if(value.isNA()){ - missingValueRows.add(row); - continue; + countMissingHand++; + whichHand[i] = 2; } - final boolean isLeftHand = isLeftHand(value); - if(isLeftHand){ - leftHand.add(row); + else if(isLeftHand(value)){ + countLeftHand++; + whichHand[i] = 1; } else{ - rightHand.add(row); + countRightHand++; + whichHand[i] = 0; } } + final List> missingValueRows = new ArrayList<>(countMissingHand); + final List> leftHand = new ArrayList<>(countLeftHand); + final List> rightHand = new ArrayList<>(countRightHand); + + for(int i=0; i<rows.size(); i++){ + final Row row = rows.get(i); + + if(whichHand[i] == 0){ + rightHand.add(row); + } + else if(whichHand[i] == 1){ + leftHand.add(row); + } + else{ + missingValueRows.add(row); + } + + } + + return new Split<>(this, leftHand, rightHand, missingValueRows); } diff --git a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java b/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java similarity index 91% rename from src/main/java/ca/joeltherrien/randomforest/DataLoader.java rename to src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java index 5a8b022..bc5ea80 100644 --- a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/DataUtils.java @@ -29,8 +29,9 @@ import org.apache.commons.csv.CSVRecord; import java.io.*; import java.util.*; import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; -public class DataLoader { +public class DataUtils { public static List> loadData(final List covariates, final ResponseLoader responseLoader, String filename) throws IOException { @@ 
-97,6 +98,12 @@ public class DataLoader { } + public static void saveObject(Serializable object, String filename) throws IOException { + try(final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename)))){ + outputStream.writeObject(object); + } + } + @FunctionalInterface public interface ResponseLoader{ Y parse(CSVRecord record);