Simplied code.

This commit is contained in:
Joel Therrien 2018-08-08 15:41:40 -07:00
parent 9d9dc9ef8d
commit 55eab76610
2 changed files with 1 additions and 5 deletions

View file

@ -29,9 +29,6 @@ public class ForestTrainer<Y, TO, FO> {
private final ResponseCombiner<TO, FO> treeResponseCombiner; private final ResponseCombiner<TO, FO> treeResponseCombiner;
private final List<Row<Y>> data; private final List<Row<Y>> data;
// number of covariates to randomly try
private final int mtry;
// number of trees to try // number of trees to try
private final int ntree; private final int ntree;
@ -39,7 +36,6 @@ public class ForestTrainer<Y, TO, FO> {
private final String saveTreeLocation; private final String saveTreeLocation;
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){ public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
this.mtry = settings.getMtry();
this.ntree = settings.getNtree(); this.ntree = settings.getNtree();
this.data = data; this.data = data;
this.displayProgress = true; this.displayProgress = true;

View file

@ -54,6 +54,7 @@ public class TrainForest {
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder() final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
.numberOfSplits(5) .numberOfSplits(5)
.nodeSize(5) .nodeSize(5)
.mtry(4)
.maxNodeDepth(100000000) .maxNodeDepth(100000000)
.groupDifferentiator(new WeightedVarianceGroupDifferentiator()) .groupDifferentiator(new WeightedVarianceGroupDifferentiator())
.responseCombiner(new MeanResponseCombiner()) .responseCombiner(new MeanResponseCombiner())
@ -63,7 +64,6 @@ public class TrainForest {
.treeTrainer(treeTrainer) .treeTrainer(treeTrainer)
.data(data) .data(data)
.covariates(covariateList) .covariates(covariateList)
.mtry(4)
.ntree(100) .ntree(100)
.treeResponseCombiner(new MeanResponseCombiner()) .treeResponseCombiner(new MeanResponseCombiner())
.displayProgress(true) .displayProgress(true)