/*
 * Copyright (c) 2019 Joel Therrien.
 *     This program is free software: you can redistribute it and/or modify
 *     it under the terms of the GNU General Public License as published by
 *     the Free Software Foundation, either version 3 of the License, or
 *     (at your option) any later version.
 *
 *     This program is distributed in the hope that it will be useful,
 *     but WITHOUT ANY WARRANTY; without even the implied warranty of
 *     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *     GNU General Public License for more details.
 *
 *     You should have received a copy of the GNU General Public License
 *     along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

package ca.joeltherrien.randomforest;

import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceSplitFinder;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.DataUtils;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Random;

/**
 * Tests verifying that forest training is deterministic: for a fixed random seed,
 * every training pathway (serial/parallel, in-memory/on-disk) must produce a
 * forest yielding the same predictions.
 */
public class TestDeterministicForests {

    // Scratch directory used by the on-disk training tests; deleted after each use.
    // NOTE(review): the fixed path means these tests must not run concurrently with
    // each other or with other tests using this directory — confirm the build runs
    // this class serially.
    private static final String SAVE_TREE_LOCATION = "src/test/resources/trees/";

    // Tolerance when comparing predictions from two supposedly-identical forests;
    // results aren't necessarily always *identical* (floating point accumulation).
    private static final double PREDICTION_TOLERANCE = 0.0000000001;

    /**
     * Creates 15 covariates: 5 numeric, 5 boolean, and 5 three-level factors,
     * with indices 0..14 assigned in that order.
     */
    private List<Covariate> generateCovariates(){
        final List<Covariate> covariateList = new ArrayList<>();

        int index = 0;
        for(int j=0; j<5; j++){
            covariateList.add(new NumericCovariate("numeric"+j, index++, false));
        }

        for(int j=0; j<5; j++){
            covariateList.add(new BooleanCovariate("boolean"+j, index++, false));
        }

        final List<String> levels = Utils.easyList("cat", "dog", "mouse");
        for(int j=0; j<5; j++){
            covariateList.add(new FactorCovariate("factor"+j, index++, levels, false));
        }

        return covariateList;
    }

    /**
     * Draws a random value appropriate for the given covariate's type.
     *
     * @throws IllegalArgumentException if the covariate is not one of the three
     *         types produced by {@link #generateCovariates()}.
     */
    private Covariate.Value generateRandomValue(Covariate covariate, Random random){
        if(covariate instanceof NumericCovariate){
            return covariate.createValue(random.nextGaussian());
        }
        else if(covariate instanceof BooleanCovariate){
            return covariate.createValue(random.nextBoolean());
        }
        else if(covariate instanceof FactorCovariate){
            // Pick one of the three levels with (approximately) equal probability.
            final double itemSelection = random.nextDouble();
            final String item;
            if(itemSelection < 1.0/3.0){
                item = "cat";
            }
            else if(itemSelection < 2.0/3.0){
                item = "dog";
            }
            else{
                item = "mouse";
            }

            return covariate.createValue(item);
        }
        else{
            throw new IllegalArgumentException("Unknown covariate type of class " + covariate.getClass().getName());
        }
    }

    /**
     * Generates n rows of random covariate values with a standard Gaussian response.
     */
    private List<Row<Double>> generateTestData(List<Covariate> covariateList, int n, Random random){
        final List<Row<Double>> rowList = new ArrayList<>(n);
        for(int i=0; i<n; i++){
            final double response = random.nextGaussian();
            final Covariate.Value[] valueArray = new Covariate.Value[covariateList.size()];

            for(int j=0; j<covariateList.size(); j++){
                valueArray[j] = generateRandomValue(covariateList.get(j), random);
            }

            // was: new Row(...) — raw type
            rowList.add(new Row<>(valueArray, i, response));
        }

        return rowList;
    }

