Add some R utility functions to help data get quickly loaded
This commit is contained in:
parent
29b154110a
commit
91cf299362
1 changed files with 114 additions and 0 deletions
|
@ -16,7 +16,15 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.zip.GZIPInputStream;
|
||||
import java.util.zip.GZIPOutputStream;
|
||||
|
||||
|
@ -62,5 +70,111 @@ public final class RUtils {
|
|||
|
||||
}
|
||||
|
||||
public static <Y> List<Row<Y>> importDataWithResponses(List<Y> responses, List<Covariate> covariates, List<String[]> rawCovariateData){
|
||||
if(covariates.size() != rawCovariateData.size()){
|
||||
throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
|
||||
}
|
||||
|
||||
final int n = responses.size();
|
||||
final int p = covariates.size();
|
||||
final List<Row<Y>> rowList = new ArrayList<>(n);
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[p];
|
||||
for(int j=0; j<p; j++){
|
||||
final Covariate covariate = covariates.get(j);
|
||||
final String rawValue = rawCovariateData.get(j)[i];
|
||||
|
||||
final Covariate.Value value = covariate.createValue(rawValue);
|
||||
valueArray[j] = value;
|
||||
}
|
||||
|
||||
final Row<Y> newRow = new Row<>(valueArray, i+1, responses.get(i));
|
||||
rowList.add(newRow);
|
||||
|
||||
}
|
||||
|
||||
return rowList;
|
||||
}
|
||||
|
||||
public static List<CovariateRow> importData(List<Covariate> covariates, List<String[]> rawCovariateData){
|
||||
if(covariates.size() != rawCovariateData.size()){
|
||||
throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
|
||||
}
|
||||
|
||||
final int n = rawCovariateData.get(0).length;
|
||||
final int p = covariates.size();
|
||||
final List<CovariateRow> rowList = new ArrayList<>(n);
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
|
||||
final Covariate.Value[] valueArray = new Covariate.Value[p];
|
||||
for(int j=0; j<p; j++){
|
||||
final Covariate covariate = covariates.get(j);
|
||||
final String rawValue = rawCovariateData.get(j)[i];
|
||||
|
||||
final Covariate.Value value = covariate.createValue(rawValue);
|
||||
valueArray[j] = value;
|
||||
}
|
||||
|
||||
final CovariateRow newRow = new CovariateRow(valueArray, i+1);
|
||||
rowList.add(newRow);
|
||||
|
||||
}
|
||||
|
||||
return rowList;
|
||||
}
|
||||
|
||||
public static List<CompetingRiskResponseWithCensorTime> importCompetingRiskResponsesWithCensorTimes(
|
||||
final int[] eventIndicators,
|
||||
final double[] eventTimes,
|
||||
final double[] censorTimes){
|
||||
|
||||
final int n = eventIndicators.length;
|
||||
|
||||
if(eventTimes.length != n || censorTimes.length != n){
|
||||
throw new IllegalArgumentException("Array lengths must match");
|
||||
}
|
||||
|
||||
final List<CompetingRiskResponseWithCensorTime> responseList = new ArrayList<>(n);
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
responseList.add(new CompetingRiskResponseWithCensorTime(eventIndicators[i], eventTimes[i], censorTimes[i]));
|
||||
}
|
||||
|
||||
return responseList;
|
||||
|
||||
}
|
||||
|
||||
public static List<CompetingRiskResponse> importCompetingRiskResponses(
|
||||
final int[] eventIndicators,
|
||||
final double[] eventTimes){
|
||||
|
||||
final int n = eventIndicators.length;
|
||||
|
||||
if(eventTimes.length != n){
|
||||
throw new IllegalArgumentException("Array lengths must match");
|
||||
}
|
||||
|
||||
final List<CompetingRiskResponse> responseList = new ArrayList<>(n);
|
||||
|
||||
for(int i=0; i<n; i++){
|
||||
responseList.add(new CompetingRiskResponse(eventIndicators[i], eventTimes[i]));
|
||||
}
|
||||
|
||||
return responseList;
|
||||
|
||||
}
|
||||
|
||||
public static List<Double> importNumericResponse(double[] values){
|
||||
final List<Double> responses = new ArrayList<>(values.length);
|
||||
|
||||
for(double value : values){
|
||||
responses.add(value);
|
||||
}
|
||||
|
||||
return responses;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue