From 585d6d3c5bc153f7f9f9ef07f6b7fef8852a7605 Mon Sep 17 00:00:00 2001 From: Joel Therrien Date: Mon, 25 Mar 2019 14:44:31 -0700 Subject: [PATCH] Make SplitRules their own class; independent of their Covariate parents. This was done so that when we serialize trees (and thus SplitRules) we don't awkwardly also serialize ntree versions of the Covariates, which is really awkward when deserializing them. --- .../randomforest/CovariateRow.java | 4 + .../randomforest/covariates/Covariate.java | 92 +------------- .../randomforest/covariates/SplitRule.java | 115 ++++++++++++++++++ .../{ => bool}/BooleanCovariate.java | 40 ++---- .../covariates/bool/BooleanSplitRule.java | 48 ++++++++ .../{ => factor}/FactorCovariate.java | 44 ++----- .../covariates/factor/FactorSplitRule.java | 49 ++++++++ .../covariates/numeric/NumericCovariate.java | 35 +----- .../covariates/numeric/NumericSplitRule.java | 38 ++++++ .../numeric/NumericSplitRuleUpdater.java | 4 +- .../numeric/NumericSplitUpdate.java | 4 +- .../settings/BooleanCovariateSettings.java | 2 +- .../settings/FactorCovariateSettings.java | 2 +- .../randomforest/tree/Forest.java | 10 +- .../joeltherrien/randomforest/tree/Split.java | 4 +- .../randomforest/tree/SplitNode.java | 4 +- .../randomforest/tree/TreeTrainer.java | 6 +- .../randomforest/TestSavingLoading.java | 5 +- .../covariates/FactorCovariateTest.java | 3 +- .../workshop/TrainSingleTreeFactor.java | 2 +- 20 files changed, 313 insertions(+), 198 deletions(-) create mode 100644 src/main/java/ca/joeltherrien/randomforest/covariates/SplitRule.java rename src/main/java/ca/joeltherrien/randomforest/covariates/{ => bool}/BooleanCovariate.java (74%) create mode 100644 src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanSplitRule.java rename src/main/java/ca/joeltherrien/randomforest/covariates/{ => factor}/FactorCovariate.java (75%) create mode 100644 src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorSplitRule.java create mode 100644 src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java diff --git a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java index 60ac488..1f44191 100644 --- a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java +++ b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java @@ -37,6 +37,10 @@ public class CovariateRow implements Serializable { return valueArray[covariate.getIndex()]; } + public Covariate.Value getValueByIndex(int index){ + return valueArray[index]; + } + @Override public String toString(){ return "CovariateRow " + this.id; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index ae21287..f0a8cc1 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -16,13 +16,14 @@ package ca.joeltherrien.randomforest.covariates; -import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.tree.Split; import java.io.Serializable; -import java.util.*; -import java.util.concurrent.ThreadLocalRandom; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Random; public interface Covariate extends Serializable, Comparable { @@ -69,92 +70,7 @@ public interface Covariate extends Serializable, Comparable { Collection> rowsMovedToLeftHand(); } - interface SplitRule extends Serializable{ - Covariate getParent(); - - /** - * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. - * This method is primarily used during the training of a tree when splits are being tested. - * - * @param rows - * @param - * @return - */ - default Split applyRule(List> rows) { - - /* - When working with really large List> we need to be careful about memory. - If the lefthand and righthand lists are too small they grow, but for a moment copies exist - and memory issues arise. - - If they're too large, we waste memory yet again - */ - - // value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand - final byte[] whichHand = new byte[rows.size()]; - int countLeftHand = 0; - int countRightHand = 0; - int countMissingHand = 0; - - - - for(int i=0; i row = rows.get(i); - - final Value value = row.getCovariateValue(getParent()); - - if(value.isNA()){ - countMissingHand++; - whichHand[i] = 2; - } - - if(isLeftHand(value)){ - countLeftHand++; - whichHand[i] = 1; - } - else{ - countRightHand++; - whichHand[i] = 0; - } - - } - - - final List> missingValueRows = new ArrayList<>(countMissingHand); - final List> leftHand = new ArrayList<>(countLeftHand); - final List> rightHand = new ArrayList<>(countRightHand); - - for(int i=0; i row = rows.get(i); - - if(whichHand[i] == 0){ - rightHand.add(row); - } - else if(whichHand[i] == 1){ - leftHand.add(row); - } - else{ - missingValueRows.add(row); - } - - } - - return new Split<>(this, leftHand, rightHand, missingValueRows); - } - - default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ - final Value value = row.getCovariateValue(getParent()); - - if(value.isNA()){ - return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand; - } - - return isLeftHand(value); - } - - boolean isLeftHand(Value value); - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/SplitRule.java b/src/main/java/ca/joeltherrien/randomforest/covariates/SplitRule.java new file mode 100644 index 0000000..36c9e7c --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/SplitRule.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.covariates; + +import ca.joeltherrien.randomforest.CovariateRow; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.tree.Split; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +public interface SplitRule extends Serializable{ + + int getParentCovariateIndex(); + + /** + * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. + * This method is primarily used during the training of a tree when splits are being tested. + * + * @param rows + * @param + * @return + */ + default Split applyRule(List> rows) { + + /* + When working with really large List> we need to be careful about memory. + If the lefthand and righthand lists are too small they grow, but for a moment copies exist + and memory issues arise. + + If they're too large, we waste memory yet again + */ + + // value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand + final byte[] whichHand = new byte[rows.size()]; + int countLeftHand = 0; + int countRightHand = 0; + int countMissingHand = 0; + + + + for(int i=0; i row = rows.get(i); + + final Covariate.Value value = row.getValueByIndex(getParentCovariateIndex()); + + if(value.isNA()){ + countMissingHand++; + whichHand[i] = 2; + } + + if(isLeftHand(value)){ + countLeftHand++; + whichHand[i] = 1; + } + else{ + countRightHand++; + whichHand[i] = 0; + } + + } + + + final List> missingValueRows = new ArrayList<>(countMissingHand); + final List> leftHand = new ArrayList<>(countLeftHand); + final List> rightHand = new ArrayList<>(countRightHand); + + for(int i=0; i row = rows.get(i); + + if(whichHand[i] == 0){ + rightHand.add(row); + } + else if(whichHand[i] == 1){ + leftHand.add(row); + } + else{ + missingValueRows.add(row); + } + + } + + return new Split<>(this, leftHand, rightHand, missingValueRows); + } + + default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ + final Covariate.Value value = row.getValueByIndex(getParentCovariateIndex()); + + if(value.isNA()){ + return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand; + } + + return isLeftHand(value); + } + + boolean isLeftHand(Covariate.Value value); + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java similarity index 74% rename from src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java index 2586929..797378b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java @@ -14,17 +14,18 @@ * along with this program. If not, see . */ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.bool; import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.utils.SingletonIterator; import lombok.Getter; -import lombok.RequiredArgsConstructor; -import java.util.*; +import java.util.Iterator; +import java.util.List; +import java.util.Random; -@RequiredArgsConstructor public final class BooleanCovariate implements Covariate { @Getter @@ -35,7 +36,13 @@ public final class BooleanCovariate implements Covariate { private boolean hasNAs = false; - private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. + private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates. + + public BooleanCovariate(String name, int index){ + this.name = name; + this.index = index; + splitRule = new BooleanSplitRule(this); + } @Override public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) { @@ -72,7 +79,7 @@ public final class BooleanCovariate implements Covariate { @Override public String toString(){ - return "BooleanCovariate(name=" + name + ")"; + return "BooleanCovariate(name=" + this.name + ", index=" + this.index + ", hasNAs=" + this.hasNAs + ")"; } public class BooleanValue implements Value{ @@ -100,25 +107,4 @@ public final class BooleanCovariate implements Covariate { } - public class BooleanSplitRule implements SplitRule{ - - @Override - public final String toString() { - return "BooleanSplitRule"; - } - - @Override - public BooleanCovariate getParent() { - return BooleanCovariate.this; - } - - @Override - public boolean isLeftHand(final Value value) { - if(value.isNA()) { - throw new IllegalArgumentException("Trying to determine split on missing value"); - } - - return !value.getValue(); - } - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanSplitRule.java new file mode 100644 index 0000000..813d13a --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanSplitRule.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.covariates.bool; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; + +public class BooleanSplitRule implements SplitRule { + + private final int parentCovariateIndex; + + public BooleanSplitRule(BooleanCovariate parent){ + this.parentCovariateIndex = parent.getIndex(); + } + + @Override + public final String toString() { + return "BooleanSplitRule"; + } + + @Override + public int getParentCovariateIndex() { + return parentCovariateIndex; + } + + @Override + public boolean isLeftHand(final Covariate.Value value) { + if(value.isNA()) { + throw new IllegalArgumentException("Trying to determine split on missing value"); + } + + return !value.getValue(); + } +} \ No newline at end of file diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java similarity index 75% rename from src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java index 4804c27..b3eb245 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java @@ -14,16 +14,17 @@ * along with this program. If not, see . */ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.factor; import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.tree.Split; import lombok.EqualsAndHashCode; import lombok.Getter; import java.util.*; -public final class FactorCovariate implements Covariate{ +public final class FactorCovariate implements Covariate { @Getter private final String name; @@ -44,6 +45,10 @@ public final class FactorCovariate implements Covariate{ this.factorLevels = new HashMap<>(); for(final String level : levels){ + if(level.equalsIgnoreCase("na")){ + throw new IllegalArgumentException("Cannot use NA (case-insensitive) as a level in factor covariate " + name); + } + final FactorValue newValue = new FactorValue(level); factorLevels.put(level, newValue); @@ -70,16 +75,16 @@ public final class FactorCovariate implements Covariate{ while(splits.size() < number){ Collections.shuffle(levels, random); - final Set leftSideValues = new HashSet<>(); - leftSideValues.add(levels.get(0)); + final Set leftSideValues = new HashSet<>(); + leftSideValues.add(levels.get(0).getValue()); for(int i=1; i{ } @Override - public String toString(){ - return "FactorCovariate(name=" + name + ")"; + public String toString() { + return "FactorCovariate(name=" + this.name + ", index=" + this.index + ", hasNAs=" + this.hasNAs + ")"; } @EqualsAndHashCode @@ -139,27 +144,4 @@ public final class FactorCovariate implements Covariate{ } } - @EqualsAndHashCode - public final class FactorSplitRule implements Covariate.SplitRule{ - - private final Set leftSideValues; - - private FactorSplitRule(final Set leftSideValues){ - this.leftSideValues = leftSideValues; - } - - @Override - public FactorCovariate getParent() { - return FactorCovariate.this; - } - - @Override - public boolean isLeftHand(final Value value) { - if(value.isNA()){ - throw new IllegalArgumentException("Trying to determine split on missing value"); - } - - return leftSideValues.contains(value); - } - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorSplitRule.java new file mode 100644 index 0000000..7fb4b4c --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorSplitRule.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.covariates.factor; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; +import lombok.EqualsAndHashCode; + +import java.util.Set; + +@EqualsAndHashCode +public final class FactorSplitRule implements SplitRule { + + private final int parentCovariateIndex; + private final Set leftSideValues; + + public FactorSplitRule(final FactorCovariate parent, final Set leftSideValues){ + this.parentCovariateIndex = parent.getIndex(); + this.leftSideValues = leftSideValues; + } + + @Override + public int getParentCovariateIndex() { + return parentCovariateIndex; + } + + @Override + public boolean isLeftHand(final Covariate.Value value) { + if(value.isNA()){ + throw new IllegalArgumentException("Trying to determine split on missing value"); + } + + return leftSideValues.contains(value.getValue()); + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java index e5b7b68..6c1c490 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java @@ -26,7 +26,10 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; -import java.util.*; +import java.util.Iterator; +import java.util.List; +import java.util.Random; +import java.util.TreeSet; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -141,34 +144,4 @@ public final class NumericCovariate implements Covariate { } } - @EqualsAndHashCode - public class NumericSplitRule implements Covariate.SplitRule{ - - private final double threshold; - - NumericSplitRule(final double threshold){ - this.threshold = threshold; - } - - @Override - public final String toString() { - return "NumericSplitRule on " + getParent().getName() + " at " + threshold; - } - - @Override - public NumericCovariate getParent() { - return NumericCovariate.this; - } - - @Override - public boolean isLeftHand(final Value x) { - if(x.isNA()) { - throw new IllegalArgumentException("Trying to determine split on missing value"); - } - - final double xNum = x.getValue(); - - return xNum <= threshold; - } - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java new file mode 100644 index 0000000..acc8e3b --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java @@ -0,0 +1,38 @@ +package ca.joeltherrien.randomforest.covariates.numeric; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; +import lombok.EqualsAndHashCode; + +@EqualsAndHashCode +public class NumericSplitRule implements SplitRule { + + private final int parentCovariateIndex; + private final double threshold; + + NumericSplitRule(NumericCovariate parent, final double threshold){ + this.parentCovariateIndex = parent.getIndex(); + this.threshold = threshold; + } + + @Override + public final String toString() { + return "NumericSplitRule on " + getParentCovariateIndex() + " at " + threshold; + } + + @Override + public int getParentCovariateIndex() { + return parentCovariateIndex; + } + + @Override + public boolean isLeftHand(final Covariate.Value x) { + if(x.isNA()) { + throw new IllegalArgumentException("Trying to determine split on missing value"); + } + + final double xNum = x.getValue(); + + return xNum <= threshold; + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java index fbbf7bf..662c23a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java @@ -41,7 +41,7 @@ public class NumericSplitRuleUpdater implements Covariate.SplitRuleUpdater> rightHandList = orderedData; this.currentSplit = new Split<>( - covariate.new NumericSplitRule(Double.NEGATIVE_INFINITY), + new NumericSplitRule(covariate, Double.NEGATIVE_INFINITY), leftHandList, rightHandList, Collections.emptyList()); @@ -67,7 +67,7 @@ public class NumericSplitRuleUpdater implements Covariate.SplitRuleUpdater> rowsMoved = orderedData.subList(currentPosition, newPosition); - final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue); + final NumericSplitRule splitRule = new NumericSplitRule(covariate, splitValue); // Update current split this.currentSplit = new Split<>( diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java index 2f831c7..f372e3c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java @@ -25,11 +25,11 @@ import java.util.List; @AllArgsConstructor public class NumericSplitUpdate implements Covariate.SplitUpdate { - private final NumericCovariate.NumericSplitRule numericSplitRule; + private final NumericSplitRule numericSplitRule; private final List> rowsMoved; @Override - public NumericCovariate.NumericSplitRule getSplitRule() { + public NumericSplitRule getSplitRule() { return numericSplitRule; } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java index d4309fd..3ed9028 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java @@ -16,7 +16,7 @@ package ca.joeltherrien.randomforest.covariates.settings; -import ca.joeltherrien.randomforest.covariates.BooleanCovariate; +import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate; import lombok.Data; import lombok.NoArgsConstructor; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java index 3568392..03ccd3b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java @@ -16,7 +16,7 @@ package ca.joeltherrien.randomforest.covariates.settings; -import ca.joeltherrien.randomforest.covariates.FactorCovariate; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import lombok.Data; import lombok.NoArgsConstructor; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index f82abef..276c16c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -67,18 +67,18 @@ public class Forest { // O = output of trees, FO = forest output. In prac return Collections.unmodifiableCollection(trees); } - public Map findSplitsByCovariate(){ - final Map countMap = new TreeMap<>(); + public Map findSplitsByCovariate(){ + final Map countMap = new TreeMap<>(); for(final Tree tree : getTrees()){ final Node rootNode = tree.getRootNode(); final List splitNodeList = rootNode.getNodesOfType(SplitNode.class); for(final SplitNode splitNode : splitNodeList){ - final Covariate covariate = splitNode.getSplitRule().getParent(); + final Integer covariateIndex = splitNode.getSplitRule().getParentCovariateIndex(); - final Integer currentCount = countMap.getOrDefault(covariate, 0); - countMap.put(covariate, currentCount+1); + final Integer currentCount = countMap.getOrDefault(covariateIndex, 0); + countMap.put(covariateIndex, currentCount+1); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java index 0b199fe..e527467 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java @@ -17,7 +17,7 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; import lombok.Data; import java.util.ArrayList; @@ -32,7 +32,7 @@ import java.util.List; @Data public final class Split { - public final Covariate.SplitRule splitRule; + public final SplitRule splitRule; public final List> leftHand; public final List> rightHand; public final List> naHand; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java index de9b3d5..f3bd170 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java @@ -17,7 +17,7 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; import lombok.Builder; import lombok.Getter; import lombok.ToString; @@ -32,7 +32,7 @@ public class SplitNode implements Node { private final Node leftHand; private final Node rightHand; - private final Covariate.SplitRule splitRule; + private final SplitRule splitRule; private final double probabilityNaLeftHand; // used when assigning NA values @Override diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 4ce3630..a633664 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -91,13 +91,13 @@ public class TreeTrainer { (double) (bestSplit.leftHand.size() + bestSplit.rightHand.size()); // Assign missing values to the split if necessary - if(bestSplit.getSplitRule().getParent().hasNAs()){ + if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){ bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists for(Row row : data) { - final Covariate covariate = bestSplit.getSplitRule().getParent(); + final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex(); - if(row.getCovariateValue(covariate).isNA()) { + if(row.getValueByIndex(covariateIndex).isNA()) { final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; if(randomDecision){ diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 8a41da2..a8fd370 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -113,12 +113,15 @@ public class TestSavingLoading { final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final File directory = new File(settings.getSaveTreeLocation()); + if(directory.exists()){ + directory.delete(); + } assertFalse(directory.exists()); directory.mkdir(); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - forestTrainer.trainParallelOnDisk(1); + forestTrainer.trainSerialOnDisk(); assertTrue(directory.exists()); assertTrue(directory.isDirectory()); diff --git a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java index 47c74bb..84c134d 100644 --- a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java +++ b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java @@ -17,6 +17,7 @@ package ca.joeltherrien.randomforest.covariates; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; @@ -63,7 +64,7 @@ public class FactorCovariateTest { void testAllSubsets(){ final FactorCovariate petCovariate = createTestCovariate(); - final List> splitRules = new ArrayList<>(); + final List> splitRules = new ArrayList<>(); petCovariate.generateSplitRuleUpdater(null, 100, new Random()) .forEachRemaining(split -> splitRules.add(split.getSplitRule())); diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index b9163d9..7313d22 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -19,7 +19,7 @@ package ca.joeltherrien.randomforest.workshop; import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.FactorCovariate; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;