Code cleanup; fixed 3 minor bugs in the settings

This commit is contained in:
Joel Therrien 2018-08-31 13:10:30 -07:00
parent 75f34853ab
commit 8333579a1f
17 changed files with 19 additions and 41 deletions

View file

@ -4,7 +4,6 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
@RequiredArgsConstructor @RequiredArgsConstructor

View file

@ -74,13 +74,11 @@ public class DataLoader {
} }
final Forest forest = Forest.<O, FO>builder() return Forest.<O, FO>builder()
.trees(treeList) .trees(treeList)
.treeResponseCombiner(treeResponseCombiner) .treeResponseCombiner(treeResponseCombiner)
.build(); .build();
return forest;
} }
@FunctionalInterface @FunctionalInterface

View file

@ -22,7 +22,6 @@ import java.util.stream.Collectors;
public class Main { public class Main {
public static void main(String[] args) throws IOException, ClassNotFoundException { public static void main(String[] args) throws IOException, ClassNotFoundException {
if(args.length < 2){ if(args.length < 2){
System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze."); System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
@ -160,7 +159,7 @@ public class Main {
yVarSettings.set("type", new TextNode("Double")); yVarSettings.set("type", new TextNode("Double"));
yVarSettings.set("name", new TextNode("y")); yVarSettings.set("name", new TextNode("y"));
final Settings settings = Settings.builder() return Settings.builder()
.covariates(Utils.easyList( .covariates(Utils.easyList(
new NumericCovariateSettings("x1"), new NumericCovariateSettings("x1"),
new BooleanCovariateSettings("x2"), new BooleanCovariateSettings("x2"),
@ -181,9 +180,6 @@ public class Main {
.saveProgress(true) .saveProgress(true)
.saveTreeLocation("trees/") .saveTreeLocation("trees/")
.build(); .build();
return settings;
} }
} }

View file

@ -133,7 +133,7 @@ public class Settings {
if(node.hasNonNull("times")){ if(node.hasNonNull("times")){
final List<Double> timeList = new ArrayList<>(); final List<Double> timeList = new ArrayList<>();
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble())); node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
times = eventList.stream().mapToDouble(db -> db).toArray(); times = timeList.stream().mapToDouble(db -> db).toArray();
} }
return new CompetingRiskResponseCombiner(events, times); return new CompetingRiskResponseCombiner(events, times);
@ -152,7 +152,7 @@ public class Settings {
if(node.hasNonNull("times")){ if(node.hasNonNull("times")){
final List<Double> timeList = new ArrayList<>(); final List<Double> timeList = new ArrayList<>();
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble())); node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
times = eventList.stream().mapToDouble(db -> db).toArray(); times = timeList.stream().mapToDouble(db -> db).toArray();
} }
return new CompetingRiskFunctionCombiner(events, times); return new CompetingRiskFunctionCombiner(events, times);
@ -171,7 +171,7 @@ public class Settings {
if(node.hasNonNull("times")){ if(node.hasNonNull("times")){
final List<Double> timeList = new ArrayList<>(); final List<Double> timeList = new ArrayList<>();
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble())); node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
times = eventList.stream().mapToDouble(db -> db).toArray(); times = timeList.stream().mapToDouble(db -> db).toArray();
} }
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times); final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times);
@ -217,9 +217,7 @@ public class Settings {
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
//mapper.enableDefaultTyping(); //mapper.enableDefaultTyping();
final Settings settings = mapper.readValue(file, Settings.class); return mapper.readValue(file, Settings.class);
return settings;
} }

View file

@ -14,7 +14,6 @@ import lombok.NoArgsConstructor;
@Getter @Getter
@JsonTypeInfo( @JsonTypeInfo(
use = JsonTypeInfo.Id.NAME, use = JsonTypeInfo.Id.NAME,
include = JsonTypeInfo.As.PROPERTY,
property = "type") property = "type")
@JsonSubTypes({ @JsonSubTypes({
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"), @JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),

View file

@ -7,7 +7,6 @@ import lombok.Getter;
import java.io.Serializable; import java.io.Serializable;
import java.util.List; import java.util.List;
import java.util.Map;
@Builder @Builder
public class CompetingRiskFunctions implements Serializable { public class CompetingRiskFunctions implements Serializable {

View file

@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.DataLoader; import ca.joeltherrien.randomforest.DataLoader;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.apache.commons.csv.CSVRecord; import org.apache.commons.csv.CSVRecord;
@ -10,6 +11,7 @@ import org.apache.commons.csv.CSVRecord;
* See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times. * See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times.
* *
*/ */
@EqualsAndHashCode(callSuper = true)
@Data @Data
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse { public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
private final double c; private final double c;

View file

@ -7,9 +7,7 @@ import ca.joeltherrien.randomforest.utils.Point;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
@RequiredArgsConstructor @RequiredArgsConstructor
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> { public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {

View file

@ -2,7 +2,6 @@ package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import lombok.Builder; import lombok.Builder;
import lombok.RequiredArgsConstructor;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;

View file

@ -184,7 +184,7 @@ public class ForestTrainer<Y, TO, FO> {
private final int treeIndex; private final int treeIndex;
private final List<Tree<TO>> treeList; private final List<Tree<TO>> treeList;
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) { TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
this.bootstrapper = new Bootstrapper<>(data); this.bootstrapper = new Bootstrapper<>(data);
this.treeIndex = treeIndex; this.treeIndex = treeIndex;
this.treeList = treeList; this.treeList = treeList;

View file

@ -43,9 +43,7 @@ public class TreeTrainer<Y, O> {
public Tree<O> growTree(List<Row<Y>> data){ public Tree<O> growTree(List<Row<Y>> data){
final Node<O> rootNode = growNode(data, 0); final Node<O> rootNode = growNode(data, 0);
final Tree<O> tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray()); return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
return tree;
} }
@ -95,8 +93,8 @@ public class TreeTrainer<Y, O> {
final List<Covariate> splitCovariates = new ArrayList<>(covariates); final List<Covariate> splitCovariates = new ArrayList<>(covariates);
Collections.shuffle(splitCovariates, ThreadLocalRandom.current()); Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
for(int treeIndex = splitCovariates.size()-1; treeIndex >= mtry; treeIndex--){ if (splitCovariates.size() > mtry) {
splitCovariates.remove(treeIndex); splitCovariates.subList(mtry, splitCovariates.size()).clear();
} }
return splitCovariates; return splitCovariates;

View file

@ -74,9 +74,7 @@ public class Utils {
public static <T> List<T> easyList(T... array){ public static <T> List<T> easyList(T... array){
final List<T> list = new ArrayList<>(array.length); final List<T> list = new ArrayList<>(array.length);
for(final T item : array){ Collections.addAll(list, array);
list.add(item);
}
return list; return list;

View file

@ -16,7 +16,6 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class TestSavingLoading { public class TestSavingLoading {

View file

@ -19,7 +19,6 @@ import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class TestCompetingRisk { public class TestCompetingRisk {

View file

@ -26,9 +26,7 @@ public class TestCompetingRiskResponseCombiner {
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null); final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);
final CompetingRiskFunctions functions = combiner.combine(data); return combiner.combine(data);
return functions;
} }
@Test @Test

View file

@ -18,6 +18,7 @@ import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLoadingCSV { public class TestLoadingCSV {
@ -51,9 +52,7 @@ public class TestLoadingCSV {
final DataLoader.ResponseLoader loader = settings.getResponseLoader(); final DataLoader.ResponseLoader loader = settings.getResponseLoader();
final List<Row<Double>> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation()); return DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
return data;
} }
@Test @Test
@ -94,9 +93,9 @@ public class TestLoadingCSV {
row = data.get(3); row = data.get(3);
assertEquals(-3.0, (double)row.getResponse()); assertEquals(-3.0, (double)row.getResponse());
assertEquals(true, row.getCovariateValue("x1").isNA()); assertTrue(row.getCovariateValue("x1").isNA());
assertEquals(true, row.getCovariateValue("x2").isNA()); assertTrue(row.getCovariateValue("x2").isNA());
assertEquals(true, row.getCovariateValue("x3").isNA()); assertTrue(row.getCovariateValue("x3").isNA());
} }
} }

View file

@ -12,7 +12,6 @@ import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.List;
public class TestPersistence { public class TestPersistence {