Fix bug where parallel forests never finish
Add test to detect case
This commit is contained in:
parent
de3de300cf
commit
c5c74ad7e9
2 changed files with 63 additions and 20 deletions
|
@ -139,7 +139,7 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
int prevNumberTreesSet = -1;
|
int prevNumberTreesSet = -1;
|
||||||
while(true){
|
while(true){
|
||||||
try {
|
try {
|
||||||
if (executorService.awaitTermination(5, TimeUnit.SECONDS)) break;
|
if (executorService.awaitTermination(1, TimeUnit.SECONDS)) break;
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
System.err.println("There was an InterruptedException while waiting for the forest to finish training; this is unusual but on its own shouldn't be a problem.");
|
System.err.println("There was an InterruptedException while waiting for the forest to finish training; this is unusual but on its own shouldn't be a problem.");
|
||||||
System.err.println("Please send a bug report about it to joelt@sfu.ca");
|
System.err.println("Please send a bug report about it to joelt@sfu.ca");
|
||||||
|
@ -147,7 +147,6 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
// do nothing; this shouldn't be an issue
|
// do nothing; this shouldn't be an issue
|
||||||
}
|
}
|
||||||
|
|
||||||
if(displayProgress) {
|
|
||||||
int numberTreesSet = 0;
|
int numberTreesSet = 0;
|
||||||
for (final Tree<TO> tree : trees) {
|
for (final Tree<TO> tree : trees) {
|
||||||
if (tree != null) {
|
if (tree != null) {
|
||||||
|
@ -155,13 +154,15 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(displayProgress && numberTreesSet != prevNumberTreesSet) {
|
||||||
// Only output trees set on screen if there was a change
|
// Only output trees set on screen if there was a change
|
||||||
// In some environments where standard output is streamed to a file this method below causes frequent writes to output
|
// In some environments where standard output is streamed to a file this method below causes frequent writes to output
|
||||||
if(numberTreesSet != prevNumberTreesSet){
|
|
||||||
System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees");
|
System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees");
|
||||||
prevNumberTreesSet = numberTreesSet;
|
prevNumberTreesSet = numberTreesSet;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(numberTreesSet == ntree){
|
||||||
|
executorService.shutdown();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -199,24 +200,25 @@ public class ForestTrainer<Y, TO, FO> {
|
||||||
int prevNumberTreesSet = -1;
|
int prevNumberTreesSet = -1;
|
||||||
while(true){
|
while(true){
|
||||||
try {
|
try {
|
||||||
if (executorService.awaitTermination(5, TimeUnit.SECONDS)) break;
|
if (executorService.awaitTermination(1, TimeUnit.SECONDS)) break;
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
System.err.println("There was an InterruptedException while waiting for the forest to finish training; this is unusual but on its own shouldn't be a problem.");
|
System.err.println("There was an InterruptedException while waiting for the forest to finish training; this is unusual but on its own shouldn't be a problem.");
|
||||||
System.err.println("Please send a bug report about it to joelt@sfu.ca");
|
System.err.println("Please send a bug report about it to joelt@sfu.ca");
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
// do nothing; this shouldn't be an issue
|
// do nothing; this shouldn't be an issue
|
||||||
}
|
}
|
||||||
|
|
||||||
if(displayProgress) {
|
|
||||||
int numberTreesSet = treeCount.get();
|
int numberTreesSet = treeCount.get();
|
||||||
|
|
||||||
|
if(displayProgress && numberTreesSet != prevNumberTreesSet) {
|
||||||
// Only output trees set on screen if there was a change
|
// Only output trees set on screen if there was a change
|
||||||
// In some environments where standard output is streamed to a file this method below causes frequent writes to output
|
// In some environments where standard output is streamed to a file this method below causes frequent writes to output
|
||||||
if(numberTreesSet != prevNumberTreesSet){
|
|
||||||
System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees");
|
System.out.print("\rFinished " + numberTreesSet + "/" + ntree + " trees");
|
||||||
prevNumberTreesSet = numberTreesSet;
|
prevNumberTreesSet = numberTreesSet;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(numberTreesSet == ntree){
|
||||||
|
executorService.shutdown();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -109,7 +109,7 @@ public class TestSavingLoading {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSavingLoading() throws IOException, ClassNotFoundException {
|
public void testSavingLoadingSerial() throws IOException, ClassNotFoundException {
|
||||||
final Settings settings = getSettings();
|
final Settings settings = getSettings();
|
||||||
final List<Covariate> covariates = settings.getCovariates();
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
@ -131,6 +131,47 @@ public class TestSavingLoading {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||||
|
|
||||||
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
||||||
|
final CompetingRiskFunctions functions = forest.evaluate(predictionRow);
|
||||||
|
assertNotNull(functions);
|
||||||
|
assertTrue(functions.getCumulativeIncidenceFunction(1).getX().length > 2);
|
||||||
|
|
||||||
|
|
||||||
|
assertEquals(NTREE, forest.getTrees().size());
|
||||||
|
|
||||||
|
cleanup(directory);
|
||||||
|
|
||||||
|
assertFalse(directory.exists());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSavingLoadingParallel() throws IOException, ClassNotFoundException {
|
||||||
|
final Settings settings = getSettings();
|
||||||
|
final List<Covariate> covariates = settings.getCovariates();
|
||||||
|
final List<Row<CompetingRiskResponse>> dataset = DataUtils.loadData(covariates, settings.getResponseLoader(), settings.getTrainingDataLocation());
|
||||||
|
|
||||||
|
final File directory = new File(settings.getSaveTreeLocation());
|
||||||
|
if(directory.exists()){
|
||||||
|
cleanup(directory);
|
||||||
|
}
|
||||||
|
assertFalse(directory.exists());
|
||||||
|
directory.mkdir();
|
||||||
|
|
||||||
|
final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer = new ForestTrainer<>(settings, dataset, covariates);
|
||||||
|
|
||||||
|
forestTrainer.trainParallelOnDisk(settings.getNumberOfThreads());
|
||||||
|
|
||||||
|
assertTrue(directory.exists());
|
||||||
|
assertTrue(directory.isDirectory());
|
||||||
|
assertEquals(NTREE, directory.listFiles().length);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = DataUtils.loadForest(directory, new CompetingRiskFunctionCombiner(new int[]{1,2}, null));
|
||||||
|
|
||||||
final CovariateRow predictionRow = getPredictionRow(covariates);
|
final CovariateRow predictionRow = getPredictionRow(covariates);
|
||||||
|
|
Loading…
Reference in a new issue