Migrate to Java 1.8
This commit is contained in:
parent
949c8789e7
commit
75f34853ab
12 changed files with 125 additions and 30 deletions
6
pom.xml
6
pom.xml
|
@ -9,9 +9,9 @@
|
||||||
<version>1.0-SNAPSHOT</version>
|
<version>1.0-SNAPSHOT</version>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
<java.version>1.10</java.version>
|
<java.version>1.8</java.version>
|
||||||
<maven.compiler.target>1.10</maven.compiler.target>
|
<maven.compiler.target>1.8</maven.compiler.target>
|
||||||
<maven.compiler.source>1.10</maven.compiler.source>
|
<maven.compiler.source>1.8</maven.compiler.source>
|
||||||
<jackson.version>2.9.6</jackson.version>
|
<jackson.version>2.9.6</jackson.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
|
|
|
@ -161,10 +161,10 @@ public class Main {
|
||||||
yVarSettings.set("name", new TextNode("y"));
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
final Settings settings = Settings.builder()
|
final Settings settings = Settings.builder()
|
||||||
.covariates(List.of(
|
.covariates(Utils.easyList(
|
||||||
new NumericCovariateSettings("x1"),
|
new NumericCovariateSettings("x1"),
|
||||||
new BooleanCovariateSettings("x2"),
|
new BooleanCovariateSettings("x2"),
|
||||||
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.dataFileLocation("data.csv")
|
.dataFileLocation("data.csv")
|
||||||
|
|
|
@ -64,5 +64,92 @@ public class Utils {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Replacement for Java 9's List.of
|
||||||
|
*
|
||||||
|
* @param array
|
||||||
|
* @param <T>
|
||||||
|
* @return A list
|
||||||
|
*/
|
||||||
|
public static <T> List<T> easyList(T... array){
|
||||||
|
final List<T> list = new ArrayList<>(array.length);
|
||||||
|
|
||||||
|
for(final T item : array){
|
||||||
|
list.add(item);
|
||||||
|
}
|
||||||
|
|
||||||
|
return list;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Replacement for Java 9's Map.of
|
||||||
|
*
|
||||||
|
* @param array
|
||||||
|
* @return A map
|
||||||
|
*/
|
||||||
|
public static Map easyMap(Object... array){
|
||||||
|
if(array.length % 2 != 0){
|
||||||
|
throw new IllegalArgumentException("Must provide a value for every key");
|
||||||
|
}
|
||||||
|
|
||||||
|
final Map map = new HashMap();
|
||||||
|
for(int i=0; i<array.length; i+=2){
|
||||||
|
map.put(array[i], array[i+1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Replacement for Java 9's Map.of
|
||||||
|
* @return A map
|
||||||
|
*/
|
||||||
|
public static <K,V> Map<K,V> easyMap(K k1, V v1){
|
||||||
|
final Map<K,V> map = new HashMap<>();
|
||||||
|
map.put(k1, v1);
|
||||||
|
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Replacement for Java 9's Map.of
|
||||||
|
* @return A map
|
||||||
|
*/
|
||||||
|
public static <K,V> Map<K,V> easyMap(K k1, V v1, K k2, V v2){
|
||||||
|
final Map<K,V> map = new HashMap<>();
|
||||||
|
map.put(k1, v1);
|
||||||
|
map.put(k2, v2);
|
||||||
|
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Replacement for Java 9's Map.of
|
||||||
|
* @return A map
|
||||||
|
*/
|
||||||
|
public static <K,V> Map<K,V> easyMap(K k1, V v1, K k2, V v2, K k3, V v3){
|
||||||
|
final Map<K,V> map = new HashMap<>();
|
||||||
|
map.put(k1, v1);
|
||||||
|
map.put(k2, v2);
|
||||||
|
map.put(k3, v3);
|
||||||
|
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Replacement for Java 9's Map.of
|
||||||
|
* @return A map
|
||||||
|
*/
|
||||||
|
public static <K,V> Map<K,V> easyMap(K k1, V v1, K k2, V v2, K k3, V v3, K k4, V v4){
|
||||||
|
final Map<K,V> map = new HashMap<>();
|
||||||
|
map.put(k1, v1);
|
||||||
|
map.put(k2, v2);
|
||||||
|
map.put(k3, v3);
|
||||||
|
map.put(k4, v4);
|
||||||
|
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctio
|
||||||
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
import ca.joeltherrien.randomforest.tree.ForestTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.*;
|
import com.fasterxml.jackson.databind.node.*;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
@ -35,7 +36,7 @@ public class TestSavingLoading {
|
||||||
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
||||||
responseCombinerSettings.set("events",
|
responseCombinerSettings.set("events",
|
||||||
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
||||||
);
|
);
|
||||||
// not setting times
|
// not setting times
|
||||||
|
|
||||||
|
@ -43,7 +44,7 @@ public class TestSavingLoading {
|
||||||
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
|
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
|
||||||
treeCombinerSettings.set("events",
|
treeCombinerSettings.set("events",
|
||||||
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
||||||
);
|
);
|
||||||
// not setting times
|
// not setting times
|
||||||
|
|
||||||
|
@ -53,7 +54,7 @@ public class TestSavingLoading {
|
||||||
yVarSettings.set("delta", new TextNode("status"));
|
yVarSettings.set("delta", new TextNode("status"));
|
||||||
|
|
||||||
return Settings.builder()
|
return Settings.builder()
|
||||||
.covariates(List.of(
|
.covariates(Utils.easyList(
|
||||||
new NumericCovariateSettings("ageatfda"),
|
new NumericCovariateSettings("ageatfda"),
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black"),
|
new BooleanCovariateSettings("black"),
|
||||||
|
@ -78,7 +79,7 @@ public class TestSavingLoading {
|
||||||
}
|
}
|
||||||
|
|
||||||
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
||||||
return CovariateRow.createSimple(Map.of(
|
return CovariateRow.createSimple(Utils.easyMap(
|
||||||
"ageatfda", "35",
|
"ageatfda", "35",
|
||||||
"idu", "false",
|
"idu", "false",
|
||||||
"black", "false",
|
"black", "false",
|
||||||
|
|
|
@ -77,7 +77,7 @@ public class TestUtils {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void reduceListToSize(){
|
public void reduceListToSize(){
|
||||||
final List<Integer> testList = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
final List<Integer> testList = Utils.easyList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||||
|
|
||||||
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
for(int i=0; i<100; i++) { // want to test many times to ensure it doesn't work just due to randomness
|
||||||
final List<Integer> testList1 = new ArrayList<>(testList);
|
final List<Integer> testList1 = new ArrayList<>(testList);
|
||||||
|
|
|
@ -9,6 +9,7 @@ import ca.joeltherrien.randomforest.tree.Node;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.Point;
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.*;
|
import com.fasterxml.jackson.databind.node.*;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
@ -37,7 +38,7 @@ public class TestCompetingRisk {
|
||||||
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
|
||||||
responseCombinerSettings.set("events",
|
responseCombinerSettings.set("events",
|
||||||
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
||||||
);
|
);
|
||||||
// not setting times
|
// not setting times
|
||||||
|
|
||||||
|
@ -45,7 +46,7 @@ public class TestCompetingRisk {
|
||||||
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
|
||||||
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
|
treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
|
||||||
treeCombinerSettings.set("events",
|
treeCombinerSettings.set("events",
|
||||||
new ArrayNode(JsonNodeFactory.instance, List.of(new IntNode(1), new IntNode(2)))
|
new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
|
||||||
);
|
);
|
||||||
// not setting times
|
// not setting times
|
||||||
|
|
||||||
|
@ -55,7 +56,7 @@ public class TestCompetingRisk {
|
||||||
yVarSettings.set("delta", new TextNode("status"));
|
yVarSettings.set("delta", new TextNode("status"));
|
||||||
|
|
||||||
return Settings.builder()
|
return Settings.builder()
|
||||||
.covariates(List.of(
|
.covariates(Utils.easyList(
|
||||||
new NumericCovariateSettings("ageatfda"),
|
new NumericCovariateSettings("ageatfda"),
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black"),
|
new BooleanCovariateSettings("black"),
|
||||||
|
@ -84,7 +85,7 @@ public class TestCompetingRisk {
|
||||||
}
|
}
|
||||||
|
|
||||||
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
public CovariateRow getPredictionRow(List<Covariate> covariates){
|
||||||
return CovariateRow.createSimple(Map.of(
|
return CovariateRow.createSimple(Utils.easyMap(
|
||||||
"ageatfda", "35",
|
"ageatfda", "35",
|
||||||
"idu", "false",
|
"idu", "false",
|
||||||
"black", "false",
|
"black", "false",
|
||||||
|
@ -96,7 +97,7 @@ public class TestCompetingRisk {
|
||||||
public void testSingleTree() throws IOException {
|
public void testSingleTree() throws IOException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
|
settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
|
||||||
settings.setCovariates(List.of(
|
settings.setCovariates(Utils.easyList(
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black")
|
new BooleanCovariateSettings("black")
|
||||||
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
)); // by only using BooleanCovariates (only one split rule) we can guarantee identical results with randomForestSRC on one tree.
|
||||||
|
@ -199,7 +200,7 @@ public class TestCompetingRisk {
|
||||||
@Test
|
@Test
|
||||||
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
settings.setCovariates(List.of(
|
settings.setCovariates(Utils.easyList(
|
||||||
new BooleanCovariateSettings("idu"),
|
new BooleanCovariateSettings("idu"),
|
||||||
new BooleanCovariateSettings("black")
|
new BooleanCovariateSettings("black")
|
||||||
));
|
));
|
||||||
|
|
|
@ -5,6 +5,7 @@ import ca.joeltherrien.randomforest.responses.competingrisk.*;
|
||||||
import ca.joeltherrien.randomforest.tree.Forest;
|
import ca.joeltherrien.randomforest.tree.Forest;
|
||||||
import ca.joeltherrien.randomforest.utils.MathFunction;
|
import ca.joeltherrien.randomforest.utils.MathFunction;
|
||||||
import ca.joeltherrien.randomforest.utils.Point;
|
import ca.joeltherrien.randomforest.utils.Point;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -26,7 +27,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
||||||
|
|
||||||
final double[] mortalityArray = new double[]{1, 4, 3, 9};
|
final double[] mortalityArray = new double[]{1, 4, 3, 9};
|
||||||
final List<CompetingRiskResponse> responseList = List.of(response1, response2, response3, response4);
|
final List<CompetingRiskResponse> responseList = Utils.easyList(response1, response2, response3, response4);
|
||||||
|
|
||||||
final int event = 1;
|
final int event = 1;
|
||||||
|
|
||||||
|
@ -52,7 +53,7 @@ public class TestCompetingRiskErrorRateCalculator {
|
||||||
final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0);
|
final CompetingRiskResponse response3 = new CompetingRiskResponse(2, 8.0);
|
||||||
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
final CompetingRiskResponse response4 = new CompetingRiskResponse(1, 3.0);
|
||||||
|
|
||||||
final List<Row<CompetingRiskResponse>> dataset = List.of(
|
final List<Row<CompetingRiskResponse>> dataset = Utils.easyList(
|
||||||
new Row<>(Collections.emptyMap(), 1, response1),
|
new Row<>(Collections.emptyMap(), 1, response1),
|
||||||
new Row<>(Collections.emptyMap(), 2, response2),
|
new Row<>(Collections.emptyMap(), 2, response2),
|
||||||
new Row<>(Collections.emptyMap(), 3, response3),
|
new Row<>(Collections.emptyMap(), 3, response3),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package ca.joeltherrien.randomforest.covariates;
|
package ca.joeltherrien.randomforest.covariates;
|
||||||
|
|
||||||
|
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.function.Executable;
|
import org.junit.jupiter.api.function.Executable;
|
||||||
|
|
||||||
|
@ -55,7 +56,7 @@ public class FactorCovariateTest {
|
||||||
|
|
||||||
|
|
||||||
private FactorCovariate createTestCovariate(){
|
private FactorCovariate createTestCovariate(){
|
||||||
final List<String> levels = List.of("DOG", "CAT", "MOUSE");
|
final List<String> levels = Utils.easyList("DOG", "CAT", "MOUSE");
|
||||||
|
|
||||||
return new FactorCovariate("pet", levels);
|
return new FactorCovariate("pet", levels);
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.Covariate;
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
||||||
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
|
||||||
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import com.fasterxml.jackson.databind.node.TextNode;
|
import com.fasterxml.jackson.databind.node.TextNode;
|
||||||
|
@ -37,8 +38,8 @@ public class TestLoadingCSV {
|
||||||
final Settings settings = Settings.builder()
|
final Settings settings = Settings.builder()
|
||||||
.dataFileLocation(filename)
|
.dataFileLocation(filename)
|
||||||
.covariates(
|
.covariates(
|
||||||
List.of(new NumericCovariateSettings("x1"),
|
Utils.easyList(new NumericCovariateSettings("x1"),
|
||||||
new FactorCovariateSettings("x2", List.of("dog", "cat", "mouse")),
|
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
||||||
new BooleanCovariateSettings("x3"))
|
new BooleanCovariateSettings("x3"))
|
||||||
)
|
)
|
||||||
.yVarSettings(yVarSettings)
|
.yVarSettings(yVarSettings)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import ca.joeltherrien.randomforest.Settings;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.covariates.*;
|
import ca.joeltherrien.randomforest.covariates.*;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import com.fasterxml.jackson.databind.node.TextNode;
|
import com.fasterxml.jackson.databind.node.TextNode;
|
||||||
|
@ -31,10 +32,10 @@ public class TestPersistence {
|
||||||
yVarSettings.set("name", new TextNode("y"));
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
final Settings settingsOriginal = Settings.builder()
|
final Settings settingsOriginal = Settings.builder()
|
||||||
.covariates(List.of(
|
.covariates(Utils.easyList(
|
||||||
new NumericCovariateSettings("x1"),
|
new NumericCovariateSettings("x1"),
|
||||||
new BooleanCovariateSettings("x2"),
|
new BooleanCovariateSettings("x2"),
|
||||||
new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
|
new FactorCovariateSettings("x3", Utils.easyList("cat", "mouse", "dog"))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.dataFileLocation("data.csv")
|
.dataFileLocation("data.csv")
|
||||||
|
|
|
@ -9,6 +9,7 @@ import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
@ -48,7 +49,7 @@ public class TrainSingleTree {
|
||||||
trainingSet.add(generateRow(x1, x2, i));
|
trainingSet.add(generateRow(x1, x2, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
|
||||||
|
|
||||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
|
@ -99,7 +100,7 @@ public class TrainSingleTree {
|
||||||
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
|
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, int id){
|
||||||
double y = generateResponse(x1.getValue(), x2.getValue());
|
double y = generateResponse(x1.getValue(), x2.getValue());
|
||||||
|
|
||||||
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
|
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
||||||
|
|
||||||
return new Row<>(map, id, y);
|
return new Row<>(map, id, y);
|
||||||
|
|
||||||
|
@ -107,7 +108,7 @@ public class TrainSingleTree {
|
||||||
|
|
||||||
|
|
||||||
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
|
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, int id){
|
||||||
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
|
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
||||||
|
|
||||||
return new CovariateRow(map, id);
|
return new CovariateRow(map, id);
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import ca.joeltherrien.randomforest.responses.regression.MeanResponseCombiner;
|
||||||
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
import ca.joeltherrien.randomforest.responses.regression.WeightedVarianceGroupDifferentiator;
|
||||||
import ca.joeltherrien.randomforest.tree.Node;
|
import ca.joeltherrien.randomforest.tree.Node;
|
||||||
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
import ca.joeltherrien.randomforest.tree.TreeTrainer;
|
||||||
|
import ca.joeltherrien.randomforest.utils.Utils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -29,7 +30,7 @@ public class TrainSingleTreeFactor {
|
||||||
|
|
||||||
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
final Covariate<Double> x1Covariate = new NumericCovariate("x1");
|
||||||
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
final Covariate<Double> x2Covariate = new NumericCovariate("x2");
|
||||||
final FactorCovariate x3Covariate = new FactorCovariate("x3", List.of("cat", "dog", "mouse"));
|
final FactorCovariate x3Covariate = new FactorCovariate("x3", Utils.easyList("cat", "dog", "mouse"));
|
||||||
|
|
||||||
final List<Covariate.Value<Double>> x1List = DoubleStream
|
final List<Covariate.Value<Double>> x1List = DoubleStream
|
||||||
.generate(() -> random.nextDouble()*10.0)
|
.generate(() -> random.nextDouble()*10.0)
|
||||||
|
@ -69,7 +70,7 @@ public class TrainSingleTreeFactor {
|
||||||
trainingSet.add(generateRow(x1, x2, x3, i));
|
trainingSet.add(generateRow(x1, x2, x3, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<Covariate> covariateNames = List.of(x1Covariate, x2Covariate);
|
final List<Covariate> covariateNames = Utils.easyList(x1Covariate, x2Covariate);
|
||||||
|
|
||||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
|
@ -127,7 +128,7 @@ public class TrainSingleTreeFactor {
|
||||||
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
|
public static Row<Double> generateRow(Covariate.Value<Double> x1, Covariate.Value<Double> x2, Covariate.Value<String> x3, int id){
|
||||||
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
|
double y = generateResponse(x1.getValue(), x2.getValue(), x3.getValue());
|
||||||
|
|
||||||
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2);
|
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2);
|
||||||
|
|
||||||
return new Row<>(map, id, y);
|
return new Row<>(map, id, y);
|
||||||
|
|
||||||
|
@ -135,7 +136,7 @@ public class TrainSingleTreeFactor {
|
||||||
|
|
||||||
|
|
||||||
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
|
public static CovariateRow generateCovariateRow(Covariate.Value x1, Covariate.Value x2, Covariate.Value x3, int id){
|
||||||
final Map<String, Covariate.Value> map = Map.of("x1", x1, "x2", x2, "x3", x3);
|
final Map<String, Covariate.Value> map = Utils.easyMap("x1", x1, "x2", x2, "x3", x3);
|
||||||
|
|
||||||
return new CovariateRow(map, id);
|
return new CovariateRow(map, id);
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue