|
5 | 5 | import java.util.ArrayList; |
6 | 6 | import java.util.List; |
7 | 7 | import java.util.Map; |
| 8 | +import java.util.Random; |
8 | 9 | import org.junit.BeforeClass; |
9 | 10 | import org.junit.Test; |
10 | 11 |
|
| 12 | +import com.feedzai.openml.data.Dataset; |
| 13 | +import com.feedzai.openml.data.Instance; |
11 | 14 | import com.feedzai.openml.data.schema.DatasetSchema; |
| 15 | +import com.feedzai.openml.mocks.MockDataset; |
12 | 16 | import com.feedzai.openml.provider.exception.ModelLoadingException; |
13 | 17 |
|
14 | 18 | import static com.feedzai.openml.provider.lightgbm.FairGBMDescriptorUtil.CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME; |
| 19 | +import static com.feedzai.openml.provider.lightgbm.LightGBMBinaryClassificationModelTrainerTest.average; |
| 20 | +import static com.feedzai.openml.provider.lightgbm.LightGBMBinaryClassificationModelTrainerTest.ensureFeatureContributions; |
15 | 21 | import static org.assertj.core.api.Assertions.assertThat; |
| 22 | +import static org.assertj.core.api.Assertions.assertThatThrownBy; |
16 | 23 |
|
17 | 24 | /** |
18 | 25 | * Tests for using the LightGBMBinaryClassificationModelTrainer class with FairGBM. |
@@ -71,6 +78,73 @@ public static void setupFixture() { |
71 | 78 | // MODEL_PARAMS.replace(NUM_ITERATIONS_PARAMETER_NAME, NUM_ITERATIONS_FOR_FAST_TESTS); |
72 | 79 | } |
73 | 80 |
|
| 81 | + /** |
| 82 | + * Asserts that a model trained with numericals+categoricals and evaluated on the same datasource |
| 83 | + * has in average higher scores for the positive class (1) than for the negative one (0). |
| 84 | + * |
| 85 | + * @throws URISyntaxException In case of error retrieving the data resource path. |
| 86 | + * @throws IOException In case of error reading data. |
| 87 | + * @throws ModelLoadingException In case of error training the model. |
| 88 | + */ |
| 89 | + @Test |
| 90 | + public void fitWithNumericalsAndCategoricals() throws URISyntaxException, IOException, ModelLoadingException { |
| 91 | + |
| 92 | + final ArrayList<List<Double>> scoresPerClass = fitModelAndGetFirstScoresPerClass( |
| 93 | + DATASET_RESOURCE_NAME, |
| 94 | + TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START, |
| 95 | + MAX_NUMBER_OF_INSTANCES_TO_TRAIN, |
| 96 | + MAX_NUMBER_OF_INSTANCES_TO_SCORE, |
| 97 | + SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE |
| 98 | + ); |
| 99 | + |
| 100 | + assertThat(average(scoresPerClass.get(0))).as("score average per class") |
| 101 | + .isLessThan(average(scoresPerClass.get(1))); |
| 102 | + } |
| 103 | + |
| 104 | + /** |
| 105 | + * Assert that in general, a model trained+scored on schemas where the position of |
| 106 | + * the label changes results in exactly the same scores. |
| 107 | + * <p> |
| 108 | + * This tests for regressions on the copying data code during train that at the start |
| 109 | + * of development resulted in broken scores (mostly constant) that were very hard to diagnose. |
| 110 | + * |
| 111 | + * @throws URISyntaxException In case of error retrieving the data resource path. |
| 112 | + * @throws IOException In case of error reading data. |
| 113 | + * @throws ModelLoadingException In case of error training the model. |
| 114 | + */ |
| 115 | + @Test |
| 116 | + public void fitCategoricalsWithLabelInStartMiddleOrEndHasSameResults() |
| 117 | + throws URISyntaxException, IOException, ModelLoadingException { |
| 118 | + |
| 119 | + final ArrayList<List<Double>> scoresPerClassForLabelAtStart = fitModelAndGetFirstScoresPerClass( |
| 120 | + DATASET_RESOURCE_NAME, |
| 121 | + TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START, |
| 122 | + MAX_NUMBER_OF_INSTANCES_TO_TRAIN, |
| 123 | + MAX_NUMBER_OF_INSTANCES_TO_SCORE, |
| 124 | + SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE |
| 125 | + ); |
| 126 | + |
| 127 | + final ArrayList<List<Double>> scoresPerClassForLabelInMiddle = fitModelAndGetFirstScoresPerClass( |
| 128 | + DATASET_RESOURCE_NAME, |
| 129 | + TestSchemas.CATEGORICALS_SCHEMA_LABEL_IN_MIDDLE, |
| 130 | + MAX_NUMBER_OF_INSTANCES_TO_TRAIN, |
| 131 | + MAX_NUMBER_OF_INSTANCES_TO_SCORE, |
| 132 | + SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE |
| 133 | + ); |
| 134 | + |
| 135 | + final ArrayList<List<Double>> scoresPerClassForLabelAtEnd = fitModelAndGetFirstScoresPerClass( |
| 136 | + DATASET_RESOURCE_NAME, |
| 137 | + TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_END, |
| 138 | + MAX_NUMBER_OF_INSTANCES_TO_TRAIN, |
| 139 | + MAX_NUMBER_OF_INSTANCES_TO_SCORE, |
| 140 | + SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE |
| 141 | + ); |
| 142 | + |
| 143 | + assertThat(scoresPerClassForLabelAtStart).as("scores") |
| 144 | + .isEqualTo(scoresPerClassForLabelInMiddle) |
| 145 | + .isEqualTo(scoresPerClassForLabelAtEnd); |
| 146 | + } |
| 147 | + |
74 | 148 | @Test |
75 | 149 | public void fitResultsAreIndependentOfTrainChunkSizes() |
76 | 150 | throws URISyntaxException, IOException, ModelLoadingException { |
@@ -104,6 +178,77 @@ public void fitResultsAreIndependentOfTrainChunkSizes() |
104 | 178 | .isEqualTo(scoresWithSingleChunk); |
105 | 179 | } |
106 | 180 |
|
| 181 | + /** |
| 182 | + * Assert that there's an error when training with no instances. |
| 183 | + */ |
| 184 | + @Test |
| 185 | + public void fitWithNoInstances() { |
| 186 | + |
| 187 | + final List<Instance> noInstances = new ArrayList<>(); |
| 188 | + final Dataset emptyDataset = new MockDataset(TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START, noInstances); |
| 189 | + |
| 190 | + assertThatThrownBy(() -> |
| 191 | + new LightGBMModelCreator().fit( |
| 192 | + emptyDataset, |
| 193 | + new Random(), |
| 194 | + MODEL_PARAMS |
| 195 | + ) |
| 196 | + ) |
| 197 | + .isInstanceOf(RuntimeException.class); |
| 198 | + } |
| 199 | + |
| 200 | + /** |
| 201 | + * Test Feature Contributions with target at end. |
| 202 | + * |
| 203 | + * @throws URISyntaxException For errors when loading the dataset resource. |
| 204 | + * @throws IOException For errors when reading the dataset. |
| 205 | + * @since 1.3.0 |
| 206 | + */ |
| 207 | + @Test |
| 208 | + public void testFeatureContributionsTargetEnd() throws URISyntaxException, IOException { |
| 209 | + final Dataset dataset = CSVUtils.getDatasetWithSchema( |
| 210 | + TestResources.getResourcePath(DATASET_RESOURCE_NAME), |
| 211 | + TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_END, |
| 212 | + 10000 |
| 213 | + ); |
| 214 | + ensureFeatureContributions(dataset, MODEL_PARAMS); |
| 215 | + } |
| 216 | + |
| 217 | + /** |
| 218 | + * Test Feature Contributions with target at middle. |
| 219 | + * |
| 220 | + * @throws URISyntaxException For errors when loading the dataset resource. |
| 221 | + * @throws IOException For errors when reading the dataset. |
| 222 | + * @since 1.3.0 |
| 223 | + */ |
| 224 | + @Test |
| 225 | + public void testFeatureContributionsTargetMiddle() throws URISyntaxException, IOException { |
| 226 | + final Dataset dataset = CSVUtils.getDatasetWithSchema( |
| 227 | + TestResources.getResourcePath(DATASET_RESOURCE_NAME), |
| 228 | + TestSchemas.CATEGORICALS_SCHEMA_LABEL_IN_MIDDLE, |
| 229 | + 10000 |
| 230 | + ); |
| 231 | + ensureFeatureContributions(dataset, MODEL_PARAMS); |
| 232 | + } |
| 233 | + |
| 234 | + /** |
| 235 | + * Test Feature Contributions with target at beginning. |
| 236 | + * |
| 237 | + * @throws URISyntaxException For errors when loading the dataset resource. |
| 238 | + * @throws IOException For errors when reading the dataset. |
| 239 | + * @since 1.3.0 |
| 240 | + */ |
| 241 | + @Test |
| 242 | + public void testFeatureContributionsTargetBeginning() throws URISyntaxException, IOException { |
| 243 | + final Dataset dataset = CSVUtils.getDatasetWithSchema( |
| 244 | + TestResources.getResourcePath(DATASET_RESOURCE_NAME), |
| 245 | + TestSchemas.CATEGORICALS_SCHEMA_LABEL_AT_START, |
| 246 | + 10000 |
| 247 | + ); |
| 248 | + ensureFeatureContributions(dataset, MODEL_PARAMS); |
| 249 | + } |
| 250 | + |
| 251 | + |
107 | 252 | static ArrayList<List<Double>> fitModelAndGetFirstScoresPerClass( |
108 | 253 | final String datasetResourceName, |
109 | 254 | final DatasetSchema schema, |
|
0 commit comments