largeRCRF-Java/src/main/java/ca/joeltherrien/randomforest/utils/RUtils.java
Joel Therrien c4bab39245 Remove the executable component, as the R package component has advanced enough that it can do everything.
Also, the executable component uses a dependency that keeps having security vulnerabilities.
2019-11-14 08:59:27 -08:00

211 lines
7.3 KiB
Java

/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.utils;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
/**
 * Static helper methods called from the R package. They exist so that bulk work (building rows,
 * response lists, sublists and file arrays) is done in Java instead of in R-level loops.
 */
public final class RUtils {

    /**
     * Returns the array produced by {@code function.getX()}.
     *
     * @param function The step function to read from.
     * @return The result of {@code function.getX()}.
     */
    public static double[] extractTimes(final RightContinuousStepFunction function){
        return function.getX();
    }

    /**
     * Returns the array produced by {@code function.getY()}.
     *
     * @param function The step function to read from.
     * @return The result of {@code function.getY()}.
     */
    public static double[] extractY(final RightContinuousStepFunction function){
        return function.getY();
    }

    /**
     * Convenience method to help the R package serialize Java objects. The object is written with
     * Java serialization and GZIP-compressed.
     *
     * @param object The object to be serialized.
     * @param filename The path of the file the object is saved to.
     * @throws IOException if the file cannot be opened or the object cannot be written.
     */
    public static void serializeObject(Serializable object, String filename) throws IOException {
        // Each stream is its own resource so that the file handle is released even if a wrapping
        // constructor or writeObject throws; closing also finishes the GZIP stream.
        try (FileOutputStream fileStream = new FileOutputStream(filename);
             GZIPOutputStream gzipStream = new GZIPOutputStream(fileStream);
             ObjectOutputStream outputStream = new ObjectOutputStream(gzipStream)) {
            outputStream.writeObject(object);
        }
    }

    /**
     * Convenience method to help the R package load Java objects that were saved as
     * GZIP-compressed Java serialization data (the format written by
     * {@link #serializeObject(Serializable, String)}).
     *
     * <p>NOTE: this uses native Java deserialization; only load files from trusted sources.
     *
     * @param filename The path of the saved object.
     * @return The deserialized object.
     * @throws IOException if the file cannot be opened, is not valid GZIP data, or cannot be read.
     * @throws ClassNotFoundException if the class of the serialized object cannot be found.
     */
    public static Object loadObject(String filename) throws IOException, ClassNotFoundException {
        // GZIPInputStream and ObjectInputStream both read a header in their constructors and may
        // throw; declaring each stream separately guarantees the file handle is still closed.
        try (FileInputStream fileStream = new FileInputStream(filename);
             GZIPInputStream gzipStream = new GZIPInputStream(fileStream);
             ObjectInputStream inputStream = new ObjectInputStream(gzipStream)) {
            return inputStream.readObject();
        }
    }

    /**
     * Builds one {@link Row} per response from column-oriented raw covariate data.
     *
     * @param responses The responses; element i becomes the response of row i.
     * @param covariates The covariates; element j parses column j of {@code rawCovariateData}.
     * @param rawCovariateData One {@code String[]} per covariate (column-major), each holding one raw
     *                         value per response.
     * @param <Y> The response type.
     * @return The rows in the same order as {@code responses}; row i is constructed with {@code i+1}
     *         as its second constructor argument.
     * @throws IllegalArgumentException if the number of columns differs from the number of covariates,
     *                                  or if any column's length differs from the number of responses.
     */
    public static <Y> List<Row<Y>> importDataWithResponses(List<Y> responses, List<Covariate> covariates, List<String[]> rawCovariateData){
        requireOneColumnPerCovariate(covariates, rawCovariateData);

        final int n = responses.size();
        final List<Row<Y>> rowList = new ArrayList<>(n);

        // Let's verify the size first
        requireColumnLengths(covariates, rawCovariateData, n, "responses");

        for(int i=0; i<n; i++){
            final Covariate.Value[] valueArray = parseRowValues(covariates, rawCovariateData, i);
            rowList.add(new Row<>(valueArray, i+1, responses.get(i)));
        }

        return rowList;
    }

    /**
     * Builds one {@link CovariateRow} per observation from column-oriented raw covariate data.
     *
     * @param covariates The covariates; element j parses column j of {@code rawCovariateData}.
     * @param rawCovariateData One {@code String[]} per covariate (column-major); all columns must have
     *                         the same length.
     * @return The rows in column order; row i is constructed with {@code i+1} as its second
     *         constructor argument. An empty list is returned when there are no covariates, because
     *         the number of observations cannot be determined.
     * @throws IllegalArgumentException if the number of columns differs from the number of covariates,
     *                                  or if the columns do not all have the same length.
     */
    public static List<CovariateRow> importData(List<Covariate> covariates, List<String[]> rawCovariateData){
        requireOneColumnPerCovariate(covariates, rawCovariateData);

        // With zero columns there is no first column to take the row count from.
        if(rawCovariateData.isEmpty()){
            return new ArrayList<>();
        }

        final int n = rawCovariateData.get(0).length;

        // Fail up-front with a clear message rather than mid-way with an ArrayIndexOutOfBoundsException
        // (short column) or by silently dropping values (long column).
        requireColumnLengths(covariates, rawCovariateData, n, "the first covariate");

        final List<CovariateRow> rowList = new ArrayList<>(n);
        for(int i=0; i<n; i++){
            final Covariate.Value[] valueArray = parseRowValues(covariates, rawCovariateData, i);
            rowList.add(new CovariateRow(valueArray, i+1));
        }

        return rowList;
    }

