Add R utility functions to help load data quickly

This commit is contained in:
Joel Therrien 2019-03-04 11:23:31 -08:00
parent 29b154110a
commit 91cf299362

View file

@ -16,7 +16,15 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import java.io.*; import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.GZIPInputStream; import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;
@ -62,5 +70,111 @@ public final class RUtils {
} }
/**
 * Builds one {@link Row} per response by pairing it with covariate values parsed from
 * column-oriented raw data.
 *
 * <p>{@code rawCovariateData} is indexed by column: element {@code j} holds the raw strings for
 * {@code covariates.get(j)}, and entry {@code i} of that array belongs to the same observation as
 * {@code responses.get(i)}. Each raw string is converted with {@code covariate.createValue(...)}.
 * Rows are given ids {@code 1..n} in input order, where {@code n} is {@code responses.size()}.
 * Entries in a column beyond index {@code n-1} are ignored.
 *
 * @param responses one response per observation; determines the number of rows produced
 * @param covariates the covariate definitions, one per column
 * @param rawCovariateData the raw column data, one {@code String[]} per covariate
 * @param <Y> the response type
 * @return a new list containing one row per response, in input order
 * @throws IllegalArgumentException if {@code covariates} and {@code rawCovariateData} differ in
 *         size, or if any column is null or has fewer entries than there are responses
 */
public static <Y> List<Row<Y>> importDataWithResponses(List<Y> responses, List<Covariate> covariates, List<String[]> rawCovariateData){
    if(covariates.size() != rawCovariateData.size()){
        throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
    }

    final int n = responses.size();
    final int p = covariates.size();

    // Fail fast with a message naming the offending column, rather than letting the row-building
    // loop die partway through with a bare ArrayIndexOutOfBoundsException / NullPointerException.
    for(int j=0; j<p; j++){
        final String[] column = rawCovariateData.get(j);
        final int columnLength = (column == null) ? 0 : column.length;
        if(columnLength < n){
            final String problem = (column == null) ? "is null" : ("has only " + columnLength + " values");
            throw new IllegalArgumentException("Column " + j + " of rawCovariateData " + problem + " but there are " + n + " responses");
        }
    }

    final List<Row<Y>> rowList = new ArrayList<>(n);

    for(int i=0; i<n; i++){
        final Covariate.Value[] valueArray = new Covariate.Value[p];

        for(int j=0; j<p; j++){
            final Covariate covariate = covariates.get(j);
            final String rawValue = rawCovariateData.get(j)[i];

            valueArray[j] = covariate.createValue(rawValue);
        }

        // Row ids are 1-based.
        rowList.add(new Row<>(valueArray, i+1, responses.get(i)));
    }

    return rowList;
}
/**
 * Builds a list of {@link CovariateRow}s (rows without responses) from column-oriented raw data.
 *
 * <p>{@code rawCovariateData} is indexed by column: element {@code j} holds the raw strings for
 * {@code covariates.get(j)}, and entry {@code i} of that array belongs to observation {@code i}.
 * The number of rows is taken from the length of the first column; entries in other columns
 * beyond that length are ignored. Each raw string is converted with
 * {@code covariate.createValue(...)}. Rows are given ids {@code 1..n} in input order.
 *
 * <p>If there are no columns at all the row count cannot be determined, so an empty list is
 * returned.
 *
 * @param covariates the covariate definitions, one per column
 * @param rawCovariateData the raw column data, one {@code String[]} per covariate
 * @return a new list containing one row per entry of the first column, in input order
 * @throws IllegalArgumentException if {@code covariates} and {@code rawCovariateData} differ in
 *         size, if the first column is null, or if any other column is null or has fewer entries
 *         than the first column
 */
public static List<CovariateRow> importData(List<Covariate> covariates, List<String[]> rawCovariateData){
    if(covariates.size() != rawCovariateData.size()){
        throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
    }

    final int p = covariates.size();

    // With zero columns there is no first column to read the row count from; previously this
    // fell through to rawCovariateData.get(0) and threw IndexOutOfBoundsException.
    if(p == 0){
        return new ArrayList<>(0);
    }

    final String[] firstColumn = rawCovariateData.get(0);
    if(firstColumn == null){
        throw new IllegalArgumentException("Column 0 of rawCovariateData is null");
    }
    final int n = firstColumn.length;

    // Fail fast with a message naming the offending column, rather than letting the row-building
    // loop die partway through with a bare ArrayIndexOutOfBoundsException / NullPointerException.
    for(int j=1; j<p; j++){
        final String[] column = rawCovariateData.get(j);
        final int columnLength = (column == null) ? 0 : column.length;
        if(columnLength < n){
            final String problem = (column == null) ? "is null" : ("has only " + columnLength + " values");
            throw new IllegalArgumentException("Column " + j + " of rawCovariateData " + problem + " but column 0 has " + n + " values");
        }
    }

    final List<CovariateRow> rowList = new ArrayList<>(n);

    for(int i=0; i<n; i++){
        final Covariate.Value[] valueArray = new Covariate.Value[p];

        for(int j=0; j<p; j++){
            final Covariate covariate = covariates.get(j);
            final String rawValue = rawCovariateData.get(j)[i];

            valueArray[j] = covariate.createValue(rawValue);
        }

        // Row ids are 1-based.
        rowList.add(new CovariateRow(valueArray, i+1));
    }

    return rowList;
}
/**
 * Combines three parallel arrays into a list of {@link CompetingRiskResponseWithCensorTime}
 * objects, one per index.
 *
 * @param eventIndicators the event indicator for each observation
 * @param eventTimes the event time for each observation
 * @param censorTimes the censor time for each observation
 * @return a new list whose element {@code i} is built from entry {@code i} of each array
 * @throws IllegalArgumentException if the three arrays are not all the same length
 */
public static List<CompetingRiskResponseWithCensorTime> importCompetingRiskResponsesWithCensorTimes(
        final int[] eventIndicators,
        final double[] eventTimes,
        final double[] censorTimes){

    final int count = eventIndicators.length;

    final boolean lengthsAgree = eventTimes.length == count && censorTimes.length == count;
    if(!lengthsAgree){
        throw new IllegalArgumentException("Array lengths must match");
    }

    final List<CompetingRiskResponseWithCensorTime> result = new ArrayList<>(count);
    for(int idx = 0; idx < count; idx++){
        final int indicator = eventIndicators[idx];
        final double eventTime = eventTimes[idx];
        final double censorTime = censorTimes[idx];

        result.add(new CompetingRiskResponseWithCensorTime(indicator, eventTime, censorTime));
    }

    return result;
}
/**
 * Combines two parallel arrays into a list of {@link CompetingRiskResponse} objects, one per
 * index.
 *
 * @param eventIndicators the event indicator for each observation
 * @param eventTimes the event time for each observation
 * @return a new list whose element {@code i} is built from entry {@code i} of each array
 * @throws IllegalArgumentException if the two arrays are not the same length
 */
public static List<CompetingRiskResponse> importCompetingRiskResponses(
        final int[] eventIndicators,
        final double[] eventTimes){

    final int count = eventIndicators.length;

    if(count != eventTimes.length){
        throw new IllegalArgumentException("Array lengths must match");
    }

    final List<CompetingRiskResponse> result = new ArrayList<>(count);
    for(int idx = 0; idx < count; idx++){
        final int indicator = eventIndicators[idx];
        final double eventTime = eventTimes[idx];

        result.add(new CompetingRiskResponse(indicator, eventTime));
    }

    return result;
}
/**
 * Copies a primitive {@code double} array into a new list of boxed {@link Double} values,
 * preserving order.
 *
 * @param values the numeric responses to box
 * @return a new list with the same length and element order as {@code values}
 */
public static List<Double> importNumericResponse(double[] values){
    final int count = values.length;
    final List<Double> boxed = new ArrayList<>(count);

    for(int idx = 0; idx < count; idx++){
        boxed.add(values[idx]);
    }

    return boxed;
}
} }