Add capabilities to get nodes of a certain type in a forest; used to produce summary statistics
This commit is contained in:
parent
77ec780304
commit
9f513ab75b
9 changed files with 110 additions and 15 deletions
|
@ -16,8 +16,8 @@
|
|||
|
||||
package ca.joeltherrien.randomforest;
|
||||
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
||||
|
@ -38,6 +38,7 @@ import java.io.File;
|
|||
import java.io.IOException;
|
||||
import java.io.PrintWriter;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
public class Main {
|
||||
|
||||
|
@ -98,7 +99,7 @@ public class Main {
|
|||
|
||||
// Let's reduce this down to n
|
||||
final int n = Integer.parseInt(args[2]);
|
||||
Utils.reduceListToSize(dataset, n);
|
||||
Utils.reduceListToSize(dataset, n, new Random());
|
||||
|
||||
final File folder = new File(settings.getSaveTreeLocation());
|
||||
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
||||
|
|
|
@ -17,11 +17,10 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Builder
|
||||
|
@ -67,4 +66,54 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
|||
return Collections.unmodifiableCollection(trees);
|
||||
}
|
||||
|
||||
public Map<Covariate, Integer> findSplitsByCovariate(){
|
||||
final Map<Covariate, Integer> countMap = new HashMap<>();
|
||||
|
||||
for(final Tree<O> tree : getTrees()){
|
||||
final Node<O> rootNode = tree.getRootNode();
|
||||
final List<SplitNode> splitNodeList = rootNode.getNodesOfType(SplitNode.class);
|
||||
|
||||
for(final SplitNode splitNode : splitNodeList){
|
||||
final Covariate covariate = splitNode.getSplitRule().getParent();
|
||||
|
||||
final Integer currentCount = countMap.getOrDefault(covariate, 0);
|
||||
countMap.put(covariate, currentCount+1);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return countMap;
|
||||
}
|
||||
|
||||
public double averageTerminalNodeSize(){
|
||||
long numberTerminalNodes = 0;
|
||||
long totalSizeTerminalNodes = 0;
|
||||
|
||||
for(final Tree<O> tree : getTrees()){
|
||||
final Node<O> rootNode = tree.getRootNode();
|
||||
final List<TerminalNode> terminalNodeList = rootNode.getNodesOfType(TerminalNode.class);
|
||||
|
||||
for(final TerminalNode terminalNode : terminalNodeList){
|
||||
numberTerminalNodes++;
|
||||
totalSizeTerminalNodes += terminalNode.getSize();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return (double) totalSizeTerminalNodes / (double) numberTerminalNodes;
|
||||
}
|
||||
|
||||
public int numberOfTerminalNodes(){
|
||||
int countTerminalNodes = 0;
|
||||
|
||||
for(final Tree<O> tree : getTrees()){
|
||||
final Node<O> rootNode = tree.getRootNode();
|
||||
final List<TerminalNode> terminalNodeList = rootNode.getNodesOfType(TerminalNode.class);
|
||||
|
||||
countTerminalNodes += terminalNodeList.size();
|
||||
}
|
||||
|
||||
return countTerminalNodes;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -19,9 +19,12 @@ package ca.joeltherrien.randomforest.tree;
|
|||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
public interface Node<Y> extends Serializable {
|
||||
|
||||
Y evaluate(CovariateRow row);
|
||||
|
||||
<C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType);
|
||||
|
||||
}
|
||||
|
|
|
@ -19,10 +19,15 @@ package ca.joeltherrien.randomforest.tree;
|
|||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Builder
|
||||
@ToString
|
||||
@Getter
|
||||
public class SplitNode<Y> implements Node<Y> {
|
||||
|
||||
private final Node<Y> leftHand;
|
||||
|
@ -41,4 +46,18 @@ public class SplitNode<Y> implements Node<Y> {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
|
||||
final List<C> nodeList = new ArrayList<>();
|
||||
if(nodeType.isInstance(this)){
|
||||
nodeList.add((C) this);
|
||||
}
|
||||
|
||||
nodeList.addAll(leftHand.getNodesOfType(nodeType));
|
||||
nodeList.addAll(rightHand.getNodesOfType(nodeType));
|
||||
|
||||
return nodeList;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,19 +17,36 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
@ToString
|
||||
public class TerminalNode<Y> implements Node<Y> {
|
||||
|
||||
private final Y responseValue;
|
||||
|
||||
@Getter
|
||||
private final int size;
|
||||
|
||||
@Override
|
||||
public Y evaluate(CovariateRow row){
|
||||
return responseValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
|
||||
|
||||
if(nodeType.isInstance(this)){
|
||||
return Collections.singletonList((C) this);
|
||||
}
|
||||
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -17,11 +17,14 @@
|
|||
package ca.joeltherrien.randomforest.tree;
|
||||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class Tree<Y> implements Node<Y> {
|
||||
|
||||
@Getter
|
||||
private final Node<Y> rootNode;
|
||||
private final int[] bootstrapRowIds;
|
||||
|
||||
|
@ -37,6 +40,11 @@ public class Tree<Y> implements Node<Y> {
|
|||
return rootNode.evaluate(row);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
|
||||
return rootNode.getNodesOfType(nodeType);
|
||||
}
|
||||
|
||||
public int[] getBootstrapRowIds(){
|
||||
return bootstrapRowIds.clone();
|
||||
}
|
||||
|
|
|
@ -77,11 +77,10 @@ public class TreeTrainer<Y, O> {
|
|||
|
||||
|
||||
if(bestSplit == null){
|
||||
|
||||
return new TerminalNode<>(
|
||||
responseCombiner.combine(
|
||||
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||
)
|
||||
), data.size()
|
||||
);
|
||||
|
||||
|
||||
|
@ -121,7 +120,7 @@ public class TreeTrainer<Y, O> {
|
|||
return new TerminalNode<>(
|
||||
responseCombiner.combine(
|
||||
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||
)
|
||||
), data.size()
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package ca.joeltherrien.randomforest.utils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
public final class Utils {
|
||||
|
||||
|
@ -52,12 +51,11 @@ public final class Utils {
|
|||
|
||||
}
|
||||
|
||||
public static <T> void reduceListToSize(List<T> list, int n){
|
||||
public static <T> void reduceListToSize(List<T> list, int n, final Random random){
|
||||
if(list.size() <= n){
|
||||
return;
|
||||
}
|
||||
|
||||
final Random random = ThreadLocalRandom.current();
|
||||
if(n > list.size()/2){
|
||||
// faster to randomly remove items
|
||||
while(list.size() > n){
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test;
|
|||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.function.DoubleSupplier;
|
||||
import java.util.stream.DoubleStream;
|
||||
|
||||
|
@ -109,24 +110,24 @@ public class TestUtils {
|
|||
@Test
|
||||
public void reduceListToSize(){
|
||||
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||
|
||||
final Random random = new Random();
|
||||
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
||||
final List<Integer> testList1 = new ArrayList<>(testList);
|
||||
// Test when removing elements
|
||||
Utils.reduceListToSize(testList1, 7);
|
||||
Utils.reduceListToSize(testList1, 7, random);
|
||||
assertEquals(7, testList1.size()); // verify proper size
|
||||
assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique
|
||||
|
||||
|
||||
final List<Integer> testList2 = new ArrayList<>(testList);
|
||||
// Test when adding elements
|
||||
Utils.reduceListToSize(testList2, 3);
|
||||
Utils.reduceListToSize(testList2, 3, random);
|
||||
assertEquals(3, testList2.size()); // verify proper size
|
||||
assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique
|
||||
|
||||
final List<Integer> testList3 = new ArrayList<>(testList);
|
||||
// verify no change
|
||||
Utils.reduceListToSize(testList3, 15);
|
||||
Utils.reduceListToSize(testList3, 15, random);
|
||||
assertEquals(10, testList3.size()); // verify proper size
|
||||
assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique
|
||||
|
||||
|
|
Loading…
Reference in a new issue