diff --git a/pmd-rules.xml b/pmd-rules.xml
new file mode 100644
index 0000000..abad748
--- /dev/null
+++ b/pmd-rules.xml
@@ -0,0 +1,93 @@
+
+
+
+
+
+ The default ruleset used by the Maven PMD Plugin, when no other ruleset is specified.
+ It contains the rules of the old (pre PMD 6.0.0) rulesets java-basic, java-empty, java-imports,
+ java-unnecessary, java-unusedcode.
+
+ This ruleset can be used as a starting point for your own customized ruleset [0].
+
+ [0] https://pmd.github.io/latest/pmd_userdocs_understanding_rulesets.html
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 5d6ba5b..6d22d2b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -13,6 +13,9 @@
1.8
1.8
2.9.6
+
+ UTF-8
+ UTF-8
@@ -89,9 +92,27 @@
+
+ org.apache.maven.plugins
+ maven-pmd-plugin
+ 3.11.0
+
+
+ package
+
+ check
+
+
+
+
+
+
+ ${project.basedir}/pmd-rules.xml
+
+
+
-
\ No newline at end of file
diff --git a/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java b/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
index b73db9d..a7577db 100644
--- a/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
+++ b/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java
@@ -4,20 +4,20 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.List;
-import java.util.concurrent.ThreadLocalRandom;
+import java.util.Random;
@RequiredArgsConstructor
public class Bootstrapper {
final private List originalData;
- public List bootstrap(){
+ public List bootstrap(Random random){
final int n = originalData.size();
final List newList = new ArrayList<>(n);
for(int i=0; i getCovariateValue(Covariate covariate){
+ public Covariate.Value getCovariateValue(Covariate covariate){
return valueArray[covariate.getIndex()];
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java
index 25e8465..5118fdc 100644
--- a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java
+++ b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java
@@ -58,7 +58,7 @@ public class DataLoader {
throw new IllegalArgumentException("Tree directory must be a directory!");
}
- final File[] treeFiles = folder.listFiles(((file, s) -> s.endsWith(".tree")));
+ final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
final List treeFileList = Arrays.asList(treeFiles);
Collections.sort(treeFileList, Comparator.comparing(File::getName));
diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java
index f9b7194..f5a7731 100644
--- a/src/main/java/ca/joeltherrien/randomforest/Main.java
+++ b/src/main/java/ca/joeltherrien/randomforest/Main.java
@@ -1,9 +1,9 @@
package ca.joeltherrien.randomforest;
-import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
-import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
-import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
+import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
diff --git a/src/main/java/ca/joeltherrien/randomforest/Row.java b/src/main/java/ca/joeltherrien/randomforest/Row.java
index 850f035..00c5078 100644
--- a/src/main/java/ca/joeltherrien/randomforest/Row.java
+++ b/src/main/java/ca/joeltherrien/randomforest/Row.java
@@ -17,6 +17,8 @@ public class Row extends CovariateRow {
}
+
+
public Y getResponse() {
return this.response;
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java
index ee40a35..c89681c 100644
--- a/src/main/java/ca/joeltherrien/randomforest/Settings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java
@@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.Covariate;
-import ca.joeltherrien.randomforest.covariates.CovariateSettings;
+import ca.joeltherrien.randomforest.covariates.settings.CovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
@@ -10,7 +10,6 @@ import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayL
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.GrayLogRankSingleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankMultipleGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.differentiator.LogRankSingleGroupDifferentiator;
-import ca.joeltherrien.randomforest.responses.regression.MeanGroupDifferentiator;
import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
@@ -68,9 +67,6 @@ public class Settings {
GROUP_DIFFERENTIATOR_MAP.put(name.toLowerCase(), groupDifferentiatorConstructor);
}
static{
- registerGroupDifferentiatorConstructor("MeanGroupDifferentiator",
- (node) -> new MeanGroupDifferentiator()
- );
registerGroupDifferentiatorConstructor("WeightedVarianceGroupDifferentiator",
(node) -> new WeightedVarianceGroupDifferentiator()
);
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java
index 8959bc4..741fc2a 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariate.java
@@ -1,14 +1,15 @@
package ca.joeltherrien.randomforest.covariates;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.tree.Split;
+import ca.joeltherrien.randomforest.utils.SingletonIterator;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
+import java.util.*;
@RequiredArgsConstructor
-public final class BooleanCovariate implements Covariate{
+public final class BooleanCovariate implements Covariate {
@Getter
private final String name;
@@ -16,11 +17,13 @@ public final class BooleanCovariate implements Covariate{
@Getter
private final int index;
+ private boolean hasNAs = false;
+
private final BooleanSplitRule splitRule = new BooleanSplitRule(); // there's only one possible rule for BooleanCovariates.
@Override
- public Collection generateSplitRules(List> data, int number) {
- return Collections.singleton(splitRule);
+ public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) {
+ return new SingletonIterator<>(this.splitRule.applyRule(data));
}
@Override
@@ -31,6 +34,7 @@ public final class BooleanCovariate implements Covariate{
@Override
public Value createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){
+ hasNAs = true;
return createValue( (Boolean) null);
}
@@ -45,6 +49,11 @@ public final class BooleanCovariate implements Covariate{
}
}
+ @Override
+ public boolean hasNAs() {
+ return hasNAs;
+ }
+
@Override
public String toString(){
return "BooleanCovariate(name=" + name + ")";
@@ -74,6 +83,7 @@ public final class BooleanCovariate implements Covariate{
}
}
+
public class BooleanSplitRule implements SplitRule{
@Override
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java
index 9fa77ef..95c64df 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/Covariate.java
@@ -5,10 +5,7 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split;
import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.LinkedList;
-import java.util.List;
+import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
public interface Covariate extends Serializable {
@@ -17,7 +14,7 @@ public interface Covariate extends Serializable {
int getIndex();
- Collection extends SplitRule> generateSplitRules(final List> data, final int number);
+ Iterator> generateSplitRuleUpdater(final List> data, final int number, final Random random);
Value createValue(V value);
@@ -29,6 +26,8 @@ public interface Covariate extends Serializable {
*/
Value createValue(String value);
+ boolean hasNAs();
+
interface Value extends Serializable{
Covariate getParent();
@@ -39,6 +38,17 @@ public interface Covariate extends Serializable {
}
+ interface SplitRuleUpdater extends Iterator>{
+ Split currentSplit();
+ boolean currentSplitValid();
+ SplitUpdate nextUpdate();
+ }
+
+ interface SplitUpdate {
+ SplitRule getSplitRule();
+ Collection> rowsMovedToLeftHand();
+ }
+
interface SplitRule extends Serializable{
Covariate getParent();
@@ -51,7 +61,7 @@ public interface Covariate extends Serializable {
* @param
* @return
*/
- default Split applyRule(List> rows) {
+ default Split applyRule(List> rows) {
final List> leftHand = new LinkedList<>();
final List> rightHand = new LinkedList<>();
@@ -59,7 +69,7 @@ public interface Covariate extends Serializable {
for(final Row row : rows) {
- final Value value = (Value) row.getCovariateValue(getParent());
+ final Value value = row.getCovariateValue(getParent());
if(value.isNA()){
missingValueRows.add(row);
@@ -77,11 +87,11 @@ public interface Covariate extends Serializable {
}
- return new Split<>(leftHand, rightHand, missingValueRows);
+ return new Split<>(this, leftHand, rightHand, missingValueRows);
}
default boolean isLeftHand(CovariateRow row, final double probabilityNaLeftHand){
- final Value value = (Value) row.getCovariateValue(getParent());
+ final Value value = row.getCovariateValue(getParent());
if(value.isNA()){
return ThreadLocalRandom.current().nextDouble() <= probabilityNaLeftHand;
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java
index 402cb7d..e0f4101 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariate.java
@@ -1,10 +1,11 @@
package ca.joeltherrien.randomforest.covariates;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.tree.Split;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import java.util.*;
-import java.util.concurrent.ThreadLocalRandom;
public final class FactorCovariate implements Covariate{
@@ -18,6 +19,8 @@ public final class FactorCovariate implements Covariate{
private final FactorValue naValue;
private final int numberOfPossiblePairings;
+ private boolean hasNAs;
+
public FactorCovariate(final String name, final int index, List levels){
this.name = name;
@@ -42,17 +45,14 @@ public final class FactorCovariate implements Covariate{
@Override
- public Set generateSplitRules(List> data, int number) {
- final Set splitRules = new HashSet<>();
+ public Iterator> generateSplitRuleUpdater(List> data, int number, Random random) {
+ final Set> splits = new HashSet<>();
// This is to ensure we don't get stuck in an infinite loop for small factors
number = Math.min(number, numberOfPossiblePairings);
- final Random random = ThreadLocalRandom.current();
final List levels = new ArrayList<>(factorLevels.values());
-
-
- while(splitRules.size() < number){
+ while(splits.size() < number){
Collections.shuffle(levels, random);
final Set leftSideValues = new HashSet<>();
leftSideValues.add(levels.get(0));
@@ -63,16 +63,18 @@ public final class FactorCovariate implements Covariate{
}
}
- splitRules.add(new FactorSplitRule(leftSideValues));
+ splits.add(new FactorSplitRule(leftSideValues).applyRule(data));
}
- return splitRules;
+ return splits.iterator();
}
+
@Override
public FactorValue createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){
+ this.hasNAs = true;
return this.naValue;
}
@@ -85,6 +87,12 @@ public final class FactorCovariate implements Covariate{
return factorValue;
}
+
+ @Override
+ public boolean hasNAs() {
+ return hasNAs;
+ }
+
@Override
public String toString(){
return "FactorCovariate(name=" + name + ")";
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java
similarity index 51%
rename from src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java
rename to src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java
index a268ba8..9b23d08 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariate.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericCovariate.java
@@ -1,17 +1,22 @@
-package ca.joeltherrien.randomforest.covariates;
+package ca.joeltherrien.randomforest.covariates.numeric;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.covariates.Covariate;
+import ca.joeltherrien.randomforest.utils.IndexedIterator;
+import ca.joeltherrien.randomforest.utils.UniqueSubsetValueIterator;
+import ca.joeltherrien.randomforest.utils.UniqueValueIterator;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import java.util.*;
-import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
@RequiredArgsConstructor
@ToString
-public final class NumericCovariate implements Covariate{
+public final class NumericCovariate implements Covariate {
@Getter
private final String name;
@@ -19,42 +24,51 @@ public final class NumericCovariate implements Covariate{
@Getter
private final int index;
+ private boolean hasNAs = false;
+
@Override
- public Collection generateSplitRules(List> data, int number) {
+ public NumericSplitRuleUpdater generateSplitRuleUpdater(List> data, int number, Random random) {
+ Stream> stream = data.stream();
- final Random random = ThreadLocalRandom.current();
-
- // only work with non-NA values
- data = data.stream().filter(value -> !value.isNA()).collect(Collectors.toList());
- //data = data.stream().filter(value -> !value.isNA()).distinct().collect(Collectors.toList()); // TODO which to use?
-
- // for this implementation we need to shuffle the data
- final List> shuffledData;
- if(number >= data.size()){
- shuffledData = new ArrayList<>(data);
- Collections.shuffle(shuffledData, random);
+ if(hasNAs()){
+ stream = stream.filter(row -> !row.getCovariateValue(this).isNA());
}
- else{ // only need the top number entries
- shuffledData = new ArrayList<>(number);
- final Set indexesToUse = new HashSet<>();
- //final List indexesToUse = new ArrayList<>(); // TODO which to use?
- while(indexesToUse.size() < number){
- final int index = random.nextInt(data.size());
+ data = stream
+ .sorted((r1, r2) -> {
+ Double d1 = r1.getCovariateValue(this).getValue();
+ Double d2 = r2.getCovariateValue(this).getValue();
- if(indexesToUse.add(index)){
- shuffledData.add(data.get(index));
- }
+ return d1.compareTo(d2);
+ })
+ .collect(Collectors.toList());
+
+ Iterator sortedDataIterator = data.stream()
+ .map(row -> row.getCovariateValue(this).getValue())
+ .iterator();
+
+
+ final IndexedIterator dataIterator;
+ if(number == 0){
+ dataIterator = new UniqueValueIterator<>(sortedDataIterator);
+ }
+ else{
+ final TreeSet indexSet = new TreeSet<>();
+
+ final int maxIndex = data.size();
+
+ for(int i=0; i(
+ new UniqueValueIterator<>(sortedDataIterator),
+ indexSet.toArray(new Integer[indexSet.size()])
+ );
+
}
- return shuffledData.stream()
- .mapToDouble(v -> v.getValue())
- .mapToObj(threshold -> new NumericSplitRule(threshold))
- .collect(Collectors.toSet());
- // by returning a set we'll make everything far more efficient as a lot of rules can repeat due to bootstrapping
-
+ return new NumericSplitRuleUpdater<>(this, data, dataIterator);
}
@@ -66,12 +80,19 @@ public final class NumericCovariate implements Covariate{
@Override
public NumericValue createValue(String value) {
if(value == null || value.equalsIgnoreCase("na")){
+ this.hasNAs = true;
return createValue((Double) null);
}
return createValue(Double.parseDouble(value));
}
+
+ @Override
+ public boolean hasNAs() {
+ return hasNAs;
+ }
+
@EqualsAndHashCode
public class NumericValue implements Covariate.Value{
@@ -102,7 +123,7 @@ public final class NumericCovariate implements Covariate{
private final double threshold;
- private NumericSplitRule(final double threshold){
+ NumericSplitRule(final double threshold){
this.threshold = threshold;
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java
new file mode 100644
index 0000000..4f1586c
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitRuleUpdater.java
@@ -0,0 +1,84 @@
+package ca.joeltherrien.randomforest.covariates.numeric;
+
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.covariates.Covariate;
+import ca.joeltherrien.randomforest.tree.Split;
+import ca.joeltherrien.randomforest.utils.IndexedIterator;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NumericSplitRuleUpdater implements Covariate.SplitRuleUpdater {
+ // Steps through candidate thresholds of one numeric covariate, moving rows from the right hand to the left hand of the current split.
+ private final NumericCovariate covariate; // covariate whose NumericSplitRule instances are created here
+ private final List> orderedData; // rows in the order given by the caller; NumericCovariate sorts them by this covariate's value before constructing
+ private final IndexedIterator dataIterator; // yields the next threshold; getIndex() is used as the boundary position in orderedData
+ // Most recent split; replaced (never mutated) on every successful nextUpdate().
+ private Split currentSplit;
+
+ public NumericSplitRuleUpdater(final NumericCovariate covariate, final List> orderedData, final IndexedIterator iterator){
+ this.covariate = covariate;
+ this.orderedData = orderedData;
+ this.dataIterator = iterator;
+ // Start with every row on the right hand and none on the left; currentSplitValid() is false until the first update.
+ final List> leftHandList = Collections.emptyList();
+ final List> rightHandList = orderedData;
+ // NEGATIVE_INFINITY lies below every finite value; Double.MIN_VALUE is the smallest positive double, not the most negative one.
+ this.currentSplit = new Split<>(
+ covariate.new NumericSplitRule(Double.NEGATIVE_INFINITY),
+ leftHandList,
+ rightHandList,
+ Collections.emptyList()); // no rows are tracked as missing-value rows here
+
+ }
+
+ @Override
+ public Split currentSplit() {
+ return this.currentSplit;
+ }
+
+ @Override
+ public boolean currentSplitValid() {
+ return currentSplit.getLeftHand().size() > 0 && currentSplit.getRightHand().size() > 0; // both hands must be non-empty
+ }
+
+ @Override
+ public NumericSplitUpdate nextUpdate() {
+ if(hasNext()){
+ final int currentPosition = dataIterator.getIndex(); // boundary before advancing
+ final Double splitValue = dataIterator.next();
+ final int newPosition = dataIterator.getIndex(); // boundary after advancing
+ // Rows between the old and new boundary are the ones that switch to the left hand.
+ final List> rowsMoved = orderedData.subList(currentPosition, newPosition);
+
+ final NumericCovariate.NumericSplitRule splitRule = covariate.new NumericSplitRule(splitValue);
+
+ // Update current split; both hands are read-only views of orderedData, not copies
+ this.currentSplit = new Split<>(
+ splitRule,
+ Collections.unmodifiableList(orderedData.subList(0, newPosition)),
+ Collections.unmodifiableList(orderedData.subList(newPosition, orderedData.size())),
+ Collections.emptyList());
+
+
+ return new NumericSplitUpdate<>(splitRule, rowsMoved);
+ }
+ // Iterator exhausted: there is no further update to report.
+ return null;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return dataIterator.hasNext();
+ }
+
+ @Override
+ public Split next() {
+ if(hasNext()){ // when exhausted, the last split is returned again instead of throwing
+ nextUpdate();
+ }
+
+ return this.currentSplit();
+ }
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java
new file mode 100644
index 0000000..f2757c0
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/numeric/NumericSplitUpdate.java
@@ -0,0 +1,24 @@
+package ca.joeltherrien.randomforest.covariates.numeric;
+
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.covariates.Covariate;
+import lombok.AllArgsConstructor;
+
+import java.util.Collection;
+
+@AllArgsConstructor
+public class NumericSplitUpdate implements Covariate.SplitUpdate {
+ // Immutable pairing of a numeric split rule with the rows handed over alongside it; Lombok generates the two-argument constructor.
+ private final NumericCovariate.NumericSplitRule numericSplitRule; // rule associated with this update
+ private final Collection> rowsMoved; // stored as given, no defensive copy is made
+
+ @Override
+ public NumericCovariate.NumericSplitRule getSplitRule() {
+ return numericSplitRule;
+ }
+
+ @Override
+ public Collection> rowsMovedToLeftHand() {
+ return rowsMoved; // returns the stored collection itself
+ }
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java
similarity index 75%
rename from src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java
rename to src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java
index 5f57c15..ba6366e 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/BooleanCovariateSettings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/BooleanCovariateSettings.java
@@ -1,5 +1,6 @@
-package ca.joeltherrien.randomforest.covariates;
+package ca.joeltherrien.randomforest.covariates.settings;
+import ca.joeltherrien.randomforest.covariates.BooleanCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java
similarity index 87%
rename from src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java
rename to src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java
index 4d850ac..9b9f93c 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/CovariateSettings.java
@@ -1,5 +1,6 @@
-package ca.joeltherrien.randomforest.covariates;
+package ca.joeltherrien.randomforest.covariates.settings;
+import ca.joeltherrien.randomforest.covariates.Covariate;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import lombok.Getter;
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java
similarity index 82%
rename from src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java
rename to src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java
index 9d7ece5..f40213c 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/FactorCovariateSettings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/FactorCovariateSettings.java
@@ -1,5 +1,6 @@
-package ca.joeltherrien.randomforest.covariates;
+package ca.joeltherrien.randomforest.covariates.settings;
+import ca.joeltherrien.randomforest.covariates.FactorCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;
diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java
similarity index 74%
rename from src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java
rename to src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java
index 9cdf898..84a69a7 100644
--- a/src/main/java/ca/joeltherrien/randomforest/covariates/NumericCovariateSettings.java
+++ b/src/main/java/ca/joeltherrien/randomforest/covariates/settings/NumericCovariateSettings.java
@@ -1,5 +1,6 @@
-package ca.joeltherrien.randomforest.covariates;
+package ca.joeltherrien.randomforest.covariates.settings;
+import ca.joeltherrien.randomforest.covariates.numeric.NumericCovariate;
import lombok.Data;
import lombok.NoArgsConstructor;
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java
index 6650dcb..2eced4c 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskGraySetsImpl.java
@@ -1,37 +1,77 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
-import ca.joeltherrien.randomforest.utils.MathFunction;
-import lombok.Builder;
-import lombok.Getter;
+import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
+public class CompetingRiskGraySetsImpl implements CompetingRiskSets {
-/**
- * Represents a response from CompetingRiskUtils#calculateGraySetsEfficiently
- *
- */
-@Builder
-@Getter
-public class CompetingRiskGraySetsImpl implements CompetingRiskSets{
+ final double[] times; // length m array
+ int[][] riskSetLeft; // J x m array
+ final int[][] riskSetTotal; // J x m array
+ int[][] numberOfEventsLeft; // J+1 x m array
+ final int[][] numberOfEventsTotal; // J+1 x m array
- private final List eventTimes;
- private final MathFunction[] riskSet;
- private final Map numberOfEvents;
-
- @Override
- public MathFunction getRiskSet(int event){
- return(riskSet[event-1]);
+ public CompetingRiskGraySetsImpl(double[] times, int[][] riskSetLeft, int[][] riskSetTotal, int[][] numberOfEventsLeft, int[][] numberOfEventsTotal) {
+ this.times = times;
+ this.riskSetLeft = riskSetLeft;
+ this.riskSetTotal = riskSetTotal;
+ this.numberOfEventsLeft = numberOfEventsLeft;
+ this.numberOfEventsTotal = numberOfEventsTotal;
}
@Override
- public int getNumberOfEvents(Double time, int event){
- if(numberOfEvents.containsKey(time)){
- return numberOfEvents.get(time)[event];
+ public double[] getDistinctTimes() {
+ return times;
+ }
+
+ @Override
+ public int getRiskSetLeft(int timeIndex, int event) {
+ return riskSetLeft[event-1][timeIndex];
+ }
+
+ @Override
+ public int getRiskSetTotal(int timeIndex, int event) {
+ return riskSetTotal[event-1][timeIndex];
+ }
+
+
+ @Override
+ public int getNumberOfEventsLeft(int timeIndex, int event) {
+ return numberOfEventsLeft[event][timeIndex];
+ }
+
+ @Override
+ public int getNumberOfEventsTotal(int timeIndex, int event) {
+ return numberOfEventsTotal[event][timeIndex];
+ }
+
+ @Override
+ public void update(CompetingRiskResponseWithCensorTime rowMovedToLeft) {
+ final double time = rowMovedToLeft.getU();
+ final int k = Arrays.binarySearch(times, time);
+ final int delta_m_1 = rowMovedToLeft.getDelta() - 1;
+ final double censorTime = rowMovedToLeft.getC();
+
+ for(int j=0; j= t, in I(...)
+ for(int i=0; i<=k; i++){
+ riskSetLeftJ[i]++;
+ }
+
+ // second iteration; only if delta-1 != j
+ // corresponds to the second part, U_i < t & delta_i != j & C_i > t
+ if(delta_m_1 != j && !rowMovedToLeft.isCensored()){
+ int i = k+1;
+ while(i < times.length && times[i] < censorTime){
+ riskSetLeftJ[i]++;
+ i++;
+ }
+ }
+
}
- return 0;
+ numberOfEventsLeft[rowMovedToLeft.getDelta()][k]++;
}
-
-
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java
index 9a53d3e..d9a72b5 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSets.java
@@ -1,13 +1,13 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
-import ca.joeltherrien.randomforest.utils.MathFunction;
+public interface CompetingRiskSets {
-import java.util.List;
+ double[] getDistinctTimes();
+ int getRiskSetLeft(int timeIndex, int event);
+ int getRiskSetTotal(int timeIndex, int event);
+ int getNumberOfEventsLeft(int timeIndex, int event);
+ int getNumberOfEventsTotal(int timeIndex, int event);
-public interface CompetingRiskSets {
-
- MathFunction getRiskSet(int event);
- int getNumberOfEvents(Double time, int event);
- List getEventTimes();
+ void update(T rowMovedToLeft);
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java
index 44a9dd8..7fd8ba3 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskSetsImpl.java
@@ -1,36 +1,59 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
-import ca.joeltherrien.randomforest.utils.MathFunction;
-import lombok.Builder;
-import lombok.Getter;
+import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
+public class CompetingRiskSetsImpl implements CompetingRiskSets {
-/**
- * Represents a response from CompetingRiskUtils#calculateSetsEfficiently
- *
- */
-@Builder
-@Getter
-public class CompetingRiskSetsImpl implements CompetingRiskSets{
+ final double[] times; // length m array
+ int[] riskSetLeft; // length m array
+ final int[] riskSetTotal; // length m array
+ int[][] numberOfEventsLeft; // J+1 x m array
+ final int[][] numberOfEventsTotal; // J+1 x m array
- private final List eventTimes;
- private final MathFunction riskSet;
- private final Map numberOfEvents;
- @Override
- public MathFunction getRiskSet(int event){
- return riskSet;
+ public CompetingRiskSetsImpl(double[] times, int[] riskSetLeft, int[] riskSetTotal, int[][] numberOfEventsLeft, int[][] numberOfEventsTotal) {
+ this.times = times;
+ this.riskSetLeft = riskSetLeft;
+ this.riskSetTotal = riskSetTotal;
+ this.numberOfEventsLeft = numberOfEventsLeft;
+ this.numberOfEventsTotal = numberOfEventsTotal;
}
@Override
- public int getNumberOfEvents(Double time, int event){
- if(numberOfEvents.containsKey(time)){
- return numberOfEvents.get(time)[event];
+ public double[] getDistinctTimes() {
+ return times;
+ }
+
+ @Override
+ public int getRiskSetLeft(int timeIndex, int event) {
+ return riskSetLeft[timeIndex];
+ }
+
+ @Override
+ public int getRiskSetTotal(int timeIndex, int event) {
+ return riskSetTotal[timeIndex];
+ }
+
+
+ @Override
+ public int getNumberOfEventsLeft(int timeIndex, int event) {
+ return numberOfEventsLeft[event][timeIndex];
+ }
+
+ @Override
+ public int getNumberOfEventsTotal(int timeIndex, int event) {
+ return numberOfEventsTotal[event][timeIndex];
+ }
+
+ @Override
+ public void update(CompetingRiskResponse rowMovedToLeft) {
+ final double time = rowMovedToLeft.getU();
+ final int k = Arrays.binarySearch(times, time);
+
+ for(int i=0; i<=k; i++){
+ riskSetLeft[i]++;
}
- return 0;
+ numberOfEventsLeft[rowMovedToLeft.getDelta()][k]++;
}
-
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java
index 3a612b1..94dc48e 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskUtils.java
@@ -1,11 +1,9 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
-import ca.joeltherrien.randomforest.utils.LeftContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
-import ca.joeltherrien.randomforest.utils.VeryDiscontinuousStepFunction;
import java.util.*;
-import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
public class CompetingRiskUtils {
@@ -102,18 +100,30 @@ public class CompetingRiskUtils {
}
- public static CompetingRiskSetsImpl calculateSetsEfficiently(final List responses, int[] eventsOfFocus){
- final int n = responses.size();
- int[] numberOfCurrentEvents = new int[eventsOfFocus.length+1];
- final Map numberOfEvents = new HashMap<>();
+ public static CompetingRiskSetsImpl calculateSetsEfficiently(final List initialLeftHand,
+ final List initialRightHand,
+ int[] eventsOfFocus,
+ boolean calculateRiskSets){
- final List eventTimes = new ArrayList<>(n);
- final List eventAndCensorTimes = new ArrayList<>(n);
- final List riskSetNumberList = new ArrayList<>(n);
+ final double[] distinctEventTimes = Stream.concat(
+ initialLeftHand.stream(),
+ initialRightHand.stream())
+ //.filter(y -> !y.isCensored())
+ .map(CompetingRiskResponse::getU)
+ .mapToDouble(Double::doubleValue)
+ .sorted()
+ .distinct()
+ .toArray();
+
+
+ final int m = distinctEventTimes.length;
+ final int[][] numberOfCurrentEventsTotal = new int[eventsOfFocus.length+1][m];
+
+ // Left Hand First
// need to first sort responses
- Collections.sort(responses, (y1, y2) -> {
+ Collections.sort(initialLeftHand, (y1, y2) -> {
if(y1.getU() < y2.getU()){
return -1;
}
@@ -125,127 +135,191 @@ public class CompetingRiskUtils {
}
});
+ final int nLeft = initialLeftHand.size();
+ final int nRight = initialRightHand.size();
+
+ final int[][] numberOfCurrentEventsLeft = new int[eventsOfFocus.length+1][m];
+ final int[] riskSetArrayLeft = new int[m];
+ final int[] riskSetArrayTotal = new int[m];
- for(int i=0; i currentResponse.getU();
- numberOfCurrentEvents[currentResponse.getDelta()]++;
+ for(int k=0; k currentResponse.getU();
+
+ final int k = Arrays.binarySearch(distinctEventTimes, currentResponse.getU());
+
+ numberOfCurrentEventsLeft[currentResponse.getDelta()][k]++;
+ numberOfCurrentEventsTotal[currentResponse.getDelta()][k]++;
if(lastOfTime){
int totalNumberOfCurrentEvents = 0;
- for(int e = 1; e < numberOfCurrentEvents.length; e++){ // exclude censored events
- totalNumberOfCurrentEvents += numberOfCurrentEvents[e];
+ for(int e = 1; e < eventsOfFocus.length+1; e++){ // exclude censored events
+ totalNumberOfCurrentEvents += numberOfCurrentEventsLeft[e][k];
}
- final double currentTime = currentResponse.getU();
+ // Calculate risk set values
+ // Note that we only decrease values in the *future*
+ if(calculateRiskSets){
+ final int decreaseBy = totalNumberOfCurrentEvents + numberOfCurrentEventsLeft[0][k];
+ for(int j=k+1; j 0){ // add numberOfCurrentEvents
- // Add point
- eventTimes.add(currentTime);
- numberOfEvents.put(currentTime, numberOfCurrentEvents);
+ }
}
- // Always do risk set
- // remember that the LeftContinuousFunction takes into account that at this currentTime the risk value is the previous value
- final int riskSet = n - (i+1);
- riskSetNumberList.add(riskSet);
- eventAndCensorTimes.add(currentTime);
-
- // reset counters
- numberOfCurrentEvents = new int[eventsOfFocus.length+1];
}
}
- final double[] riskSetArray = new double[eventAndCensorTimes.size()];
- final double[] timesArray = new double[eventAndCensorTimes.size()];
- for(int i=0; i {
+ if(y1.getU() < y2.getU()){
+ return -1;
+ }
+ else if(y1.getU() > y2.getU()){
+ return 1;
+ }
+ else{
+ return 0;
+ }
+ });
+
+ // Right Hand
+ int[] currentEventsRight = new int[eventsOfFocus.length+1];
+ for(int i=0; i currentResponse.getU();
+
+ final int k = Arrays.binarySearch(distinctEventTimes, currentResponse.getU());
+
+ currentEventsRight[currentResponse.getDelta()]++;
+ numberOfCurrentEventsTotal[currentResponse.getDelta()][k]++;
+
+ if(lastOfTime){
+ int totalNumberOfCurrentEvents = 0;
+ for(int e = 1; e < eventsOfFocus.length+1; e++){ // exclude censored events
+ totalNumberOfCurrentEvents += currentEventsRight[e];
+ }
+
+ // Calculate risk set values
+ // Note that we only decrease values in the *future*
+ if(calculateRiskSets){
+ final int decreaseBy = totalNumberOfCurrentEvents + currentEventsRight[0];
+ for(int j=k+1; j responses, int[] eventsOfFocus){
- final List sillyList = responses; // annoying Java generic work-around
- final CompetingRiskSetsImpl originalSets = calculateSetsEfficiently(sillyList, eventsOfFocus);
- final double[] allTimes = DoubleStream.concat(
- responses.stream()
- .mapToDouble(CompetingRiskResponseWithCensorTime::getC),
- responses.stream()
- .mapToDouble(CompetingRiskResponseWithCensorTime::getU)
- ).sorted().distinct().toArray();
+ public static CompetingRiskGraySetsImpl calculateGraySetsEfficiently(final List initialLeftHand,
+ final List initialRightHand,
+ int[] eventsOfFocus){
+ final List leftHandGenericsSuck = initialLeftHand;
+ final List rightHandGenericsSuck = initialRightHand;
+ final CompetingRiskSetsImpl normalSets = calculateSetsEfficiently(
+ leftHandGenericsSuck,
+ rightHandGenericsSuck,
+ eventsOfFocus, false);
- final VeryDiscontinuousStepFunction[] riskSets = new VeryDiscontinuousStepFunction[eventsOfFocus.length];
+ final double[] times = normalSets.times;
+ final int[][] numberOfEventsLeft = normalSets.numberOfEventsLeft;
+ final int[][] numberOfEventsTotal = normalSets.numberOfEventsTotal;
- for(final int event : eventsOfFocus){
- final double[] yAt = new double[allTimes.length];
- final double[] yRight = new double[allTimes.length];
+ // FYI; initialLeftHand and initialRightHand have both now been sorted
+ // Time to calculate the Gray modified risk sets
+ final int[][] riskSetsLeft = new int[eventsOfFocus.length][times.length];
+ final int[][] riskSetsTotal = new int[eventsOfFocus.length][times.length];
- for(final CompetingRiskResponseWithCensorTime response : responses){
- if(response.getDelta() == event){
- // traditional case only; increment on time t when I(t <= Ui)
- final double time = response.getU();
- final int index = Arrays.binarySearch(allTimes, time);
+ // Left hand first
+ for(final CompetingRiskResponseWithCensorTime response : initialLeftHand){
+ final double time = response.getU();
+ final int k = Arrays.binarySearch(times, time);
+ final int delta_m_1 = response.getDelta() - 1;
+ final double censorTime = response.getC();
- if(index < 0){ // TODO remove once code is stable
- throw new IllegalStateException("Index shouldn't be negative!");
- }
+ for(int j=0; j= t, in I(...)
+ for(int i=0; i<=k; i++){
+ riskSetLeftJ[i]++;
+ riskSetTotalJ[i]++;
+ }
+
+ // second iteration; only if delta-1 != j
+ // corresponds to the second part, U_i < t & delta_i != j & C_i > t
+ if(delta_m_1 != j && !response.isCensored()){
+ int i = k+1;
+ while(i < times.length && times[i] < censorTime){
+ riskSetLeftJ[i]++;
+ riskSetTotalJ[i]++;
+ i++;
}
}
- else{
- // need to increment on time t on following conditions; I(t <= Ui | t < Ci)
- // Fact: Ci >= Ui.
- // increment yAt up to Ci. If Ui==Ci, increment yAt at Ci.
- final double time = response.getC();
- final int index = Arrays.binarySearch(allTimes, time);
-
- if(index < 0){ // TODO remove once code is stable
- throw new IllegalStateException("Index shouldn't be negative!");
- }
-
- for(int i=0; i= t, in I(...)
+ for(int i=0; i<=k; i++){
+ riskSetTotalJ[i]++;
+ }
+
+ // second iteration; only if delta-1 != j
+ // corresponds to the second part, U_i < t & delta_i != j & C_i > t
+ if(delta_m_1 != j && !response.isCensored()){
+ int i = k+1;
+ while(i < times.length && times[i] < censorTime){
+ riskSetTotalJ[i]++;
+ i++;
+ }
+ }
+
+ }
+
+ }
+
+ return new CompetingRiskGraySetsImpl(times, riskSetsLeft, riskSetsTotal, numberOfEventsLeft, numberOfEventsTotal);
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java
index f6efdeb..703ebfc 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/CompetingRiskGroupDifferentiator.java
@@ -1,60 +1,132 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
+import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.Split;
+import ca.joeltherrien.randomforest.tree.SplitAndScore;
import lombok.AllArgsConstructor;
import lombok.Data;
+import java.util.Iterator;
import java.util.List;
-import java.util.stream.Stream;
+import java.util.stream.Collectors;
/**
* See page 761 of Random survival forests for competing risks by Ishwaran et al. The class is abstract as Gray's test
* modifies the abstract method.
*
*/
-public abstract class CompetingRiskGroupDifferentiator implements GroupDifferentiator{
+public abstract class CompetingRiskGroupDifferentiator implements GroupDifferentiator {
+
+ abstract protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand);
+
+ abstract protected Double getScore(final CompetingRiskSets competingRiskSets);
@Override
- public abstract Double differentiate(List leftHand, List rightHand);
+ public SplitAndScore differentiate(Iterator> splitIterator) {
+ if(splitIterator instanceof Covariate.SplitRuleUpdater){
+ return differentiateWithSplitUpdater((Covariate.SplitRuleUpdater) splitIterator);
+ }
+ else{
+ return differentiateWithBasicIterator(splitIterator);
+ }
+ }
+
+ private SplitAndScore differentiateWithBasicIterator(Iterator> splitIterator){
+ Double bestScore = null;
+ Split bestSplit = null;
+
+ while(splitIterator.hasNext()){
+ final Split candidateSplit = splitIterator.next();
+
+ final List leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
+ final List rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
+
+ if(leftHand.isEmpty() || rightHand.isEmpty()){
+ continue;
+ }
+
+ final CompetingRiskSets competingRiskSets = createCompetingRiskSets(leftHand, rightHand);
+
+ final Double score = getScore(competingRiskSets);
+
+ if(Double.isFinite(score) && (bestScore == null || score > bestScore)){
+ bestScore = score;
+ bestSplit = candidateSplit;
+ }
+ }
+
+ if(bestSplit == null){
+ return null;
+ }
+
+ return new SplitAndScore<>(bestSplit, bestScore);
+ }
+
+ private SplitAndScore differentiateWithSplitUpdater(Covariate.SplitRuleUpdater splitRuleUpdater) {
+
+ final List leftInitialSplit = splitRuleUpdater.currentSplit().getLeftHand()
+ .stream().map(Row::getResponse).collect(Collectors.toList());
+ final List rightInitialSplit = splitRuleUpdater.currentSplit().getRightHand()
+ .stream().map(Row::getResponse).collect(Collectors.toList());
+
+ final CompetingRiskSets competingRiskSets = createCompetingRiskSets(leftInitialSplit, rightInitialSplit);
+
+ Double bestScore = null;
+ Split bestSplit = null;
+
+ while(splitRuleUpdater.hasNext()){
+ for(Row rowMoved : splitRuleUpdater.nextUpdate().rowsMovedToLeftHand()){
+ competingRiskSets.update(rowMoved.getResponse());
+ }
+
+ final Double score = getScore(competingRiskSets);
+
+ if(Double.isFinite(score) && (bestScore == null || score > bestScore)){
+ bestScore = score;
+ bestSplit = splitRuleUpdater.currentSplit();
+ }
+ }
+
+ if(bestSplit == null){
+ return null;
+ }
+
+ return new SplitAndScore<>(bestSplit, bestScore);
+
+ }
/**
* Calculates the log rank value (or the Gray's test value) for a *specific* event cause.
*
* @param eventOfFocus
- * @param competingRiskSetsLeft A summary of the different sets used in the calculation for the left side
- * @param competingRiskSetsRight A summary of the different sets used in the calculation for the right side
+ * @param competingRiskSets A summary of the different sets used in the calculation
* @return
*/
- LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSetsLeft, final CompetingRiskSets competingRiskSetsRight){
-
- final double[] distinctEventTimes = Stream.concat(
- competingRiskSetsLeft.getEventTimes().stream(),
- competingRiskSetsRight.getEventTimes().stream())
- .mapToDouble(Double::doubleValue)
- .sorted()
- .distinct()
- .toArray();
+ LogRankValue specificLogRankValue(final int eventOfFocus, final CompetingRiskSets competingRiskSets){
double summation = 0.0;
double variance = 0.0;
- for(final double time_k : distinctEventTimes){
+ final double[] distinctTimes = competingRiskSets.getDistinctTimes();
+
+ for(int k = 0; k leftHand, List rightHand) {
- if(leftHand.size() == 0 || rightHand.size() == 0){
- return null;
- }
-
- final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
- final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
+ protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand){
+ return CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, rightHand, events);
+ }
+ @Override
+ protected Double getScore(final CompetingRiskSets competingRiskSets){
double numerator = 0.0;
double denominatorSquared = 0.0;
for(final int eventOfFocus : events){
- final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
+ final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
denominatorSquared += valueOfInterest.getVariance();
@@ -37,7 +35,6 @@ public class GrayLogRankMultipleGroupDifferentiator extends CompetingRiskGroupDi
}
return Math.abs(numerator / Math.sqrt(denominatorSquared));
-
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java
index 48e3b6f..afe66db 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/GrayLogRankSingleGroupDifferentiator.java
@@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
-import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskGraySetsImpl;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor;
@@ -18,18 +18,14 @@ public class GrayLogRankSingleGroupDifferentiator extends CompetingRiskGroupDiff
private final int[] events;
@Override
- public Double differentiate(List leftHand, List rightHand) {
- if(leftHand.size() == 0 || rightHand.size() == 0){
- return null;
- }
-
- final CompetingRiskGraySetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, events);
- final CompetingRiskGraySetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateGraySetsEfficiently(rightHand, events);
-
- final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
+ protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand){
+ return CompetingRiskUtils.calculateGraySetsEfficiently(leftHand, rightHand, events);
+ }
+ @Override
+ protected Double getScore(final CompetingRiskSets competingRiskSets){
+ final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
-
}
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java
index 2ad2424..6465b44 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankMultipleGroupDifferentiator.java
@@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
-import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor;
@@ -17,19 +17,17 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
private final int[] events;
@Override
- public Double differentiate(List leftHand, List rightHand) {
- if(leftHand.size() == 0 || rightHand.size() == 0){
- return null;
- }
-
- final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
- final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
+ protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand){
+ return CompetingRiskUtils.calculateSetsEfficiently(leftHand, rightHand, events, true);
+ }
+ @Override
+ protected Double getScore(final CompetingRiskSets competingRiskSets){
double numerator = 0.0;
double denominatorSquared = 0.0;
for(final int eventOfFocus : events){
- final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
+ final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
numerator += valueOfInterest.getNumerator()*valueOfInterest.getVarianceSqrt();
denominatorSquared += valueOfInterest.getVariance();
@@ -37,7 +35,7 @@ public class LogRankMultipleGroupDifferentiator extends CompetingRiskGroupDiffer
}
return Math.abs(numerator / Math.sqrt(denominatorSquared));
-
}
+
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java
index 107964e..7c633b1 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/differentiator/LogRankSingleGroupDifferentiator.java
@@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.responses.competingrisk.differentiator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
-import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSetsImpl;
+import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskSets;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskUtils;
import lombok.RequiredArgsConstructor;
@@ -18,18 +18,14 @@ public class LogRankSingleGroupDifferentiator extends CompetingRiskGroupDifferen
private final int[] events;
@Override
- public Double differentiate(List leftHand, List rightHand) {
- if(leftHand.size() == 0 || rightHand.size() == 0){
- return null;
- }
-
- final CompetingRiskSetsImpl competingRiskSetsLeft = CompetingRiskUtils.calculateSetsEfficiently(leftHand, events);
- final CompetingRiskSetsImpl competingRiskSetsRight = CompetingRiskUtils.calculateSetsEfficiently(rightHand, events);
-
- final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSetsLeft, competingRiskSetsRight);
+ protected CompetingRiskSets createCompetingRiskSets(List leftHand, List rightHand){
+ return CompetingRiskUtils.calculateSetsEfficiently(leftHand, rightHand, events, true);
+ }
+ @Override
+ protected Double getScore(final CompetingRiskSets competingRiskSets){
+ final LogRankValue valueOfInterest = specificLogRankValue(eventOfFocus, competingRiskSets);
return Math.abs(valueOfInterest.getNumerator() / valueOfInterest.getVarianceSqrt());
-
}
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java
deleted file mode 100644
index 75bc129..0000000
--- a/src/main/java/ca/joeltherrien/randomforest/responses/regression/MeanGroupDifferentiator.java
+++ /dev/null
@@ -1,26 +0,0 @@
-package ca.joeltherrien.randomforest.responses.regression;
-
-import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
-
-import java.util.List;
-
-public class MeanGroupDifferentiator implements GroupDifferentiator {
-
- @Override
- public Double differentiate(List leftHand, List rightHand) {
-
- double leftHandSize = leftHand.size();
- double rightHandSize = rightHand.size();
-
- if(leftHandSize == 0 || rightHandSize == 0){
- return null;
- }
-
- double leftHandMean = leftHand.stream().mapToDouble(db -> db/leftHandSize).sum();
- double rightHandMean = rightHand.stream().mapToDouble(db -> db/rightHandSize).sum();
-
- return Math.abs(leftHandMean - rightHandMean);
-
- }
-
-}
diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java
index 25f7e6e..9ae4673 100644
--- a/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/responses/regression/WeightedVarianceGroupDifferentiator.java
@@ -1,13 +1,13 @@
package ca.joeltherrien.randomforest.responses.regression;
-import ca.joeltherrien.randomforest.tree.GroupDifferentiator;
+import ca.joeltherrien.randomforest.tree.SimpleGroupDifferentiator;
import java.util.List;
-public class WeightedVarianceGroupDifferentiator implements GroupDifferentiator {
+public class WeightedVarianceGroupDifferentiator extends SimpleGroupDifferentiator {
@Override
- public Double differentiate(List leftHand, List rightHand) {
+ public Double getScore(List leftHand, List rightHand) {
final double leftHandSize = leftHand.size();
final double rightHandSize = rightHand.size();
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
index 37ee4ce..b6330cc 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java
@@ -14,8 +14,10 @@ import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
+import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -45,17 +47,17 @@ public class ForestTrainer {
this.covariates = covariates;
this.treeResponseCombiner = settings.getTreeCombiner();
this.treeTrainer = new TreeTrainer<>(settings, covariates);
-
}
public Forest trainSerial(){
final List> trees = new ArrayList<>(ntree);
final Bootstrapper> bootstrapper = new Bootstrapper<>(data);
+ final Random random = new Random();
for(int j=0; j {
}
- final File[] treeFiles = folder.listFiles(((file, s) -> s.endsWith(".tree")));
+ final File[] treeFiles = folder.listFiles((file, s) -> s.endsWith(".tree"));
final ExecutorService executorService = Executors.newFixedThreadPool(threads);
final AtomicInteger treeCount = new AtomicInteger(treeFiles.length); // tracks how many trees are finished
@@ -162,9 +164,9 @@ public class ForestTrainer {
}
- private Tree trainTree(final Bootstrapper> bootstrapper){
- final List> bootstrappedData = bootstrapper.bootstrap();
- return treeTrainer.growTree(bootstrappedData);
+ private Tree trainTree(final Bootstrapper> bootstrapper, Random random){
+ final List> bootstrappedData = bootstrapper.bootstrap(random);
+ return treeTrainer.growTree(bootstrappedData, random);
}
public void saveTree(final Tree tree, String name) throws IOException {
@@ -193,7 +195,8 @@ public class ForestTrainer {
@Override
public void run() {
- final Tree tree = trainTree(bootstrapper);
+ // ThreadLocalRandom should make sure we don't duplicate seeds
+ final Tree tree = trainTree(bootstrapper, ThreadLocalRandom.current());
// should be okay as the list structure isn't changing
treeList.set(treeIndex, tree);
@@ -216,7 +219,8 @@ public class ForestTrainer {
@Override
public void run() {
- final Tree tree = trainTree(bootstrapper);
+ // ThreadLocalRandom should make sure we don't duplicate seeds
+ final Tree tree = trainTree(bootstrapper, ThreadLocalRandom.current());
try {
saveTree(tree, filename);
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
index cbd1247..66e8cbc 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/GroupDifferentiator.java
@@ -1,15 +1,17 @@
package ca.joeltherrien.randomforest.tree;
-import java.util.List;
+import java.util.Iterator;
/**
* When choosing an optimal node to split on, we choose the split that maximizes the difference between the two groups.
- * The GroupDifferentiator has one method that outputs a score to show how different groups are. The larger the score,
- * the greater the difference.
+ * The GroupDifferentiator has one method that cycles through an iterator of Splits and returns a SplitAndScore.
+ * If the iterator is an instance of Covariate.SplitRuleUpdater, you also get access to the rows that change between splits.
*
+ * If you want to implement a very trivial GroupDifferentiator that just takes two Lists as arguments, try extending
+ * SimpleGroupDifferentiator.
*/
public interface GroupDifferentiator {
- Double differentiate(List leftHand, List rightHand);
+ SplitAndScore differentiate(Iterator> splitIterator);
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java
new file mode 100644
index 0000000..596f81e
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/SimpleGroupDifferentiator.java
@@ -0,0 +1,50 @@
+package ca.joeltherrien.randomforest.tree;
+
+import ca.joeltherrien.randomforest.Row;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public abstract class SimpleGroupDifferentiator implements GroupDifferentiator {
+
+ @Override
+ public SplitAndScore differentiate(Iterator> splitIterator) {
+ Double bestScore = null;
+ Split bestSplit = null;
+
+ while(splitIterator.hasNext()){
+ final Split candidateSplit = splitIterator.next();
+
+ final List leftHand = candidateSplit.getLeftHand().stream().map(Row::getResponse).collect(Collectors.toList());
+ final List rightHand = candidateSplit.getRightHand().stream().map(Row::getResponse).collect(Collectors.toList());
+
+ if(leftHand.isEmpty() || rightHand.isEmpty()){
+ continue;
+ }
+
+ final Double score = getScore(leftHand, rightHand);
+
+ if(score != null && (bestScore == null || score > bestScore)){
+ bestScore = score;
+ bestSplit = candidateSplit;
+ }
+ }
+
+ if(bestSplit == null){
+ return null;
+ }
+
+ return new SplitAndScore<>(bestSplit, bestScore);
+ }
+
+ /**
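+ * Return a score describing how different the two groups are; higher is better.
+ *
+ * @param leftHand responses of the rows in the left-hand group of the candidate split; not empty when called from differentiate
+ * @param rightHand responses of the rows in the right-hand group of the candidate split; not empty when called from differentiate
+ * @return the score of the candidate split; differentiate ignores a split whose score is null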
+ */
+ public abstract Double getScore(List leftHand, List rightHand);
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java
index e566e64..b55444c 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/Split.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/Split.java
@@ -1,19 +1,21 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.Row;
+import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Data;
import java.util.List;
/**
- * Very simple class that contains three lists; it's essentially a thruple.
+ * Very simple class that contains three lists and a SplitRule.
*
* @author joel
*
*/
@Data
-public class Split {
+public final class Split {
+ public final Covariate.SplitRule splitRule;
public final List> leftHand;
public final List> rightHand;
public final List> naHand;
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java
new file mode 100644
index 0000000..1160680
--- /dev/null
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitAndScore.java
@@ -0,0 +1,15 @@
+package ca.joeltherrien.randomforest.tree;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+@AllArgsConstructor
+public class SplitAndScore {
+
+ @Getter
+ private final Split split;
+
+ @Getter
+ private final Double score;
+
+}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java
index 8837170..d43e273 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/SplitNode.java
@@ -3,8 +3,10 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Builder;
+import lombok.ToString;
@Builder
+@ToString
public class SplitNode implements Node {
private final Node leftHand;
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java b/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java
index 917c10d..cc96d03 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TerminalNode.java
@@ -2,8 +2,10 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import lombok.RequiredArgsConstructor;
+import lombok.ToString;
@RequiredArgsConstructor
+@ToString
public class TerminalNode implements Node {
private final Y responseValue;
@@ -14,6 +16,4 @@ public class TerminalNode implements Node {
}
-
-
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java
index 12d3d03..de29e80 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/Tree.java
@@ -29,4 +29,8 @@ public class Tree implements Node {
return Arrays.binarySearch(this.bootstrapRowIds, id) >= 0;
}
+ @Override
+ public String toString(){
+ return rootNode.toString();
+ }
}
diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
index 0e359b4..d1d141d 100644
--- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
+++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java
@@ -8,7 +8,6 @@ import lombok.AllArgsConstructor;
import lombok.Builder;
import java.util.*;
-import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
@Builder
@@ -47,20 +46,21 @@ public class TreeTrainer {
this.covariates = covariates;
}
- public Tree