/*
 * Copyright (c) 2019 Joel Therrien.
 *     This program is free software: you can redistribute it and/or modify
 *     it under the terms of the GNU General Public License as published by
 *     the Free Software Foundation, either version 3 of the License, or
 *     (at your option) any later version.
 *
 *     This program is distributed in the hope that it will be useful,
 *     but WITHOUT ANY WARRANTY; without even the implied warranty of
 *     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *     GNU General Public License for more details.
 *
 *     You should have received a copy of the GNU General Public License
 *     along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

package ca.joeltherrien.randomforest.utils;

import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;

import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * These static methods are designed to make the R interface more performant; and to avoid using R loops.
 *
 */
public final class RUtils {

    public static double[] extractTimes(final RightContinuousStepFunction function){
        return function.getX();
    }

    public static double[] extractY(final RightContinuousStepFunction function){
        return function.getY();
    }

    /**
     * Convenience method to help R package serialize Java objects.
     *
     * @param object The object to be serialized.
     * @param filename The path to the object to be saved.
     * @throws IOException
     */
    public static void serializeObject(Serializable object, String filename) throws IOException {
        final ObjectOutputStream outputStream = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(filename)));
        outputStream.writeObject(object);
        outputStream.close();
    }

    /**
     * Convenience method to help R package load serialized Java objects.
     *
     * @param filename The path to the object saved.
     * @throws IOException
     */
    public static Object loadObject(String filename) throws IOException, ClassNotFoundException {
        final ObjectInputStream inputStream = new ObjectInputStream(new GZIPInputStream(new FileInputStream(filename)));
        final Object object = inputStream.readObject();
        inputStream.close();

        return object;

    }

    public static <Y> List<Row<Y>> importDataWithResponses(List<Y> responses, List<Covariate> covariates, List<String[]> rawCovariateData){
        if(covariates.size() != rawCovariateData.size()){
            throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
        }

        final int n = responses.size();
        final int p = covariates.size();
        final List<Row<Y>> rowList = new ArrayList<>(n);

        // Let's verify the size first
        for(int j=0; j<p; j++){
            if(rawCovariateData.get(j).length != n){
                final String covariateWithBadLength = covariates.get(j).getName();
                throw new IllegalArgumentException(
                        "Length of covariate " + covariateWithBadLength +
                                "(" + rawCovariateData.get(j).length +
                                ") does not match length of responses (" + n + ").");
            }
        }

        for(int i=0; i<n; i++){

            final Covariate.Value[] valueArray = new Covariate.Value[p];
            for(int j=0; j<p; j++){
                final Covariate covariate = covariates.get(j);
                final String rawValue = rawCovariateData.get(j)[i];

                final Covariate.Value value = covariate.createValue(rawValue);
                valueArray[j] = value;
            }

            final Row<Y> newRow = new Row<>(valueArray, i+1, responses.get(i));
            rowList.add(newRow);

        }

        return rowList;
    }

    public static List<CovariateRow> importData(List<Covariate> covariates, List<String[]> rawCovariateData){
        if(covariates.size() != rawCovariateData.size()){
            throw new IllegalArgumentException("covariates size doesn't match number of columns in rawCovariateData; there must be a one-to-one relationship");
        }

        final int n = rawCovariateData.get(0).length;
        final int p = covariates.size();
        final List<CovariateRow> rowList = new ArrayList<>(n);

        for(int i=0; i<n; i++){

            final Covariate.Value[] valueArray = new Covariate.Value[p];
            for(int j=0; j<p; j++){
                final Covariate covariate = covariates.get(j);
                final String rawValue = rawCovariateData.get(j)[i];

                final Covariate.Value value = covariate.createValue(rawValue);
                valueArray[j] = value;
            }

            final CovariateRow newRow = new CovariateRow(valueArray, i+1);
            rowList.add(newRow);

        }

        return rowList;
    }

    public static List<CompetingRiskResponseWithCensorTime> importCompetingRiskResponsesWithCensorTimes(
            final int[] eventIndicators,
            final double[] eventTimes,
            final double[] censorTimes){

        final int n = eventIndicators.length;

        if(eventTimes.length != n || censorTimes.length != n){
            throw new IllegalArgumentException("Array lengths must match");
        }

        final List<CompetingRiskResponseWithCensorTime> responseList = new ArrayList<>(n);

        for(int i=0; i<n; i++){
            responseList.add(new CompetingRiskResponseWithCensorTime(eventIndicators[i], eventTimes[i], censorTimes[i]));
        }

        return responseList;

    }

    public static List<CompetingRiskResponse> importCompetingRiskResponses(
            final int[] eventIndicators,
            final double[] eventTimes){

        final int n = eventIndicators.length;

        if(eventTimes.length != n){
            throw new IllegalArgumentException("Array lengths must match");
        }

        final List<CompetingRiskResponse> responseList = new ArrayList<>(n);

        for(int i=0; i<n; i++){
            responseList.add(new CompetingRiskResponse(eventIndicators[i], eventTimes[i]));
        }

        return responseList;

    }

    public static List<Double> importNumericResponse(double[] values){
        final List<Double> responses = new ArrayList<>(values.length);

        for(double value : values){
            responses.add(value);
        }

        return responses;
    }

    public static List<Object> produceSublist(List<Object> initialList, int[] indices){
        final List<Object> newList = new ArrayList<>(indices.length);

        for(int i : indices){
            newList.add(initialList.get(i));
        }

        return newList;
    }

    public static File[] getTreeFileArray(String folderPath, int endingId){
        final File[] fileArray = new File[endingId];

        for(int i = 1; i <= endingId; i++){
            fileArray[i-1] = new File(folderPath + "/tree-" + i + ".tree");
        }

        return fileArray;
    }

}