Code cleanup; fixed 3 minor bugs in Settings (the `times` array was built from `eventList` instead of `timeList` in three places)
This commit is contained in:
parent
75f34853ab
commit
8333579a1f
17 changed files with 19 additions and 41 deletions
|
@ -4,7 +4,6 @@ import lombok.RequiredArgsConstructor;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
|
|
|
@ -74,13 +74,11 @@ public class DataLoader {
|
|||
|
||||
}
|
||||
|
||||
final Forest forest = Forest.<O, FO>builder()
|
||||
return Forest.<O, FO>builder()
|
||||
.trees(treeList)
|
||||
.treeResponseCombiner(treeResponseCombiner)
|
||||
.build();
|
||||
|
||||
return forest;
|
||||
|
||||
}
|
||||
|
||||
@FunctionalInterface
|
||||
|
|
|
@ -22,7 +22,6 @@ import java.util.stream.Collectors;
|
|||
|
||||
public class Main {
|
||||
|
||||
|
||||
public static void main(String[] args) throws IOException, ClassNotFoundException {
|
||||
if(args.length < 2){
|
||||
System.out.println("Must provide two arguments - the path to the settings.yaml file and instructions to either train or analyze.");
|
||||
|
@ -160,7 +159,7 @@ public class Main {
|
|||
yVarSettings.set("type", new TextNode("Double"));
|
||||
yVarSettings.set("name", new TextNode("y"));
|
||||
|
||||
final Settings settings = Settings.builder()
|
||||
return Settings.builder()
|
||||
.covariates(Utils.easyList(
|
||||
new NumericCovariateSettings("x1"),
|
||||
new BooleanCovariateSettings("x2"),
|
||||
|
@ -181,9 +180,6 @@ public class Main {
|
|||
.saveProgress(true)
|
||||
.saveTreeLocation("trees/")
|
||||
.build();
|
||||
|
||||
|
||||
return settings;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -133,7 +133,7 @@ public class Settings {
|
|||
if(node.hasNonNull("times")){
|
||||
final List<Double> timeList = new ArrayList<>();
|
||||
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
||||
times = eventList.stream().mapToDouble(db -> db).toArray();
|
||||
times = timeList.stream().mapToDouble(db -> db).toArray();
|
||||
}
|
||||
|
||||
return new CompetingRiskResponseCombiner(events, times);
|
||||
|
@ -152,7 +152,7 @@ public class Settings {
|
|||
if(node.hasNonNull("times")){
|
||||
final List<Double> timeList = new ArrayList<>();
|
||||
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
||||
times = eventList.stream().mapToDouble(db -> db).toArray();
|
||||
times = timeList.stream().mapToDouble(db -> db).toArray();
|
||||
}
|
||||
|
||||
return new CompetingRiskFunctionCombiner(events, times);
|
||||
|
@ -171,7 +171,7 @@ public class Settings {
|
|||
if(node.hasNonNull("times")){
|
||||
final List<Double> timeList = new ArrayList<>();
|
||||
node.get("times").elements().forEachRemaining(event -> timeList.add(event.asDouble()));
|
||||
times = eventList.stream().mapToDouble(db -> db).toArray();
|
||||
times = timeList.stream().mapToDouble(db -> db).toArray();
|
||||
}
|
||||
|
||||
final CompetingRiskResponseCombiner responseCombiner = new CompetingRiskResponseCombiner(events, times);
|
||||
|
@ -217,9 +217,7 @@ public class Settings {
|
|||
final ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
|
||||
//mapper.enableDefaultTyping();
|
||||
|
||||
final Settings settings = mapper.readValue(file, Settings.class);
|
||||
|
||||
return settings;
|
||||
return mapper.readValue(file, Settings.class);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@ import lombok.NoArgsConstructor;
|
|||
@Getter
|
||||
@JsonTypeInfo(
|
||||
use = JsonTypeInfo.Id.NAME,
|
||||
include = JsonTypeInfo.As.PROPERTY,
|
||||
property = "type")
|
||||
@JsonSubTypes({
|
||||
@JsonSubTypes.Type(value = BooleanCovariateSettings.class, name = "boolean"),
|
||||
|
|
|
@ -7,7 +7,6 @@ import lombok.Getter;
|
|||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Builder
|
||||
public class CompetingRiskFunctions implements Serializable {
|
||||
|
|
|
@ -3,6 +3,7 @@ package ca.joeltherrien.randomforest.responses.competingrisk;
|
|||
import ca.joeltherrien.randomforest.DataLoader;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.apache.commons.csv.CSVRecord;
|
||||
|
||||
|
@ -10,6 +11,7 @@ import org.apache.commons.csv.CSVRecord;
|
|||
* See Ishwaran paper on splitting rule modelled after Gray's test. This requires that we know the censor times.
|
||||
*
|
||||
*/
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public final class CompetingRiskResponseWithCensorTime extends CompetingRiskResponse {
|
||||
private final double c;
|
||||
|
|
|
@ -7,9 +7,7 @@ import ca.joeltherrien.randomforest.utils.Point;
|
|||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class CompetingRiskFunctionCombiner implements ResponseCombiner<CompetingRiskFunctions, CompetingRiskFunctions> {
|
||||
|
|
|
@ -2,7 +2,6 @@ package ca.joeltherrien.randomforest.tree;
|
|||
|
||||
import ca.joeltherrien.randomforest.CovariateRow;
|
||||
import lombok.Builder;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
|
|
|
@ -184,7 +184,7 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
private final int treeIndex;
|
||||
private final List<Tree<TO>> treeList;
|
||||
|
||||
public TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
|
||||
TreeInMemoryWorker(final List<Row<Y>> data, final int treeIndex, final List<Tree<TO>> treeList) {
|
||||
this.bootstrapper = new Bootstrapper<>(data);
|
||||
this.treeIndex = treeIndex;
|
||||
this.treeList = treeList;
|
||||
|
|
|
@ -43,9 +43,7 @@ public class TreeTrainer<Y, O> {
|
|||
public Tree<O> growTree(List<Row<Y>> data){
|
||||
|
||||
final Node<O> rootNode = growNode(data, 0);
|
||||
final Tree<O> tree = new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
||||
|
||||
return tree;
|
||||
return new Tree<>(rootNode, data.stream().mapToInt(Row::getId).toArray());
|
||||
|
||||
}
|
||||
|
||||
|
@ -95,8 +93,8 @@ public class TreeTrainer<Y, O> {
|
|||
final List<Covariate> splitCovariates = new ArrayList<>(covariates);
|
||||
Collections.shuffle(splitCovariates, ThreadLocalRandom.current());
|
||||
|
||||
for(int treeIndex = splitCovariates.size()-1; treeIndex >= mtry; treeIndex--){
|
||||
splitCovariates.remove(treeIndex);
|
||||
if (splitCovariates.size() > mtry) {
|
||||
splitCovariates.subList(mtry, splitCovariates.size()).clear();
|
||||
}
|
||||
|
||||
return splitCovariates;
|
||||
|
|
|
@ -74,9 +74,7 @@ public class Utils {
|
|||
public static <T> List<T> easyList(T... array){
|
||||
final List<T> list = new ArrayList<>(array.length);
|
||||
|
||||
for(final T item : array){
|
||||
list.add(item);
|
||||
}
|
||||
Collections.addAll(list, array);
|
||||
|
||||
return list;
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class TestSavingLoading {
|
||||
|
|
|
@ -19,7 +19,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class TestCompetingRisk {
|
||||
|
|
|
@ -26,9 +26,7 @@ public class TestCompetingRiskResponseCombiner {
|
|||
|
||||
final CompetingRiskResponseCombiner combiner = new CompetingRiskResponseCombiner(new int[]{1,2}, null);
|
||||
|
||||
final CompetingRiskFunctions functions = combiner.combine(data);
|
||||
|
||||
return functions;
|
||||
return combiner.combine(data);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -18,6 +18,7 @@ import java.util.List;
|
|||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestLoadingCSV {
|
||||
|
||||
|
@ -51,9 +52,7 @@ public class TestLoadingCSV {
|
|||
|
||||
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
||||
|
||||
final List<Row<Double>> data = DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
|
||||
|
||||
return data;
|
||||
return DataLoader.loadData(covariates, loader, settings.getDataFileLocation());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -94,9 +93,9 @@ public class TestLoadingCSV {
|
|||
|
||||
row = data.get(3);
|
||||
assertEquals(-3.0, (double)row.getResponse());
|
||||
assertEquals(true, row.getCovariateValue("x1").isNA());
|
||||
assertEquals(true, row.getCovariateValue("x2").isNA());
|
||||
assertEquals(true, row.getCovariateValue("x3").isNA());
|
||||
assertTrue(row.getCovariateValue("x1").isNA());
|
||||
assertTrue(row.getCovariateValue("x2").isNA());
|
||||
assertTrue(row.getCovariateValue("x3").isNA());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@ import org.junit.jupiter.api.Test;
|
|||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
public class TestPersistence {
|
||||
|
||||
|
|
Loading…
Reference in a new issue