    // Builds the tree-training configuration shared by every test in this class.
    private TreeTrainer<Double, Double> buildTreeTrainer(List<Covariate> covariateList){
        return TreeTrainer.<Double, Double>builder()
                .checkNodePurity(false)
                .covariates(covariateList)
                .maxNodeDepth(100)
                .mtry(1)
                .nodeSize(10)
                .numberOfSplits(1) // want results to be dominated by randomness
                .responseCombiner(new MeanResponseCombiner())
                .splitFinder(new WeightedVarianceSplitFinder())
                .build();
    }

    // Builds a forest trainer for the given tree count and seed; all other settings are fixed.
    private ForestTrainer<Double, Double, Double> buildForestTrainer(
            TreeTrainer<Double, Double> treeTrainer,
            List<Covariate> covariateList,
            List<Row<Double>> trainingData,
            int ntree,
            long trainingSeed){
        return ForestTrainer.<Double, Double, Double>builder()
                .treeTrainer(treeTrainer)
                .covariates(covariateList)
                .data(trainingData)
                .displayProgress(false)
                .ntree(ntree)
                .randomSeed(trainingSeed)
                .treeResponseCombiner(new MeanResponseCombiner())
                .saveTreeLocation(SAVE_TREE_LOCATION)
                .build();
    }

    @Test
    public void testResultsAlwaysSame() throws IOException, ClassNotFoundException {
        final List<Covariate> covariateList = generateCovariates();

        final Random dataGeneratingRandom = new Random();

        final List<Row<Double>> trainingData = generateTestData(covariateList, 100, dataGeneratingRandom);
        final List<Row<Double>> testData = generateTestData(covariateList, 10, dataGeneratingRandom);

        // pick a new seed at random
        final long trainingSeed = dataGeneratingRandom.nextLong();

        final TreeTrainer<Double, Double> treeTrainer = buildTreeTrainer(covariateList);
        final ForestTrainer<Double, Double, Double> forestTrainer =
                buildForestTrainer(treeTrainer, covariateList, trainingData, 10, trainingSeed);

        // By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
        final Forest<Double, Double> referenceForest = forestTrainer.trainSerialInMemory(Optional.empty());

        verifySerialInMemoryTraining(referenceForest, forestTrainer, testData);
        verifyParallelInMemoryTraining(referenceForest, forestTrainer, testData);
        verifySerialOnDiskTraining(referenceForest, forestTrainer, testData);
        verifyParallelOnDiskTraining(referenceForest, forestTrainer, testData);
    }

    /**
     * Tests that if we train a forest under a specified seed for 10 trees, that it is equal
     * to training a forest for 5 trees only, and then starting from that point to train the
     * last 5.
     *
     * <p>(The method name keeps its original spelling — "Interuppted" — to avoid churn in
     * historical test reports.)</p>
     *
     * @throws IOException            if on-disk training or loading fails
     * @throws ClassNotFoundException if deserializing a saved forest fails
     */
    @Test
    public void testInterupptedTrainingProducesSameResults() throws IOException, ClassNotFoundException {
        final List<Covariate> covariateList = generateCovariates();

        final Random dataGeneratingRandom = new Random();

        final List<Row<Double>> trainingData = generateTestData(covariateList, 100, dataGeneratingRandom);
        final List<Row<Double>> testData = generateTestData(covariateList, 10, dataGeneratingRandom);

        // pick a new seed at random
        final long trainingSeed = dataGeneratingRandom.nextLong();

        final TreeTrainer<Double, Double> treeTrainer = buildTreeTrainer(covariateList);
        final ForestTrainer<Double, Double, Double> forestTrainer5Trees =
                buildForestTrainer(treeTrainer, covariateList, trainingData, 5, trainingSeed);
        final ForestTrainer<Double, Double, Double> forestTrainer10Trees =
                buildForestTrainer(treeTrainer, covariateList, trainingData, 10, trainingSeed);

        // By training the referenceForest through one method we also verify that all the methods produce the same forests for a given seed.
        final Forest<Double, Double> referenceForest = forestTrainer10Trees.trainSerialInMemory(Optional.empty());

        final File saveTreeFile = new File(SAVE_TREE_LOCATION);

        // Train 5 trees, then let the 10-tree trainer resume from the saved state;
        // the completed forest must match the uninterrupted reference.
        forestTrainer5Trees.trainSerialOnDisk(Optional.empty());
        forestTrainer10Trees.trainSerialOnDisk(Optional.empty());
        final Forest<Double, Double> forestSerial = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
        TestUtils.removeFolder(saveTreeFile);
        verifyTwoForestsEqual(testData, referenceForest, forestSerial);

        // Same interruption scenario, but resuming with parallel on-disk training.
        forestTrainer5Trees.trainParallelOnDisk(Optional.empty(), 4);
        forestTrainer10Trees.trainParallelOnDisk(Optional.empty(), 4);
        final Forest<Double, Double> forestParallel = DataUtils.loadOnlineForest(saveTreeFile, new MeanResponseCombiner());
        TestUtils.removeFolder(saveTreeFile);
        verifyTwoForestsEqual(testData, referenceForest, forestParallel);
    }

