package ca.joeltherrien.randomforest.tree;

import ca.joeltherrien.randomforest.*;
import lombok.Builder;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Grows a single decision tree by recursively splitting a list of training rows.
 *
 * <p>At every node each candidate covariate proposes split rules; each rule is applied to the node's
 * rows and the resulting left/right responses are scored by the {@code GroupDifferentiator}. The rule
 * with the highest score is used to split the node. When a node cannot or should not be split, its
 * responses are reduced by the {@code ResponseCombiner} into a {@code TerminalNode}.
 *
 * @param <Y> the type of the response held by each row
 */
@Builder
public class TreeTrainer<Y> {

    /** Collector that reduces the responses of the rows in a terminal node into that node's value. */
    private final ResponseCombiner<Y, ?> responseCombiner;

    /** Scores a candidate split from its left and right responses; higher scores win, {@code null} means "no score". */
    private final GroupDifferentiator<Y> groupDifferentiator;

    /**
     * The number of splits to try on each covariate. A value of 0 means the number of rows in the
     * node is passed to the covariate instead, i.e. all possible splits are tried.
     */
    private final int numberOfSplits;

    /** A node is only considered for splitting when it holds at least {@code 2 * nodeSize} rows. */
    private final int nodeSize;

    /** Nodes at this depth (the root is at depth 0) are never split. */
    private final int maxNodeDepth;


    /**
     * Grows a complete tree from the given training data.
     *
     * @param data            the training rows placed at the root of the tree
     * @param covariatesToTry the covariates that may be used to split any node of the tree
     * @return the root node of the grown tree
     */
    public Node<Y> growTree(List<Row<Y>> data, List<Covariate> covariatesToTry){
        return growNode(data, covariatesToTry, 0);
    }

    /**
     * Recursively grows the node holding {@code data}.
     *
     * @param data            the rows that reached this node
     * @param covariatesToTry the covariates that may be used to split this node
     * @param depth           the depth of this node; the root is at depth 0
     * @return a split node when a usable split rule was found, otherwise a terminal node
     */
    private Node<Y> growNode(List<Row<Y>> data, List<Covariate> covariatesToTry, int depth){
        // TODO: decide what the minimum number of rows per node should be
        final boolean splittable = data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data);
        if(!splittable){
            return createTerminalNode(data);
        }

        final Covariate.SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
        if(bestSplitRule == null){
            // No candidate rule produced a usable score, so there is nothing to split on.
            return createTerminalNode(data);
        }

        final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule

        final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
        final Node<Y> rightNode = growNode(split.rightHand, covariatesToTry, depth+1);

        return new SplitNode<>(leftNode, rightNode, bestSplitRule);
    }

    /**
     * Builds a terminal node whose value is the combination of the responses of {@code data}.
     *
     * @param data the rows that ended up in the terminal node
     * @return the terminal node
     */
    private Node<Y> createTerminalNode(List<Row<Y>> data){
        return new TerminalNode<>(
                data.stream()
                        .map(row -> row.getResponse())
                        .collect(responseCombiner)
        );
    }

    /**
     * Searches all candidate covariates for the split rule with the highest score.
     *
     * <p>Candidate rules are ignored when they leave one side of the split empty (such a rule does not
     * partition the node) or when the differentiator returns {@code null} or NaN for them (NaN compares
     * false against everything, so accepting it would block every later, genuinely better rule).
     *
     * @param data            the rows of the node being split
     * @param covariatesToTry the covariates to generate split rules from
     * @return the best scoring split rule, or {@code null} if no rule produced a usable score
     */
    private Covariate.SplitRule findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry){
        Covariate.SplitRule bestSplitRule = null;
        double bestSplitScore = 0.0;

        // Does not depend on the covariate, so compute it once up front.
        final int numberToTry = numberOfSplits==0 ? data.size() : numberOfSplits;

        for(final Covariate covariate : covariatesToTry){

            final Collection<Covariate.SplitRule> splitRulesToTry = covariate
                    .generateSplitRules(
                            data
                                    .stream()
                                    .map(row -> row.getCovariateValue(covariate.getName()))
                                    .collect(Collectors.toList())
                            , numberToTry);

            for(final Covariate.SplitRule possibleRule : splitRulesToTry){
                final Split<Y> possibleSplit = possibleRule.applyRule(data);

                // A rule that sends every row to one side would recreate this node as its own child.
                if(possibleSplit.leftHand.isEmpty() || possibleSplit.rightHand.isEmpty()){
                    continue;
                }

                final Double score = groupDifferentiator.differentiate(
                        possibleSplit.leftHand.stream().map(row -> row.getResponse()).collect(Collectors.toList()),
                        possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
                );

                if(score == null || score.isNaN()){
                    continue;
                }

                // The first usable rule is always accepted, whatever its score; later ones must beat it.
                if(bestSplitRule == null || score > bestSplitScore){
                    bestSplitRule = possibleRule;
                    bestSplitScore = score;
                }
            }

        }

        return bestSplitRule;

    }

    /**
     * Checks whether every row in the node has the same response, in which case splitting is pointless.
     *
     * @param data the rows of the node
     * @return {@code true} if {@code data} is empty or all responses are equal to the first row's
     *         response (null responses are only equal to other null responses)
     */
    private boolean nodeIsPure(List<Row<Y>> data){
        if(data.isEmpty()){
            // Nothing to tell apart; also avoids data.get(0) failing on an empty node.
            return true;
        }

        final Y first = data.get(0).getResponse();
        // Objects.equals keeps the row.equals(first) direction while tolerating null responses.
        return data.stream().allMatch(row -> Objects.equals(row.getResponse(), first));
    }

}