// Patch summary (commentary only; the patch text from the "diff --git" header onward is unchanged).
//
// Rewrites WeightedVarianceGroupDifferentiator so that it implements GroupDifferentiator
// directly instead of extending SimpleGroupDifferentiator.
//
// - getScore(Set, Set) becomes private and works on running-sum accumulators. It returns null
//   when either side has n == 0; otherwise it returns
//   -(leftVariance + rightVariance) / (leftHand.n + rightHand.n), where each "variance" term
//   is computed as summationSquared - n * mean * mean.
// - differentiate(...) passes the iterator to differentiateWithSplitUpdater when it is an
//   instance of Covariate.SplitRuleUpdater, and to differentiateWithBasicIterator otherwise.
// - differentiateWithBasicIterator collects the responses of each candidate split's left and
//   right hand into lists, skips candidates where either list is empty, builds a fresh pair of
//   accumulators per candidate, and keeps the split with the largest non-null, finite score.
//   It returns null when no candidate qualifies.
// - differentiateWithSplitUpdater builds the two accumulators once from currentSplit(); then,
//   for every nextUpdate(), each row from rowsMovedToLeftHand() is added to the left
//   accumulator and removed from the right one before scoring. The split as it stands before
//   the first nextUpdate() call is never scored. Returns null when no update yields a
//   non-null, finite score.
// - Set (private inner class) holds n, summation and summationSquared, maintained by
//   updateAdd / updateRemove; getMean() is summation / n.
//
// NOTE(review): as stored here the patch's line breaks are collapsed (the header and the whole
//   hunk sit on three physical lines), so it presumably will not apply as-is - confirm against
//   the original commit.
// NOTE(review): generic type arguments look stripped (e.g. "Iterator> splitIterator" has an
//   unmatched '>'); verify the real signatures in the repository before relying on this text.
// NOTE(review): summationSquared - n * mean * mean, combined with incremental subtraction in
//   updateRemove, can lose precision in floating point and may come out slightly negative when
//   responses are large relative to their spread - TODO confirm whether this matters here.
// NOTE(review): Set does not appear to use the enclosing instance, so it could probably be a
//   static nested class; its name also coincides with java.util.Set (not imported in this hunk).
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java index cb57e9f..a16f403 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java @@ -16,31 +16,137 @@ package ca.joeltherrien.randomforest.responses.regression; -import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.tree.GroupDifferentiator; +import ca.joeltherrien.randomforest.tree.Split; +import ca.joeltherrien.randomforest.tree.SplitAndScore; +import java.util.Iterator; import java.util.List; +import java.util.stream.Collectors; -public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator { +public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator { - @Override - public Double getScore(List leftHand, List rightHand) { + private Double getScore(Set leftHand, Set rightHand) { - final double leftHandSize = leftHand.size(); - final double rightHandSize = rightHand.size(); - final double n = leftHandSize + rightHandSize; - - if(leftHandSize == 0 || rightHandSize == 0){ + if(leftHand.n == 0 || rightHand.n == 0){ return null; } - final double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum(); - final double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum(); + final double leftHandMean = leftHand.getMean(); + final double rightHandMean = rightHand.getMean(); - final double leftVariance = leftHand.stream().mapToDouble(db -> (db - leftHandMean)*(db - leftHandMean)).sum(); - final double rightVariance = rightHand.stream().mapToDouble(db -> 
(db - rightHandMean)*(db - rightHandMean)).sum(); + final double leftVariance = leftHand.summationSquared - ((double) leftHand.n) * leftHandMean*leftHandMean; + final double rightVariance = rightHand.summationSquared - ((double) rightHand.n) * rightHandMean*rightHandMean; - return -(leftVariance + rightVariance) / n; + return -(leftVariance + rightVariance) / (leftHand.n + rightHand.n); + } + + @Override + public SplitAndScore differentiate(Iterator> splitIterator) { + + if(splitIterator instanceof Covariate.SplitRuleUpdater){ + return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator); + } + else{ + return differentiateWithBasicIterator(splitIterator); + } + } + + private SplitAndScore differentiateWithBasicIterator(Iterator> splitIterator){ + Double bestScore = null; + Split bestSplit = null; + + while(splitIterator.hasNext()){ + final Split candidateSplit = splitIterator.next(); + + final List leftHandList = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList()); + final List rightHandList = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList()); + + if(leftHandList.isEmpty() || rightHandList.isEmpty()){ + continue; + } + + final Set setLeft = new Set(leftHandList); + final Set setRight = new Set(rightHandList); + + final Double score = getScore(setLeft, setRight); + + if(score != null && Double.isFinite(score) && (bestScore == null || score > bestScore)){ + bestScore = score; + bestSplit = candidateSplit; + } + } + + if(bestSplit == null){ + return null; + } + + return new SplitAndScore<>(bestSplit, bestScore); + } + + private SplitAndScore differentiateWithSplitUpdater(Covariate.SplitRuleUpdater splitRuleUpdater) { + + final List leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand() + .stream().map(Row::getResponse).collect(Collectors.toList()); + final List rightInitialSplit = splitRuleUpdater.currentSplit().getRightHand() + 
.stream().map(Row::getResponse).collect(Collectors.toList()); + + final Set setLeft = new Set(leftInitialSplit); + final Set setRight = new Set(rightInitialSplit); + + Double bestScore = null; + Split bestSplit = null; + + while(splitRuleUpdater.hasNext()){ + for(Row rowMoved : splitRuleUpdater.nextUpdate().rowsMovedToLeftHand()){ + setLeft.updateAdd(rowMoved.getResponse()); + setRight.updateRemove(rowMoved.getResponse()); + } + + final Double score = getScore(setLeft, setRight); + + if(score != null && Double.isFinite(score) && (bestScore == null || score > bestScore)){ + bestScore = score; + bestSplit = splitRuleUpdater.currentSplit(); + } + } + + if(bestSplit == null){ + return null; + } + + return new SplitAndScore<>(bestSplit, bestScore); } + private class Set { + private int n = 0; + private double summation = 0.0; + private double summationSquared = 0.0; + + private Set(List list){ + for(Double number : list){ + updateAdd(number); + } + } + + private double getMean(){ + return summation / n; + } + + private void updateAdd(double number){ + summation += number; + summationSquared += number*number; + n++; + } + + private void updateRemove(double number){ + summation -= number; + summationSquared -= number*number; + n--; + } + } + }