Migrate to Java 1.8

This commit is contained in:
Joel Therrien 2018-08-31 12:48:39 -07:00
parent 949c8789e7
commit 75f34853ab
12 changed files with 125 additions and 30 deletions

View file

@ -9,9 +9,9 @@
<version>1.0-SNAPSHOT</version> <version>1.0-SNAPSHOT</version>
<properties> <properties>
<java.version>1.10</java.version> <java.version>1.8</java.version>
<maven.compiler.target>1.10</maven.compiler.target> <maven.compiler.target>1.8</maven.compiler.target>
<maven.compiler.source>1.10</maven.compiler.source> <maven.compiler.source>1.8</maven.compiler.source>
<jackson.version>2.9.6</jackson.version> <jackson.version>2.9.6</jackson.version>
</properties> </properties>

View file

@ -161,10 +161,10 @@ public class Main {
yVarSettings.set("name", new TextNode("y")); yVarSettings.set("name", new TextNode("y"));
final Settings settings = Settings.builder() final Settings settings = Settings.builder()
.covariates(List.of( .covariates(Utils.easyList(
new NumericCovariateSettings("x1"), new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"), new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog")) new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
) )
) )
.dataFileLocation("data.csv") .dataFileLocation("data.csv")

View file

@ -64,5 +64,92 @@ public class Utils {
} }
} }
/**
* Replacement for Java 9's List.of
*
* @param array
* @param <T>
* @return A list
*/
public static <T> List<T> easyList(T... array){
final List<T> list = new ArrayList<>(array.length);
for(final T item : array){
list.add(item);
}
return list;
}
/**
* Replacement for Java 9's Map.of
*
* @param array
* @return A map
*/
public static Map easyMap(Object... array){
if(array.length % 2 != 0){
throw new IllegalArgumentException("Must provide a value for every key");
}
final Map map = new HashMap();
for(int i=0; i<array.length; i+=2){
map.put(array[i], array[i+1]);
}
return map;
}
/**
* Replacement for Java 9's Map.of
* @return A map
*/
public static <K,V> Map<K,V> easyMap(K k1, V v1){
final Map<K,V> map = new HashMap<>();
map.put(k1, v1);
return map;
}
/**
* Replacement for Java 9's Map.of
* @return A map
*/
public static <K,V> Map<K,V> easyMap(K k1, V v1, K k2, V v2){
final Map<K,V> map = new HashMap<>();
map.put(k1, v1);
map.put(k2, v2);
return map;
}
/**
* Replacement for Java 9's Map.of
* @return A map
*/
public static <K,V> Map<K,V> easyMap(K k1, V v1, K k2, V v2, K k3, V v3){
final Map<K,V> map = new HashMap<>();
map.put(k1, v1);
map.put(k2, v2);
map.put(k3, v3);
return map;
}
/**
* Replacement for Java 9's Map.of
* @return A map
*/
public static <K,V> Map<K,V> easyMap(K k1, V v1, K k2, V v2, K k3, V v3, K k4, V v4){
final Map<K,V> map = new HashMap<>();
map.put(k1, v1);
map.put(k2, v2);
map.put(k3, v3);
map.put(k4, v4);
return map;
}
} }

View file

@ -8,6 +8,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctio
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*; import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -35,7 +36,7 @@ public class TestSavingLoading {
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
responseCombinerSettings.set("events", responseCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2))) new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
); );
// not setting times // not setting times
@ -43,7 +44,7 @@ public class TestSavingLoading {
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner")); treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
treeCombinerSettings.set("events", treeCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2))) new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
); );
// not setting times // not setting times
@ -53,7 +54,7 @@ public class TestSavingLoading {
yVarSettings.set("delta", new TextNode("status")); yVarSettings.set("delta", new TextNode("status"));
return Settings.builder() return Settings.builder()
.covariates(List.of( .covariates(Utils.easyList(
new NumericCovariateSettings("ageatfda"), new NumericCovariateSettings("ageatfda"),
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black"), new BooleanCovariateSettings("black"),
@ -78,7 +79,7 @@ public class TestSavingLoading {
} }
public CovariateRow getPredictionRow(List<Covariate> covariates){ public CovariateRow getPredictionRow(List<Covariate> covariates){
return CovariateRow.createSimple(Map.of( return CovariateRow.createSimple(Utils.easyMap(
"ageatfda", "35", "ageatfda", "35",
"idu", "false", "idu", "false",
"black", "false", "black", "false",

View file

@ -77,7 +77,7 @@ public class TestUtils {
@Test @Test
public void reduceListToSize(){ public void reduceListToSize(){
final List<Integer> testList = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
final List<Integer> testList1 = new ArrayList<>(testList); final List<Integer> testList1 = new ArrayList<>(testList);

View file

@ -9,6 +9,7 @@ import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*; import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -37,7 +38,7 @@ public class TestCompetingRisk {
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner")); responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
responseCombinerSettings.set("events", responseCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2))) new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
); );
// not setting times // not setting times
@ -45,7 +46,7 @@ public class TestCompetingRisk {
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance); final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner")); treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
treeCombinerSettings.set("events", treeCombinerSettings.set("events",
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2))) new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
); );
// not setting times // not setting times
@ -55,7 +56,7 @@ public class TestCompetingRisk {
yVarSettings.set("delta", new TextNode("status")); yVarSettings.set("delta", new TextNode("status"));
return Settings.builder() return Settings.builder()
.covariates(List.of( .covariates(Utils.easyList(
new NumericCovariateSettings("ageatfda"), new NumericCovariateSettings("ageatfda"),
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black"), new BooleanCovariateSettings("black"),
@ -84,7 +85,7 @@ public class TestCompetingRisk {
} }
public CovariateRow getPredictionRow(List<Covariate> covariates){ public CovariateRow getPredictionRow(List<Covariate> covariates){
return CovariateRow.createSimple(Map.of( return CovariateRow.createSimple(Utils.easyMap(
"ageatfda", "35", "ageatfda", "35",
"idu", "false", "idu", "false",
"black", "false", "black", "false",
@ -96,7 +97,7 @@ public class TestCompetingRisk {
public void testSingleTree() throws IOException { public void testSingleTree() throws IOException {
final Settings settings = getSettings(); final Settings settings = getSettings();
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv"); settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
settings.setCovariates(List.of( settings.setCovariates(Utils.easyList(
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black") new BooleanCovariateSettings("black")
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree. )); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
@ -199,7 +200,7 @@ public class TestCompetingRisk {
@Test @Test
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException { public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
final Settings settings = getSettings(); final Settings settings = getSettings();
settings.setCovariates(List.of( settings.setCovariates(Utils.easyList(
new BooleanCovariateSettings("idu"), new BooleanCovariateSettings("idu"),
new BooleanCovariateSettings("black") new BooleanCovariateSettings("black")
)); ));

View file

@ -5,6 +5,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.utils.MathFunction; import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.Collections; import java.util.Collections;
@ -26,7 +27,7 @@ public class TestCompetingRiskErrorRateCalculator {
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0); final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
final double[] mortalityArray = new double[]{1, 4, 3, 9}; final double[] mortalityArray = new double[]{1, 4, 3, 9};
final List<CompetingRiskResponse> responseList = List.of(response1, response2, response3, response4); final List<CompetingRiskResponse> responseList = Utils.easyList(response1, response2, response3, response4);
final int event = 1; final int event = 1;
@ -52,7 +53,7 @@ public class TestCompetingRiskErrorRateCalculator {
final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0); final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0);
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0); final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
final List<Row<CompetingRiskResponse>> dataset = List.of( final List<Row<CompetingRiskResponse>> dataset = Utils.easyList(
new Row<>(Collections.emptyMap(), 1, response1), new Row<>(Collections.emptyMap(), 1, response1),
new Row<>(Collections.emptyMap(), 2, response2), new Row<>(Collections.emptyMap(), 2, response2),
new Row<>(Collections.emptyMap(), 3, response3), new Row<>(Collections.emptyMap(), 3, response3),

View file

@ -1,6 +1,7 @@
package ca.joeltherrien.randomforest.covariates; package ca.joeltherrien.randomforest.covariates;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.api.function.Executable;
@ -55,7 +56,7 @@ public class FactorCovariateTest {
private FactorCovariate createTestCovariate(){ private FactorCovariate createTestCovariate(){
final List<String> levels = List.of("DOG", "CAT", "MOUSE"); final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
return new FactorCovariate("pet", levels); return new FactorCovariate("pet", levels);
} }

View file

@ -7,6 +7,7 @@ import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings; import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings; import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode; import com.fasterxml.jackson.databind.node.TextNode;
@ -37,8 +38,8 @@ public class TestLoadingCSV {
final Settings settings = Settings.builder() final Settings settings = Settings.builder()
.dataFileLocation(filename) .dataFileLocation(filename)
.covariates( .covariates(
List.of(new NumericCovariateSettings("x1"), Utils.easyList(new NumericCovariateSettings("x1"),
new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")), new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
new BooleanCovariateSettings("x3")) new BooleanCovariateSettings("x3"))
) )
.yVarSettings(yVarSettings) .yVarSettings(yVarSettings)

View file

@ -4,6 +4,7 @@ import ca.joeltherrien.randomforest.Settings;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode; import com.fasterxml.jackson.databind.node.TextNode;
@ -31,10 +32,10 @@ public class TestPersistence {
yVarSettings.set("name", new TextNode("y")); yVarSettings.set("name", new TextNode("y"));
final Settings settingsOriginal = Settings.builder() final Settings settingsOriginal = Settings.builder()
.covariates(List.of( .covariates(Utils.easyList(
new NumericCovariateSettings("x1"), new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"), new BooleanCovariateSettings("x2"),
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog")) new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
) )
) )
.dataFileLocation("data.csv") .dataFileLocation("data.csv")

View file

@ -9,6 +9,7 @@ import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -48,7 +49,7 @@ public class TrainSingleTree {
trainingSet.add(generateRow(x1, x2, i)); trainingSet.add(generateRow(x1, x2, i));
} }
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate); final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder() final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
@ -99,7 +100,7 @@ public class TrainSingleTree {
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){ public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
double y = generateResponse(x1.getValue(), x2.getValue()); double y = generateResponse(x1.getValue(), x2.getValue());
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2); final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
return new Row<>(map, id, y); return new Row<>(map, id, y);
@ -107,7 +108,7 @@ public class TrainSingleTree {
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){ public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2); final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
return new CovariateRow(map, id); return new CovariateRow(map, id);

View file

@ -9,6 +9,7 @@ import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator; import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
import ca.joeltherrien.randomforest.tree.Node; import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer; import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.Utils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -29,7 +30,7 @@ public class TrainSingleTreeFactor {
final Covariate<Double> x1Covariate = new NumericCovariate("x1"); final Covariate<Double> x1Covariate = new NumericCovariate("x1");
final Covariate<Double> x2Covariate = new NumericCovariate("x2"); final Covariate<Double> x2Covariate = new NumericCovariate("x2");
final FactorCovariate x3Covariate = new FactorCovariate("x3", List.of("cat", "dog", "mouse")); final FactorCovariate x3Covariate = new FactorCovariate("x3", Utils.easyList("cat", "dog", "mouse"));
final List<Covariate.Value<Double>> x1List = DoubleStream final List<Covariate.Value<Double>> x1List = DoubleStream
.generate(() -> random.nextDouble()*10.0) .generate(() -> random.nextDouble()*10.0)
@ -69,7 +70,7 @@ public class TrainSingleTreeFactor {
trainingSet.add(generateRow(x1, x2, x3, i)); trainingSet.add(generateRow(x1, x2, x3, i));
} }
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate); final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder() final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
@ -127,7 +128,7 @@ public class TrainSingleTreeFactor {
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){ public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue()); double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2); final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
return new Row<>(map, id, y); return new Row<>(map, id, y);
@ -135,7 +136,7 @@ public class TrainSingleTreeFactor {
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){ public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2, "x3", x3); final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2, "x3", x3);
return new CovariateRow(map, id); return new CovariateRow(map, id);