diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java b/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java index b5e8c7e..98d30dd 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java @@ -16,7 +16,15 @@ package ca.joeltherrien.randomforest.utils; +import ca.joeltherrien.randomforest.CovariateRow; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; +import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime; + import java.io.*; +import java.util.ArrayList; +import java.util.List; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; @@ -62,5 +70,111 @@ public final class RUtils { } + public static List> importDataWithResponses(List responses, List covariates, List rawCovariateData){ + if(covariates.size() != rawCovariateData.size()){ + throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship"); + } + + final int n = responses.size(); + final int p = covariates.size(); + final List> rowList = new ArrayList<>(n); + + for(int i=0; i newRow = new Row<>(valueArray, i+1, responses.get(i)); + rowList.add(newRow); + + } + + return rowList; + } + + public static List importData(List covariates, List rawCovariateData){ + if(covariates.size() != rawCovariateData.size()){ + throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship"); + } + + final int n = rawCovariateData.get(0).length; + final int p = covariates.size(); + final List rowList = new ArrayList<>(n); + + for(int i=0; i importCompetingRiskResponsesWithCensorTimes( + final int[] eventIndicators, + final double[] eventTimes, + final double[] censorTimes){ + + final int n = eventIndicators.length; + + if(eventTimes.length != n || censorTimes.length != n){ + throw new IllegalArgumentException("Array lengths must match"); + } + + final List responseList = new ArrayList<>(n); + + for(int i=0; i importCompetingRiskResponses( + final int[] eventIndicators, + final double[] eventTimes){ + + final int n = eventIndicators.length; + + if(eventTimes.length != n){ + throw new IllegalArgumentException("Array lengths must match"); + } + + final List responseList = new ArrayList<>(n); + + for(int i=0; i importNumericResponse(double[] values){ + final List responses = new ArrayList<>(values.length); + + for(double value : values){ + responses.add(value); + } + + return responses; + } }