Change ResponseCombiner to be a Collector that's compatible

with Streams.
This commit is contained in:
Joel Therrien 2018-07-02 12:27:18 -07:00
parent 3c9c78741f
commit 6192643e12
3 changed files with 66 additions and 6 deletions

View file

@ -1,8 +1,9 @@
package ca.joeltherrien.randomforest;
import java.util.List;
import java.util.stream.Collector;
public interface ResponseCombiner<Y> {
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
Y combine(List<Y> responses);

View file

@ -3,8 +3,19 @@ package ca.joeltherrien.randomforest.regression;
import ca.joeltherrien.randomforest.ResponseCombiner;
import java.util.List;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
public class MeanResponseCombiner implements ResponseCombiner<Double> {
/**
* This implementation of the collector isn't great... but good enough given that I'm not planning to fully support regression trees.
*
* (It's not great because you'll lose accuracy as you sum up the doubles, since dividing by n is the very last step.)
*
*/
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
@Override
public Double combine(List<Double> responses) {
@ -13,4 +24,52 @@ public class MeanResponseCombiner implements ResponseCombiner<Double> {
return responses.stream().mapToDouble(db -> db/size).sum();
}
@Override
public Supplier<Container> supplier() {
return () -> new Container(0 ,0);
}
@Override
public BiConsumer<Container, Double> accumulator() {
return (container, number) -> {
container.number+=number;
container.n++;
};
}
@Override
public BinaryOperator<Container> combiner() {
return (c1, c2) -> {
c1.number += c2.number;
c1.n += c2.n;
return c1;
};
}
@Override
public Function<Container, Double> finisher() {
return (container) -> container.number/(double)container.n;
}
@Override
public Set<Characteristics> characteristics() {
return Set.of(Characteristics.UNORDERED);
}
public static class Container{
Container(double number, int n){
this.number = number;
this.n = n;
}
public Double number;
public int n;
}
}

View file

@ -13,7 +13,7 @@ import java.util.stream.Collectors;
@Builder
public class TreeTrainer<Y> {
private final ResponseCombiner<Y> responseCombiner;
private final ResponseCombiner<Y, ?> responseCombiner;
private final GroupDifferentiator<Y> groupDifferentiator;
/**
@ -43,10 +43,10 @@ public class TreeTrainer<Y> {
}
else{
return new TerminalNode<>(responseCombiner.combine(
return new TerminalNode<>(
data.stream()
.map(row -> row.getResponse())
.collect(Collectors.toList()))
.map(row -> row.getResponse())
.collect(responseCombiner)
);
}