Make SplitRules their own class; independent of their Covariate parents.
This was done so that when we serialize trees (and thus SplitRules) we don't awkwardly also serialize ntree versions of the Covariates, which is really awkward when deserializing them.
This commit is contained in:
parent
76b2cdd3c4
commit
585d6d3c5b
20 changed files with 313 additions and 198 deletions
|
@ -37,6 +37,10 @@ public class CovariateRow implements Serializable {
|
||||||
return valueArray[covariate.getIndex()];
|
return valueArray[covariate.getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public <V> Covariate.Value<V> getValueByIndex(int index){
|
||||||
|
return valueArray[index];
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString(){
|
public String toString(){
|
||||||
return "CovariateRow " + this.id;
|
return "CovariateRow " + this.id;
|
||||||
|
|
|
@ -16,13 +16,14 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.tree.Split;
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.*;
|
import java.util.Collection;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
public interface Covariate<V> extends Serializable, Comparable<Covariate> {
|
public interface Covariate<V> extends Serializable, Comparable<Covariate> {
|
||||||
|
|
||||||
|
@ -69,92 +70,7 @@ public interface Covariate<V> extends Serializable, Comparable<Covariate> {
|
||||||
Collection<Row<Y>> rowsMovedToLeftHand();
|
Collection<Row<Y>> rowsMovedToLeftHand();
|
||||||
}
|
}
|
||||||
|
|
||||||
interface SplitRule<V> extends Serializable{
|
|
||||||
|
|
||||||
Covariate<V> getParent();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.
|
|
||||||
* This method is primarily used during the training of a tree when splits are being tested.
|
|
||||||
*
|
|
||||||
* @param rows
|
|
||||||
* @param <Y>
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
|
|
||||||
|
|
||||||
/*
|
|
||||||
When working with really large List<Row<Y>> we need to be careful about memory.
|
|
||||||
If the lefthand and righthand lists are too small they grow, but for a moment copies exist
|
|
||||||
and memory issues arise.
|
|
||||||
|
|
||||||
If they're too large, we waste memory yet again
|
|
||||||
*/
|
|
||||||
|
|
||||||
// value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand
|
|
||||||
final byte[] whichHand = new byte[rows.size()];
|
|
||||||
int countLeftHand = 0;
|
|
||||||
int countRightHand = 0;
|
|
||||||
int countMissingHand = 0;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for(int i=0; i<whichHand.length; i++){
|
|
||||||
final Row<Y> row = rows.get(i);
|
|
||||||
|
|
||||||
final Value<V> value = row.getCovariateValue(getParent());
|
|
||||||
|
|
||||||
if(value.isNA()){
|
|
||||||
countMissingHand++;
|
|
||||||
whichHand[i] = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(isLeftHand(value)){
|
|
||||||
countLeftHand++;
|
|
||||||
whichHand[i] = 1;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
countRightHand++;
|
|
||||||
whichHand[i] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
final List<Row<Y>> missingValueRows = new ArrayList<>(countMissingHand);
|
|
||||||
final List<Row<Y>> leftHand = new ArrayList<>(countLeftHand);
|
|
||||||
final List<Row<Y>> rightHand = new ArrayList<>(countRightHand);
|
|
||||||
|
|
||||||
for(int i=0; i<whichHand.length; i++){
|
|
||||||
final Row<Y> row = rows.get(i);
|
|
||||||
|
|
||||||
if(whichHand[i] == 0){
|
|
||||||
rightHand.add(row);
|
|
||||||
}
|
|
||||||
else if(whichHand[i] == 1){
|
|
||||||
leftHand.add(row);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
missingValueRows.add(row);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Split<>(this, leftHand, rightHand, missingValueRows);
|
|
||||||
}
|
|
||||||
|
|
||||||
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
|
||||||
final Value<V> value = row.getCovariateValue(getParent());
|
|
||||||
|
|
||||||
if(value.isNA()){
|
|
||||||
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
|
|
||||||
}
|
|
||||||
|
|
||||||
return isLeftHand(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean isLeftHand(Value<V> value);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
|
||||||
|
public interface SplitRule<V> extends Serializable{
|
||||||
|
|
||||||
|
int getParentCovariateIndex();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides.
|
||||||
|
* This method is primarily used during the training of a tree when splits are being tested.
|
||||||
|
*
|
||||||
|
* @param rows
|
||||||
|
* @param <Y>
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) {
|
||||||
|
|
||||||
|
/*
|
||||||
|
When working with really large List<Row<Y>> we need to be careful about memory.
|
||||||
|
If the lefthand and righthand lists are too small they grow, but for a moment copies exist
|
||||||
|
and memory issues arise.
|
||||||
|
|
||||||
|
If they're too large, we waste memory yet again
|
||||||
|
*/
|
||||||
|
|
||||||
|
// value of 0 = rightHand, value of 1 = leftHand, value of 2 = missingValueHand
|
||||||
|
final byte[] whichHand = new byte[rows.size()];
|
||||||
|
int countLeftHand = 0;
|
||||||
|
int countRightHand = 0;
|
||||||
|
int countMissingHand = 0;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for(int i=0; i<whichHand.length; i++){
|
||||||
|
final Row<Y> row = rows.get(i);
|
||||||
|
|
||||||
|
final Covariate.Value<V> value = row.getValueByIndex(getParentCovariateIndex());
|
||||||
|
|
||||||
|
if(value.isNA()){
|
||||||
|
countMissingHand++;
|
||||||
|
whichHand[i] = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(isLeftHand(value)){
|
||||||
|
countLeftHand++;
|
||||||
|
whichHand[i] = 1;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
countRightHand++;
|
||||||
|
whichHand[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
final List<Row<Y>> missingValueRows = new ArrayList<>(countMissingHand);
|
||||||
|
final List<Row<Y>> leftHand = new ArrayList<>(countLeftHand);
|
||||||
|
final List<Row<Y>> rightHand = new ArrayList<>(countRightHand);
|
||||||
|
|
||||||
|
for(int i=0; i<whichHand.length; i++){
|
||||||
|
final Row<Y> row = rows.get(i);
|
||||||
|
|
||||||
|
if(whichHand[i] == 0){
|
||||||
|
rightHand.add(row);
|
||||||
|
}
|
||||||
|
else if(whichHand[i] == 1){
|
||||||
|
leftHand.add(row);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
missingValueRows.add(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Split<>(this, leftHand, rightHand, missingValueRows);
|
||||||
|
}
|
||||||
|
|
||||||
|
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
|
||||||
|
final Covariate.Value<V> value = row.getValueByIndex(getParentCovariateIndex());
|
||||||
|
|
||||||
|
if(value.isNA()){
|
||||||
|
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
|
||||||
|
}
|
||||||
|
|
||||||
|
return isLeftHand(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean isLeftHand(Covariate.Value<V> value);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -14,17 +14,18 @@
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates.bool;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.tree.Split;
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
import ca.joeltherrien.randomforest.utils.SingletonIterator;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public final class BooleanCovariate implements Covariate<Boolean> {
|
public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -35,7 +36,13 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
private boolean hasNAs = false;
|
private boolean hasNAs = false;
|
||||||
|
|
||||||
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
|
private final BooleanSplitRule splitRule; // there's only one possible rule for BooleanCovariates.
|
||||||
|
|
||||||
|
public BooleanCovariate(String name, int index){
|
||||||
|
this.name = name;
|
||||||
|
this.index = index;
|
||||||
|
splitRule = new BooleanSplitRule(this);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
public <Y> Iterator<Split<Y, Boolean>> generateSplitRuleUpdater(List<Row<Y>> data, int number, Random random) {
|
||||||
|
@ -72,7 +79,7 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString(){
|
public String toString(){
|
||||||
return "BooleanCovariate(name=" + name + ")";
|
return "BooleanCovariate(name=" + this.name + ", index=" + this.index + ", hasNAs=" + this.hasNAs + ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
public class BooleanValue implements Value<Boolean>{
|
public class BooleanValue implements Value<Boolean>{
|
||||||
|
@ -100,25 +107,4 @@ public final class BooleanCovariate implements Covariate<Boolean> {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public class BooleanSplitRule implements SplitRule<Boolean>{
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public final String toString() {
|
|
||||||
return "BooleanSplitRule";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public BooleanCovariate getParent() {
|
|
||||||
return BooleanCovariate.this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeftHand(final Value<Boolean> value) {
|
|
||||||
if(value.isNA()) {
|
|
||||||
throw new IllegalArgumentException("Trying to determine split on missing value");
|
|
||||||
}
|
|
||||||
|
|
||||||
return !value.getValue();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.covariates.bool;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.SplitRule;
|
||||||
|
|
||||||
|
public class BooleanSplitRule implements SplitRule<Boolean> {
|
||||||
|
|
||||||
|
private final int parentCovariateIndex;
|
||||||
|
|
||||||
|
public BooleanSplitRule(BooleanCovariate parent){
|
||||||
|
this.parentCovariateIndex = parent.getIndex();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final String toString() {
|
||||||
|
return "BooleanSplitRule";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getParentCovariateIndex() {
|
||||||
|
return parentCovariateIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeftHand(final Covariate.Value<Boolean> value) {
|
||||||
|
if(value.isNA()) {
|
||||||
|
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||||
|
}
|
||||||
|
|
||||||
|
return !value.getValue();
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,16 +14,17 @@
|
||||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates.factor;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.tree.Split;
|
import ca.joeltherrien.randomforest.tree.Split;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
public final class FactorCovariate implements Covariate<String>{
|
public final class FactorCovariate implements Covariate<String> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private final String name;
|
private final String name;
|
||||||
|
@ -44,6 +45,10 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
this.factorLevels = new HashMap<>();
|
this.factorLevels = new HashMap<>();
|
||||||
|
|
||||||
for(final String level : levels){
|
for(final String level : levels){
|
||||||
|
if(level.equalsIgnoreCase("na")){
|
||||||
|
throw new IllegalArgumentException("Cannot use NA (case-insensitive) as a level in factor covariate " + name);
|
||||||
|
}
|
||||||
|
|
||||||
final FactorValue newValue = new FactorValue(level);
|
final FactorValue newValue = new FactorValue(level);
|
||||||
|
|
||||||
factorLevels.put(level, newValue);
|
factorLevels.put(level, newValue);
|
||||||
|
@ -70,16 +75,16 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
|
|
||||||
while(splits.size() < number){
|
while(splits.size() < number){
|
||||||
Collections.shuffle(levels, random);
|
Collections.shuffle(levels, random);
|
||||||
final Set<FactorValue> leftSideValues = new HashSet<>();
|
final Set<String> leftSideValues = new HashSet<>();
|
||||||
leftSideValues.add(levels.get(0));
|
leftSideValues.add(levels.get(0).getValue());
|
||||||
|
|
||||||
for(int i=1; i<levels.size()/2; i++){
|
for(int i=1; i<levels.size()/2; i++){
|
||||||
if(random.nextBoolean()){
|
if(random.nextBoolean()){
|
||||||
leftSideValues.add(levels.get(i));
|
leftSideValues.add(levels.get(i).getValue());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
splits.add(new FactorSplitRule(leftSideValues).applyRule(data));
|
splits.add(new FactorSplitRule(this, leftSideValues).applyRule(data));
|
||||||
}
|
}
|
||||||
|
|
||||||
return splits.iterator();
|
return splits.iterator();
|
||||||
|
@ -110,8 +115,8 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString(){
|
public String toString() {
|
||||||
return "FactorCovariate(name=" + name + ")";
|
return "FactorCovariate(name=" + this.name + ", index=" + this.index + ", hasNAs=" + this.hasNAs + ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
|
@ -139,27 +144,4 @@ public final class FactorCovariate implements Covariate<String>{
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@EqualsAndHashCode
|
|
||||||
public final class FactorSplitRule implements Covariate.SplitRule<String>{
|
|
||||||
|
|
||||||
private final Set<FactorValue> leftSideValues;
|
|
||||||
|
|
||||||
private FactorSplitRule(final Set<FactorValue> leftSideValues){
|
|
||||||
this.leftSideValues = leftSideValues;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public FactorCovariate getParent() {
|
|
||||||
return FactorCovariate.this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeftHand(final Value<String> value) {
|
|
||||||
if(value.isNA()){
|
|
||||||
throw new IllegalArgumentException("Trying to determine split on missing value");
|
|
||||||
}
|
|
||||||
|
|
||||||
return leftSideValues.contains(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019 Joel Therrien.
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU General Public License
|
||||||
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ca.joeltherrien.randomforest.covariates.factor;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.SplitRule;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@EqualsAndHashCode
|
||||||
|
public final class FactorSplitRule implements SplitRule<String> {
|
||||||
|
|
||||||
|
private final int parentCovariateIndex;
|
||||||
|
private final Set<String> leftSideValues;
|
||||||
|
|
||||||
|
public FactorSplitRule(final FactorCovariate parent, final Set<String> leftSideValues){
|
||||||
|
this.parentCovariateIndex = parent.getIndex();
|
||||||
|
this.leftSideValues = leftSideValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getParentCovariateIndex() {
|
||||||
|
return parentCovariateIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeftHand(final Covariate.Value<String> value) {
|
||||||
|
if(value.isNA()){
|
||||||
|
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||||
|
}
|
||||||
|
|
||||||
|
return leftSideValues.contains(value.getValue());
|
||||||
|
}
|
||||||
|
}
|
|
@ -26,7 +26,10 @@ import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.TreeSet;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
@ -141,34 +144,4 @@ public final class NumericCovariate implements Covariate<Double> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@EqualsAndHashCode
|
|
||||||
public class NumericSplitRule implements Covariate.SplitRule<Double>{
|
|
||||||
|
|
||||||
private final double threshold;
|
|
||||||
|
|
||||||
NumericSplitRule(final double threshold){
|
|
||||||
this.threshold = threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public final String toString() {
|
|
||||||
return "NumericSplitRule on " + getParent().getName() + " at " + threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NumericCovariate getParent() {
|
|
||||||
return NumericCovariate.this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeftHand(final Value<Double> x) {
|
|
||||||
if(x.isNA()) {
|
|
||||||
throw new IllegalArgumentException("Trying to determine split on missing value");
|
|
||||||
}
|
|
||||||
|
|
||||||
final double xNum = x.getValue();
|
|
||||||
|
|
||||||
return xNum <= threshold;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
package ca.joeltherrien.randomforest.covariates.numeric;
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
|
import ca.joeltherrien.randomforest.covariates.SplitRule;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
|
@EqualsAndHashCode
|
||||||
|
public class NumericSplitRule implements SplitRule<Double> {
|
||||||
|
|
||||||
|
private final int parentCovariateIndex;
|
||||||
|
private final double threshold;
|
||||||
|
|
||||||
|
NumericSplitRule(NumericCovariate parent, final double threshold){
|
||||||
|
this.parentCovariateIndex = parent.getIndex();
|
||||||
|
this.threshold = threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final String toString() {
|
||||||
|
return "NumericSplitRule on " + getParentCovariateIndex() + " at " + threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getParentCovariateIndex() {
|
||||||
|
return parentCovariateIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeftHand(final Covariate.Value<Double> x) {
|
||||||
|
if(x.isNA()) {
|
||||||
|
throw new IllegalArgumentException("Trying to determine split on missing value");
|
||||||
|
}
|
||||||
|
|
||||||
|
final double xNum = x.getValue();
|
||||||
|
|
||||||
|
return xNum <= threshold;
|
||||||
|
}
|
||||||
|
}
|
|
@ -41,7 +41,7 @@ public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y,
|
||||||
final List<Row<Y>> rightHandList = orderedData;
|
final List<Row<Y>> rightHandList = orderedData;
|
||||||
|
|
||||||
this.currentSplit = new Split<>(
|
this.currentSplit = new Split<>(
|
||||||
covariate.new NumericSplitRule(Double.NEGATIVE_INFINITY),
|
new NumericSplitRule(covariate, Double.NEGATIVE_INFINITY),
|
||||||
leftHandList,
|
leftHandList,
|
||||||
rightHandList,
|
rightHandList,
|
||||||
Collections.emptyList());
|
Collections.emptyList());
|
||||||
|
@ -67,7 +67,7 @@ public class NumericSplitRuleUpdater<Y> implements Covariate.SplitRuleUpdater<Y,
|
||||||
|
|
||||||
final List<Row<Y>> rowsMoved = orderedData.subList(currentPosition, newPosition);
|
final List<Row<Y>> rowsMoved = orderedData.subList(currentPosition, newPosition);
|
||||||
|
|
||||||
final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue);
|
final NumericSplitRule splitRule = new NumericSplitRule(covariate, splitValue);
|
||||||
|
|
||||||
// Update current split
|
// Update current split
|
||||||
this.currentSplit = new Split<>(
|
this.currentSplit = new Split<>(
|
||||||
|
|
|
@ -25,11 +25,11 @@ import java.util.List;
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class NumericSplitUpdate<Y> implements Covariate.SplitUpdate<Y, Double> {
|
public class NumericSplitUpdate<Y> implements Covariate.SplitUpdate<Y, Double> {
|
||||||
|
|
||||||
private final NumericCovariate.NumericSplitRule numericSplitRule;
|
private final NumericSplitRule numericSplitRule;
|
||||||
private final List<Row<Y>> rowsMoved;
|
private final List<Row<Y>> rowsMoved;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public NumericCovariate.NumericSplitRule getSplitRule() {
|
public NumericSplitRule getSplitRule() {
|
||||||
return numericSplitRule;
|
return numericSplitRule;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.covariates.settings;
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.BooleanCovariate;
|
import ca.joeltherrien.randomforest.covariates.bool.BooleanCovariate;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package ca.joeltherrien.randomforest.covariates.settings;
|
package ca.joeltherrien.randomforest.covariates.settings;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
|
|
@ -67,18 +67,18 @@ public class Forest<O, FO> { // O = output of trees, FO = forest output. In prac
|
||||||
return Collections.unmodifiableCollection(trees);
|
return Collections.unmodifiableCollection(trees);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Covariate, Integer> findSplitsByCovariate(){
|
public Map<Integer, Integer> findSplitsByCovariate(){
|
||||||
final Map<Covariate, Integer> countMap = new TreeMap<>();
|
final Map<Integer, Integer> countMap = new TreeMap<>();
|
||||||
|
|
||||||
for(final Tree<O> tree : getTrees()){
|
for(final Tree<O> tree : getTrees()){
|
||||||
final Node<O> rootNode = tree.getRootNode();
|
final Node<O> rootNode = tree.getRootNode();
|
||||||
final List<SplitNode> splitNodeList = rootNode.getNodesOfType(SplitNode.class);
|
final List<SplitNode> splitNodeList = rootNode.getNodesOfType(SplitNode.class);
|
||||||
|
|
||||||
for(final SplitNode splitNode : splitNodeList){
|
for(final SplitNode splitNode : splitNodeList){
|
||||||
final Covariate covariate = splitNode.getSplitRule().getParent();
|
final Integer covariateIndex = splitNode.getSplitRule().getParentCovariateIndex();
|
||||||
|
|
||||||
final Integer currentCount = countMap.getOrDefault(covariate, 0);
|
final Integer currentCount = countMap.getOrDefault(covariateIndex, 0);
|
||||||
countMap.put(covariate, currentCount+1);
|
countMap.put(covariateIndex, currentCount+1);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.Row;
|
import ca.joeltherrien.randomforest.Row;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.SplitRule;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -32,7 +32,7 @@ import java.util.List;
|
||||||
@Data
|
@Data
|
||||||
public final class Split<Y, V> {
|
public final class Split<Y, V> {
|
||||||
|
|
||||||
public final Covariate.SplitRule<V> splitRule;
|
public final SplitRule<V> splitRule;
|
||||||
public final List<Row<Y>> leftHand;
|
public final List<Row<Y>> leftHand;
|
||||||
public final List<Row<Y>> rightHand;
|
public final List<Row<Y>> rightHand;
|
||||||
public final List<Row<Y>> naHand;
|
public final List<Row<Y>> naHand;
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package ca.joeltherrien.randomforest.tree;
|
package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.SplitRule;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
@ -32,7 +32,7 @@ public class SplitNode<Y> implements Node<Y> {
|
||||||
|
|
||||||
private final Node<Y> leftHand;
|
private final Node<Y> leftHand;
|
||||||
private final Node<Y> rightHand;
|
private final Node<Y> rightHand;
|
||||||
private final Covariate.SplitRule splitRule;
|
private final SplitRule splitRule;
|
||||||
private final double probabilityNaLeftHand; // used when assigning NA values
|
private final double probabilityNaLeftHand; // used when assigning NA values
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -91,13 +91,13 @@ public class TreeTrainer<Y, O> {
|
||||||
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
(double) (bestSplit.leftHand.size() + bestSplit.rightHand.size());
|
||||||
|
|
||||||
// Assign missing values to the split if necessary
|
// Assign missing values to the split if necessary
|
||||||
if(bestSplit.getSplitRule().getParent().hasNAs()){
|
if(covariates.get(bestSplit.getSplitRule().getParentCovariateIndex()).hasNAs()){
|
||||||
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
bestSplit = bestSplit.modifiableClone(); // the lists in bestSplit are otherwise usually unmodifiable lists
|
||||||
|
|
||||||
for(Row<Y> row : data) {
|
for(Row<Y> row : data) {
|
||||||
final Covariate<?> covariate = bestSplit.getSplitRule().getParent();
|
final int covariateIndex = bestSplit.getSplitRule().getParentCovariateIndex();
|
||||||
|
|
||||||
if(row.getCovariateValue(covariate).isNA()) {
|
if(row.getValueByIndex(covariateIndex).isNA()) {
|
||||||
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
final boolean randomDecision = random.nextDouble() <= probabilityLeftHand;
|
||||||
|
|
||||||
if(randomDecision){
|
if(randomDecision){
|
||||||
|
|
|
@ -113,12 +113,15 @@ public class TestSavingLoading {
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
final File directory = new File(settings.getSaveTreeLocation());
|
final File directory = new File(settings.getSaveTreeLocation());
|
||||||
|
if(directory.exists()){
|
||||||
|
directory.delete();
|
||||||
|
}
|
||||||
assertFalse(directory.exists());
|
assertFalse(directory.exists());
|
||||||
directory.mkdir();
|
directory.mkdir();
|
||||||
|
|
||||||
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
forestTrainer.trainParallelOnDisk(1);
|
forestTrainer.trainSerialOnDisk();
|
||||||
|
|
||||||
assertTrue(directory.exists());
|
assertTrue(directory.exists());
|
||||||
assertTrue(directory.isDirectory());
|
assertTrue(directory.isDirectory());
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
import ca.joeltherrien.randomforest.utils.Utils;
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.function.Executable;
|
import org.junit.jupiter.api.function.Executable;
|
||||||
|
@ -63,7 +64,7 @@ public class FactorCovariateTest {
|
||||||
void testAllSubsets(){
|
void testAllSubsets(){
|
||||||
final FactorCovariate petCovariate = createTestCovariate();
|
final FactorCovariate petCovariate = createTestCovariate();
|
||||||
|
|
||||||
final List<Covariate.SplitRule<String>> splitRules = new ArrayList<>();
|
final List<SplitRule<String>> splitRules = new ArrayList<>();
|
||||||
|
|
||||||
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
|
petCovariate.generateSplitRuleUpdater(null, 100, new Random())
|
||||||
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
|
.forEachRemaining(split -> splitRules.add(split.getSplitRule()));
|
||||||
|
|
|
@ -19,7 +19,7 @@ package ca.joeltherrien.randomforest.workshop;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.*;
|
import ca.joeltherrien.randomforest.*;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariate;
|
import ca.joeltherrien.randomforest.covariates.factor.FactorCovariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
|
|
Loading…
Add table
Reference in a new issue