Change ResponseCombiner to be a Collector that's compatible
with Streams.
This commit is contained in:
parent
3c9c78741f
commit
6192643e12
3 changed files with 66 additions and 6 deletions
|
@ -1,8 +1,9 @@
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.stream.Collector;
|
||||||
|
|
||||||
public interface ResponseCombiner<Y> {
|
public interface ResponseCombiner<Y, K> extends Collector<Y, K, Y> {
|
||||||
|
|
||||||
Y combine(List<Y> responses);
|
Y combine(List<Y> responses);
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,19 @@ package ca.joeltherrien.randomforest.regression;
|
||||||
import ca.joeltherrien.randomforest.ResponseCombiner;
|
import ca.joeltherrien.randomforest.ResponseCombiner;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.function.BiConsumer;
|
||||||
|
import java.util.function.BinaryOperator;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
public class MeanResponseCombiner implements ResponseCombiner<Double> {
|
/**
|
||||||
|
* This implementation of the collector isn't great... but good enough given that I'm not planning to fully support regression trees.
|
||||||
|
*
|
||||||
|
* (It's not great because you'll lose accuracy as you sum up the doubles, since dividing by n is the very last step.)
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class MeanResponseCombiner implements ResponseCombiner<Double, MeanResponseCombiner.Container> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Double combine(List<Double> responses) {
|
public Double combine(List<Double> responses) {
|
||||||
|
@ -13,4 +24,52 @@ public class MeanResponseCombiner implements ResponseCombiner<Double> {
|
||||||
return responses.stream().mapToDouble(db -> db/size).sum();
|
return responses.stream().mapToDouble(db -> db/size).sum();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Supplier<Container> supplier() {
|
||||||
|
return () -> new Container(0 ,0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BiConsumer<Container, Double> accumulator() {
|
||||||
|
return (container, number) -> {
|
||||||
|
container.number+=number;
|
||||||
|
container.n++;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BinaryOperator<Container> combiner() {
|
||||||
|
return (c1, c2) -> {
|
||||||
|
c1.number += c2.number;
|
||||||
|
c1.n += c2.n;
|
||||||
|
|
||||||
|
return c1;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Function<Container, Double> finisher() {
|
||||||
|
return (container) -> container.number/(double)container.n;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<Characteristics> characteristics() {
|
||||||
|
return Set.of(Characteristics.UNORDERED);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static class Container{
|
||||||
|
|
||||||
|
Container(double number, int n){
|
||||||
|
this.number = number;
|
||||||
|
this.n = n;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double number;
|
||||||
|
public int n;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ import java.util.stream.Collectors;
|
||||||
@Builder
|
@Builder
|
||||||
public class TreeTrainer<Y> {
|
public class TreeTrainer<Y> {
|
||||||
|
|
||||||
private final ResponseCombiner<Y> responseCombiner;
|
private final ResponseCombiner<Y, ?> responseCombiner;
|
||||||
private final GroupDifferentiator<Y> groupDifferentiator;
|
private final GroupDifferentiator<Y> groupDifferentiator;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -43,10 +43,10 @@ public class TreeTrainer<Y> {
|
||||||
|
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
return new TerminalNode<>(responseCombiner.combine(
|
return new TerminalNode<>(
|
||||||
data.stream()
|
data.stream()
|
||||||
.map(row -> row.getResponse())
|
.map(row -> row.getResponse())
|
||||||
.collect(Collectors.toList()))
|
.collect(responseCombiner)
|
||||||
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue