Add capabilities to get nodes of a certain type in a forest; used to produce summary statistics
This commit is contained in:
parent
77ec780304
commit
9f513ab75b
9 changed files with 110 additions and 15 deletions
|
@ -16,8 +16,8 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest;
|
package ca.joeltherrien.randomforest;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
|
||||||
|
@ -38,6 +38,7 @@ import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.PrintWriter;
|
import java.io.PrintWriter;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
|
@ -98,7 +99,7 @@ public class Main {
|
||||||
|
|
||||||
// Let's reduce this down to n
|
// Let's reduce this down to n
|
||||||
final int n = Integer.parseInt(args[2]);
|
final int n = Integer.parseInt(args[2]);
|
||||||
Utils.reduceListToSize(dataset, n);
|
Utils.reduceListToSize(dataset, n, new Random());
|
||||||
|
|
||||||
final File folder = new File(settings.getSaveTreeLocation());
|
final File folder = new File(settings.getSaveTreeLocation());
|
||||||
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
final Forest<?, CompetingRiskFunctions> forest = DataLoader.loadForest(folder, responseCombiner);
|
||||||
|
|
|
@ -17,11 +17,10 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.*;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
|
@ -67,4 +66,54 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
return Collections.unmodifiableCollection(trees);
|
return Collections.unmodifiableCollection(trees);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Map<Covariate, Integer> findSplitsByCovariate(){
|
||||||
|
final Map<Covariate, Integer> countMap = new HashMap<>();
|
||||||
|
|
||||||
|
for(final Tree<O> tree : getTrees()){
|
||||||
|
final Node<O> rootNode = tree.getRootNode();
|
||||||
|
final List<SplitNode> splitNodeList = rootNode.getNodesOfType(SplitNode.class);
|
||||||
|
|
||||||
|
for(final SplitNode splitNode : splitNodeList){
|
||||||
|
final Covariate covariate = splitNode.getSplitRule().getParent();
|
||||||
|
|
||||||
|
final Integer currentCount = countMap.getOrDefault(covariate, 0);
|
||||||
|
countMap.put(covariate, currentCount+1);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return countMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double averageTerminalNodeSize(){
|
||||||
|
long numberTerminalNodes = 0;
|
||||||
|
long totalSizeTerminalNodes = 0;
|
||||||
|
|
||||||
|
for(final Tree<O> tree : getTrees()){
|
||||||
|
final Node<O> rootNode = tree.getRootNode();
|
||||||
|
final List<TerminalNode> terminalNodeList = rootNode.getNodesOfType(TerminalNode.class);
|
||||||
|
|
||||||
|
for(final TerminalNode terminalNode : terminalNodeList){
|
||||||
|
numberTerminalNodes++;
|
||||||
|
totalSizeTerminalNodes += terminalNode.getSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return (double) totalSizeTerminalNodes / (double) numberTerminalNodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int numberOfTerminalNodes(){
|
||||||
|
int countTerminalNodes = 0;
|
||||||
|
|
||||||
|
for(final Tree<O> tree : getTrees()){
|
||||||
|
final Node<O> rootNode = tree.getRootNode();
|
||||||
|
final List<TerminalNode> terminalNodeList = rootNode.getNodesOfType(TerminalNode.class);
|
||||||
|
|
||||||
|
countTerminalNodes += terminalNodeList.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
return countTerminalNodes;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,9 +19,12 @@ package ca.joeltherrien.randomforest.tree;
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public interface Node<Y> extends Serializable {
|
public interface Node<Y> extends Serializable {
|
||||||
|
|
||||||
Y evaluate(CovariateRow row);
|
Y evaluate(CovariateRow row);
|
||||||
|
|
||||||
|
/**
 * Returns the nodes of the requested type that can be found starting from this node.
 *
 * <p>In the implementations visible alongside this interface, a node includes itself when
 * it is an instance of {@code nodeType}, and a split node also includes the matches found
 * under both of its children.
 *
 * @param nodeType class object identifying which kind of node to collect
 * @param <C> the node type being collected
 * @return a list of the matching nodes
 */
<C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,10 +19,15 @@ package ca.joeltherrien.randomforest.tree;
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@ToString
|
@ToString
|
||||||
|
@Getter
|
||||||
public class SplitNode<Y> implements Node<Y> {
|
public class SplitNode<Y> implements Node<Y> {
|
||||||
|
|
||||||
private final Node<Y> leftHand;
|
private final Node<Y> leftHand;
|
||||||
|
@ -41,4 +46,18 @@ public class SplitNode<Y> implements Node<Y> {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
|
||||||
|
final List<C> nodeList = new ArrayList<>();
|
||||||
|
if(nodeType.isInstance(this)){
|
||||||
|
nodeList.add((C) this);
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeList.addAll(leftHand.getNodesOfType(nodeType));
|
||||||
|
nodeList.addAll(rightHand.getNodesOfType(nodeType));
|
||||||
|
|
||||||
|
return nodeList;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,19 +17,36 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@ToString
|
@ToString
|
||||||
public class TerminalNode<Y> implements Node<Y> {
|
public class TerminalNode<Y> implements Node<Y> {
|
||||||
|
|
||||||
private final Y responseValue;
|
private final Y responseValue;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final int size;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Y evaluate(CovariateRow row){
|
public Y evaluate(CovariateRow row){
|
||||||
return responseValue;
|
return responseValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
|
||||||
|
|
||||||
|
if(nodeType.isInstance(this)){
|
||||||
|
return Collections.singletonList((C) this);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,11 +17,14 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class Tree<Y> implements Node<Y> {
|
public class Tree<Y> implements Node<Y> {
|
||||||
|
|
||||||
|
@Getter
|
||||||
private final Node<Y> rootNode;
|
private final Node<Y> rootNode;
|
||||||
private final int[] bootstrapRowIds;
|
private final int[] bootstrapRowIds;
|
||||||
|
|
||||||
|
@ -37,6 +40,11 @@ public class Tree<Y> implements Node<Y> {
|
||||||
return rootNode.evaluate(row);
|
return rootNode.evaluate(row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <C extends Node<Y>> List<C> getNodesOfType(Class<C> nodeType) {
|
||||||
|
return rootNode.getNodesOfType(nodeType);
|
||||||
|
}
|
||||||
|
|
||||||
public int[] getBootstrapRowIds(){
|
public int[] getBootstrapRowIds(){
|
||||||
return bootstrapRowIds.clone();
|
return bootstrapRowIds.clone();
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,11 +77,10 @@ public class TreeTrainer<Y, O> {
|
||||||
|
|
||||||
|
|
||||||
if(bestSplit == null){
|
if(bestSplit == null){
|
||||||
|
|
||||||
return new TerminalNode<>(
|
return new TerminalNode<>(
|
||||||
responseCombiner.combine(
|
responseCombiner.combine(
|
||||||
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
)
|
), data.size()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
@ -121,7 +120,7 @@ public class TreeTrainer<Y, O> {
|
||||||
return new TerminalNode<>(
|
return new TerminalNode<>(
|
||||||
responseCombiner.combine(
|
responseCombiner.combine(
|
||||||
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
data.stream().map(row -> row.getResponse()).collect(Collectors.toList())
|
||||||
)
|
), data.size()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
package ca.joeltherrien.randomforest.utils;
|
package ca.joeltherrien.randomforest.utils;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
|
||||||
|
|
||||||
public final class Utils {
|
public final class Utils {
|
||||||
|
|
||||||
|
@ -52,12 +51,11 @@ public final class Utils {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <T> void reduceListToSize(List<T> list, int n){
|
public static <T> void reduceListToSize(List<T> list, int n, final Random random){
|
||||||
if(list.size() <= n){
|
if(list.size() <= n){
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
final Random random = ThreadLocalRandom.current();
|
|
||||||
if(n > list.size()/2){
|
if(n > list.size()/2){
|
||||||
// faster to randomly remove items
|
// faster to randomly remove items
|
||||||
while(list.size() > n){
|
while(list.size() > n){
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
import java.util.function.DoubleSupplier;
|
import java.util.function.DoubleSupplier;
|
||||||
import java.util.stream.DoubleStream;
|
import java.util.stream.DoubleStream;
|
||||||
|
|
||||||
|
@ -109,24 +110,24 @@ public class TestUtils {
|
||||||
@Test
|
@Test
|
||||||
public void reduceListToSize(){
|
public void reduceListToSize(){
|
||||||
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||||
|
final Random random = new Random();
|
||||||
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
||||||
final List<Integer> testList1 = new ArrayList<>(testList);
|
final List<Integer> testList1 = new ArrayList<>(testList);
|
||||||
// Test when removing elements
|
// Test when removing elements
|
||||||
Utils.reduceListToSize(testList1, 7);
|
Utils.reduceListToSize(testList1, 7, random);
|
||||||
assertEquals(7, testList1.size()); // verify proper size
|
assertEquals(7, testList1.size()); // verify proper size
|
||||||
assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique
|
assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique
|
||||||
|
|
||||||
|
|
||||||
final List<Integer> testList2 = new ArrayList<>(testList);
|
final List<Integer> testList2 = new ArrayList<>(testList);
|
||||||
// Test when adding elements
|
// Test when adding elements
|
||||||
Utils.reduceListToSize(testList2, 3);
|
Utils.reduceListToSize(testList2, 3, random);
|
||||||
assertEquals(3, testList2.size()); // verify proper size
|
assertEquals(3, testList2.size()); // verify proper size
|
||||||
assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique
|
assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique
|
||||||
|
|
||||||
final List<Integer> testList3 = new ArrayList<>(testList);
|
final List<Integer> testList3 = new ArrayList<>(testList);
|
||||||
// verify no change
|
// verify no change
|
||||||
Utils.reduceListToSize(testList3, 15);
|
Utils.reduceListToSize(testList3, 15, random);
|
||||||
assertEquals(10, testList3.size()); // verify proper size
|
assertEquals(10, testList3.size()); // verify proper size
|
||||||
assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique
|
assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue