Explicitly store RightContinuousStepFunction in CompetingRiskFunctions

Done so that RUtils is useful. Also optimized imports.
This commit is contained in:
Joel Therrien 2018-10-25 10:49:43 -07:00
parent c68f67e47a
commit ae91dbe9e7
9 changed files with 37 additions and 29 deletions

View file

@ -1,7 +1,13 @@
package ca.joeltherrien.randomforest; package ca.joeltherrien.randomforest;
import ca.joeltherrien.randomforest.covariates.*; import ca.joeltherrien.randomforest.covariates.BooleanCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.*; import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.covariates.FactorCovariateSettings;
import ca.joeltherrien.randomforest.covariates.NumericCovariateSettings;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskErrorRateCalculator;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponse;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskResponseWithCensorTime;
import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner; import ca.joeltherrien.randomforest.responses.competingrisk.combiner.CompetingRiskFunctionCombiner;
import ca.joeltherrien.randomforest.tree.Forest; import ca.joeltherrien.randomforest.tree.Forest;
import ca.joeltherrien.randomforest.tree.ForestTrainer; import ca.joeltherrien.randomforest.tree.ForestTrainer;
@ -12,7 +18,9 @@ import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode; import com.fasterxml.jackson.databind.node.TextNode;
import java.io.*; import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List; import java.util.List;
public class Main { public class Main {

View file

@ -5,7 +5,10 @@ import ca.joeltherrien.randomforest.Row;
import ca.joeltherrien.randomforest.tree.Split; import ca.joeltherrien.randomforest.tree.Split;
import java.io.Serializable; import java.io.Serializable;
import java.util.*; import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
public interface Covariate<V> extends Serializable { public interface Covariate<V> extends Serializable {

View file

@ -1,6 +1,6 @@
package ca.joeltherrien.randomforest.responses.competingrisk; package ca.joeltherrien.randomforest.responses.competingrisk;
import ca.joeltherrien.randomforest.utils.StepFunction; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
@ -10,22 +10,22 @@ import java.util.List;
@Builder @Builder
public class CompetingRiskFunctions implements Serializable { public class CompetingRiskFunctions implements Serializable {
private final List<StepFunction> causeSpecificHazards; private final List<RightContinuousStepFunction> causeSpecificHazards;
private final List<StepFunction> cumulativeIncidenceCurves; private final List<RightContinuousStepFunction> cumulativeIncidenceCurves;
@Getter @Getter
private final StepFunction survivalCurve; private final RightContinuousStepFunction survivalCurve;
public StepFunction getCauseSpecificHazardFunction(int cause){ public RightContinuousStepFunction getCauseSpecificHazardFunction(int cause){
return causeSpecificHazards.get(cause-1); return causeSpecificHazards.get(cause-1);
} }
public StepFunction getCumulativeIncidenceFunction(int cause) { public RightContinuousStepFunction getCumulativeIncidenceFunction(int cause) {
return cumulativeIncidenceCurves.get(cause-1); return cumulativeIncidenceCurves.get(cause-1);
} }
public double calculateEventSpecificMortality(final int event, final double tau){ public double calculateEventSpecificMortality(final int event, final double tau){
final StepFunction cif = getCumulativeIncidenceFunction(event); final RightContinuousStepFunction cif = getCumulativeIncidenceFunction(event);
double summation = 0.0; double summation = 0.0;

View file

@ -3,7 +3,6 @@ package ca.joeltherrien.randomforest.responses.competingrisk.combiner;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
@ -50,10 +49,10 @@ public class CompetingRiskFunctionCombiner implements ResponseCombiner<Competing
.sum(); .sum();
} }
final StepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0); final RightContinuousStepFunction survivalFunction = new RightContinuousStepFunction(timesToUse, survivalY, 1.0);
final List<StepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<StepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length); final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
for(final int event : events){ for(final int event : events){

View file

@ -5,9 +5,10 @@ import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskRespons
import ca.joeltherrien.randomforest.tree.ResponseCombiner; import ca.joeltherrien.randomforest.tree.ResponseCombiner;
import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import java.util.*; import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** /**
* This class takes all of the observations in a terminal node and combines them to produce estimates of the cause-specific hazard function * This class takes all of the observations in a terminal node and combines them to produce estimates of the cause-specific hazard function
@ -38,8 +39,8 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
@Override @Override
public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) { public CompetingRiskFunctions combine(List<CompetingRiskResponse> responses) {
final List<StepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length); final List<RightContinuousStepFunction> causeSpecificCumulativeHazardFunctionList = new ArrayList<>(events.length);
final List<StepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length); final List<RightContinuousStepFunction> cumulativeIncidenceFunctionList = new ArrayList<>(events.length);
Collections.sort(responses, (y1, y2) -> { Collections.sort(responses, (y1, y2) -> {
if(y1.getU() < y2.getU()){ if(y1.getU() < y2.getU()){
@ -97,7 +98,7 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
} }
} }
final StepFunction survivalCurve = RightContinuousStepFunction.constructFromPoints(survivalPoints, 1.0); final RightContinuousStepFunction survivalCurve = RightContinuousStepFunction.constructFromPoints(survivalPoints, 1.0);
for(final int event : events){ for(final int event : events){
@ -138,10 +139,10 @@ public class CompetingRiskResponseCombiner implements ResponseCombiner<Competing
} }
final StepFunction causeSpecificCumulativeHazardFunction = RightContinuousStepFunction.constructFromPoints(hazardFunctionPoints, 0.0); final RightContinuousStepFunction causeSpecificCumulativeHazardFunction = RightContinuousStepFunction.constructFromPoints(hazardFunctionPoints, 0.0);
causeSpecificCumulativeHazardFunctionList.add(event-1, causeSpecificCumulativeHazardFunction); causeSpecificCumulativeHazardFunctionList.add(event-1, causeSpecificCumulativeHazardFunction);
final StepFunction cifFunction = RightContinuousStepFunction.constructFromPoints(cifPoints, 0.0); final RightContinuousStepFunction cifFunction = RightContinuousStepFunction.constructFromPoints(cifPoints, 0.0);
cumulativeIncidenceFunctionList.add(event-1, cifFunction); cumulativeIncidenceFunctionList.add(event-1, cifFunction);
} }

View file

@ -1,7 +1,7 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.covariates.Covariate;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import ca.joeltherrien.randomforest.covariates.Covariate;
import lombok.Builder; import lombok.Builder;
@Builder @Builder

View file

@ -1,7 +1,6 @@
package ca.joeltherrien.randomforest.tree; package ca.joeltherrien.randomforest.tree;
import ca.joeltherrien.randomforest.CovariateRow; import ca.joeltherrien.randomforest.CovariateRow;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@RequiredArgsConstructor @RequiredArgsConstructor

View file

@ -1,7 +1,6 @@
package ca.joeltherrien.randomforest.utils; package ca.joeltherrien.randomforest.utils;
import java.io.*; import java.io.*;
import java.util.List;
import java.util.zip.GZIPInputStream; import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;

View file

@ -4,7 +4,6 @@ import ca.joeltherrien.randomforest.TestUtils;
import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions; import ca.joeltherrien.randomforest.responses.competingrisk.CompetingRiskFunctions;
import ca.joeltherrien.randomforest.utils.Point; import ca.joeltherrien.randomforest.utils.Point;
import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction; import ca.joeltherrien.randomforest.utils.RightContinuousStepFunction;
import ca.joeltherrien.randomforest.utils.StepFunction;
import ca.joeltherrien.randomforest.utils.Utils; import ca.joeltherrien.randomforest.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -14,7 +13,7 @@ public class TestCompetingRiskFunctions {
@Test @Test
public void testCalculateEventSpecificMortality(){ public void testCalculateEventSpecificMortality(){
final StepFunction cif1 = RightContinuousStepFunction.constructFromPoints( final RightContinuousStepFunction cif1 = RightContinuousStepFunction.constructFromPoints(
Utils.easyList( Utils.easyList(
new Point(1.0, 0.3), new Point(1.0, 0.3),
new Point(1.5, 0.7), new Point(1.5, 0.7),
@ -23,10 +22,10 @@ public class TestCompetingRiskFunctions {
); );
// not being used // not being used
final StepFunction chf1 = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0); final RightContinuousStepFunction chf1 = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
// not being used // not being used
final StepFunction km = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0); final RightContinuousStepFunction km = RightContinuousStepFunction.constructFromPoints(Collections.emptyList(), 0.0);
final CompetingRiskFunctions functions = CompetingRiskFunctions.builder() final CompetingRiskFunctions functions = CompetingRiskFunctions.builder()
.causeSpecificHazards(Collections.singletonList(chf1)) .causeSpecificHazards(Collections.singletonList(chf1))