
Also, the executable component uses a dependency that keeps having security vulnerabilities.
66 lines
2.1 KiB
Java
66 lines
2.1 KiB
Java
/*
|
|
* Copyright (c) 2019 Joel Therrien.
|
|
* This program is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package ca.joeltherrien.randomforest.tree;
|
|
|
|
import ca.joeltherrien.randomforest.Row;
|
|
|
|
import java.util.Iterator;
|
|
import java.util.List;
|
|
import java.util.stream.Collectors;
|
|
|
|
public abstract class SimpleSplitFinder<Y> implements SplitFinder<Y> {
|
|
|
|
@Override
|
|
public SplitAndScore<Y, ?> findBestSplit(Iterator<Split<Y, ?>> splitIterator) {
|
|
Double bestScore = null;
|
|
Split<Y, ?> bestSplit = null;
|
|
|
|
while(splitIterator.hasNext()){
|
|
final Split<Y, ?> candidateSplit = splitIterator.next();
|
|
|
|
final List<Y> leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
|
final List<Y> rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
|
|
|
|
if(leftHand.isEmpty() || rightHand.isEmpty()){
|
|
continue;
|
|
}
|
|
|
|
final Double score = getScore(leftHand, rightHand);
|
|
|
|
if(score != null && (bestScore == null || score > bestScore)){
|
|
bestScore = score;
|
|
bestSplit = candidateSplit;
|
|
}
|
|
}
|
|
|
|
if(bestSplit == null){
|
|
return null;
|
|
}
|
|
|
|
return new SplitAndScore<>(bestSplit, bestScore);
|
|
}
|
|
|
|
/**
|
|
* Return a score; higher is better.
|
|
*
|
|
* @param leftHand
|
|
* @param rightHand
|
|
* @return
|
|
*/
|
|
public abstract Double getScore(List<Y> leftHand, List<Y> rightHand);
|
|
|
|
}
|