
Also, the executable component uses a dependency that keeps having security vulnerabilities.
211 lines
7.3 KiB
Java
211 lines
7.3 KiB
Java
/*
|
|
* Copyright (c) 2019 Joel Therrien.
|
|
* This program is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package ca.joeltherrien.randomforest.utils;
|
|
|
|
import ca.joeltherrien.randomforest.CovariateRow;
|
|
import ca.joeltherrien.randomforest.Row;
|
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
|
|
|
|
import java.io.*;
|
|
import java.util.ArrayList;
|
|
import java.util.List;
|
|
import java.util.zip.GZIPInputStream;
|
|
import java.util.zip.GZIPOutputStream;
|
|
|
|
/**
|
|
* These static methods are designed to make the R interface more performant and to avoid using R loops.
|
|
*
|
|
*/
|
|
public final class RUtils {
|
|
|
|
public static double[] extractTimes(final RightContinuousStepFunction function){
|
|
return function.getX();
|
|
}
|
|
|
|
public static double[] extractY(final RightContinuousStepFunction function){
|
|
return function.getY();
|
|
}
|
|
|
|
/**
|
|
* Convenience method to help R package serialize Java objects.
|
|
*
|
|
* @param object The object to be serialized.
|
|
* @param filename The path to the object to be saved.
|
|
* @throws IOException
|
|
*/
|
|
public static void serializeObject(Serializable object, String filename) throws IOException {
|
|
final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename)));
|
|
outputStream.writeObject(object);
|
|
outputStream.close();
|
|
}
|
|
|
|
/**
|
|
* Convenience method to help R package load serialized Java objects.
|
|
*
|
|
* @param filename The path to the object saved.
|
|
* @throws IOException
|
|
*/
|
|
public static Object loadObject(String filename) throws IOException, ClassNotFoundException {
|
|
final ObjectInputStream inputStream = new ObjectInputStream(new GZIPInputStream(new FileInputStream(filename)));
|
|
final Object object = inputStream.readObject();
|
|
inputStream.close();
|
|
|
|
return object;
|
|
|
|
}
|
|
|
|
public static <Y> List<Row<Y>> importDataWithResponses(List<Y> responses, List<Covariate> covariates, List<String[]> rawCovariateData){
|
|
if(covariates.size() != rawCovariateData.size()){
|
|
throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
|
|
}
|
|
|
|
final int n = responses.size();
|
|
final int p = covariates.size();
|
|
final List<Row<Y>> rowList = new ArrayList<>(n);
|
|
|
|
// Let's verify the size first
|
|
for(int j=0; j<p; j++){
|
|
if(rawCovariateData.get(j).length != n){
|
|
final String covariateWithBadLength = covariates.get(j).getName();
|
|
throw new IllegalArgumentException(
|
|
"Length of covariate " + covariateWithBadLength +
|
|
"(" + rawCovariateData.get(j).length +
|
|
") does not match length of responses (" + n + ").");
|
|
}
|
|
}
|
|
|
|
for(int i=0; i<n; i++){
|
|
|
|
final Covariate.Value[] valueArray = new Covariate.Value[p];
|
|
for(int j=0; j<p; j++){
|
|
final Covariate covariate = covariates.get(j);
|
|
final String rawValue = rawCovariateData.get(j)[i];
|
|
|
|
final Covariate.Value value = covariate.createValue(rawValue);
|
|
valueArray[j] = value;
|
|
}
|
|
|
|
final Row<Y> newRow = new Row<>(valueArray, i+1, responses.get(i));
|
|
rowList.add(newRow);
|
|
|
|
}
|
|
|
|
return rowList;
|
|
}
|
|
|
|
public static List<CovariateRow> importData(List<Covariate> covariates, List<String[]> rawCovariateData){
|
|
if(covariates.size() != rawCovariateData.size()){
|
|
throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
|
|
}
|
|
|
|
final int n = rawCovariateData.get(0).length;
|
|
final int p = covariates.size();
|
|
final List<CovariateRow> rowList = new ArrayList<>(n);
|
|
|
|
for(int i=0; i<n; i++){
|
|
|
|
final Covariate.Value[] valueArray = new Covariate.Value[p];
|
|
for(int j=0; j<p; j++){
|
|
final Covariate covariate = covariates.get(j);
|
|
final String rawValue = rawCovariateData.get(j)[i];
|
|
|
|
final Covariate.Value value = covariate.createValue(rawValue);
|
|
valueArray[j] = value;
|
|
}
|
|
|
|
final CovariateRow newRow = new CovariateRow(valueArray, i+1);
|
|
rowList.add(newRow);
|
|
|
|
}
|
|
|
|
return rowList;
|
|
}
|
|
|
|
public static List<CompetingRiskResponseWithCensorTime> importCompetingRiskResponsesWithCensorTimes(
|
|
final int[] eventIndicators,
|
|
final double[] eventTimes,
|
|
final double[] censorTimes){
|
|
|
|
final int n = eventIndicators.length;
|
|
|
|
if(eventTimes.length != n || censorTimes.length != n){
|
|
throw new IllegalArgumentException("Array lengths must match");
|
|
}
|
|
|
|
final List<CompetingRiskResponseWithCensorTime> responseList = new ArrayList<>(n);
|
|
|
|
for(int i=0; i<n; i++){
|
|
responseList.add(new CompetingRiskResponseWithCensorTime(eventIndicators[i], eventTimes[i], censorTimes[i]));
|
|
}
|
|
|
|
return responseList;
|
|
|
|
}
|
|
|
|
public static List<CompetingRiskResponse> importCompetingRiskResponses(
|
|
final int[] eventIndicators,
|
|
final double[] eventTimes){
|
|
|
|
final int n = eventIndicators.length;
|
|
|
|
if(eventTimes.length != n){
|
|
throw new IllegalArgumentException("Array lengths must match");
|
|
}
|
|
|
|
final List<CompetingRiskResponse> responseList = new ArrayList<>(n);
|
|
|
|
for(int i=0; i<n; i++){
|
|
responseList.add(new CompetingRiskResponse(eventIndicators[i], eventTimes[i]));
|
|
}
|
|
|
|
return responseList;
|
|
|
|
}
|
|
|
|
public static List<Double> importNumericResponse(double[] values){
|
|
final List<Double> responses = new ArrayList<>(values.length);
|
|
|
|
for(double value : values){
|
|
responses.add(value);
|
|
}
|
|
|
|
return responses;
|
|
}
|
|
|
|
public static List<Object> produceSublist(List<Object> initialList, int[] indices){
|
|
final List<Object> newList = new ArrayList<>(indices.length);
|
|
|
|
for(int i : indices){
|
|
newList.add(initialList.get(i));
|
|
}
|
|
|
|
return newList;
|
|
}
|
|
|
|
public static File[] getTreeFileArray(String folderPath, int endingId){
|
|
final File[] fileArray = new File[endingId];
|
|
|
|
for(int i = 1; i <= endingId; i++){
|
|
fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree");
|
|
}
|
|
|
|
return fileArray;
|
|
}
|
|
|
|
}
|