diff --git a/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java b/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java index 2570201..b73db9d 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java +++ b/src/main/java/ca/joeltherrien/randomforest/Bootstrapper.java @@ -4,7 +4,6 @@ import lombok.RequiredArgsConstructor; import java.util.ArrayList; import java.util.List; -import java.util.Random; import java.util.concurrent.ThreadLocalRandom; @RequiredArgsConstructor diff --git a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java index 0d999cb..b7e2e8c 100644 --- a/src/main/java/ca/joeltherrien/randomforest/DataLoader.java +++ b/src/main/java/ca/joeltherrien/randomforest/DataLoader.java @@ -74,13 +74,11 @@ public class DataLoader { } - final Forest forest = Forest.builder() + return Forest.builder() .trees(treeList) .treeResponseCombiner(treeResponseCombiner) .build(); - return forest; - } @FunctionalInterface diff --git a/src/main/java/ca/joeltherrien/randomforest/Main.java b/src/main/java/ca/joeltherrien/randomforest/Main.java index 37835fd..5fd76a7 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Main.java +++ b/src/main/java/ca/joeltherrien/randomforest/Main.java @@ -22,7 +22,6 @@ import java.util.stream.Collectors; public class Main { - public static void main(String[] args) throws IOException, ClassNotFoundException { if(args.length < 2){ System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze."); @@ -160,7 +159,7 @@ public class Main { yVarSettings.set("type", new TextNode("Double")); yVarSettings.set("name", new TextNode("y")); - final Settings settings = Settings.builder() + return Settings.builder() .covariates(Utils.easyList( new NumericCovariateSettings("x1"), new BooleanCovariateSettings("x2"), @@ -181,9 +180,6 @@ public class Main { .saveProgress(true) .saveTreeLocation("trees/") .build(); - - - return settings; } } diff --git a/src/main/java/ca/joeltherrien/randomforest/Settings.java b/src/main/java/ca/joeltherrien/randomforest/Settings.java index bbbfb29..70d857a 100644 --- a/src/main/java/ca/joeltherrien/randomforest/Settings.java +++ b/src/main/java/ca/joeltherrien/randomforest/Settings.java @@ -133,7 +133,7 @@ public class Settings { if(node.hasNonNull("times")){ final List timeList = new ArrayList<>(); node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble())); - times = eventList.stream().mapToDouble(db -> db).toArray(); + times = timeList.stream().mapToDouble(db -> db).toArray(); } return new CompetingRiskResponseCombiner(events, times); @@ -152,7 +152,7 @@ public class Settings { if(node.hasNonNull("times")){ final List timeList = new ArrayList<>(); node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble())); - times = eventList.stream().mapToDouble(db -> db).toArray(); + times = timeList.stream().mapToDouble(db -> db).toArray(); } return new CompetingRiskFunctionCombiner(events, times); @@ -171,7 +171,7 @@ public class Settings { if(node.hasNonNull("times")){ final List timeList = new ArrayList<>(); node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble())); - times = eventList.stream().mapToDouble(db -> db).toArray(); + times = timeList.stream().mapToDouble(db -> db).toArray(); } final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times); @@ -217,9 +217,7 @@ public class Settings { final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); //mapper.enableDefaultTyping(); - final Settings settings = mapper.readValue(file, Settings.class); - - return settings; + return mapper.readValue(file, Settings.class); } diff --git a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java index 1418428..18f53f2 100644 --- a/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java +++ b/src/main/java/ca/joeltherrien/randomforest/covariates/CovariateSettings.java @@ -14,7 +14,6 @@ import lombok.NoArgsConstructor; @Getter @JsonTypeInfo( use = JsonTypeInfo.Id.NAME, - include = JsonTypeInfo.As.PROPERTY, property = "type") @JsonSubTypes({ @JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"), diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java index 4aa4470..869b7f4 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskFunctions.java @@ -7,7 +7,6 @@ import lombok.Getter; import java.io.Serializable; import java.util.List; -import java.util.Map; @Builder public class CompetingRiskFunctions implements Serializable { diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java index 7aefb88..3aa78c8 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/CompetingRiskResponseWithCensorTime.java @@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk; import ca.joeltherrien.randomforest.DataLoader; import com.fasterxml.jackson.databind.node.ObjectNode; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.RequiredArgsConstructor; import org.apache.commons.csv.CSVRecord; @@ -10,6 +11,7 @@ import org.apache.commons.csv.CSVRecord; * See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times. * */ +@EqualsAndHashCode(callSuper = true) @Data public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse { private final double c; diff --git a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java index ed5fff2..e371ca4 100644 --- a/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java +++ b/src/main/java/ca/joeltherrien/randomforest/responses/competingrisk/combiner/CompetingRiskFunctionCombiner.java @@ -7,9 +7,7 @@ import ca.joeltherrien.randomforest.utils.Point; import lombok.RequiredArgsConstructor; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; @RequiredArgsConstructor public class CompetingRiskFunctionCombiner implements ResponseCombiner { diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java index 12948e6..5835e89 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/Forest.java @@ -2,7 +2,6 @@ package ca.joeltherrien.randomforest.tree; import ca.joeltherrien.randomforest.CovariateRow; import lombok.Builder; -import lombok.RequiredArgsConstructor; import java.util.Collection; import java.util.Collections; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java index 69f7ab0..37ee4ce 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/ForestTrainer.java @@ -184,7 +184,7 @@ public class ForestTrainer { private final int treeIndex; private final List> treeList; - public TreeInMemoryWorker(final List> data, final int treeIndex, final List> treeList) { + TreeInMemoryWorker(final List> data, final int treeIndex, final List> treeList) { this.bootstrapper = new Bootstrapper<>(data); this.treeIndex = treeIndex; this.treeList = treeList; diff --git a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java index 15f0cdb..7ec00fa 100644 --- a/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java +++ b/src/main/java/ca/joeltherrien/randomforest/tree/TreeTrainer.java @@ -43,9 +43,7 @@ public class TreeTrainer { public Tree growTree(List> data){ final Node rootNode = growNode(data, 0); - final Tree tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray()); - - return tree; + return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray()); } @@ -95,8 +93,8 @@ public class TreeTrainer { final List splitCovariates = new ArrayList<>(covariates); Collections.shuffle(splitCovariates, ThreadLocalRandom.current()); - for(int treeIndex = splitCovariates.size()-1; treeIndex >= mtry; treeIndex--){ - splitCovariates.remove(treeIndex); + if (splitCovariates.size() > mtry) { + splitCovariates.subList(mtry, splitCovariates.size()).clear(); } return splitCovariates; diff --git a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java index f9b5c86..2342851 100644 --- a/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java +++ b/src/main/java/ca/joeltherrien/randomforest/utils/Utils.java @@ -74,9 +74,7 @@ public class Utils { public static List easyList(T... array){ final List list = new ArrayList<>(array.length); - for(final T item : array){ - list.add(item); - } + Collections.addAll(list, array); return list; diff --git a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java index 1e5e220..49572c9 100644 --- a/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java +++ b/src/test/java/ca/joeltherrien/randomforest/TestSavingLoading.java @@ -16,7 +16,6 @@ import static org.junit.jupiter.api.Assertions.*; import java.io.File; import java.io.IOException; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; public class TestSavingLoading { diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java index 44296d3..8bbd4ce 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRisk.java @@ -19,7 +19,6 @@ import static org.junit.jupiter.api.Assertions.*; import java.io.IOException; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; public class TestCompetingRisk { diff --git a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java index 24cfaae..d66fb0f 100644 --- a/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java +++ b/src/test/java/ca/joeltherrien/randomforest/competingrisk/TestCompetingRiskResponseCombiner.java @@ -26,9 +26,7 @@ public class TestCompetingRiskResponseCombiner { final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null); - final CompetingRiskFunctions functions = combiner.combine(data); - - return functions; + return combiner.combine(data); } @Test diff --git a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java index 423c857..fcb2c1d 100644 --- a/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java +++ b/src/test/java/ca/joeltherrien/randomforest/csv/TestLoadingCSV.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestLoadingCSV { @@ -51,9 +52,7 @@ public class TestLoadingCSV { final DataLoader.ResponseLoader loader = settings.getResponseLoader(); - final List> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation()); - - return data; + return DataLoader.loadData(covariates, loader, settings.getDataFileLocation()); } @Test @@ -94,9 +93,9 @@ public class TestLoadingCSV { row = data.get(3); assertEquals(-3.0, (double)row.getResponse()); - assertEquals(true, row.getCovariateValue("x1").isNA()); - assertEquals(true, row.getCovariateValue("x2").isNA()); - assertEquals(true, row.getCovariateValue("x3").isNA()); + assertTrue(row.getCovariateValue("x1").isNA()); + assertTrue(row.getCovariateValue("x2").isNA()); + assertTrue(row.getCovariateValue("x3").isNA()); } } diff --git a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java index 6c149ea..87f8449 100644 --- a/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java +++ b/src/test/java/ca/joeltherrien/randomforest/settings/TestPersistence.java @@ -12,7 +12,6 @@ import org.junit.jupiter.api.Test; import java.io.File; import java.io.IOException; -import java.util.List; public class TestPersistence {