package ca.joeltherrien.randomforest.competingrisk;

import ca.joeltherrien.randomforest.*;
import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
import ca.joeltherrien.randomforest.tree.Node;
import ca.joeltherrien.randomforest.tree.TreeTrainer;
import ca.joeltherrien.randomforest.utils.MathFunction;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.Utils;
import com.fasterxml.jackson.databind.node.*;
import org.junit.jupiter.api.Test;

import static ca.joeltherrien.randomforest.TestUtils.assertCumulativeFunction;
import static ca.joeltherrien.randomforest.TestUtils.closeEnough;
import static org.junit.jupiter.api.Assertions.*;

import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Integration tests for competing-risk forests trained on the WIHS dataset,
 * comparing single-tree and full-forest output against reference values
 * (some of which come from randomForestSRC).
 */
public class TestCompetingRisk {

    /**
     * Builds the baseline settings used by every test. By default uses the
     * single log-rank test with event 1 as the event of focus.
     *
     * @return a fully-populated {@link Settings} instance for the WIHS data.
     */
    public Settings getSettings() {
        final ObjectNode groupDifferentiatorSettings = new ObjectNode(JsonNodeFactory.instance);
        groupDifferentiatorSettings.set("type", new TextNode("LogRankSingleGroupDifferentiator"));
        groupDifferentiatorSettings.set("eventOfFocus", new IntNode(1));

        final ObjectNode responseCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
        responseCombinerSettings.set("type", new TextNode("CompetingRiskResponseCombiner"));
        responseCombinerSettings.set("events",
                new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
        );
        // not setting times

        final ObjectNode treeCombinerSettings = new ObjectNode(JsonNodeFactory.instance);
        treeCombinerSettings.set("type", new TextNode("CompetingRiskFunctionCombiner"));
        treeCombinerSettings.set("events",
                new ArrayNode(JsonNodeFactory.instance, Utils.easyList(new IntNode(1), new IntNode(2)))
        );
        // not setting times

        final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
        yVarSettings.set("type", new TextNode("CompetingRiskResponse"));
        yVarSettings.set("u", new TextNode("time"));
        yVarSettings.set("delta", new TextNode("status"));

        return Settings.builder()
                .covariates(Utils.easyList(
                        new NumericCovariateSettings("ageatfda"),
                        new BooleanCovariateSettings("idu"),
                        new BooleanCovariateSettings("black"),
                        new NumericCovariateSettings("cd4nadir")
                        )
                )
                .dataFileLocation("src/test/resources/wihs.csv")
                .responseCombinerSettings(responseCombinerSettings)
                .treeCombinerSettings(treeCombinerSettings)
                .groupDifferentiatorSettings(groupDifferentiatorSettings)
                .yVarSettings(yVarSettings)
                .maxNodeDepth(100000)
                // TODO fill in these settings
                .mtry(2)
                .nodeSize(6)
                .ntree(100)
                .numberOfSplits(5)
                .numberOfThreads(3)
                .saveProgress(true)
                .saveTreeLocation("trees/")
                .build();
    }

    /**
     * Materializes {@link Covariate} instances from the covariate settings.
     *
     * @param settings the settings whose covariate definitions should be built.
     * @return the built covariates, in the same order as declared.
     */
    public List<Covariate> getCovariates(Settings settings) {
        return settings.getCovariates().stream()
                .map(covariateSettings -> covariateSettings.build())
                .collect(Collectors.toList());
    }

    /**
     * The fixed prediction row used across tests.
     * Corresponds to: time status ageatfda idu black cd4nadir
     *           409   1.3  1      35       FALSE FALSE 0.81
     *
     * @param covariates the covariates to parse the raw values against.
     * @return a single {@link CovariateRow} for prediction.
     */
    public CovariateRow getPredictionRow(List<Covariate> covariates) {
        return CovariateRow.createSimple(
                Utils.easyMap(
                        "ageatfda", "35",
                        "idu", "false",
                        "black", "false",
                        "cd4nadir", "0.81"),
                covariates, 1);
    }

    /**
     * Shared single-tree pipeline: load the data described by {@code settings},
     * grow one tree, and evaluate it on the standard prediction row.
     *
     * @param settings fully-configured settings (data file + covariates already set).
     * @return the predicted competing-risk functions for the standard row.
     * @throws IOException if the data file cannot be read.
     */
    private CompetingRiskFunctions trainSingleTreeAndPredict(final Settings settings) throws IOException {
        final List<Covariate> covariates = getCovariates(settings);
        final List<Row<CompetingRiskResponse>> dataset =
                DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());

        final TreeTrainer<CompetingRiskResponse, CompetingRiskFunctions> treeTrainer =
                new TreeTrainer<>(settings, covariates);

        final Node<CompetingRiskFunctions> node = treeTrainer.growTree(dataset);

        final CovariateRow newRow = getPredictionRow(covariates);
        return node.evaluate(newRow);
    }

    /**
     * Asserts that both cause-specific hazard functions and both cumulative
     * incidence functions are valid cumulative functions.
     *
     * @param functions the predicted functions to validate.
     */
    private void assertAllCumulativeFunctions(final CompetingRiskFunctions functions) {
        assertCumulativeFunction(functions.getCauseSpecificHazardFunction(1));
        assertCumulativeFunction(functions.getCauseSpecificHazardFunction(2));
        assertCumulativeFunction(functions.getCumulativeIncidenceFunction(1));
        assertCumulativeFunction(functions.getCumulativeIncidenceFunction(2));
    }

    @Test
    public void testSingleTree() throws IOException {
        final Settings settings = getSettings();
        settings.setDataFileLocation("src/test/resources/wihs.bootstrapped.csv");
        // by only using BooleanCovariates (only one split rule) we can
        // guarantee identical results with randomForestSRC on one tree.
        settings.setCovariates(Utils.easyList(
                new BooleanCovariateSettings("idu"),
                new BooleanCovariateSettings("black")
        ));

        final CompetingRiskFunctions functions = trainSingleTreeAndPredict(settings);

        final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
        final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
        final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
        final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);

        final double margin = 0.0000001;
        closeEnough(0.003003003, causeOneCIFFunction.evaluate(0.02).getY(), margin);
        closeEnough(0.166183852, causeOneCIFFunction.evaluate(1.00).getY(), margin);
        closeEnough(0.715625487, causeOneCIFFunction.evaluate(6.50).getY(), margin);
        closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.60).getY(), margin);
        closeEnough(0.794796334, causeOneCIFFunction.evaluate(10.80).getY(), margin);

        closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
        closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin);
        closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin);

        closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin);
        closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin);
        closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin);

        closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin);
        closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin);
        closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin);
    }

    /**
     * Note - this test triggers a situation where the variance calculation in
     * the log-rank test experiences an NaN.
     *
     * @throws IOException if the data file cannot be read.
     */
    @Test
    public void testSingleTree2() throws IOException {
        final Settings settings = getSettings();
        settings.setMtry(4);
        settings.setNumberOfSplits(0);
        settings.setDataFileLocation("src/test/resources/wihs.bootstrapped2.csv");

        final CompetingRiskFunctions functions = trainSingleTreeAndPredict(settings);

        final MathFunction causeOneCIFFunction = functions.getCumulativeIncidenceFunction(1);
        final MathFunction causeTwoCIFFunction = functions.getCumulativeIncidenceFunction(2);
        final MathFunction cumHazOneFunction = functions.getCauseSpecificHazardFunction(1);
        final MathFunction cumHazTwoFunction = functions.getCauseSpecificHazardFunction(2);

        final double margin = 0.0000001;
        closeEnough(0, causeOneCIFFunction.evaluate(0.02).getY(), margin);
        closeEnough(0.555555555, causeOneCIFFunction.evaluate(0.4).getY(), margin);
        closeEnough(0.66666666666, causeOneCIFFunction.evaluate(0.8).getY(), margin);
        closeEnough(0.88888888888, causeOneCIFFunction.evaluate(0.9).getY(), margin);
        closeEnough(1.0, causeOneCIFFunction.evaluate(1.0).getY(), margin);

        /*
        closeEnough(0.08149211, causeTwoCIFFunction.evaluate(1.00).getY(), margin);
        closeEnough(0.14926318, causeTwoCIFFunction.evaluate(6.50).getY(), margin);
        closeEnough(0.15332850, causeTwoCIFFunction.evaluate(10.80).getY(), margin);

        closeEnough(0.1888601, cumHazOneFunction.evaluate(1.00).getY(), margin);
        closeEnough(1.6189759, cumHazOneFunction.evaluate(6.50).getY(), margin);
        closeEnough(2.4878342, cumHazOneFunction.evaluate(10.80).getY(), margin);

        closeEnough(0.08946513, cumHazTwoFunction.evaluate(1.00).getY(), margin);
        closeEnough(0.32801830, cumHazTwoFunction.evaluate(6.50).getY(), margin);
        closeEnough(0.36505534, cumHazTwoFunction.evaluate(10.80).getY(), margin);
        */
    }

    @Test
    public void testLogRankSingleGroupDifferentiatorTwoBooleans() throws IOException {
        final Settings settings = getSettings();
        settings.setCovariates(Utils.easyList(
                new BooleanCovariateSettings("idu"),
                new BooleanCovariateSettings("black")
        ));

        final List<Covariate> covariates = getCovariates(settings);
        final List<Row<CompetingRiskResponse>> dataset =
                DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());

        final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer =
                new ForestTrainer<>(settings, dataset, covariates);

        final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();

        // prediction row
        //    time status ageatfda idu   black cd4nadir
        //409 1.3  1      35       FALSE FALSE 0.81
        final CovariateRow newRow = getPredictionRow(covariates);

        final CompetingRiskFunctions functions = forest.evaluate(newRow);

        assertAllCumulativeFunctions(functions);

        closeEnough(0.63, functions.getCumulativeIncidenceFunction(1).evaluate(4.0).getY(), 0.01);
        closeEnough(0.765, functions.getCumulativeIncidenceFunction(1).evaluate(10.8).getY(), 0.01);

        closeEnough(0.163, functions.getCumulativeIncidenceFunction(2).evaluate(4.0).getY(), 0.01);
        closeEnough(0.195, functions.getCumulativeIncidenceFunction(2).evaluate(10.8).getY(), 0.01);

        final CompetingRiskErrorRateCalculator errorRateCalculator =
                new CompetingRiskErrorRateCalculator(dataset, forest);
        final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1, 2});

        // Error rates happen to be about the same
        /* randomForestSRC results; ignored for now
        closeEnough(0.4795, errorRates[0], 0.007);
        closeEnough(0.478, errorRates[1], 0.008);
        */
        System.out.println(errorRates[0]);
        System.out.println(errorRates[1]);

        closeEnough(0.452, errorRates[0], 0.02);
        closeEnough(0.446, errorRates[1], 0.02);

        System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1, 2}));
    }

    @Test
    public void verifyDataset() throws IOException {
        final Settings settings = getSettings();

        final List<Covariate> covariates = getCovariates(settings);
        final List<Row<CompetingRiskResponse>> dataset =
                DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());

        // Let's count the events and make sure the data was correctly read.
        int countCensored = 0;
        int countEventOne = 0;
        int countEventTwo = 0;
        for (final Row<CompetingRiskResponse> row : dataset) {
            final CompetingRiskResponse response = row.getResponse();

            switch (response.getDelta()) {
                case 0:
                    countCensored++;
                    break;
                case 1:
                    countEventOne++;
                    break;
                case 2:
                    countEventTwo++;
                    break;
                default:
                    throw new RuntimeException("There's an event of type " + response.getDelta());
            }
        }

        assertEquals(126, countCensored);
        assertEquals(679, countEventOne);
        assertEquals(359, countEventTwo);
    }

    @Test
    public void testLogRankSingleGroupDifferentiatorAllCovariates() throws IOException {
        final Settings settings = getSettings();
        settings.setNtree(300); // results are too variable at 100

        final List<Covariate> covariates = getCovariates(settings);
        final List<Row<CompetingRiskResponse>> dataset =
                DataLoader.loadData(covariates, settings.getResponseLoader(), settings.getDataFileLocation());

        final ForestTrainer<CompetingRiskResponse, CompetingRiskFunctions, CompetingRiskFunctions> forestTrainer =
                new ForestTrainer<>(settings, dataset, covariates);

        final Forest<CompetingRiskFunctions, CompetingRiskFunctions> forest = forestTrainer.trainSerial();

        // prediction row
        //    time status ageatfda idu   black cd4nadir
        //409 1.3  1      35       FALSE FALSE 0.81
        final CovariateRow newRow = getPredictionRow(covariates);

        final CompetingRiskFunctions functions = forest.evaluate(newRow);

        assertAllCumulativeFunctions(functions);

        final List<Point> causeOneCIFPoints = functions.getCumulativeIncidenceFunction(1).getPoints();

        // We seem to consistently underestimate the results.
        // note; most observations from randomForestSRC hover around 0.78 but I've seen it as low as 0.72
        assertTrue(causeOneCIFPoints.get(causeOneCIFPoints.size() - 1).getY() > 0.74,
                "Results should match randomForestSRC; had "
                        + causeOneCIFPoints.get(causeOneCIFPoints.size() - 1).getY());

        final CompetingRiskErrorRateCalculator errorRateCalculator =
                new CompetingRiskErrorRateCalculator(dataset, forest);
        final double[] errorRates = errorRateCalculator.calculateConcordance(new int[]{1, 2});

        System.out.println(errorRates[0]);
        System.out.println(errorRates[1]);

        /* randomForestSRC results; ignored for now
        closeEnough(0.412, errorRates[0], 0.007);
        closeEnough(0.384, errorRates[1], 0.007);
        */

        // Consistency results
        closeEnough(0.395, errorRates[0], 0.01);
        closeEnough(0.345, errorRates[1], 0.01);

        System.out.println(errorRateCalculator.calculateNaiveMortalityError(new int[]{1, 2}));
    }
}