Introduce Support for Factors #7
1 changed files with 13 additions and 6 deletions
|
@ -31,9 +31,18 @@ public class TreeTrainer<Y> {
|
|||
|
||||
private Node<Y> growNode(List<Row<Y>> data, List<String> covariatesToTry, int depth){
|
||||
// TODO; what is minimum per tree?
|
||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data, covariatesToTry)){
|
||||
if(data.size() >= 2*nodeSize && depth < maxNodeDepth && !nodeIsPure(data)){
|
||||
final SplitRule bestSplitRule = findBestSplitRule(data, covariatesToTry);
|
||||
|
||||
if(bestSplitRule == null){
|
||||
return new TerminalNode<>(
|
||||
data.stream()
|
||||
.map(row -> row.getResponse())
|
||||
.collect(responseCombiner)
|
||||
|
||||
);
|
||||
}
|
||||
|
||||
final Split<Y> split = bestSplitRule.applyRule(data); // TODO optimize this as we're duplicating work done in findBestSplitRule
|
||||
|
||||
final Node<Y> leftNode = growNode(split.leftHand, covariatesToTry, depth+1);
|
||||
|
@ -56,7 +65,7 @@ public class TreeTrainer<Y> {
|
|||
|
||||
private SplitRule findBestSplitRule(List<Row<Y>> data, List<String> covariatesToTry){
|
||||
SplitRule bestSplitRule = null;
|
||||
Double bestSplitScore = 0.0; // may be null
|
||||
double bestSplitScore = 0.0;
|
||||
boolean first = true;
|
||||
|
||||
for(final String covariate : covariatesToTry){
|
||||
|
@ -92,7 +101,7 @@ public class TreeTrainer<Y> {
|
|||
possibleSplit.rightHand.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||
);
|
||||
|
||||
if( first || (score != null && (bestSplitScore == null || score > bestSplitScore))){
|
||||
if(score != null && (score > bestSplitScore || first)){
|
||||
bestSplitRule = possibleRule;
|
||||
bestSplitScore = score;
|
||||
first = false;
|
||||
|
@ -107,9 +116,7 @@ public class TreeTrainer<Y> {
|
|||
|
||||
}
|
||||
|
||||
private boolean nodeIsPure(List<Row<Y>> data, List<String> covariatesToTry){
|
||||
// TODO how is this done?
|
||||
|
||||
private boolean nodeIsPure(List<Row<Y>> data){
|
||||
final Y first = data.get(0).getResponse();
|
||||
return data.stream().allMatch(row -> row.getResponse().equals(first));
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue