diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index 5d98322..847a5fb 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -16,8 +16,8 @@ package ca.joeltherrien.randomforest; -import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator; @@ -38,6 +38,7 @@ import java.io.File; import java.io.IOException; import java.io.PrintWriter; import java.util.List; +import java.util.Random; public class Main { @@ -98,7 +99,7 @@ public class Main { // Let's reduce this down to n final int n = Integer.parseInt(args[2]); - Utils.reduceListToSize(dataset, n); + Utils.reduceListToSize(dataset, n, new Random()); final File folder = new File(settings.getSaveTreeLocation()); final Forest forest = DataLoader.loadForest(folder, responseCombiner); diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index a9976c5..cabce79 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -17,11 +17,10 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; +import ca.joeltherrien.randomforest.covariates.Covariate; import lombok.Builder; -import java.util.Collection; -import java.util.Collections; -import java.util.List; +import java.util.*; import java.util.stream.Collectors; @Builder @@ -67,4 +66,54 @@ public class Forest { // O = output of trees, FO = forest output. In prac return Collections.unmodifiableCollection(trees); } + public Map findSplitsByCovariate(){ + final Map countMap = new HashMap<>(); + + for(final Tree tree : getTrees()){ + final Node rootNode = tree.getRootNode(); + final List splitNodeList = rootNode.getNodesOfType(SplitNode.class); + + for(final SplitNode splitNode : splitNodeList){ + final Covariate covariate = splitNode.getSplitRule().getParent(); + + final Integer currentCount = countMap.getOrDefault(covariate, 0); + countMap.put(covariate, currentCount+1); + } + + } + + return countMap; + } + + public double averageTerminalNodeSize(){ + long numberTerminalNodes = 0; + long totalSizeTerminalNodes = 0; + + for(final Tree tree : getTrees()){ + final Node rootNode = tree.getRootNode(); + final List terminalNodeList = rootNode.getNodesOfType(TerminalNode.class); + + for(final TerminalNode terminalNode : terminalNodeList){ + numberTerminalNodes++; + totalSizeTerminalNodes += terminalNode.getSize(); + } + + } + + return (double) totalSizeTerminalNodes / (double) numberTerminalNodes; + } + + public int numberOfTerminalNodes(){ + int countTerminalNodes = 0; + + for(final Tree tree : getTrees()){ + final Node rootNode = tree.getRootNode(); + final List terminalNodeList = rootNode.getNodesOfType(TerminalNode.class); + + countTerminalNodes += terminalNodeList.size(); + } + + return countTerminalNodes; + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Node.java b/src/main/java/ca/joeltherrien/randomforest/tree/Node.java index 5921c10..233003c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Node.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Node.java @@ -19,9 +19,12 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; import java.io.Serializable; +import java.util.List; public interface Node extends Serializable { Y evaluate(CovariateRow row); + > List getNodesOfType(Class nodeType); + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java index d64a783..de9b3d5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java @@ -19,10 +19,15 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.covariates.Covariate; import lombok.Builder; +import lombok.Getter; import lombok.ToString; +import java.util.ArrayList; +import java.util.List; + @Builder @ToString +@Getter public class SplitNode implements Node { private final Node leftHand; @@ -41,4 +46,18 @@ public class SplitNode implements Node { } } + + @Override + public > List getNodesOfType(Class nodeType) { + final List nodeList = new ArrayList<>(); + if(nodeType.isInstance(this)){ + nodeList.add((C) this); + } + + nodeList.addAll(leftHand.getNodesOfType(nodeType)); + nodeList.addAll(rightHand.getNodesOfType(nodeType)); + + return nodeList; + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java index b4b9c18..17c6249 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java @@ -17,19 +17,36 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; +import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; +import java.util.Collections; +import java.util.List; + @RequiredArgsConstructor @ToString public class TerminalNode implements Node { private final Y responseValue; + @Getter + private final int size; + @Override public Y evaluate(CovariateRow row){ return responseValue; } + @Override + public > List getNodesOfType(Class nodeType) { + + if(nodeType.isInstance(this)){ + return Collections.singletonList((C) this); + } + + return Collections.emptyList(); + } + } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java index 6b8b3d8..b3dba61 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java @@ -17,11 +17,14 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; +import lombok.Getter; import java.util.Arrays; +import java.util.List; public class Tree implements Node { + @Getter private final Node rootNode; private final int[] bootstrapRowIds; @@ -37,6 +40,11 @@ public class Tree implements Node { return rootNode.evaluate(row); } + @Override + public > List getNodesOfType(Class nodeType) { + return rootNode.getNodesOfType(nodeType); + } + public int[] getBootstrapRowIds(){ return bootstrapRowIds.clone(); } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 4757b9f..df5fcb5 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -77,11 +77,10 @@ public class TreeTrainer { if(bestSplit == null){ - return new TerminalNode<>( responseCombiner.combine( data.stream().map(row -> row.getResponse()).collect(Collectors.toList()) - ) + ), data.size() ); @@ -121,7 +120,7 @@ public class TreeTrainer { return new TerminalNode<>( responseCombiner.combine( data.stream().map(row -> row.getResponse()).collect(Collectors.toList()) - ) + ), data.size() ); } diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java index 51e4dcd..a8bc873 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java @@ -17,7 +17,6 @@ package ca.joeltherrien.randomforest.utils; import java.util.*; -import java.util.concurrent.ThreadLocalRandom; public final class Utils { @@ -52,12 +51,11 @@ public final class Utils { } - public static void reduceListToSize(List list, int n){ + public static void reduceListToSize(List list, int n, final Random random){ if(list.size() <= n){ return; } - final Random random = ThreadLocalRandom.current(); if(n > list.size()/2){ // faster to randomly remove items while(list.size() > n){ diff --git a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java index 157810c..24faceb 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestUtils.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestUtils.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Random; import java.util.function.DoubleSupplier; import java.util.stream.DoubleStream; @@ -109,24 +110,24 @@ public class TestUtils { @Test public void reduceListToSize(){ final List testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); - + final Random random = new Random(); for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness final List testList1 = new ArrayList<>(testList); // Test when removing elements - Utils.reduceListToSize(testList1, 7); + Utils.reduceListToSize(testList1, 7, random); assertEquals(7, testList1.size()); // verify proper size assertEquals(7, new HashSet<>(testList1).size()); // verify the items are unique final List testList2 = new ArrayList<>(testList); // Test when adding elements - Utils.reduceListToSize(testList2, 3); + Utils.reduceListToSize(testList2, 3, random); assertEquals(3, testList2.size()); // verify proper size assertEquals(3, new HashSet<>(testList2).size()); // verify the items are unique final List testList3 = new ArrayList<>(testList); // verify no change - Utils.reduceListToSize(testList3, 15); + Utils.reduceListToSize(testList3, 15, random); assertEquals(10, testList3.size()); // verify proper size assertEquals(10, new HashSet<>(testList3).size()); // verify the items are unique