Skip to content

Commit 23ca156

Browse files
author
andre.cruz
committed
added tests for FairGBM openml interface
1 parent 8f7863d commit 23ca156

File tree

2 files changed

+151
-6
lines changed

2 files changed

+151
-6
lines changed

openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/FairGBMBinaryClassificationModelTrainerTest.java

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,21 @@
55
import java.util.ArrayList;
66
import java.util.List;
77
import java.util.Map;
8+
import java.util.Random;
89
import org.junit.BeforeClass;
910
import org.junit.Test;
1011

12+
import com.feedzai.openml.data.Dataset;
13+
import com.feedzai.openml.data.Instance;
1114
import com.feedzai.openml.data.schema.DatasetSchema;
15+
import com.feedzai.openml.mocks.MockDataset;
1216
import com.feedzai.openml.provider.exception.ModelLoadingException;
1317

1418
import static com.feedzai.openml.provider.lightgbm.FairGBMDescriptorUtil.CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME;
19+
import static com.feedzai.openml.provider.lightgbm.LightGBMBinaryClassificationModelTrainerTest.average;
20+
import static com.feedzai.openml.provider.lightgbm.LightGBMBinaryClassificationModelTrainerTest.ensureFeatureContributions;
1521
import static org.assertj.core.api.Assertions.assertThat;
22+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
1623

1724
/**
1825
* Tests for using the LightGBMBinaryClassificationModelTrainer class with FairGBM.
@@ -71,6 +78,73 @@ public static void setupFixture() {
7178
// MODEL_PARAMS.replace(NUM_ITERATIONS_PARAMETER_NAME, NUM_ITERATIONS_FOR_FAST_TESTS);
7279
}
7380

81+
/**
82+
* Asserts that a model trained with numerical and categorical features and evaluated on the same
83+
* data source has, on average, higher scores for the positive class (1) than for the negative one (0).
84+
*
85+
* @throws URISyntaxException In case of error retrieving the data resource path.
86+
* @throws IOException In case of error reading data.
87+
* @throws ModelLoadingException In case of error training the model.
88+
*/
89+
@Test
90+
public void fitWithNumericalsAndCategoricals() throws URISyntaxException, IOException, ModelLoadingException {
91+
92+
final ArrayList<List<Double>> scoresPerClass = fitModelAndGetFirstScoresPerClass(
93+
DATASET_RESOURCE_NAME,
94+
TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START,
95+
MAX_NUMBER_OF_INSTANCES_TO_TRAIN,
96+
MAX_NUMBER_OF_INSTANCES_TO_SCORE,
97+
SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE
98+
);
99+
100+
assertThat(average(scoresPerClass.get(0))).as("score average per class")
101+
.isLessThan(average(scoresPerClass.get(1)));
102+
}
103+
104+
/**
105+
* Asserts that training and scoring a model yields exactly the same scores regardless of
106+
* whether the label sits at the start, in the middle, or at the end of the schema.
107+
* <p>
108+
* This guards against regressions in the code that copies data during training, which early
109+
* in development produced broken (mostly constant) scores that were very hard to diagnose.
110+
*
111+
* @throws URISyntaxException In case of error retrieving the data resource path.
112+
* @throws IOException In case of error reading data.
113+
* @throws ModelLoadingException In case of error training the model.
114+
*/
115+
@Test
116+
public void fitCategoricalsWithLabelInStartMiddleOrEndHasSameResults()
117+
throws URISyntaxException, IOException, ModelLoadingException {
118+
119+
final ArrayList<List<Double>> scoresPerClassForLabelAtStart = fitModelAndGetFirstScoresPerClass(
120+
DATASET_RESOURCE_NAME,
121+
TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START,
122+
MAX_NUMBER_OF_INSTANCES_TO_TRAIN,
123+
MAX_NUMBER_OF_INSTANCES_TO_SCORE,
124+
SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE
125+
);
126+
127+
final ArrayList<List<Double>> scoresPerClassForLabelInMiddle = fitModelAndGetFirstScoresPerClass(
128+
DATASET_RESOURCE_NAME,
129+
TestSchemas.CATEGORICALS_SCHEMA_LABEL_IN_MIDDLE,
130+
MAX_NUMBER_OF_INSTANCES_TO_TRAIN,
131+
MAX_NUMBER_OF_INSTANCES_TO_SCORE,
132+
SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE
133+
);
134+
135+
final ArrayList<List<Double>> scoresPerClassForLabelAtEnd = fitModelAndGetFirstScoresPerClass(
136+
DATASET_RESOURCE_NAME,
137+
TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_END,
138+
MAX_NUMBER_OF_INSTANCES_TO_TRAIN,
139+
MAX_NUMBER_OF_INSTANCES_TO_SCORE,
140+
SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE
141+
);
142+
143+
assertThat(scoresPerClassForLabelAtStart).as("scores")
144+
.isEqualTo(scoresPerClassForLabelInMiddle)
145+
.isEqualTo(scoresPerClassForLabelAtEnd);
146+
}
147+
74148
@Test
75149
public void fitResultsAreIndependentOfTrainChunkSizes()
76150
throws URISyntaxException, IOException, ModelLoadingException {
@@ -104,6 +178,77 @@ public void fitResultsAreIndependentOfTrainChunkSizes()
104178
.isEqualTo(scoresWithSingleChunk);
105179
}
106180

181+
/**
182+
* Asserts that fitting a model on a dataset with no instances throws a {@link RuntimeException}.
183+
*/
184+
@Test
185+
public void fitWithNoInstances() {
186+
187+
final List<Instance> noInstances = new ArrayList<>();
188+
final Dataset emptyDataset = new MockDataset(TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START, noInstances);
189+
190+
assertThatThrownBy(() ->
191+
new LightGBMModelCreator().fit(
192+
emptyDataset,
193+
new Random(),
194+
MODEL_PARAMS
195+
)
196+
)
197+
.isInstanceOf(RuntimeException.class);
198+
}
199+
200+
/**
201+
* Tests feature contributions with the target at the end of the schema.
202+
*
203+
* @throws URISyntaxException For errors when loading the dataset resource.
204+
* @throws IOException For errors when reading the dataset.
205+
* @since 1.3.0
206+
*/
207+
@Test
208+
public void testFeatureContributionsTargetEnd() throws URISyntaxException, IOException {
209+
final Dataset dataset = CSVUtils.getDatasetWithSchema(
210+
TestResources.getResourcePath(DATASET_RESOURCE_NAME),
211+
TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_END,
212+
10000
213+
);
214+
ensureFeatureContributions(dataset, MODEL_PARAMS);
215+
}
216+
217+
/**
218+
* Tests feature contributions with the target in the middle of the schema.
219+
*
220+
* @throws URISyntaxException For errors when loading the dataset resource.
221+
* @throws IOException For errors when reading the dataset.
222+
* @since 1.3.0
223+
*/
224+
@Test
225+
public void testFeatureContributionsTargetMiddle() throws URISyntaxException, IOException {
226+
final Dataset dataset = CSVUtils.getDatasetWithSchema(
227+
TestResources.getResourcePath(DATASET_RESOURCE_NAME),
228+
TestSchemas.CATEGORICALS_SCHEMA_LABEL_IN_MIDDLE,
229+
10000
230+
);
231+
ensureFeatureContributions(dataset, MODEL_PARAMS);
232+
}
233+
234+
/**
235+
* Tests feature contributions with the target at the beginning of the schema.
236+
*
237+
* @throws URISyntaxException For errors when loading the dataset resource.
238+
* @throws IOException For errors when reading the dataset.
239+
* @since 1.3.0
240+
*/
241+
@Test
242+
public void testFeatureContributionsTargetBeginning() throws URISyntaxException, IOException {
243+
final Dataset dataset = CSVUtils.getDatasetWithSchema(
244+
TestResources.getResourcePath(DATASET_RESOURCE_NAME),
245+
TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START,
246+
10000
247+
);
248+
ensureFeatureContributions(dataset, MODEL_PARAMS);
249+
}
250+
251+
107252
static ArrayList<List<Double>> fitModelAndGetFirstScoresPerClass(
108253
final String datasetResourceName,
109254
final DatasetSchema schema,

openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ public void testFeatureContributionsTargetEnd() throws URISyntaxException, IOExc
317317
TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_END,
318318
10000
319319
);
320-
ensureFeatureContributions(dataset);
320+
ensureFeatureContributions(dataset, MODEL_PARAMS);
321321
}
322322

323323
/**
@@ -334,7 +334,7 @@ public void testFeatureContributionsTargetMiddle() throws URISyntaxException, IO
334334
TestSchemas.CATEGORICALS_SCHEMA_LABEL_IN_MIDDLE,
335335
10000
336336
);
337-
ensureFeatureContributions(dataset);
337+
ensureFeatureContributions(dataset, MODEL_PARAMS);
338338
}
339339

340340
/**
@@ -351,7 +351,7 @@ public void testFeatureContributionsTargetBeginning() throws URISyntaxException,
351351
TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START,
352352
10000
353353
);
354-
ensureFeatureContributions(dataset);
354+
ensureFeatureContributions(dataset, MODEL_PARAMS);
355355
}
356356

357357
/**
@@ -360,12 +360,12 @@ public void testFeatureContributionsTargetBeginning() throws URISyntaxException,
360360
* @param dataset The {@link Dataset}.
361361
* @since 1.3.0
362362
*/
363-
private void ensureFeatureContributions(final Dataset dataset) {
363+
static void ensureFeatureContributions(final Dataset dataset, final Map<String, String> modelParams) {
364364
final int targetIndex = dataset.getSchema().getTargetIndex().get();
365365
final int num1Index = 1;
366366
final int cat1Index = 4;
367367

368-
final Map<String, String> trainParams = new HashMap<>(MODEL_PARAMS);
368+
final Map<String, String> trainParams = new HashMap<>(modelParams);
369369
trainParams.replace(NUM_ITERATIONS_PARAMETER_NAME, "100");
370370

371371
final LightGBMBinaryClassificationModel model = new LightGBMModelCreator().fit(
@@ -484,7 +484,7 @@ static ArrayList<List<Double>> getClassScores(final Dataset dataset,
484484
* @param inputArray Input array from which to compute the average.
485485
* @return Average
486486
*/
487-
double average(final List<Double> inputArray) {
487+
static double average(final List<Double> inputArray) {
488488

489489
double sum = 0.0;
490490
for (final Double x : inputArray) {

0 commit comments

Comments
 (0)