package ca.joeltherrien.randomforest.covariates; import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.Row; import ca.joeltherrien.randomforest.tree.Split; import java.io.Serializable; import java.util.*; import java.util.concurrent.ThreadLocalRandom; public interface Covariate<V> extends Serializable { String getName(); int getIndex(); <Y> Iterator<Split<Y, V>> generateSplitRuleUpdater(final List<Row<Y>> data, final int number, final Random random); Value<V> createValue(V value); /** * Creates a Value of the appropriate type from a String; primarily used when parsing CSVs. * * @param value * @return */ Value<V> createValue(String value); interface Value<V> extends Serializable{ Covariate<V> getParent(); V getValue(); boolean isNA(); } interface SplitRuleUpdater<Y, V> extends Iterator<Split<Y, V>>{ Split<Y, V> currentSplit(); SplitUpdate<Y, V> nextUpdate(); } interface SplitUpdate<Y, V> { SplitRule<V> getSplitRule(); Collection<Row<Y>> rowsMovedToLeftHand(); } interface SplitRule<V> extends Serializable{ Covariate<V> getParent(); /** * Applies the SplitRule to a list of rows and returns a Split object, which contains two lists for both sides. * This method is primarily used during the training of a tree when splits are being tested. * * @param rows * @param <Y> * @return */ default <Y> Split<Y, V> applyRule(List<Row<Y>> rows) { final List<Row<Y>> leftHand = new LinkedList<>(); final List<Row<Y>> rightHand = new LinkedList<>(); final List<Row<Y>> missingValueRows = new ArrayList<>(); for(final Row<Y> row : rows) { final Value<V> value = row.getCovariateValue(getParent()); if(value.isNA()){ missingValueRows.add(row); continue; } final boolean isLeftHand = isLeftHand(value); if(isLeftHand){ leftHand.add(row); } else{ rightHand.add(row); } } return new Split<>(this, leftHand, rightHand, missingValueRows); } default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){ final Value<V> value = row.getCovariateValue(getParent()); if(value.isNA()){ return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand; } return isLeftHand(value); } boolean isLeftHand(Value<V> value); } }