Changed the covariates to be more clever with how they produce the different splits. In the future (not yet implemented) a clever GroupDifferentiator could update the current score calculation based just on how many rows moved from one hand to the other. There were a few other changes as well; TreeTrainer#growTree now accepts a Random as a parameter which is used throughout the entire growing process. This means it's now theoretically possible to grow trees using a seed, so that results can be fully reproducible.
103 lines
3.8 KiB
Java
103 lines
3.8 KiB
Java
package ca.joeltherrien.randomforest.csv;
|
|
|
|
import ca.joeltherrien.randomforest.DataLoader;
|
|
import ca.joeltherrien.randomforest.Row;
|
|
import ca.joeltherrien.randomforest.Settings;
|
|
import ca.joeltherrien.randomforest.covariates.settings.BooleanCovariateSettings;
|
|
import ca.joeltherrien.randomforest.covariates.Covariate;
|
|
import ca.joeltherrien.randomforest.covariates.settings.FactorCovariateSettings;
|
|
import ca.joeltherrien.randomforest.covariates.settings.NumericCovariateSettings;
|
|
import ca.joeltherrien.randomforest.utils.Utils;
|
|
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
|
import com.fasterxml.jackson.databind.node.TextNode;
|
|
import org.junit.jupiter.api.Test;
|
|
|
|
import java.io.IOException;
|
|
import java.util.List;
|
|
|
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
|
|
public class TestLoadingCSV {
|
|
|
|
/*
|
|
y,x1,x2,x3
|
|
5,3.0,"mouse",true
|
|
2,1.0,"dog",false
|
|
9,1.5,"cat",true
|
|
-3,NA,NA,NA
|
|
*/
|
|
|
|
|
|
public List<Row<Double>> loadData(String filename) throws IOException {
|
|
final ObjectNode yVarSettings = new ObjectNode(JsonNodeFactory.instance);
|
|
yVarSettings.set("type", new TextNode("Double"));
|
|
yVarSettings.set("name", new TextNode("y"));
|
|
|
|
final Settings settings = Settings.builder()
|
|
.trainingDataLocation(filename)
|
|
.covariateSettings(
|
|
Utils.easyList(new NumericCovariateSettings("x1"),
|
|
new FactorCovariateSettings("x2", Utils.easyList("dog", "cat", "mouse")),
|
|
new BooleanCovariateSettings("x3"))
|
|
)
|
|
.yVarSettings(yVarSettings)
|
|
.build();
|
|
|
|
final List<Covariate> covariates = settings.getCovariates();
|
|
|
|
|
|
final DataLoader.ResponseLoader loader = settings.getResponseLoader();
|
|
|
|
return DataLoader.loadData(covariates, loader, settings.getTrainingDataLocation());
|
|
}
|
|
|
|
@Test
|
|
public void verifyLoadingNormal(final List<Covariate> covariates) throws IOException {
|
|
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv");
|
|
|
|
assertData(data, covariates);
|
|
}
|
|
|
|
@Test
|
|
public void verifyLoadingGz(final List<Covariate> covariates) throws IOException {
|
|
final List<Row<Double>> data = loadData("src/test/resources/testCSV.csv.gz");
|
|
|
|
assertData(data, covariates);
|
|
}
|
|
|
|
|
|
private void assertData(final List<Row<Double>> data, final List<Covariate> covariates){
|
|
final Covariate x1 = covariates.get(0);
|
|
final Covariate x2 = covariates.get(0);
|
|
final Covariate x3 = covariates.get(0);
|
|
|
|
assertEquals(4, data.size());
|
|
|
|
Row<Double> row = data.get(0);
|
|
assertEquals(5.0, (double)row.getResponse());
|
|
assertEquals(3.0, row.getCovariateValue(x1).getValue());
|
|
assertEquals("mouse", row.getCovariateValue(x2).getValue());
|
|
assertEquals(true, row.getCovariateValue(x3).getValue());
|
|
|
|
row = data.get(1);
|
|
assertEquals(2.0, (double)row.getResponse());
|
|
assertEquals(1.0, row.getCovariateValue(x1).getValue());
|
|
assertEquals("dog", row.getCovariateValue(x2).getValue());
|
|
assertEquals(false, row.getCovariateValue(x3).getValue());
|
|
|
|
row = data.get(2);
|
|
assertEquals(9.0, (double)row.getResponse());
|
|
assertEquals(1.5, row.getCovariateValue(x1).getValue());
|
|
assertEquals("cat", row.getCovariateValue(x2).getValue());
|
|
assertEquals(true, row.getCovariateValue(x3).getValue());
|
|
|
|
row = data.get(3);
|
|
assertEquals(-3.0, (double)row.getResponse());
|
|
assertTrue(row.getCovariateValue(x1).isNA());
|
|
assertTrue(row.getCovariateValue(x2).isNA());
|
|
assertTrue(row.getCovariateValue(x3).isNA());
|
|
}
|
|
|
|
}
|