Introduce Support for Factors #7
1 changed files with 13 additions and 6 deletions
|
@ -31,9 +31,18 @@ public class TreeTrainer<Y> {
|
||||||
|
|
||||||
private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){
|
private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){
|
||||||
// TODO; what is minimum per tree?
|
// TODO; what is minimum per tree?
|
||||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){
|
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||||
final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
||||||
|
|
||||||
|
if(bestSplitRule == null){
|
||||||
|
return new TerminalNode<>(
|
||||||
|
data.stream()
|
||||||
|
.map(row -> row.getResponse())
|
||||||
|
.collect(responseCombiner)
|
||||||
|
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
||||||
|
|
||||||
final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
|
final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
|
||||||
|
@ -56,7 +65,7 @@ public class TreeTrainer<Y> {
|
||||||
|
|
||||||
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
|
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
|
||||||
SplitRule bestSplitRule = null;
|
SplitRule bestSplitRule = null;
|
||||||
Double bestSplitScore = 0.0; // may be null
|
double bestSplitScore = 0.0;
|
||||||
boolean first = true;
|
boolean first = true;
|
||||||
|
|
||||||
for(final String covariate : covariatesToTry){
|
for(final String covariate : covariatesToTry){
|
||||||
|
@ -92,7 +101,7 @@ public class TreeTrainer<Y> {
|
||||||
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
);
|
);
|
||||||
|
|
||||||
if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){
|
if(score != null && (score > bestSplitScore || first)){
|
||||||
bestSplitRule = possibleRule;
|
bestSplitRule = possibleRule;
|
||||||
bestSplitScore = score;
|
bestSplitScore = score;
|
||||||
first = false;
|
first = false;
|
||||||
|
@ -107,9 +116,7 @@ public class TreeTrainer<Y> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean nodeIsPure(List<Row<Y>> data, List<String> covariatesToTry){
|
private boolean nodeIsPure(List<Row<Y>> data){
|
||||||
// TODO how is this done?
|
|
||||||
|
|
||||||
final Y first = data.get(0).getResponse();
|
final Y first = data.get(0).getResponse();
|
||||||
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue