Make SplitRules their own class; independent of their Covariate parents.

This was done so that serializing trees (and thus their SplitRules) no longer also serializes ntree copies of the Covariates,
which made deserializing the trees awkward.
This commit is contained in:
Joel Therrien 2019-03-25 14:44:31 -07:00
parent 76b2cdd3c4
commit 585d6d3c5b
20 changed files with 313 additions and 198 deletions

View file

@ -37,6 +37,10 @@ public class CovariateRow implements Serializable {
return valueArray[covariate.getIndex()]; return valueArray[covariate.getIndex()];
} }
/**
 * Reads a covariate value straight out of this row's value array by position.
 *
 * @param index the position to read in the value array
 * @param <V> the value type the caller expects at that position
 * @return the entry stored at {@code index}
 */
public <V> Covariate.Value<V> getValueByIndex(int index){
    final Covariate.Value<V> valueAtIndex = valueArray[index];
    return valueAtIndex;
}
@Override @Override
public String toString(){ public String toString(){
return "CovariateRow " + this.id; return "CovariateRow " + this.id;

View file

@ -16,13 +16,14 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import java.io.Serializable; import java.io.Serializable;
import java.util.*; import java.util.Collection;
import java.util.concurrent.ThreadLocalRandom; import java.util.Iterator;
import java.util.List;
import java.util.Random;
public interface Covariate<V> extends Serializable, Comparable<Covariate> { public interface Covariate<V> extends Serializable, Comparable<Covariate> {
@ -69,92 +70,7 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
Collection<Row<Y>> rowsMovedToLeftHand(); Collection<Row<Y>> rowsMovedToLeftHand();
} }
/**
 * A rule tied to one covariate that sends each non-missing value to either the left hand
 * or the right hand side of a split.
 *
 * @param <V> the type of value held by the parent covariate
 */
interface SplitRule<V> extends Serializable{

    /**
     * @return the covariate whose value this rule reads from each row
     */
    Covariate<V> getParent();

    /**
     * Applies the SplitRule to a list of rows and returns a Split object, which contains the rows
     * for the left hand, the right hand, and the rows whose value is missing.
     * This method is primarily used during the training of a tree when splits are being tested.
     *
     * Rows with an NA value are only placed in the missing-value list; they are never passed to
     * {@link #isLeftHand(Value)}.
     *
     * @param rows the rows to partition
     * @param <Y> the response type carried by the rows
     * @return a Split built from this rule and the left hand, right hand and missing-value lists
     */
    default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {

        /*
            When working with really large List<Row<Y>> we need to be careful about memory.
            If the lefthand and righthand lists are too small they grow, but for a moment copies exist
            and memory issues arise.
            If they're too large, we waste memory yet again.
            So a first pass records each row's side and counts them, and the lists are then
            allocated at exactly the needed capacity.
         */

        // value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand
        final byte[] whichHand = new byte[rows.size()];
        int countLeftHand = 0;
        int countRightHand = 0;
        int countMissingHand = 0;

        final Covariate<V> parent = getParent();

        for(int i=0; i<whichHand.length; i++){
            final Row<Y> row = rows.get(i);
            final Value<V> value = row.getCovariateValue(parent);

            if(value.isNA()){
                countMissingHand++;
                whichHand[i] = 2;
            }
            // Must be an else-if: an NA value must not reach isLeftHand(value), and a row must
            // only ever be counted on one side.
            else if(isLeftHand(value)){
                countLeftHand++;
                whichHand[i] = 1;
            }
            else{
                countRightHand++;
                whichHand[i] = 0;
            }
        }

        final List<Row<Y>> missingValueRows = new ArrayList<>(countMissingHand);
        final List<Row<Y>> leftHand = new ArrayList<>(countLeftHand);
        final List<Row<Y>> rightHand = new ArrayList<>(countRightHand);

        for(int i=0; i<whichHand.length; i++){
            final Row<Y> row = rows.get(i);

            if(whichHand[i] == 0){
                rightHand.add(row);
            }
            else if(whichHand[i] == 1){
                leftHand.add(row);
            }
            else{
                missingValueRows.add(row);
            }
        }

        return new Split<>(this, leftHand, rightHand, missingValueRows);
    }

    /**
     * Decides which side a single row goes to, assigning rows with a missing value at random.
     *
     * @param row the row to route
     * @param probabilityNaLeftHand for an NA value, the row goes left when a uniform random draw
     *                              is less than or equal to this number
     * @return true if the row goes to the left hand side
     */
    default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
        final Value<V> value = row.getCovariateValue(getParent());

        if(value.isNA()){
            return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
        }

        return isLeftHand(value);
    }

    /**
     * Decides which side a non-missing value goes to.
     *
     * @param value the value to test
     * @return true if the value goes to the left hand side
     */
    boolean isLeftHand(Value<V> value);
}
} }

View file

@ -0,0 +1,115 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
/**
 * A rule that sends each non-missing value of one covariate to either the left hand or the
 * right hand side of a split. The covariate is identified only by its index, which is used
 * to look the value up in each row.
 *
 * @param <V> the type of value the rule tests
 */
public interface SplitRule<V> extends Serializable{

    /**
     * @return the index used to read this rule's covariate value from a row
     */
    int getParentCovariateIndex();

    /**
     * Applies the SplitRule to a list of rows and returns a Split object, which contains the rows
     * for the left hand, the right hand, and the rows whose value is missing.
     * This method is primarily used during the training of a tree when splits are being tested.
     *
     * Rows with an NA value are only placed in the missing-value list; they are never passed to
     * {@link #isLeftHand(Covariate.Value)}.
     *
     * @param rows the rows to partition
     * @param <Y> the response type carried by the rows
     * @return a Split built from this rule and the left hand, right hand and missing-value lists
     */
    default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {

        /*
            When working with really large List<Row<Y>> we need to be careful about memory.
            If the lefthand and righthand lists are too small they grow, but for a moment copies exist
            and memory issues arise.
            If they're too large, we waste memory yet again.
            So a first pass records each row's side and counts them, and the lists are then
            allocated at exactly the needed capacity.
         */

        // value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand
        final byte[] whichHand = new byte[rows.size()];
        int countLeftHand = 0;
        int countRightHand = 0;
        int countMissingHand = 0;

        final int covariateIndex = getParentCovariateIndex();

        for(int i=0; i<whichHand.length; i++){
            final Row<Y> row = rows.get(i);
            final Covariate.Value<V> value = row.getValueByIndex(covariateIndex);

            if(value.isNA()){
                countMissingHand++;
                whichHand[i] = 2;
            }
            // Must be an else-if: an NA value must not reach isLeftHand(value), and a row must
            // only ever be counted on one side.
            else if(isLeftHand(value)){
                countLeftHand++;
                whichHand[i] = 1;
            }
            else{
                countRightHand++;
                whichHand[i] = 0;
            }
        }

        final List<Row<Y>> missingValueRows = new ArrayList<>(countMissingHand);
        final List<Row<Y>> leftHand = new ArrayList<>(countLeftHand);
        final List<Row<Y>> rightHand = new ArrayList<>(countRightHand);

        for(int i=0; i<whichHand.length; i++){
            final Row<Y> row = rows.get(i);

            if(whichHand[i] == 0){
                rightHand.add(row);
            }
            else if(whichHand[i] == 1){
                leftHand.add(row);
            }
            else{
                missingValueRows.add(row);
            }
        }

        return new Split<>(this, leftHand, rightHand, missingValueRows);
    }

    /**
     * Decides which side a single row goes to, assigning rows with a missing value at random.
     *
     * @param row the row to route
     * @param probabilityNaLeftHand for an NA value, the row goes left when a uniform random draw
     *                              is less than or equal to this number
     * @return true if the row goes to the left hand side
     */
    default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
        final Covariate.Value<V> value = row.getValueByIndex(getParentCovariateIndex());

        if(value.isNA()){
            return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
        }

        return isLeftHand(value);
    }

    /**
     * Decides which side a non-missing value goes to.
     *
     * @param value the value to test
     * @return true if the value goes to the left hand side
     */
    boolean isLeftHand(Covariate.Value<V> value);
}

View file

@ -14,17 +14,18 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates.bool;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import ca.joeltherrien.randomforest.utils.SingletonIterator; import ca.joeltherrien.randomforest.utils.SingletonIterator;
import lombok.Getter; import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.*; import java.util.Iterator;
import java.util.List;
import java.util.Random;
@RequiredArgsConstructor
public final class BooleanCovariate implements Covariate<Boolean> { public final class BooleanCovariate implements Covariate<Boolean> {
@Getter @Getter
@ -35,7 +36,13 @@ public final class BooleanCovariate implements Covariate<Boolean> {
private boolean hasNAs = false; private boolean hasNAs = false;
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates. private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
/**
 * Creates the covariate together with its single split rule.
 *
 * @param name the covariate's name
 * @param index the covariate's index
 */
public BooleanCovariate(String name, int index){
    this.index = index;
    this.name = name;

    // Built last on purpose: BooleanSplitRule's constructor calls getIndex() on the parent it
    // is given (presumably reading the index field assigned above - confirm against the getter).
    this.splitRule = new BooleanSplitRule(this);
}
@Override @Override
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) { public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
@ -72,7 +79,7 @@ public final class BooleanCovariate implements Covariate<Boolean> {
@Override @Override
public String toString(){ public String toString(){
return "BooleanCovariate(name=" + name + ")"; return "BooleanCovariate(name=" + this.name + ", index=" + this.index + ", hasNAs=" + this.hasNAs + ")";
} }
public class BooleanValue implements Value<Boolean>{ public class BooleanValue implements Value<Boolean>{
@ -100,25 +107,4 @@ public final class BooleanCovariate implements Covariate<Boolean> {
} }
/**
 * The split rule of the enclosing BooleanCovariate: a false value goes to the left hand side.
 */
public class BooleanSplitRule implements SplitRule<Boolean>{

    /**
     * @return the enclosing BooleanCovariate instance
     */
    @Override
    public BooleanCovariate getParent() {
        return BooleanCovariate.this;
    }

    /**
     * @param value the value to test; must not be NA
     * @return true when the value is false
     * @throws IllegalArgumentException if the value is NA
     */
    @Override
    public boolean isLeftHand(final Value<Boolean> value) {
        final boolean missing = value.isNA();
        if(missing) {
            throw new IllegalArgumentException("Trying to determine split on missing value");
        }

        final boolean isTrue = value.getValue();
        return !isTrue;
    }

    @Override
    public final String toString() {
        return "BooleanSplitRule";
    }
}
} }

View file

@ -0,0 +1,48 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.covariates.bool;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.SplitRule;
/**
 * The split rule for a boolean covariate: a false value goes to the left hand side.
 * Only the parent's index is kept, not a reference to the parent covariate.
 */
public class BooleanSplitRule implements SplitRule<Boolean> {

    private final int parentCovariateIndex;

    /**
     * @param parent the covariate this rule belongs to; only its index is read and stored
     */
    public BooleanSplitRule(BooleanCovariate parent){
        parentCovariateIndex = parent.getIndex();
    }

    /**
     * @return the index taken from the parent covariate at construction
     */
    @Override
    public int getParentCovariateIndex() {
        return parentCovariateIndex;
    }

    /**
     * @param value the value to test; must not be NA
     * @return true when the value is false
     * @throws IllegalArgumentException if the value is NA
     */
    @Override
    public boolean isLeftHand(final Covariate.Value<Boolean> value) {
        final boolean missing = value.isNA();
        if(missing) {
            throw new IllegalArgumentException("Trying to determine split on missing value");
        }

        final boolean isTrue = value.getValue();
        return !isTrue;
    }

    @Override
    public final String toString() {
        return "BooleanSplitRule";
    }
}

View file

@ -14,16 +14,17 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>. * along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates.factor;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import java.util.*; import java.util.*;
public final class FactorCovariate implements Covariate<String>{ public final class FactorCovariate implements Covariate<String> {
@Getter @Getter
private final String name; private final String name;
@ -44,6 +45,10 @@ public final class FactorCovariate implements Covariate<String>{
this.factorLevels = new HashMap<>(); this.factorLevels = new HashMap<>();
for(final String level : levels){ for(final String level : levels){
if(level.equalsIgnoreCase("na")){
throw new IllegalArgumentException("Cannot use NA (case-insensitive) as a level in factor covariate " + name);
}
final FactorValue newValue = new FactorValue(level); final FactorValue newValue = new FactorValue(level);
factorLevels.put(level, newValue); factorLevels.put(level, newValue);
@ -70,16 +75,16 @@ public final class FactorCovariate implements Covariate<String>{
while(splits.size() < number){ while(splits.size() < number){
Collections.shuffle(levels, random); Collections.shuffle(levels, random);
final Set<FactorValue> leftSideValues = new HashSet<>(); final Set<String> leftSideValues = new HashSet<>();
leftSideValues.add(levels.get(0)); leftSideValues.add(levels.get(0).getValue());
for(int i=1; i<levels.size()/2; i++){ for(int i=1; i<levels.size()/2; i++){
if(random.nextBoolean()){ if(random.nextBoolean()){
leftSideValues.add(levels.get(i)); leftSideValues.add(levels.get(i).getValue());
} }
} }
splits.add(new FactorSplitRule(leftSideValues).applyRule(data)); splits.add(new FactorSplitRule(this, leftSideValues).applyRule(data));
} }
return splits.iterator(); return splits.iterator();
@ -110,8 +115,8 @@ public final class FactorCovariate implements Covariate<String>{
} }
@Override @Override
public String toString(){ public String toString() {
return "FactorCovariate(name=" + name + ")"; return "FactorCovariate(name=" + this.name + ", index=" + this.index + ", hasNAs=" + this.hasNAs + ")";
} }
@EqualsAndHashCode @EqualsAndHashCode
@ -139,27 +144,4 @@ public final class FactorCovariate implements Covariate<String>{
} }
} }
/**
 * A split rule of the enclosing FactorCovariate: a value goes to the left hand side when it is
 * contained in this rule's set of left-side values.
 */
@EqualsAndHashCode
public final class FactorSplitRule implements Covariate.SplitRule<String>{

    private final Set<FactorValue> leftSideValues;

    /**
     * @param leftSideValues the values that go to the left hand side; the set is stored as given, not copied
     */
    private FactorSplitRule(final Set<FactorValue> leftSideValues){
        this.leftSideValues = leftSideValues;
    }

    /**
     * @param value the value to test; must not be NA
     * @return true when the value is in the left-side set
     * @throws IllegalArgumentException if the value is NA
     */
    @Override
    public boolean isLeftHand(final Value<String> value) {
        final boolean missing = value.isNA();
        if(missing){
            throw new IllegalArgumentException("Trying to determine split on missing value");
        }

        final boolean onLeftSide = leftSideValues.contains(value);
        return onLeftSide;
    }

    /**
     * @return the enclosing FactorCovariate instance
     */
    @Override
    public FactorCovariate getParent() {
        return FactorCovariate.this;
    }
}
} }

View file

@ -0,0 +1,49 @@
/*
* Copyright (c) 2019 Joel Therrien.
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package ca.joeltherrien.randomforest.covariates.factor;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.SplitRule;
import lombok.EqualsAndHashCode;
import java.util.Set;
/**
 * A split rule for a factor covariate: a value goes to the left hand side when its level is
 * contained in this rule's set of left-side levels.
 * Only the parent's index is kept, not a reference to the parent covariate.
 */
@EqualsAndHashCode
public final class FactorSplitRule implements SplitRule<String> {

    private final int parentCovariateIndex;
    private final Set<String> leftSideValues;

    /**
     * @param parent the covariate this rule belongs to; only its index is read and stored
     * @param leftSideValues the levels that go to the left hand side; the set is stored as given, not copied
     */
    public FactorSplitRule(final FactorCovariate parent, final Set<String> leftSideValues){
        this.leftSideValues = leftSideValues;
        this.parentCovariateIndex = parent.getIndex();
    }

    /**
     * @param value the value to test; must not be NA
     * @return true when the value's level is in the left-side set
     * @throws IllegalArgumentException if the value is NA
     */
    @Override
    public boolean isLeftHand(final Covariate.Value<String> value) {
        final boolean missing = value.isNA();
        if(missing){
            throw new IllegalArgumentException("Trying to determine split on missing value");
        }

        final String level = value.getValue();
        return leftSideValues.contains(level);
    }

    /**
     * @return the index taken from the parent covariate at construction
     */
    @Override
    public int getParentCovariateIndex() {
        return parentCovariateIndex;
    }
}

View file

@ -26,7 +26,10 @@ import lombok.Getter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.ToString; import lombok.ToString;
import java.util.*; import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -141,34 +144,4 @@ public final class NumericCovariate implements Covariate<Double> {
} }
} }
/**
 * A split rule of the enclosing NumericCovariate: a value goes to the left hand side when it
 * is less than or equal to the threshold.
 */
@EqualsAndHashCode
public class NumericSplitRule implements Covariate.SplitRule<Double>{

    private final double threshold;

    /**
     * @param threshold the largest value that still goes to the left hand side
     */
    NumericSplitRule(final double threshold){
        this.threshold = threshold;
    }

    /**
     * @return the enclosing NumericCovariate instance
     */
    @Override
    public NumericCovariate getParent() {
        return NumericCovariate.this;
    }

    /**
     * @param x the value to test; must not be NA
     * @return true when the value is less than or equal to the threshold
     * @throws IllegalArgumentException if the value is NA
     */
    @Override
    public boolean isLeftHand(final Value<Double> x) {
        final boolean missing = x.isNA();
        if(missing) {
            throw new IllegalArgumentException("Trying to determine split on missing value");
        }

        return x.getValue() <= threshold;
    }

    @Override
    public final String toString() {
        final StringBuilder description = new StringBuilder("NumericSplitRule on ");
        description.append(getParent().getName());
        description.append(" at ");
        description.append(threshold);
        return description.toString();
    }
}
} }

View file

@ -0,0 +1,38 @@
package ca.joeltherrien.randomforest.covariates.numeric;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.SplitRule;
import lombok.EqualsAndHashCode;
/**
 * A split rule for a numeric covariate: a value goes to the left hand side when it is less
 * than or equal to the threshold.
 * Only the parent's index is kept, not a reference to the parent covariate.
 */
@EqualsAndHashCode
public class NumericSplitRule implements SplitRule<Double> {

    private final int parentCovariateIndex;
    private final double threshold;

    /**
     * @param parent the covariate this rule belongs to; only its index is read and stored
     * @param threshold the largest value that still goes to the left hand side
     */
    NumericSplitRule(NumericCovariate parent, final double threshold){
        this.threshold = threshold;
        this.parentCovariateIndex = parent.getIndex();
    }

    /**
     * @return the index taken from the parent covariate at construction
     */
    @Override
    public int getParentCovariateIndex() {
        return parentCovariateIndex;
    }

    /**
     * @param x the value to test; must not be NA
     * @return true when the value is less than or equal to the threshold
     * @throws IllegalArgumentException if the value is NA
     */
    @Override
    public boolean isLeftHand(final Covariate.Value<Double> x) {
        final boolean missing = x.isNA();
        if(missing) {
            throw new IllegalArgumentException("Trying to determine split on missing value");
        }

        return x.getValue() <= threshold;
    }

    @Override
    public final String toString() {
        // Goes through the getter rather than the field, as this class is not final.
        final StringBuilder description = new StringBuilder("NumericSplitRule on ");
        description.append(getParentCovariateIndex());
        description.append(" at ");
        description.append(threshold);
        return description.toString();
    }
}

View file

@ -41,7 +41,7 @@ public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y,
final List<Row<Y>> rightHandList = orderedData; final List<Row<Y>> rightHandList = orderedData;
this.currentSplit = new Split<>( this.currentSplit = new Split<>(
covariate.new NumericSplitRule(Double.NEGATIVE_INFINITY), new NumericSplitRule(covariate, Double.NEGATIVE_INFINITY),
leftHandList, leftHandList,
rightHandList, rightHandList,
Collections.emptyList()); Collections.emptyList());
@ -67,7 +67,7 @@ public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y,
final List<Row<Y>> rowsMoved = orderedData.subList(currentPosition, newPosition); final List<Row<Y>> rowsMoved = orderedData.subList(currentPosition, newPosition);
final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue); final NumericSplitRule splitRule = new NumericSplitRule(covariate, splitValue);
// Update current split // Update current split
this.currentSplit = new Split<>( this.currentSplit = new Split<>(

View file

@ -25,11 +25,11 @@ import java.util.List;
@AllArgsConstructor @AllArgsConstructor
public class NumericSplitUpdate<Y> implements Covariate.SplitUpdate<Y, Double> { public class NumericSplitUpdate<Y> implements Covariate.SplitUpdate<Y, Double> {
private final NumericCovariate.NumericSplitRule numericSplitRule; private final NumericSplitRule numericSplitRule;
private final List<Row<Y>> rowsMoved; private final List<Row<Y>> rowsMoved;
@Override @Override
public NumericCovariate.NumericSplitRule getSplitRule() { public NumericSplitRule getSplitRule() {
return numericSplitRule; return numericSplitRule;
} }

View file

@ -16,7 +16,7 @@
package ca.joeltherrien.randomforest.covariates.settings; package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.BooleanCovariate; import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;

View file

@ -16,7 +16,7 @@
package ca.joeltherrien.randomforest.covariates.settings; package ca.joeltherrien.randomforest.covariates.settings;
import ca.joeltherrien.randomforest.covariates.FactorCovariate; import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;

View file

@ -67,18 +67,18 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
return Collections.unmodifiableCollection(trees); return Collections.unmodifiableCollection(trees);
} }
public Map<Covariate, Integer> findSplitsByCovariate(){ public Map<Integer, Integer> findSplitsByCovariate(){
final Map<Covariate, Integer> countMap = new TreeMap<>(); final Map<Integer, Integer> countMap = new TreeMap<>();
for(final Tree<O> tree : getTrees()){ for(final Tree<O> tree : getTrees()){
final Node<O> rootNode = tree.getRootNode(); final Node<O> rootNode = tree.getRootNode();
final List<SplitNode> splitNodeList = rootNode.getNodesOfType(SplitNode.class); final List<SplitNode> splitNodeList = rootNode.getNodesOfType(SplitNode.class);
for(final SplitNode splitNode : splitNodeList){ for(final SplitNode splitNode : splitNodeList){
final Covariate covariate = splitNode.getSplitRule().getParent(); final Integer covariateIndex = splitNode.getSplitRule().getParentCovariateIndex();
final Integer currentCount = countMap.getOrDefault(covariate, 0); final Integer currentCount = countMap.getOrDefault(covariateIndex, 0);
countMap.put(covariate, currentCount+1); countMap.put(covariateIndex, currentCount+1);
} }
} }

View file

@ -17,7 +17,7 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.SplitRule;
import lombok.Data; import lombok.Data;
import java.util.ArrayList; import java.util.ArrayList;
@ -32,7 +32,7 @@ import java.util.List;
@Data @Data
public final class Split<Y, V> { public final class Split<Y, V> {
public final Covariate.SplitRule<V> splitRule; public final SplitRule<V> splitRule;
public final List<Row<Y>> leftHand; public final List<Row<Y>> leftHand;
public final List<Row<Y>> rightHand; public final List<Row<Y>> rightHand;
public final List<Row<Y>> naHand; public final List<Row<Y>> naHand;

View file

@ -17,7 +17,7 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.SplitRule;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.ToString; import lombok.ToString;
@ -32,7 +32,7 @@ public class SplitNode<Y> implements Node<Y> {
private final Node<Y> leftHand; private final Node<Y> leftHand;
private final Node<Y> rightHand; private final Node<Y> rightHand;
private final Covariate.SplitRule splitRule; private final SplitRule splitRule;
private final double probabilityNaLeftHand; // used when assigning NA values private final double probabilityNaLeftHand; // used when assigning NA values
@Override @Override

View file

@ -91,13 +91,13 @@ public class TreeTrainer<Y, O> {
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size()); (double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
// Assign missing values to the split if necessary // Assign missing values to the split if necessary
if(bestSplit.getSplitRule().getParent().hasNAs()){ if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
for(Row<Y> row : data) { for(Row<Y> row : data) {
final Covariate<?> covariate = bestSplit.getSplitRule().getParent(); final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex();
if(row.getCovariateValue(covariate).isNA()) { if(row.getValueByIndex(covariateIndex).isNA()) {
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand; final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
if(randomDecision){ if(randomDecision){

View file

@ -113,12 +113,15 @@ public class TestSavingLoading {
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation()); final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
final File directory = new File(settings.getSaveTreeLocation()); final File directory = new File(settings.getSaveTreeLocation());
if(directory.exists()){
directory.delete();
}
assertFalse(directory.exists()); assertFalse(directory.exists());
directory.mkdir(); directory.mkdir();
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates); final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
forestTrainer.trainParallelOnDisk(1); forestTrainer.trainSerialOnDisk();
assertTrue(directory.exists()); assertTrue(directory.exists());
assertTrue(directory.isDirectory()); assertTrue(directory.isDirectory());

View file

@ -17,6 +17,7 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.api.function.Executable;
@ -63,7 +64,7 @@ public class FactorCovariateTest {
void testAllSubsets(){ void testAllSubsets(){
final FactorCovariate petCovariate = createTestCovariate(); final FactorCovariate petCovariate = createTestCovariate();
final List<Covariate.SplitRule<String>> splitRules = new ArrayList<>(); final List<SplitRule<String>> splitRules = new ArrayList<>();
petCovariate.generateSplitRuleUpdater(null, 100, new Random()) petCovariate.generateSplitRuleUpdater(null, 100, new Random())
.forEachRemaining(split -> splitRules.add(split.getSplitRule())); .forEachRemaining(split -> splitRules.add(split.getSplitRule()));

View file

@ -19,7 +19,7 @@ package ca.joeltherrien.randomforest.workshop;
import ca.joeltherrien.randomforest.*; import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariate; import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate; import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner; import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;