    // Retrains serially in memory three times; each result must match the reference.
    private void verifySerialInMemoryTraining(
            final Forest<Double, Double> referenceForest,
            ForestTrainer<Double, Double, Double> forestTrainer,
            List<Row<Double>> testData){

        for(int k=0; k<3; k++){
            final Forest<Double, Double> replicantForest = forestTrainer.trainSerialInMemory(Optional.empty());
            verifyTwoForestsEqual(testData, referenceForest, replicantForest);
        }
    }

    // Retrains in parallel (4 threads) in memory three times; each result must match the reference.
    private void verifyParallelInMemoryTraining(
            Forest<Double, Double> referenceForest,
            ForestTrainer<Double, Double, Double> forestTrainer,
            List<Row<Double>> testData){

        for(int k=0; k<3; k++){
            final Forest<Double, Double> replicantForest = forestTrainer.trainParallelInMemory(Optional.empty(), 4);
            verifyTwoForestsEqual(testData, referenceForest, replicantForest);
        }
    }

    // Retrains serially on disk three times, reloading each saved forest; must match the reference.
    private void verifySerialOnDiskTraining(
            Forest<Double, Double> referenceForest,
            ForestTrainer<Double, Double, Double> forestTrainer,
            List<Row<Double>> testData) throws IOException, ClassNotFoundException {

        final MeanResponseCombiner responseCombiner = new MeanResponseCombiner();
        final File saveTreeFile = new File(SAVE_TREE_LOCATION);

        for(int k=0; k<3; k++){
            forestTrainer.trainSerialOnDisk(Optional.empty());
            final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
            TestUtils.removeFolder(saveTreeFile);
            verifyTwoForestsEqual(testData, referenceForest, replicantForest);
        }
    }

    // Retrains in parallel (4 threads) on disk three times, reloading each saved forest; must match the reference.
    private void verifyParallelOnDiskTraining(
            final Forest<Double, Double> referenceForest, ForestTrainer<Double, Double, Double> forestTrainer,
            List<Row<Double>> testData) throws IOException, ClassNotFoundException {

        final MeanResponseCombiner responseCombiner = new MeanResponseCombiner();
        final File saveTreeFile = new File(SAVE_TREE_LOCATION);

        for(int k=0; k<3; k++){
            forestTrainer.trainParallelOnDisk(Optional.empty(), 4);
            final Forest<Double, Double> replicantForest = DataUtils.loadOnlineForest(saveTreeFile, responseCombiner);
            TestUtils.removeFolder(saveTreeFile);
            verifyTwoForestsEqual(testData, referenceForest, replicantForest);
        }
    }

    // Technically verifies the two forests give equal predictions on a given test dataset
    private void verifyTwoForestsEqual(final List<Row<Double>> testData,
                                      final Forest<Double, Double> forest1,
                                      final Forest<Double, Double> forest2){

        // was: raw Row in the loop variable
        for(final Row<Double> row : testData){
            final Double prediction1 = forest1.evaluate(row);
            final Double prediction2 = forest2.evaluate(row);

            // I've noticed that results aren't necessarily always *identical*
            TestUtils.closeEnough(prediction1, prediction2, PREDICTION_TOLERANCE);
        }

    }


}