Improve WeightedVarianceGroupDifferentiator to be faster
This commit is contained in:
parent
ee137370a1
commit
d935fe0bc0
1 changed files with 120 additions and 14 deletions
|
@ -16,31 +16,137 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.responses.regression;
|
package ca.joeltherrien.randomforest.responses.regression;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
import ca.joeltherrien.randomforest.tree.SplitAndScore;
|
||||||
|
|
||||||
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator<Double> {
|
public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator<Double> {
|
||||||
|
|
||||||
@Override
|
private Double getScore(Set leftHand, Set rightHand) {
|
||||||
public Double getScore(List<Double> leftHand, List<Double> rightHand) {
|
|
||||||
|
|
||||||
final double leftHandSize = leftHand.size();
|
if(leftHand.n == 0 || rightHand.n == 0){
|
||||||
final double rightHandSize = rightHand.size();
|
|
||||||
final double n = leftHandSize + rightHandSize;
|
|
||||||
|
|
||||||
if(leftHandSize == 0 || rightHandSize == 0){
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
final double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
|
final double leftHandMean = leftHand.getMean();
|
||||||
final double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
|
final double rightHandMean = rightHand.getMean();
|
||||||
|
|
||||||
final double leftVariance = leftHand.stream().mapToDouble(db -> (db - leftHandMean)*(db - leftHandMean)).sum();
|
final double leftVariance = leftHand.summationSquared - ((double) leftHand.n) * leftHandMean*leftHandMean;
|
||||||
final double rightVariance = rightHand.stream().mapToDouble(db -> (db - rightHandMean)*(db - rightHandMean)).sum();
|
final double rightVariance = rightHand.summationSquared - ((double) rightHand.n) * rightHandMean*rightHandMean;
|
||||||
|
|
||||||
return -(leftVariance + rightVariance) / n;
|
return -(leftVariance + rightVariance) / (leftHand.n + rightHand.n);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SplitAndScore<Double, ?> differentiate(Iterator<Split<Double, ?>> splitIterator) {
|
||||||
|
|
||||||
|
if(splitIterator instanceof Covariate.SplitRuleUpdater){
|
||||||
|
return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
return differentiateWithBasicIterator(splitIterator);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private SplitAndScore<Double, ?> differentiateWithBasicIterator(Iterator<Split<Double, ?>> splitIterator){
|
||||||
|
Double bestScore = null;
|
||||||
|
Split<Double, ?> bestSplit = null;
|
||||||
|
|
||||||
|
while(splitIterator.hasNext()){
|
||||||
|
final Split<Double, ?> candidateSplit = splitIterator.next();
|
||||||
|
|
||||||
|
final List<Double> leftHandList = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Double> rightHandList = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
if(leftHandList.isEmpty() || rightHandList.isEmpty()){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
final Set setLeft = new Set(leftHandList);
|
||||||
|
final Set setRight = new Set(rightHandList);
|
||||||
|
|
||||||
|
final Double score = getScore(setLeft, setRight);
|
||||||
|
|
||||||
|
if(score != null && Double.isFinite(score) && (bestScore == null || score > bestScore)){
|
||||||
|
bestScore = score;
|
||||||
|
bestSplit = candidateSplit;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(bestSplit == null){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new SplitAndScore<>(bestSplit, bestScore);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SplitAndScore<Double, ?> differentiateWithSplitUpdater(Covariate.SplitRuleUpdater<Double, ?> splitRuleUpdater) {
|
||||||
|
|
||||||
|
final List<Double> leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
|
||||||
|
.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
final List<Double> rightInitialSplit = splitRuleUpdater.currentSplit().getRightHand()
|
||||||
|
.stream().map(Row::getResponse).collect(Collectors.toList());
|
||||||
|
|
||||||
|
final Set setLeft = new Set(leftInitialSplit);
|
||||||
|
final Set setRight = new Set(rightInitialSplit);
|
||||||
|
|
||||||
|
Double bestScore = null;
|
||||||
|
Split<Double, ?> bestSplit = null;
|
||||||
|
|
||||||
|
while(splitRuleUpdater.hasNext()){
|
||||||
|
for(Row<Double> rowMoved : splitRuleUpdater.nextUpdate().rowsMovedToLeftHand()){
|
||||||
|
setLeft.updateAdd(rowMoved.getResponse());
|
||||||
|
setRight.updateRemove(rowMoved.getResponse());
|
||||||
|
}
|
||||||
|
|
||||||
|
final Double score = getScore(setLeft, setRight);
|
||||||
|
|
||||||
|
if(score != null && Double.isFinite(score) && (bestScore == null || score > bestScore)){
|
||||||
|
bestScore = score;
|
||||||
|
bestSplit = splitRuleUpdater.currentSplit();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(bestSplit == null){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new SplitAndScore<>(bestSplit, bestScore);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class Set {
|
||||||
|
private int n = 0;
|
||||||
|
private double summation = 0.0;
|
||||||
|
private double summationSquared = 0.0;
|
||||||
|
|
||||||
|
private Set(List<Double> list){
|
||||||
|
for(Double number : list){
|
||||||
|
updateAdd(number);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private double getMean(){
|
||||||
|
return summation / n;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateAdd(double number){
|
||||||
|
summation += number;
|
||||||
|
summationSquared += number*number;
|
||||||
|
n++;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateRemove(double number){
|
||||||
|
summation -= number;
|
||||||
|
summationSquared -= number*number;
|
||||||
|
n--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue