Explicitly store RightContinuousStepFunction in CompetingRiskFunctions

Done so that RUtils is useful. Also optimized imports.
This commit is contained in:
Joel Therrien 2018-10-25 10:49:43 -07:00
parent c68f67e47a
commit ae91dbe9e7
9 changed files with 37 additions and 29 deletions

View file

@ -1,7 +1,13 @@
package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.*;
import ca.joeltherrien.randomforest.responses.competingrisk.*;
import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer;
@ -12,7 +18,9 @@ import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import java.io.*;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
public class Main {

View file

@ -5,7 +5,10 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split;
import java.io.Serializable;
import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
public interface Covariate<V> extends Serializable {

View file

@ -1,6 +1,6 @@
package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import lombok.Builder;
import lombok.Getter;
@ -10,22 +10,22 @@ import java.util.List;
@Builder
public class CompetingRiskFunctions implements Serializable {
private final List<StepFunction> causeSpecificHazards;
private final List<StepFunction> cumulativeIncidenceCurves;
private final List<RightContinuousStepFunction> causeSpecificHazards;
private final List<RightContinuousStepFunction> cumulativeIncidenceCurves;
@Getter
private final StepFunction survivalCurve;
private final RightContinuousStepFunction survivalCurve;
public StepFunction getCauseSpecificHazardFunction(int cause){
public RightContinuousStepFunction getCauseSpecificHazardFunction(int cause){
return causeSpecificHazards.get(cause-1);
}
public StepFunction getCumulativeIncidenceFunction(int cause) {
public RightContinuousStepFunction getCumulativeIncidenceFunction(int cause) {
return cumulativeIncidenceCurves.get(cause-1);
}
public double calculateEventSpecificMortality(final int event, final double tau){
final StepFunction cif = getCumulativeIncidenceFunction(event);
final RightContinuousStepFunction cif = getCumulativeIncidenceFunction(event);
double summation = 0.0;

View file

@ -3,7 +3,6 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
@ -50,10 +49,10 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
.sum();
}
final StepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
final List<StepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<StepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
for(final int event : events){

View file

@ -5,9 +5,10 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
* This class takes all of the observations in a terminal node and combines them to produce estimates of the cause-specific hazard function
@ -38,8 +39,8 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
@Override
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
final List<StepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<StepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
Collections.sort(responses, (y1, y2) -> {
if(y1.getU() < y2.getU()){
@ -97,7 +98,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
}
}
final StepFunction survivalCurve = RightContinuousStepFunction.constructFromPoints(survivalPoints, 1.0);
final RightContinuousStepFunction survivalCurve = RightContinuousStepFunction.constructFromPoints(survivalPoints, 1.0);
for(final int event : events){
@ -138,10 +139,10 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
}
final StepFunction causeSpecificCumulativeHazardFunction = RightContinuousStepFunction.constructFromPoints(hazardFunctionPoints, 0.0);
final RightContinuousStepFunction causeSpecificCumulativeHazardFunction = RightContinuousStepFunction.constructFromPoints(hazardFunctionPoints, 0.0);
causeSpecificCumulativeHazardFunctionList.add(event-1, causeSpecificCumulativeHazardFunction);
final StepFunction cifFunction = RightContinuousStepFunction.constructFromPoints(cifPoints, 0.0);
final RightContinuousStepFunction cifFunction = RightContinuousStepFunction.constructFromPoints(cifPoints, 0.0);
cumulativeIncidenceFunctionList.add(event-1, cifFunction);
}

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Builder;
@Builder

View file

@ -1,7 +1,6 @@
package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow;
import lombok.RequiredArgsConstructor;
@RequiredArgsConstructor

View file

@ -1,7 +1,6 @@
package ca.joeltherrien.randomforest.utils;
import java.io.*;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

View file

@ -4,7 +4,6 @@ import ca.joeltherrien.randomforest.TestUtils;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test;
@ -14,7 +13,7 @@ public class TestCompetingRiskFunctions {
@Test
public void testCalculateEventSpecificMortality(){
final StepFunction cif1 = RightContinuousStepFunction.constructFromPoints(
final RightContinuousStepFunction cif1 = RightContinuousStepFunction.constructFromPoints(
Utils.easyList(
new Point(1.0, 0.3),
new Point(1.5, 0.7),
@ -23,10 +22,10 @@ public class TestCompetingRiskFunctions {
);
// not being used
final StepFunction chf1 = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
final RightContinuousStepFunction chf1 = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
// not being used
final StepFunction km = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
final RightContinuousStepFunction km = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
final CompetingRiskFunctions functions = CompetingRiskFunctions.builder()
.causeSpecificHazards(Collections.singletonList(chf1))