largeRCRF-Java/src/main/java/ca/joeltherrien/randomforest/Main.java

105 lines
3.6 KiB
Java

package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * Command-line entry point for training a random forest from a YAML settings file.
 *
 * <p>Invoke with exactly one argument, the path to the settings file. When invoked with no
 * arguments, a template settings file named {@code template.yaml} is written to the current
 * working directory instead of training.
 */
public class Main {

    /**
     * Loads the settings, reads the dataset they point to, and trains a forest.
     *
     * <p>If {@code args} does not contain exactly one element, a usage message is printed and
     * no training happens; additionally, when no arguments are given, a template settings file
     * ({@code template.yaml}) is generated.
     *
     * @param args a single element: the path to the settings YAML file
     * @throws IOException if an I/O error occurs, e.g. the data file cannot be opened
     */
    public static void main(String[] args) throws IOException {
        if (args.length != 1) {
            System.out.println("Must provide one argument - the path to the settings.yaml file.");
            if (args.length == 0) {
                System.out.println("Generating template file.");
                defaultTemplate().save(new File("template.yaml"));
            }
            return;
        }

        final File settingsFile = new File(args[0]);
        final Settings settings = Settings.load(settingsFile);

        final List<Covariate> covariates = settings.getCovariates().stream()
                .map(cs -> cs.build()).collect(Collectors.toList());

        final List<Row<Double>> dataset = loadData(covariates, settings);

        final ForestTrainer<Double> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);

        // Settings decide whether trees are persisted as they are trained or kept in memory.
        if (settings.isSaveProgress()) {
            forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
        } else {
            forestTrainer.trainParallelInMemory(settings.getNumberOfThreads());
        }
    }

    /**
     * Reads the CSV data file named by {@link Settings#getDataFileLocation()} into rows.
     *
     * <p>The file is parsed as RFC 4180 CSV whose first record is the header. For every data
     * record, each covariate's value is read from the column with the covariate's name and
     * converted via {@code Covariate#createValue}. The response is read from the column named
     * by {@code settings.getYVar()} and parsed with {@link Double#parseDouble(String)}. Rows
     * are assigned sequential ids starting at 1, in file order.
     *
     * @param covariates the covariates whose columns should be extracted from each record
     * @param settings supplies the data file location and the response column name
     * @return the loaded rows, in file order
     * @throws IOException if the data file cannot be opened or read
     * @throws NumberFormatException if a response value is not a parseable double
     */
    public static List<Row<Double>> loadData(final List<Covariate> covariates, final Settings settings) throws IOException {
        final List<Row<Double>> dataset = new ArrayList<>();

        // try-with-resources guarantees the file handle is released even when a record fails
        // to parse (e.g. a non-numeric response value throws mid-iteration).
        try (final Reader input = new FileReader(settings.getDataFileLocation());
             final CSVParser parser = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(input)) {

            int id = 1;
            for (final CSVRecord record : parser) {
                final Map<String, Covariate.Value> covariateValueMap = new HashMap<>();

                for (final Covariate<?> covariate : covariates) {
                    covariateValueMap.put(covariate.getName(), covariate.createValue(record.get(covariate.getName())));
                }

                final String yStr = record.get(settings.getYVar());
                final Double yNum = Double.parseDouble(yStr);

                dataset.add(new Row<>(covariateValueMap, id++, yNum));
            }
        }

        return dataset;
    }

    /**
     * Builds an example {@link Settings} instance used to generate {@code template.yaml}.
     *
     * <p>It declares three sample covariates (numeric {@code x1}, boolean {@code x2}, and factor
     * {@code x3} with levels {@code cat}, {@code mouse}, {@code dog}), response column {@code y}
     * read from {@code data.csv}, and illustrative training parameters.
     *
     * @return the template settings
     */
    private static Settings defaultTemplate() {
        return Settings.builder()
                .covariates(List.of(
                        new NumericCovariateSettings("x1"),
                        new BooleanCovariateSettings("x2"),
                        new FactorCovariateSettings("x3", List.of("cat", "mouse", "dog"))
                        )
                )
                .yVar("y")
                .dataFileLocation("data.csv")
                .groupDifferentiator("WeightedVarianceGroupDifferentiator")
                .responseCombiner("MeanResponseCombiner")
                .treeResponseCombiner("MeanResponseCombiner")
                .maxNodeDepth(100000)
                .mtry(2)
                .nodeSize(5)
                .ntree(500)
                .numberOfSplits(5)
                .numberOfThreads(1)
                .saveProgress(true)
                .saveTreeLocation("trees/")
                .build();
    }
}