largeRCRF-Java/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java

174 lines
5.7 KiB
Java

package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.Settings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
import java.util.*;
import java.util.stream.Collectors;
@Builder
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class TreeTrainer<Y, O> {
private final ResponseCombiner<Y, O> responseCombiner;
private final GroupDifferentiator<Y> groupDifferentiator;
/**
* The number of splits to perform on each covariate. A value of 0 means all possible splits are tried.
*
*/
private final int numberOfSplits;
private final int nodeSize;
private final int maxNodeDepth;
private final int mtry;
/**
* Whether to check if a node is pure or not when deciding to split. Splitting on a pure node won't change predictive accuracy,
* but (depending on conditions) may hurt performance.
*/
private final boolean checkNodePurity;
private final List<Covariate> covariates;
public TreeTrainer(final Settings settings, final List<Covariate> covariates){
this.numberOfSplits = settings.getNumberOfSplits();
this.nodeSize = settings.getNodeSize();
this.maxNodeDepth = settings.getMaxNodeDepth();
this.mtry = settings.getMtry();
this.checkNodePurity = settings.isCheckNodePurity();
this.responseCombiner = settings.getResponseCombiner();
this.groupDifferentiator = settings.getGroupDifferentiator();
this.covariates = covariates;
}
public Tree<O> growTree(List<Row<Y>> data, Random random){
final Node<O> rootNode = growNode(data, 0, random);
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
}
private Node<O> growNode(List<Row<Y>> data, int depth, Random random){
// See https://kogalur.github.io/randomForestSRC/theory.html#section3.1 (near bottom)
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
final List<Covariate> covariatesToTry = selectCovariates(this.mtry, random);
final Split<Y,?> bestSplit = findBestSplitRule(data, covariatesToTry, random);
if(bestSplit == null){
return new TerminalNode<>(
responseCombiner.combine(
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
)
);
}
// Now that we have the best split; we need to handle any NAs that were dropped off
final double probabilityLeftHand = (double) bestSplit.leftHand.size() /
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
// Assign missing values to the split if necessary
if(bestSplit.getSplitRule().getParent().hasNAs()){
for(Row<Y> row : data) {
if(row.getCovariateValue(bestSplit.getSplitRule().getParent()).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){
bestSplit.getLeftHand().add(row);
}
else{
bestSplit.getRightHand().add(row);
}
}
}
}
final Node<O> leftNode = growNode(bestSplit.leftHand, depth+1, random);
final Node<O> rightNode = growNode(bestSplit.rightHand, depth+1, random);
return new SplitNode<>(leftNode, rightNode, bestSplit.getSplitRule(), probabilityLeftHand);
}
else{
return new TerminalNode<>(
responseCombiner.combine(
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
)
);
}
}
private List<Covariate> selectCovariates(int mtry, Random random){
if(mtry >= covariates.size()){
return covariates;
}
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
Collections.shuffle(splitCovariates, random);
if (splitCovariates.size() > mtry) {
splitCovariates.subList(mtry, splitCovariates.size()).clear();
}
return splitCovariates;
}
private Split<Y, ?> findBestSplitRule(List<Row<Y>> data, List<Covariate> covariatesToTry, Random random){
SplitAndScore<Y, ?> bestSplitAndScore = null;
final GroupDifferentiator noGenericDifferentiator = groupDifferentiator; // cause Java generics suck
for(final Covariate covariate : covariatesToTry) {
final Iterator<Split> iterator = covariate.generateSplitRuleUpdater(data, this.numberOfSplits, random);
final SplitAndScore<Y, ?> candidateSplitAndScore = noGenericDifferentiator.differentiate(iterator);
if(candidateSplitAndScore != null && (bestSplitAndScore == null ||
candidateSplitAndScore.getScore() > bestSplitAndScore.getScore())) {
bestSplitAndScore = candidateSplitAndScore;
}
}
if(bestSplitAndScore == null){
return null;
}
return bestSplitAndScore.getSplit();
}
private boolean nodeIsPure(List<Row<Y>> data){
if(!checkNodePurity){
return false;
}
if(data.size() <= 1){
return true;
}
final Y first = data.get(0).getResponse();
for(int i = 1; i< data.size(); i++){
if(!data.get(i).getResponse().equals(first)){
return false;
}
}
return true;
}
}