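Code cleanup; fixed 3 minor bugs in Settings (the `times` array was built from `eventList` instead of `timeList` in three response-combiner constructors)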
This commit is contained in:
parent
75f34853ab
commit
8333579a1f
17 changed files with 19 additions and 41 deletions
|
@ -4,7 +4,6 @@ import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
|
|
|
@ -74,13 +74,11 @@ public class DataLoader {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final Forest forest = Forest.<O, FO>builder()
|
return Forest.<O, FO>builder()
|
||||||
.trees(treeList)
|
.trees(treeList)
|
||||||
.treeResponseCombiner(treeResponseCombiner)
|
.treeResponseCombiner(treeResponseCombiner)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
return forest;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@FunctionalInterface
|
@FunctionalInterface
|
||||||
|
|
|
@ -22,7 +22,6 @@ import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, ClassNotFoundException {
|
public static void main(String[] args) throws IOException, ClassNotFoundException {
|
||||||
if(args.length < 2){
|
if(args.length < 2){
|
||||||
System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
|
System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
|
||||||
|
@ -160,7 +159,7 @@ public class Main {
|
||||||
yVarSettings.set("type", new TextNode("Double"));
|
yVarSettings.set("type", new TextNode("Double"));
|
||||||
yVarSettings.set("name", new TextNode("y"));
|
yVarSettings.set("name", new TextNode("y"));
|
||||||
|
|
||||||
final Settings settings = Settings.builder()
|
return Settings.builder()
|
||||||
.covariates(Utils.easyList(
|
.covariates(Utils.easyList(
|
||||||
new NumericCovariateSettings("x1"),
|
new NumericCovariateSettings("x1"),
|
||||||
new BooleanCovariateSettings("x2"),
|
new BooleanCovariateSettings("x2"),
|
||||||
|
@ -181,9 +180,6 @@ public class Main {
|
||||||
.saveProgress(true)
|
.saveProgress(true)
|
||||||
.saveTreeLocation("trees/")
|
.saveTreeLocation("trees/")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
||||||
return settings;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,7 +133,7 @@ public class Settings {
|
||||||
if(node.hasNonNull("times")){
|
if(node.hasNonNull("times")){
|
||||||
final List<Double> timeList = new ArrayList<>();
|
final List<Double> timeList = new ArrayList<>();
|
||||||
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
||||||
times = eventList.stream().mapToDouble(db -> db).toArray();
|
times = timeList.stream().mapToDouble(db -> db).toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
return new CompetingRiskResponseCombiner(events, times);
|
return new CompetingRiskResponseCombiner(events, times);
|
||||||
|
@ -152,7 +152,7 @@ public class Settings {
|
||||||
if(node.hasNonNull("times")){
|
if(node.hasNonNull("times")){
|
||||||
final List<Double> timeList = new ArrayList<>();
|
final List<Double> timeList = new ArrayList<>();
|
||||||
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
||||||
times = eventList.stream().mapToDouble(db -> db).toArray();
|
times = timeList.stream().mapToDouble(db -> db).toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
return new CompetingRiskFunctionCombiner(events, times);
|
return new CompetingRiskFunctionCombiner(events, times);
|
||||||
|
@ -171,7 +171,7 @@ public class Settings {
|
||||||
if(node.hasNonNull("times")){
|
if(node.hasNonNull("times")){
|
||||||
final List<Double> timeList = new ArrayList<>();
|
final List<Double> timeList = new ArrayList<>();
|
||||||
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
||||||
times = eventList.stream().mapToDouble(db -> db).toArray();
|
times = timeList.stream().mapToDouble(db -> db).toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times);
|
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times);
|
||||||
|
@ -217,9 +217,7 @@ public class Settings {
|
||||||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||||
//mapper.enableDefaultTyping();
|
//mapper.enableDefaultTyping();
|
||||||
|
|
||||||
final Settings settings = mapper.readValue(file, Settings.class);
|
return mapper.readValue(file, Settings.class);
|
||||||
|
|
||||||
return settings;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,6 @@ import lombok.NoArgsConstructor;
|
||||||
@Getter
|
@Getter
|
||||||
@JsonTypeInfo(
|
@JsonTypeInfo(
|
||||||
use = JsonTypeInfo.Id.NAME,
|
use = JsonTypeInfo.Id.NAME,
|
||||||
include = JsonTypeInfo.As.PROPERTY,
|
|
||||||
property = "type")
|
property = "type")
|
||||||
@JsonSubTypes({
|
@JsonSubTypes({
|
||||||
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),
|
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),
|
||||||
|
|
|
@ -7,7 +7,6 @@ import lombok.Getter;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public class CompetingRiskFunctions implements Serializable {
|
public class CompetingRiskFunctions implements Serializable {
|
||||||
|
|
|
@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
||||||
import ca.joeltherrien.randomforest.DataLoader;
|
import ca.joeltherrien.randomforest.DataLoader;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.apache.commons.csv.CSVRecord;
|
import org.apache.commons.csv.CSVRecord;
|
||||||
|
|
||||||
|
@ -10,6 +11,7 @@ import org.apache.commons.csv.CSVRecord;
|
||||||
* See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times.
|
* See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
||||||
private final double c;
|
private final double c;
|
||||||
|
|
|
@ -7,9 +7,7 @@ import ca.joeltherrien.randomforest.utils.Point;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package ca.joeltherrien.randomforest.tree;
|
||||||
|
|
||||||
import ca.joeltherrien.randomforest.CovariateRow;
|
import ca.joeltherrien.randomforest.CovariateRow;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
|
@ -184,7 +184,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
private final int treeIndex;
|
private final int treeIndex;
|
||||||
private final List<Tree<TO>> treeList;
|
private final List<Tree<TO>> treeList;
|
||||||
|
|
||||||
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
|
TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
|
||||||
this.bootstrapper = new Bootstrapper<>(data);
|
this.bootstrapper = new Bootstrapper<>(data);
|
||||||
this.treeIndex = treeIndex;
|
this.treeIndex = treeIndex;
|
||||||
this.treeList = treeList;
|
this.treeList = treeList;
|
||||||
|
|
|
@ -43,9 +43,7 @@ public class TreeTrainer<Y, O> {
|
||||||
public Tree<O> growTree(List<Row<Y>> data){
|
public Tree<O> growTree(List<Row<Y>> data){
|
||||||
|
|
||||||
final Node<O> rootNode = growNode(data, 0);
|
final Node<O> rootNode = growNode(data, 0);
|
||||||
final Tree<O> tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
||||||
|
|
||||||
return tree;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,8 +93,8 @@ public class TreeTrainer<Y, O> {
|
||||||
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
|
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
|
||||||
Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
|
Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
|
||||||
|
|
||||||
for(int treeIndex = splitCovariates.size()-1; treeIndex >= mtry; treeIndex--){
|
if (splitCovariates.size() > mtry) {
|
||||||
splitCovariates.remove(treeIndex);
|
splitCovariates.subList(mtry, splitCovariates.size()).clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
return splitCovariates;
|
return splitCovariates;
|
||||||
|
|
|
@ -74,9 +74,7 @@ public class Utils {
|
||||||
public static <T> List<T> easyList(T... array){
|
public static <T> List<T> easyList(T... array){
|
||||||
final List<T> list = new ArrayList<>(array.length);
|
final List<T> list = new ArrayList<>(array.length);
|
||||||
|
|
||||||
for(final T item : array){
|
Collections.addAll(list, array);
|
||||||
list.add(item);
|
|
||||||
}
|
|
||||||
|
|
||||||
return list;
|
return list;
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class TestSavingLoading {
|
public class TestSavingLoading {
|
||||||
|
|
|
@ -19,7 +19,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class TestCompetingRisk {
|
public class TestCompetingRisk {
|
||||||
|
|
|
@ -26,9 +26,7 @@ public class TestCompetingRiskResponseCombiner {
|
||||||
|
|
||||||
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);
|
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);
|
||||||
|
|
||||||
final CompetingRiskFunctions functions = combiner.combine(data);
|
return combiner.combine(data);
|
||||||
|
|
||||||
return functions;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -18,6 +18,7 @@ import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestLoadingCSV {
|
public class TestLoadingCSV {
|
||||||
|
|
||||||
|
@ -51,9 +52,7 @@ public class TestLoadingCSV {
|
||||||
|
|
||||||
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||||
|
|
||||||
final List<Row<Double>> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
|
return DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
|
||||||
|
|
||||||
return data;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -94,9 +93,9 @@ public class TestLoadingCSV {
|
||||||
|
|
||||||
row = data.get(3);
|
row = data.get(3);
|
||||||
assertEquals(-3.0, (double)row.getResponse());
|
assertEquals(-3.0, (double)row.getResponse());
|
||||||
assertEquals(true, row.getCovariateValue("x1").isNA());
|
assertTrue(row.getCovariateValue("x1").isNA());
|
||||||
assertEquals(true, row.getCovariateValue("x2").isNA());
|
assertTrue(row.getCovariateValue("x2").isNA());
|
||||||
assertEquals(true, row.getCovariateValue("x3").isNA());
|
assertTrue(row.getCovariateValue("x3").isNA());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,6 @@ import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class TestPersistence {
|
public class TestPersistence {
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue