Simplied code.

This commit is contained in:
Joel Therrien 2018-08-08 15:41:40 -07:00
parent 9d9dc9ef8d
commit 55eab76610
2 changed files with 1 additions and 5 deletions

View file

@ -29,9 +29,6 @@ public class ForestTrainer<Y, TO, FO> {
private final ResponseCombiner<TO, FO> treeResponseCombiner;
private final List<Row<Y>> data;
// number of covariates to randomly try
private final int mtry;
// number of trees to try
private final int ntree;
@ -39,7 +36,6 @@ public class ForestTrainer<Y, TO, FO> {
private final String saveTreeLocation;
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
this.mtry = settings.getMtry();
this.ntree = settings.getNtree();
this.data = data;
this.displayProgress = true;

View file

@ -54,6 +54,7 @@ public class TrainForest {
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.numberOfSplits(5)
.nodeSize(5)
.mtry(4)
.maxNodeDepth(100000000)
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.responseCombiner(new MeanResponseCombiner())
@ -63,7 +64,6 @@ public class TrainForest {
.treeTrainer(treeTrainer)
.data(data)
.covariates(covariateList)
.mtry(4)
.ntree(100)
.treeResponseCombiner(new MeanResponseCombiner())
.displayProgress(true)