diff --git a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java index 60ac488..1f44191 100644 --- a/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java +++ b/src/main/java/ca/joeltherrien/randomforest/CovariateRow.java @@ -37,6 +37,10 @@ public class CovariateRow implements Serializable { return valueArray[covariate.getIndex()]; } + public Covariate.Value getValueByIndex(int index){ + return valueArray[index]; + } + @Override public String toString(){ return "CovariateRow " + this.id; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java index ae21287..f0a8cc1 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java @@ -16,13 +16,14 @@ package ca.joeltherrien.randomforest.covariates; -import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.tree.Split; import java.io.Serializable; -import java.util.*; -import java.util.concurrent.ThreadLocalRandom; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Random; public interface Covariate extends Serializable, Comparable { @@ -69,92 +70,7 @@ public interface Covariate extends Serializable, Comparable { Collection> rowsMovedToLeftHand(); } - interface SplitRule extends Serializable{ - Covariate getParent(); - - /** - * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. - * This method is primarily used during the training of a tree when splits are being tested. - * - * @param rows - * @param - * @return - */ - default Split applyRule(List> rows) { - - /* - When working with really large List> we need to be careful about memory. - If the lefthand and righthand lists are too small they grow, but for a moment copies exist - and memory issues arise. - - If they're too large, we waste memory yet again - */ - - // value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand - final byte[] whichHand = new byte[rows.size()]; - int countLeftHand = 0; - int countRightHand = 0; - int countMissingHand = 0; - - - - for(int i=0; i row = rows.get(i); - - final Value value = row.getCovariateValue(getParent()); - - if(value.isNA()){ - countMissingHand++; - whichHand[i] = 2; - } - - if(isLeftHand(value)){ - countLeftHand++; - whichHand[i] = 1; - } - else{ - countRightHand++; - whichHand[i] = 0; - } - - } - - - final List> missingValueRows = new ArrayList<>(countMissingHand); - final List> leftHand = new ArrayList<>(countLeftHand); - final List> rightHand = new ArrayList<>(countRightHand); - - for(int i=0; i row = rows.get(i); - - if(whichHand[i] == 0){ - rightHand.add(row); - } - else if(whichHand[i] == 1){ - leftHand.add(row); - } - else{ - missingValueRows.add(row); - } - - } - - return new Split<>(this, leftHand, rightHand, missingValueRows); - } - - default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ - final Value value = row.getCovariateValue(getParent()); - - if(value.isNA()){ - return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand; - } - - return isLeftHand(value); - } - - boolean isLeftHand(Value value); - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/SplitRule.java b/src/main/java/ca/joeltherrien/randomforest/covariates/SplitRule.java new file mode 100644 index 0000000..36c9e7c --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/SplitRule.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.covariates; + +import ca.joeltherrien.randomforest.CovariateRow; +import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.tree.Split; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +public interface SplitRule extends Serializable{ + + int getParentCovariateIndex(); + + /** + * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. + * This method is primarily used during the training of a tree when splits are being tested. + * + * @param rows + * @param + * @return + */ + default Split applyRule(List> rows) { + + /* + When working with really large List> we need to be careful about memory. + If the lefthand and righthand lists are too small they grow, but for a moment copies exist + and memory issues arise. + + If they're too large, we waste memory yet again + */ + + // value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand + final byte[] whichHand = new byte[rows.size()]; + int countLeftHand = 0; + int countRightHand = 0; + int countMissingHand = 0; + + + + for(int i=0; i row = rows.get(i); + + final Covariate.Value value = row.getValueByIndex(getParentCovariateIndex()); + + if(value.isNA()){ + countMissingHand++; + whichHand[i] = 2; + } + + if(isLeftHand(value)){ + countLeftHand++; + whichHand[i] = 1; + } + else{ + countRightHand++; + whichHand[i] = 0; + } + + } + + + final List> missingValueRows = new ArrayList<>(countMissingHand); + final List> leftHand = new ArrayList<>(countLeftHand); + final List> rightHand = new ArrayList<>(countRightHand); + + for(int i=0; i row = rows.get(i); + + if(whichHand[i] == 0){ + rightHand.add(row); + } + else if(whichHand[i] == 1){ + leftHand.add(row); + } + else{ + missingValueRows.add(row); + } + + } + + return new Split<>(this, leftHand, rightHand, missingValueRows); + } + + default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ + final Covariate.Value value = row.getValueByIndex(getParentCovariateIndex()); + + if(value.isNA()){ + return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand; + } + + return isLeftHand(value); + } + + boolean isLeftHand(Covariate.Value value); + + +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java similarity index 74% rename from src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java index 2586929..797378b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanCovariate.java @@ -14,17 +14,18 @@ * along with this program. If not, see . */ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.bool; import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.utils.SingletonIterator; import lombok.Getter; -import lombok.RequiredArgsConstructor; -import java.util.*; +import java.util.Iterator; +import java.util.List; +import java.util.Random; -@RequiredArgsConstructor public final class BooleanCovariate implements Covariate { @Getter @@ -35,7 +36,13 @@ public final class BooleanCovariate implements Covariate { private boolean hasNAs = false; - private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. + private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates. + + public BooleanCovariate(String name, int index){ + this.name = name; + this.index = index; + splitRule = new BooleanSplitRule(this); + } @Override public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) { @@ -72,7 +79,7 @@ public final class BooleanCovariate implements Covariate { @Override public String toString(){ - return "BooleanCovariate(name=" + name + ")"; + return "BooleanCovariate(name=" + this.name + ", index=" + this.index + ", hasNAs=" + this.hasNAs + ")"; } public class BooleanValue implements Value{ @@ -100,25 +107,4 @@ public final class BooleanCovariate implements Covariate { } - public class BooleanSplitRule implements SplitRule{ - - @Override - public final String toString() { - return "BooleanSplitRule"; - } - - @Override - public BooleanCovariate getParent() { - return BooleanCovariate.this; - } - - @Override - public boolean isLeftHand(final Value value) { - if(value.isNA()) { - throw new IllegalArgumentException("Trying to determine split on missing value"); - } - - return !value.getValue(); - } - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanSplitRule.java new file mode 100644 index 0000000..813d13a --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/bool/BooleanSplitRule.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.covariates.bool; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; + +public class BooleanSplitRule implements SplitRule { + + private final int parentCovariateIndex; + + public BooleanSplitRule(BooleanCovariate parent){ + this.parentCovariateIndex = parent.getIndex(); + } + + @Override + public final String toString() { + return "BooleanSplitRule"; + } + + @Override + public int getParentCovariateIndex() { + return parentCovariateIndex; + } + + @Override + public boolean isLeftHand(final Covariate.Value value) { + if(value.isNA()) { + throw new IllegalArgumentException("Trying to determine split on missing value"); + } + + return !value.getValue(); + } +} \ No newline at end of file diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java similarity index 75% rename from src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java rename to src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java index 4804c27..b3eb245 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorCovariate.java @@ -14,16 +14,17 @@ * along with this program. If not, see . */ -package ca.joeltherrien.randomforest.covariates; +package ca.joeltherrien.randomforest.covariates.factor; import ca.joeltherrien.randomforest.Row; +import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.tree.Split; import lombok.EqualsAndHashCode; import lombok.Getter; import java.util.*; -public final class FactorCovariate implements Covariate{ +public final class FactorCovariate implements Covariate { @Getter private final String name; @@ -44,6 +45,10 @@ public final class FactorCovariate implements Covariate{ this.factorLevels = new HashMap<>(); for(final String level : levels){ + if(level.equalsIgnoreCase("na")){ + throw new IllegalArgumentException("Cannot use NA (case-insensitive) as a level in factor covariate " + name); + } + final FactorValue newValue = new FactorValue(level); factorLevels.put(level, newValue); @@ -70,16 +75,16 @@ public final class FactorCovariate implements Covariate{ while(splits.size() < number){ Collections.shuffle(levels, random); - final Set leftSideValues = new HashSet<>(); - leftSideValues.add(levels.get(0)); + final Set leftSideValues = new HashSet<>(); + leftSideValues.add(levels.get(0).getValue()); for(int i=1; i{ } @Override - public String toString(){ - return "FactorCovariate(name=" + name + ")"; + public String toString() { + return "FactorCovariate(name=" + this.name + ", index=" + this.index + ", hasNAs=" + this.hasNAs + ")"; } @EqualsAndHashCode @@ -139,27 +144,4 @@ public final class FactorCovariate implements Covariate{ } } - @EqualsAndHashCode - public final class FactorSplitRule implements Covariate.SplitRule{ - - private final Set leftSideValues; - - private FactorSplitRule(final Set leftSideValues){ - this.leftSideValues = leftSideValues; - } - - @Override - public FactorCovariate getParent() { - return FactorCovariate.this; - } - - @Override - public boolean isLeftHand(final Value value) { - if(value.isNA()){ - throw new IllegalArgumentException("Trying to determine split on missing value"); - } - - return leftSideValues.contains(value); - } - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorSplitRule.java new file mode 100644 index 0000000..7fb4b4c --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/factor/FactorSplitRule.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 Joel Therrien. + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package ca.joeltherrien.randomforest.covariates.factor; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; +import lombok.EqualsAndHashCode; + +import java.util.Set; + +@EqualsAndHashCode +public final class FactorSplitRule implements SplitRule { + + private final int parentCovariateIndex; + private final Set leftSideValues; + + public FactorSplitRule(final FactorCovariate parent, final Set leftSideValues){ + this.parentCovariateIndex = parent.getIndex(); + this.leftSideValues = leftSideValues; + } + + @Override + public int getParentCovariateIndex() { + return parentCovariateIndex; + } + + @Override + public boolean isLeftHand(final Covariate.Value value) { + if(value.isNA()){ + throw new IllegalArgumentException("Trying to determine split on missing value"); + } + + return leftSideValues.contains(value.getValue()); + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java index e5b7b68..6c1c490 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java @@ -26,7 +26,10 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; -import java.util.*; +import java.util.Iterator; +import java.util.List; +import java.util.Random; +import java.util.TreeSet; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -141,34 +144,4 @@ public final class NumericCovariate implements Covariate { } } - @EqualsAndHashCode - public class NumericSplitRule implements Covariate.SplitRule{ - - private final double threshold; - - NumericSplitRule(final double threshold){ - this.threshold = threshold; - } - - @Override - public final String toString() { - return "NumericSplitRule on " + getParent().getName() + " at " + threshold; - } - - @Override - public NumericCovariate getParent() { - return NumericCovariate.this; - } - - @Override - public boolean isLeftHand(final Value x) { - if(x.isNA()) { - throw new IllegalArgumentException("Trying to determine split on missing value"); - } - - final double xNum = x.getValue(); - - return xNum <= threshold; - } - } } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java new file mode 100644 index 0000000..acc8e3b --- /dev/null +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRule.java @@ -0,0 +1,38 @@ +package ca.joeltherrien.randomforest.covariates.numeric; + +import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; +import lombok.EqualsAndHashCode; + +@EqualsAndHashCode +public class NumericSplitRule implements SplitRule { + + private final int parentCovariateIndex; + private final double threshold; + + NumericSplitRule(NumericCovariate parent, final double threshold){ + this.parentCovariateIndex = parent.getIndex(); + this.threshold = threshold; + } + + @Override + public final String toString() { + return "NumericSplitRule on " + getParentCovariateIndex() + " at " + threshold; + } + + @Override + public int getParentCovariateIndex() { + return parentCovariateIndex; + } + + @Override + public boolean isLeftHand(final Covariate.Value x) { + if(x.isNA()) { + throw new IllegalArgumentException("Trying to determine split on missing value"); + } + + final double xNum = x.getValue(); + + return xNum <= threshold; + } +} diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java index fbbf7bf..662c23a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java @@ -41,7 +41,7 @@ public class NumericSplitRuleUpdater implements Covariate.SplitRuleUpdater> rightHandList = orderedData; this.currentSplit = new Split<>( - covariate.new NumericSplitRule(Double.NEGATIVE_INFINITY), + new NumericSplitRule(covariate, Double.NEGATIVE_INFINITY), leftHandList, rightHandList, Collections.emptyList()); @@ -67,7 +67,7 @@ public class NumericSplitRuleUpdater implements Covariate.SplitRuleUpdater> rowsMoved = orderedData.subList(currentPosition, newPosition); - final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue); + final NumericSplitRule splitRule = new NumericSplitRule(covariate, splitValue); // Update current split this.currentSplit = new Split<>( diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java index 2f831c7..f372e3c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java @@ -25,11 +25,11 @@ import java.util.List; @AllArgsConstructor public class NumericSplitUpdate implements Covariate.SplitUpdate { - private final NumericCovariate.NumericSplitRule numericSplitRule; + private final NumericSplitRule numericSplitRule; private final List> rowsMoved; @Override - public NumericCovariate.NumericSplitRule getSplitRule() { + public NumericSplitRule getSplitRule() { return numericSplitRule; } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java index d4309fd..3ed9028 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java @@ -16,7 +16,7 @@ package ca.joeltherrien.randomforest.covariates.settings; -import ca.joeltherrien.randomforest.covariates.BooleanCovariate; +import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate; import lombok.Data; import lombok.NoArgsConstructor; diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java index 3568392..03ccd3b 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java @@ -16,7 +16,7 @@ package ca.joeltherrien.randomforest.covariates.settings; -import ca.joeltherrien.randomforest.covariates.FactorCovariate; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import lombok.Data; import lombok.NoArgsConstructor; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index f82abef..276c16c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -67,18 +67,18 @@ public class Forest { // O = output of trees, FO = forest output. In prac return Collections.unmodifiableCollection(trees); } - public Map findSplitsByCovariate(){ - final Map countMap = new TreeMap<>(); + public Map findSplitsByCovariate(){ + final Map countMap = new TreeMap<>(); for(final Tree tree : getTrees()){ final Node rootNode = tree.getRootNode(); final List splitNodeList = rootNode.getNodesOfType(SplitNode.class); for(final SplitNode splitNode : splitNodeList){ - final Covariate covariate = splitNode.getSplitRule().getParent(); + final Integer covariateIndex = splitNode.getSplitRule().getParentCovariateIndex(); - final Integer currentCount = countMap.getOrDefault(covariate, 0); - countMap.put(covariate, currentCount+1); + final Integer currentCount = countMap.getOrDefault(covariateIndex, 0); + countMap.put(covariateIndex, currentCount+1); } } diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java index 0b199fe..e527467 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java @@ -17,7 +17,7 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.Row; -import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; import lombok.Data; import java.util.ArrayList; @@ -32,7 +32,7 @@ import java.util.List; @Data public final class Split { - public final Covariate.SplitRule splitRule; + public final SplitRule splitRule; public final List> leftHand; public final List> rightHand; public final List> naHand; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java index de9b3d5..f3bd170 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java @@ -17,7 +17,7 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; -import ca.joeltherrien.randomforest.covariates.Covariate; +import ca.joeltherrien.randomforest.covariates.SplitRule; import lombok.Builder; import lombok.Getter; import lombok.ToString; @@ -32,7 +32,7 @@ public class SplitNode implements Node { private final Node leftHand; private final Node rightHand; - private final Covariate.SplitRule splitRule; + private final SplitRule splitRule; private final double probabilityNaLeftHand; // used when assigning NA values @Override diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 4ce3630..a633664 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -91,13 +91,13 @@ public class TreeTrainer { (double) (bestSplit.leftHand.size() + bestSplit.rightHand.size()); // Assign missing values to the split if necessary - if(bestSplit.getSplitRule().getParent().hasNAs()){ + if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){ bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists for(Row row : data) { - final Covariate covariate = bestSplit.getSplitRule().getParent(); + final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex(); - if(row.getCovariateValue(covariate).isNA()) { + if(row.getValueByIndex(covariateIndex).isNA()) { final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; if(randomDecision){ diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 8a41da2..a8fd370 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -113,12 +113,15 @@ public class TestSavingLoading { final List> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final File directory = new File(settings.getSaveTreeLocation()); + if(directory.exists()){ + directory.delete(); + } assertFalse(directory.exists()); directory.mkdir(); final ForestTrainer forestTrainer = new ForestTrainer<>(settings, dataset, covariates); - forestTrainer.trainParallelOnDisk(1); + forestTrainer.trainSerialOnDisk(); assertTrue(directory.exists()); assertTrue(directory.isDirectory()); diff --git a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java index 47c74bb..84c134d 100644 --- a/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java +++ b/src/test/java/ca/joeltherrien/randomforest/covariates/FactorCovariateTest.java @@ -17,6 +17,7 @@ package ca.joeltherrien.randomforest.covariates; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.utils.Utils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; @@ -63,7 +64,7 @@ public class FactorCovariateTest { void testAllSubsets(){ final FactorCovariate petCovariate = createTestCovariate(); - final List> splitRules = new ArrayList<>(); + final List> splitRules = new ArrayList<>(); petCovariate.generateSplitRuleUpdater(null, 100, new Random()) .forEachRemaining(split -> splitRules.add(split.getSplitRule())); diff --git a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java index b9163d9..7313d22 100644 --- a/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java +++ b/src/test/java/ca/joeltherrien/randomforest/workshop/TrainSingleTreeFactor.java @@ -19,7 +19,7 @@ package ca.joeltherrien.randomforest.workshop; import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.covariates.Covariate; -import ca.joeltherrien.randomforest.covariates.FactorCovariate; +import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;