package ca.joeltherrien.randomforest.competingrisk;

import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskResponseCombiner;
import ca.joeltherrien.randomforest.utils.MathFunction;
import org.junit.jupiter.api.Test;

import static ca.joeltherrien.randomforest.TestUtils.closeEnough;

import java.util.ArrayList;
import java.util.List;

public class TestCompetingRiskResponseCombiner {

    private CompetingRiskFunctions generateFunctions(){
        final List<CompetingRiskResponse> data = new ArrayList<>();

        data.add(new CompetingRiskResponse(1, 1.0));
        data.add(new CompetingRiskResponse(1, 1.0));
        data.add(new CompetingRiskResponse(1, 2.0));
        data.add(new CompetingRiskResponse(2, 1.5));
        data.add(new CompetingRiskResponse(2, 2.0));
        data.add(new CompetingRiskResponse(0, 1.5));
        data.add(new CompetingRiskResponse(0, 2.5));

        final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);

        return combiner.combine(data);
    }

    @Test
    public void testCompetingRiskResponseCombiner(){
        final CompetingRiskFunctions functions = generateFunctions();

        final MathFunction survivalCurve = functions.getSurvivalCurve();

        // time = 1.0 1.5 2.0 2.5
        // surv = 0.7142857 0.5714286 0.1904762 0.1904762

        final double margin = 0.0000001;

        closeEnough(0.7142857, survivalCurve.evaluate(1.0).getY(), margin);
        closeEnough(0.5714286, survivalCurve.evaluate(1.5).getY(), margin);
        closeEnough(0.1904762, survivalCurve.evaluate(2.0).getY(), margin);
        closeEnough(0.1904762, survivalCurve.evaluate(2.5).getY(), margin);


        // Time = 1.0 1.5 2.0 2.5
        /* Cumulative hazard function. Each row for one event.
                  [,1]       [,2]       [,3]       [,4]
        [1,]  0.2857143  0.2857143  0.6190476  0.6190476
        [2,]  0.0000000  0.2000000  0.5333333  0.5333333
         */

        final MathFunction cumHaz1 = functions.getCauseSpecificHazardFunction(1);
        closeEnough(0.2857143, cumHaz1.evaluate(1.0).getY(), margin);
        closeEnough(0.2857143, cumHaz1.evaluate(1.5).getY(), margin);
        closeEnough(0.6190476, cumHaz1.evaluate(2.0).getY(), margin);
        closeEnough(0.6190476, cumHaz1.evaluate(2.5).getY(), margin);

        final MathFunction cumHaz2 = functions.getCauseSpecificHazardFunction(2);
        closeEnough(0.0, cumHaz2.evaluate(1.0).getY(), margin);
        closeEnough(0.2, cumHaz2.evaluate(1.5).getY(), margin);
        closeEnough(0.5333333, cumHaz2.evaluate(2.0).getY(), margin);
        closeEnough(0.5333333, cumHaz2.evaluate(2.5).getY(), margin);

        /* Time = 1.0 1.5 2.0 2.5
        Cumulative Incidence Curve. Each row for one event.
                  [,1]      [,2]      [,3]      [,4]
        [1,] 0.2857143 0.2857143 0.4761905 0.4761905
        [2,] 0.0000000 0.1428571 0.3333333 0.3333333
         */

        final MathFunction cic1 = functions.getCumulativeIncidenceFunction(1);
        closeEnough(0.2857143, cic1.evaluate(1.0).getY(), margin);
        closeEnough(0.2857143, cic1.evaluate(1.5).getY(), margin);
        closeEnough(0.4761905, cic1.evaluate(2.0).getY(), margin);
        closeEnough(0.4761905, cic1.evaluate(2.5).getY(), margin);

        final MathFunction cic2 = functions.getCumulativeIncidenceFunction(2);
        closeEnough(0.0, cic2.evaluate(1.0).getY(), margin);
        closeEnough(0.1428571, cic2.evaluate(1.5).getY(), margin);
        closeEnough(0.3333333, cic2.evaluate(2.0).getY(), margin);
        closeEnough(0.3333333, cic2.evaluate(2.5).getY(), margin);

    }

}