    /**
     * Zips three parallel arrays into a list of {@link CompetingRiskResponseWithCensorTime}.
     *
     * @param eventIndicators First constructor argument for each response.
     * @param eventTimes Second constructor argument for each response.
     * @param censorTimes Third constructor argument for each response.
     * @return One response per index, in array order.
     * @throws IllegalArgumentException if the arrays do not all have the same length.
     */
    public static List<CompetingRiskResponseWithCensorTime> importCompetingRiskResponsesWithCensorTimes(
            final int[] eventIndicators,
            final double[] eventTimes,
            final double[] censorTimes){

        final int n = eventIndicators.length;
        if(eventTimes.length != n || censorTimes.length != n){
            throw new IllegalArgumentException("Array lengths must match");
        }

        final List<CompetingRiskResponseWithCensorTime> responseList = new ArrayList<>(n);
        for(int i=0; i<n; i++){
            responseList.add(new CompetingRiskResponseWithCensorTime(eventIndicators[i], eventTimes[i], censorTimes[i]));
        }

        return responseList;
    }

    /**
     * Zips two parallel arrays into a list of {@link CompetingRiskResponse}.
     *
     * @param eventIndicators First constructor argument for each response.
     * @param eventTimes Second constructor argument for each response.
     * @return One response per index, in array order.
     * @throws IllegalArgumentException if the arrays do not have the same length.
     */
    public static List<CompetingRiskResponse> importCompetingRiskResponses(
            final int[] eventIndicators,
            final double[] eventTimes){

        final int n = eventIndicators.length;
        if(eventTimes.length != n){
            throw new IllegalArgumentException("Array lengths must match");
        }

        final List<CompetingRiskResponse> responseList = new ArrayList<>(n);
        for(int i=0; i<n; i++){
            responseList.add(new CompetingRiskResponse(eventIndicators[i], eventTimes[i]));
        }

        return responseList;
    }

    /**
     * Boxes a primitive array into a list of {@link Double}.
     *
     * @param values The values to copy.
     * @return A new list holding the values in array order.
     */
    public static List<Double> importNumericResponse(double[] values){
        final List<Double> responses = new ArrayList<>(values.length);
        for(double value : values){
            responses.add(value);
        }

        return responses;
    }

    /**
     * Selects elements of a list by position.
     *
     * @param initialList The list to select from.
     * @param indices Positions passed directly to {@link List#get(int)} (so they are 0-based);
     *                repeats are allowed and the given order is kept.
     * @return A new list containing {@code initialList.get(i)} for each {@code i} in {@code indices}.
     * @throws IndexOutOfBoundsException if an index is outside the list.
     */
    public static List<Object> produceSublist(List<Object> initialList, int[] indices){
        final List<Object> newList = new ArrayList<>(indices.length);
        for(int i : indices){
            newList.add(initialList.get(i));
        }

        return newList;
    }

    /**
     * Builds the array of tree file paths {@code folderPath/tree-1.tree} through
     * {@code folderPath/tree-<endingId>.tree}. The files are not checked for existence.
     *
     * @param folderPath The folder containing the tree files.
     * @param endingId The id of the last tree; also the length of the returned array.
     * @return The files in increasing id order.
     * @throws IllegalArgumentException if {@code endingId} is negative.
     */
    public static File[] getTreeFileArray(String folderPath, int endingId){
        if(endingId < 0){
            throw new IllegalArgumentException("endingId must not be negative; was " + endingId);
        }

        final File[] fileArray = new File[endingId];
        for(int i = 1; i <= endingId; i++){
            fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree");
        }

        return fileArray;
    }

    /** Verifies that there is exactly one raw data column per covariate. */
    private static void requireOneColumnPerCovariate(List<Covariate> covariates, List<String[]> rawCovariateData){
        if(covariates.size() != rawCovariateData.size()){
            throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
        }
    }

    /** Verifies that every raw data column has exactly {@code n} entries; {@code comparedTo} names what n is the length of. */
    private static void requireColumnLengths(List<Covariate> covariates, List<String[]> rawCovariateData, int n, String comparedTo){
        for(int j=0; j<covariates.size(); j++){
            final int columnLength = rawCovariateData.get(j).length;
            if(columnLength != n){
                final String covariateWithBadLength = covariates.get(j).getName();
                throw new IllegalArgumentException(
                        "Length of covariate " + covariateWithBadLength +
                        "(" + columnLength +
                        ") does not match length of " + comparedTo + " (" + n + ").");
            }
        }
    }

    /** Parses the i-th entry of every column with its covariate, producing the values for row i. */
    private static Covariate.Value[] parseRowValues(List<Covariate> covariates, List<String[]> rawCovariateData, int i){
        final int p = covariates.size();
        final Covariate.Value[] valueArray = new Covariate.Value[p];
        for(int j=0; j<p; j++){
            final Covariate covariate = covariates.get(j);
            final String rawValue = rawCovariateData.get(j)[i];
            valueArray[j] = covariate.createValue(rawValue);
        }

        return valueArray;
    }
}