Simplied code.
This commit is contained in:
parent
9d9dc9ef8d
commit
55eab76610
2 changed files with 1 additions and 5 deletions
|
@ -29,9 +29,6 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
private final ResponseCombiner<TO, FO> treeResponseCombiner;
|
||||
private final List<Row<Y>> data;
|
||||
|
||||
// number of covariates to randomly try
|
||||
private final int mtry;
|
||||
|
||||
// number of trees to try
|
||||
private final int ntree;
|
||||
|
||||
|
@ -39,7 +36,6 @@ public class ForestTrainer<Y, TO, FO> {
|
|||
private final String saveTreeLocation;
|
||||
|
||||
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
|
||||
this.mtry = settings.getMtry();
|
||||
this.ntree = settings.getNtree();
|
||||
this.data = data;
|
||||
this.displayProgress = true;
|
||||
|
|
|
@ -54,6 +54,7 @@ public class TrainForest {
|
|||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||
.numberOfSplits(5)
|
||||
.nodeSize(5)
|
||||
.mtry(4)
|
||||
.maxNodeDepth(100000000)
|
||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||
.responseCombiner(new MeanResponseCombiner())
|
||||
|
@ -63,7 +64,6 @@ public class TrainForest {
|
|||
.treeTrainer(treeTrainer)
|
||||
.data(data)
|
||||
.covariates(covariateList)
|
||||
.mtry(4)
|
||||
.ntree(100)
|
||||
.treeResponseCombiner(new MeanResponseCombiner())
|
||||
.displayProgress(true)
|
||||
|
|
Loading…
Reference in a new issue