Add capabilities to get nodes of a certain type in a forest; used to produce summary statistics

This commit is contained in:
Joel Therrien 2019-02-02 09:36:00 -08:00
parent 77ec780304
commit 9f513ab75b
9 changed files with 110 additions and 15 deletions

View file

@ -16,8 +16,8 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
@ -38,6 +38,7 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.util.List; import java.util.List;
import java.util.Random;
public class Main { public class Main {
@ -98,7 +99,7 @@ public class Main {
// Let's reduce this down to n // Let's reduce this down to n
final int n = Integer.parseInt(args[2]); final int n = Integer.parseInt(args[2]);
Utils.reduceListToSize(dataset, n); Utils.reduceListToSize(dataset, n, new Random());
final File folder = new File(settings.getSaveTreeLocation()); final File folder = new File(settings.getSaveTreeLocation());
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner); final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);

View file

@ -17,11 +17,10 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Builder; import lombok.Builder;
import java.util.Collection; import java.util.*;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Builder @Builder
@ -67,4 +66,54 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
return Collections.unmodifiableCollection(trees); return Collections.unmodifiableCollection(trees);
} }
public Map<Covariate, Integer> findSplitsByCovariate(){
final Map<Covariate, Integer> countMap = new HashMap<>();
for(final Tree<O> tree : getTrees()){
final Node<O> rootNode = tree.getRootNode();
final List<SplitNode> splitNodeList = rootNode.getNodesOfType(SplitNode.class);
for(final SplitNode splitNode : splitNodeList){
final Covariate covariate = splitNode.getSplitRule().getParent();
final Integer currentCount = countMap.getOrDefault(covariate, 0);
countMap.put(covariate, currentCount+1);
}
}
return countMap;
}
public double averageTerminalNodeSize(){
long numberTerminalNodes = 0;
long totalSizeTerminalNodes = 0;
for(final Tree<O> tree : getTrees()){
final Node<O> rootNode = tree.getRootNode();
final List<TerminalNode> terminalNodeList = rootNode.getNodesOfType(TerminalNode.class);
for(final TerminalNode terminalNode : terminalNodeList){
numberTerminalNodes++;
totalSizeTerminalNodes += terminalNode.getSize();
}
}
return (double) totalSizeTerminalNodes / (double) numberTerminalNodes;
}
public int numberOfTerminalNodes(){
int countTerminalNodes = 0;
for(final Tree<O> tree : getTrees()){
final Node<O> rootNode = tree.getRootNode();
final List<TerminalNode> terminalNodeList = rootNode.getNodesOfType(TerminalNode.class);
countTerminalNodes += terminalNodeList.size();
}
return countTerminalNodes;
}
} }

View file

@ -19,9 +19,12 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import java.io.Serializable; import java.io.Serializable;
import java.util.List;
public interface Node<Y> extends Serializable { public interface Node<Y> extends Serializable {
Y evaluate(CovariateRow row); Y evaluate(CovariateRow row);
<C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType);
} }

View file

@ -19,10 +19,15 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Builder; import lombok.Builder;
import lombok.Getter;
import lombok.ToString; import lombok.ToString;
import java.util.ArrayList;
import java.util.List;
@Builder @Builder
@ToString @ToString
@Getter
public class SplitNode<Y> implements Node<Y> { public class SplitNode<Y> implements Node<Y> {
private final Node<Y> leftHand; private final Node<Y> leftHand;
@ -41,4 +46,18 @@ public class SplitNode<Y> implements Node<Y> {
} }
} }
@Override
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
final List<C> nodeList = new ArrayList<>();
if(nodeType.isInstance(this)){
nodeList.add((C) this);
}
nodeList.addAll(leftHand.getNodesOfType(nodeType));
nodeList.addAll(rightHand.getNodesOfType(nodeType));
return nodeList;
}
} }

View file

@ -17,19 +17,36 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Getter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.ToString; import lombok.ToString;
import java.util.Collections;
import java.util.List;
@RequiredArgsConstructor @RequiredArgsConstructor
@ToString @ToString
public class TerminalNode<Y> implements Node<Y> { public class TerminalNode<Y> implements Node<Y> {
private final Y responseValue; private final Y responseValue;
@Getter
private final int size;
@Override @Override
public Y evaluate(CovariateRow row){ public Y evaluate(CovariateRow row){
return responseValue; return responseValue;
} }
@Override
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
if(nodeType.isInstance(this)){
return Collections.singletonList((C) this);
}
return Collections.emptyList();
}
} }

View file

@ -17,11 +17,14 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Getter;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
public class Tree<Y> implements Node<Y> { public class Tree<Y> implements Node<Y> {
@Getter
private final Node<Y> rootNode; private final Node<Y> rootNode;
private final int[] bootstrapRowIds; private final int[] bootstrapRowIds;
@ -37,6 +40,11 @@ public class Tree<Y> implements Node<Y> {
return rootNode.evaluate(row); return rootNode.evaluate(row);
} }
@Override
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
return rootNode.getNodesOfType(nodeType);
}
public int[] getBootstrapRowIds(){ public int[] getBootstrapRowIds(){
return bootstrapRowIds.clone(); return bootstrapRowIds.clone();
} }

View file

@ -77,11 +77,10 @@ public class TreeTrainer<Y, O> {
if(bestSplit == null){ if(bestSplit == null){
return new TerminalNode<>( return new TerminalNode<>(
responseCombiner.combine( responseCombiner.combine(
data.stream().map(row -> row.getResponse()).collect(Collectors.toList()) data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
) ), data.size()
); );
@ -121,7 +120,7 @@ public class TreeTrainer<Y, O> {
return new TerminalNode<>( return new TerminalNode<>(
responseCombiner.combine( responseCombiner.combine(
data.stream().map(row -> row.getResponse()).collect(Collectors.toList()) data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
) ), data.size()
); );
} }

View file

@ -17,7 +17,6 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import java.util.*; import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
public final class Utils { public final class Utils {
@ -52,12 +51,11 @@ public final class Utils {
} }
public static <T> void reduceListToSize(List<T> list, int n){ public static <T> void reduceListToSize(List<T> list, int n, final Random random){
if(list.size() <= n){ if(list.size() <= n){
return; return;
} }
final Random random = ThreadLocalRandom.current();
if(n > list.size()/2){ if(n > list.size()/2){
// faster to randomly remove items // faster to randomly remove items
while(list.size() > n){ while(list.size() > n){

View file

@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Random;
import java.util.function.DoubleSupplier; import java.util.function.DoubleSupplier;
import java.util.stream.DoubleStream; import java.util.stream.DoubleStream;
@ -109,24 +110,24 @@ public class TestUtils {
@Test @Test
public void reduceListToSize(){ public void reduceListToSize(){
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
final Random random = new Random();
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
final List<Integer> testList1 = new ArrayList<>(testList); final List<Integer> testList1 = new ArrayList<>(testList);
// Test when removing elements // Test when removing elements
Utils.reduceListToSize(testList1, 7); Utils.reduceListToSize(testList1, 7, random);
assertEquals(7, testList1.size()); // verify proper size assertEquals(7, testList1.size()); // verify proper size
assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique
final List<Integer> testList2 = new ArrayList<>(testList); final List<Integer> testList2 = new ArrayList<>(testList);
// Test when adding elements // Test when adding elements
Utils.reduceListToSize(testList2, 3); Utils.reduceListToSize(testList2, 3, random);
assertEquals(3, testList2.size()); // verify proper size assertEquals(3, testList2.size()); // verify proper size
assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique
final List<Integer> testList3 = new ArrayList<>(testList); final List<Integer> testList3 = new ArrayList<>(testList);
// verify no change // verify no change
Utils.reduceListToSize(testList3, 15); Utils.reduceListToSize(testList3, 15, random);
assertEquals(10, testList3.size()); // verify proper size assertEquals(10, testList3.size()); // verify proper size
assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique