From 6192643e125713eda4161f4d658df30b29559555 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Mon, 2 Jul 2018 12:27:18 -0700 Subject: [PATCH] Change ResponseCombiner to be a Collector that's compatible with Streams. --- .../randomforest/ResponseCombiner.java | 3 +- .../regression/MeanResponseCombiner.java | 61 ++++++++++++++++++- .../randomforest/tree/TreeTrainer.java | 8 +-- 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java index 7f24928..e418fdd 100644 --- a/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/ResponseCombiner.java @@ -1,8 +1,9 @@ package ca.joeltherrien.randomforest; import java.util.List; +import java.util.stream.Collector; -public interface ResponseCombiner { +public interface ResponseCombiner extends Collector { Y combine(List responses); diff --git a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java index 1e37f3c..4923a32 100644 --- a/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/regression/MeanResponseCombiner.java @@ -3,8 +3,19 @@ package ca.joeltherrien.randomforest.regression; import ca.joeltherrien.randomforest.ResponseCombiner; import java.util.List; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import java.util.function.Supplier; -public class MeanResponseCombiner implements ResponseCombiner { +/** + * This implementation of the collector isn't great... but good enough given that I'm not planning to fully support regression trees. + * + * (It's not great because you'll lose accuracy as you sum up the doubles, since dividing by n is the very last step.) + * + */ +public class MeanResponseCombiner implements ResponseCombiner { @Override public Double combine(List responses) { @@ -13,4 +24,52 @@ public class MeanResponseCombiner implements ResponseCombiner { return responses.stream().mapToDouble(db -> db/size).sum(); } + + @Override + public Supplier supplier() { + return () -> new Container(0 ,0); + } + + @Override + public BiConsumer accumulator() { + return (container, number) -> { + container.number+=number; + container.n++; + }; + } + + @Override + public BinaryOperator combiner() { + return (c1, c2) -> { + c1.number += c2.number; + c1.n += c2.n; + + return c1; + }; + } + + @Override + public Function finisher() { + return (container) -> container.number/(double)container.n; + } + + @Override + public Set characteristics() { + return Set.of(Characteristics.UNORDERED); + } + + + public static class Container{ + + Container(double number, int n){ + this.number = number; + this.n = n; + } + + public Double number; + public int n; + + } + + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index df2ae5c..fed92bf 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -13,7 +13,7 @@ import java.util.stream.Collectors; @Builder public class TreeTrainer { - private final ResponseCombiner responseCombiner; + private final ResponseCombiner responseCombiner; private final GroupDifferentiator groupDifferentiator; /** @@ -43,10 +43,10 @@ public class TreeTrainer { } else{ - return new TerminalNode<>(responseCombiner.combine( + return new TerminalNode<>( data.stream() - .map(row -> row.getResponse()) - .collect(Collectors.toList())) + .map(row -> row.getResponse()) + .collect(responseCombiner) ); }