Simplied code.
This commit is contained in:
parent
9d9dc9ef8d
commit
55eab76610
2 changed files with 1 additions and 5 deletions
|
@ -29,9 +29,6 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
private final ResponseCombiner<TO, FO> treeResponseCombiner;
|
private final ResponseCombiner<TO, FO> treeResponseCombiner;
|
||||||
private final List<Row<Y>> data;
|
private final List<Row<Y>> data;
|
||||||
|
|
||||||
// number of covariates to randomly try
|
|
||||||
private final int mtry;
|
|
||||||
|
|
||||||
// number of trees to try
|
// number of trees to try
|
||||||
private final int ntree;
|
private final int ntree;
|
||||||
|
|
||||||
|
@ -39,7 +36,6 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
private final String saveTreeLocation;
|
private final String saveTreeLocation;
|
||||||
|
|
||||||
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
|
public ForestTrainer(final Settings settings, final List<Row<Y>> data, final List<Covariate> covariates){
|
||||||
this.mtry = settings.getMtry();
|
|
||||||
this.ntree = settings.getNtree();
|
this.ntree = settings.getNtree();
|
||||||
this.data = data;
|
this.data = data;
|
||||||
this.displayProgress = true;
|
this.displayProgress = true;
|
||||||
|
|
|
@ -54,6 +54,7 @@ public class TrainForest {
|
||||||
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
final TreeTrainer<Double, Double> treeTrainer = TreeTrainer.<Double, Double>builder()
|
||||||
.numberOfSplits(5)
|
.numberOfSplits(5)
|
||||||
.nodeSize(5)
|
.nodeSize(5)
|
||||||
|
.mtry(4)
|
||||||
.maxNodeDepth(100000000)
|
.maxNodeDepth(100000000)
|
||||||
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
.groupDifferentiator(new WeightedVarianceGroupDifferentiator())
|
||||||
.responseCombiner(new MeanResponseCombiner())
|
.responseCombiner(new MeanResponseCombiner())
|
||||||
|
@ -63,7 +64,6 @@ public class TrainForest {
|
||||||
.treeTrainer(treeTrainer)
|
.treeTrainer(treeTrainer)
|
||||||
.data(data)
|
.data(data)
|
||||||
.covariates(covariateList)
|
.covariates(covariateList)
|
||||||
.mtry(4)
|
|
||||||
.ntree(100)
|
.ntree(100)
|
||||||
.treeResponseCombiner(new MeanResponseCombiner())
|
.treeResponseCombiner(new MeanResponseCombiner())
|
||||||
.displayProgress(true)
|
.displayProgress(true)
|
||||||
|
|
Loading…
Reference in a new issue