Change ResponseCombiner to be a Collector that's compatible
with Streams.
This commit is contained in:
parent
3c9c78741f
commit
6192643e12
3 changed files with 66 additions and 6 deletions
|
@ -1,8 +1,9 @@
|
|||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collector;
|
||||
|
||||
public interface ResponseCombiner<Y> {
|
||||
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
|
||||
|
||||
Y combine(List<Y> responses);
|
||||
|
||||
|
|
|
@ -3,8 +3,19 @@ package ca.joeltherrien.randomforest.regression;
|
|||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.BinaryOperator;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
public class MeanResponseCombiner implements ResponseCombiner<Double> {
|
||||
/**
|
||||
* This collector implementation is not ideal, but it is adequate because full support for regression trees is not planned.
|
||||
*
|
||||
* (Precision is lost because the doubles are summed first and only divided by n in the very last step.)
|
||||
*
|
||||
*/
|
||||
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
|
||||
|
||||
@Override
|
||||
public Double combine(List<Double> responses) {
|
||||
|
@ -13,4 +24,52 @@ public class MeanResponseCombiner implements ResponseCombiner<Double> {
|
|||
return responses.stream().mapToDouble(db -> db/size).sum();
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Supplier<Container> supplier() {
|
||||
return () -> new Container(0 ,0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BiConsumer<Container, Double> accumulator() {
|
||||
return (container, number) -> {
|
||||
container.number+=number;
|
||||
container.n++;
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public BinaryOperator<Container> combiner() {
|
||||
return (c1, c2) -> {
|
||||
c1.number += c2.number;
|
||||
c1.n += c2.n;
|
||||
|
||||
return c1;
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Function<Container, Double> finisher() {
|
||||
return (container) -> container.number/(double)container.n;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<Characteristics> characteristics() {
|
||||
return Set.of(Characteristics.UNORDERED);
|
||||
}
|
||||
|
||||
|
||||
public static class Container{
|
||||
|
||||
Container(double number, int n){
|
||||
this.number = number;
|
||||
this.n = n;
|
||||
}
|
||||
|
||||
public Double number;
|
||||
public int n;
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ import java.util.stream.Collectors;
|
|||
@Builder
|
||||
public class TreeTrainer<Y> {
|
||||
|
||||
private final ResponseCombiner<Y> responseCombiner;
|
||||
private final ResponseCombiner<Y, ?> responseCombiner;
|
||||
private final GroupDifferentiator<Y> groupDifferentiator;
|
||||
|
||||
/**
|
||||
|
@ -43,10 +43,10 @@ public class TreeTrainer<Y> {
|
|||
|
||||
}
|
||||
else{
|
||||
return new TerminalNode<>(responseCombiner.combine(
|
||||
return new TerminalNode<>(
|
||||
data.stream()
|
||||
.map(row -> row.getResponse())
|
||||
.collect(Collectors.toList()))
|
||||
.map(row -> row.getResponse())
|
||||
.collect(responseCombiner)
|
||||
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue