Improve WeightedVarianceGroupDifferentiator to be faster
This commit is contained in:
parent
ee137370a1
commit
d935fe0bc0
1 changed files with 120 additions and 14 deletions
|
@ -16,31 +16,137 @@
|
|||
|
||||
package ca.joeltherrien.randomforest.responses.regression;
|
||||
|
||||
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.Row;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||
import ca.joeltherrien.randomforest.tree.Split;
|
||||
import ca.joeltherrien.randomforest.tree.SplitAndScore;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator<Double> {
|
||||
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||
|
||||
@Override
|
||||
public Double getScore(List<Double> leftHand, List<Double> rightHand) {
|
||||
private Double getScore(Set leftHand, Set rightHand) {
|
||||
|
||||
final double leftHandSize = leftHand.size();
|
||||
final double rightHandSize = rightHand.size();
|
||||
final double n = leftHandSize + rightHandSize;
|
||||
|
||||
if(leftHandSize == 0 || rightHandSize == 0){
|
||||
if(leftHand.n == 0 || rightHand.n == 0){
|
||||
return null;
|
||||
}
|
||||
|
||||
final double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
|
||||
final double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
|
||||
final double leftHandMean = leftHand.getMean();
|
||||
final double rightHandMean = rightHand.getMean();
|
||||
|
||||
final double leftVariance = leftHand.stream().mapToDouble(db -> (db - leftHandMean)*(db - leftHandMean)).sum();
|
||||
final double rightVariance = rightHand.stream().mapToDouble(db -> (db - rightHandMean)*(db - rightHandMean)).sum();
|
||||
final double leftVariance = leftHand.summationSquared - ((double) leftHand.n) * leftHandMean*leftHandMean;
|
||||
final double rightVariance = rightHand.summationSquared - ((double) rightHand.n) * rightHandMean*rightHandMean;
|
||||
|
||||
return -(leftVariance + rightVariance) / n;
|
||||
return -(leftVariance + rightVariance) / (leftHand.n + rightHand.n);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SplitAndScore<Double, ?> differentiate(Iterator<Split<Double, ?>> splitIterator) {
|
||||
|
||||
if(splitIterator instanceof Covariate.SplitRuleUpdater){
|
||||
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
|
||||
}
|
||||
else{
|
||||
return differentiateWithBasicIterator(splitIterator);
|
||||
}
|
||||
}
|
||||
|
||||
private SplitAndScore<Double, ?> differentiateWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){
|
||||
Double bestScore = null;
|
||||
Split<Double, ?> bestSplit = null;
|
||||
|
||||
while(splitIterator.hasNext()){
|
||||
final Split<Double, ?> candidateSplit = splitIterator.next();
|
||||
|
||||
final List<Double> leftHandList = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
final List<Double> rightHandList = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
|
||||
if(leftHandList.isEmpty() || rightHandList.isEmpty()){
|
||||
continue;
|
||||
}
|
||||
|
||||
final Set setLeft = new Set(leftHandList);
|
||||
final Set setRight = new Set(rightHandList);
|
||||
|
||||
final Double score = getScore(setLeft, setRight);
|
||||
|
||||
if(score != null && Double.isFinite(score) && (bestScore == null || score > bestScore)){
|
||||
bestScore = score;
|
||||
bestSplit = candidateSplit;
|
||||
}
|
||||
}
|
||||
|
||||
if(bestSplit == null){
|
||||
return null;
|
||||
}
|
||||
|
||||
return new SplitAndScore<>(bestSplit, bestScore);
|
||||
}
|
||||
|
||||
private SplitAndScore<Double, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) {
|
||||
|
||||
final List<Double> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
|
||||
.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
final List<Double> rightInitialSplit = splitRuleUpdater.currentSplit().getRightHand()
|
||||
.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||
|
||||
final Set setLeft = new Set(leftInitialSplit);
|
||||
final Set setRight = new Set(rightInitialSplit);
|
||||
|
||||
Double bestScore = null;
|
||||
Split<Double, ?> bestSplit = null;
|
||||
|
||||
while(splitRuleUpdater.hasNext()){
|
||||
for(Row<Double> rowMoved : splitRuleUpdater.nextUpdate().rowsMovedToLeftHand()){
|
||||
setLeft.updateAdd(rowMoved.getResponse());
|
||||
setRight.updateRemove(rowMoved.getResponse());
|
||||
}
|
||||
|
||||
final Double score = getScore(setLeft, setRight);
|
||||
|
||||
if(score != null && Double.isFinite(score) && (bestScore == null || score > bestScore)){
|
||||
bestScore = score;
|
||||
bestSplit = splitRuleUpdater.currentSplit();
|
||||
}
|
||||
}
|
||||
|
||||
if(bestSplit == null){
|
||||
return null;
|
||||
}
|
||||
|
||||
return new SplitAndScore<>(bestSplit, bestScore);
|
||||
|
||||
}
|
||||
|
||||
private class Set {
|
||||
private int n = 0;
|
||||
private double summation = 0.0;
|
||||
private double summationSquared = 0.0;
|
||||
|
||||
private Set(List<Double> list){
|
||||
for(Double number : list){
|
||||
updateAdd(number);
|
||||
}
|
||||
}
|
||||
|
||||
private double getMean(){
|
||||
return summation / n;
|
||||
}
|
||||
|
||||
private void updateAdd(double number){
|
||||
summation += number;
|
||||
summationSquared += number*number;
|
||||
n++;
|
||||
}
|
||||
|
||||
private void updateRemove(double number){
|
||||
summation -= number;
|
||||
summationSquared -= number*number;
|
||||
n--;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue