Code cleanup; fixed 3 minor bugs in the settings (in three places the `times` array was built from `eventList` instead of `timeList`)

This commit is contained in:
Joel Therrien 2018-08-31 13:10:30 -07:00
parent 75f34853ab
commit 8333579a1f
17 changed files with 19 additions and 41 deletions

View file

@ -4,7 +4,6 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
@RequiredArgsConstructor

View file

@ -74,13 +74,11 @@ public class DataLoader {
}
final Forest forest = Forest.<O, FO>builder()
return Forest.<O, FO>builder()
.trees(treeList)
.treeResponseCombiner(treeResponseCombiner)
.build();
return forest;
}
@FunctionalInterface

View file

@ -22,7 +22,6 @@ import java.util.stream.Collectors;
public class Main {
public static void main(String[] args) throws IOException, ClassNotFoundException {
if(args.length < 2){
System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
@ -160,7 +159,7 @@ public class Main {
yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y"));
final Settings settings = Settings.builder()
return Settings.builder()
.covariates(Utils.easyList(
new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"),
@ -181,9 +180,6 @@ public class Main {
.saveProgress(true)
.saveTreeLocation("trees/")
.build();
return settings;
}
}

View file

@ -133,7 +133,7 @@ public class Settings {
if(node.hasNonNull("times")){
final List<Double> timeList = new ArrayList<>();
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
times = eventList.stream().mapToDouble(db -> db).toArray();
times = timeList.stream().mapToDouble(db -> db).toArray();
}
return new CompetingRiskResponseCombiner(events, times);
@ -152,7 +152,7 @@ public class Settings {
if(node.hasNonNull("times")){
final List<Double> timeList = new ArrayList<>();
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
times = eventList.stream().mapToDouble(db -> db).toArray();
times = timeList.stream().mapToDouble(db -> db).toArray();
}
return new CompetingRiskFunctionCombiner(events, times);
@ -171,7 +171,7 @@ public class Settings {
if(node.hasNonNull("times")){
final List<Double> timeList = new ArrayList<>();
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
times = eventList.stream().mapToDouble(db -> db).toArray();
times = timeList.stream().mapToDouble(db -> db).toArray();
}
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times);
@ -217,9 +217,7 @@ public class Settings {
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
//mapper.enableDefaultTyping();
final Settings settings = mapper.readValue(file, Settings.class);
return settings;
return mapper.readValue(file, Settings.class);
}

View file

@ -14,7 +14,6 @@ import lombok.NoArgsConstructor;
@Getter
@JsonTypeInfo(
use = JsonTypeInfo.Id.NAME,
include = JsonTypeInfo.As.PROPERTY,
property = "type")
@JsonSubTypes({
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),

View file

@ -7,7 +7,6 @@ import lombok.Getter;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
@Builder
public class CompetingRiskFunctions implements Serializable {

View file

@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.DataLoader;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord;
@ -10,6 +11,7 @@ import org.apache.commons.csv.CSVRecord;
* See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times.
*
*/
@EqualsAndHashCode(callSuper = true)
@Data
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
private final double c;

View file

@ -7,9 +7,7 @@ import ca.joeltherrien.randomforest.utils.Point;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RequiredArgsConstructor
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {

View file

@ -2,7 +2,6 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Builder;
import lombok.RequiredArgsConstructor;
import java.util.Collection;
import java.util.Collections;

View file

@ -184,7 +184,7 @@ public class ForestTrainer<Y, TO, FO> {
private final int treeIndex;
private final List<Tree<TO>> treeList;
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
this.bootstrapper = new Bootstrapper<>(data);
this.treeIndex = treeIndex;
this.treeList = treeList;

View file

@ -43,9 +43,7 @@ public class TreeTrainer<Y, O> {
public Tree<O> growTree(List<Row<Y>> data){
final Node<O> rootNode = growNode(data, 0);
final Tree<O> tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
return tree;
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
}
@ -95,8 +93,8 @@ public class TreeTrainer<Y, O> {
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
for(int treeIndex = splitCovariates.size()-1; treeIndex >= mtry; treeIndex--){
splitCovariates.remove(treeIndex);
if (splitCovariates.size() > mtry) {
splitCovariates.subList(mtry, splitCovariates.size()).clear();
}
return splitCovariates;

View file

@ -74,9 +74,7 @@ public class Utils {
public static <T> List<T> easyList(T... array){
final List<T> list = new ArrayList<>(array.length);
for(final T item : array){
list.add(item);
}
Collections.addAll(list, array);
return list;

View file

@ -16,7 +16,6 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class TestSavingLoading {

View file

@ -19,7 +19,6 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class TestCompetingRisk {

View file

@ -26,9 +26,7 @@ public class TestCompetingRiskResponseCombiner {
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);
final CompetingRiskFunctions functions = combiner.combine(data);
return functions;
return combiner.combine(data);
}
@Test

View file

@ -18,6 +18,7 @@ import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLoadingCSV {
@ -51,9 +52,7 @@ public class TestLoadingCSV {
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
final List<Row<Double>> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
return data;
return DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
}
@Test
@ -94,9 +93,9 @@ public class TestLoadingCSV {
row = data.get(3);
assertEquals(-3.0, (double)row.getResponse());
assertEquals(true, row.getCovariateValue("x1").isNA());
assertEquals(true, row.getCovariateValue("x2").isNA());
assertEquals(true, row.getCovariateValue("x3").isNA());
assertTrue(row.getCovariateValue("x1").isNA());
assertTrue(row.getCovariateValue("x2").isNA());
assertTrue(row.getCovariateValue("x3").isNA());
}
}

View file

@ -12,7 +12,6 @@ import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.util.List;
public class TestPersistence {