|
15 | 15 |
|
16 | 16 | """Tests for croissant_builder.""" |
17 | 17 |
|
| 18 | +from typing import Any, Dict, List, Type |
18 | 19 | import numpy as np |
19 | 20 | import pytest |
20 | 21 | from tensorflow_datasets import testing |
@@ -146,7 +147,10 @@ def _create_mlc_field( |
146 | 147 | ], |
147 | 148 | ) |
148 | 149 | def test_simple_datatype_converter( |
149 | | - mlc_field, expected_feature, int_dtype, float_dtype |
| 150 | + mlc_field: mlc.Field, |
| 151 | + expected_feature: type[Any], |
| 152 | + int_dtype: np.dtype | None, |
| 153 | + float_dtype: np.dtype | None, |
150 | 154 | ): |
151 | 155 | actual_feature = croissant_builder.datatype_converter( |
152 | 156 | mlc_field, |
@@ -252,7 +256,11 @@ def test_datatype_converter_bbox_with_invalid_format(): |
252 | 256 | ), |
253 | 257 | ], |
254 | 258 | ) |
255 | | -def test_datatype_converter_complex(mlc_field, feature_type, subfield_types): |
| 259 | +def test_datatype_converter_complex( |
| 260 | + mlc_field: mlc.Field, |
| 261 | + feature_type: Type[Any], |
| 262 | + subfield_types: Dict[str, Type[Any]] | None, |
| 263 | +): |
256 | 264 | actual_feature = croissant_builder.datatype_converter(mlc_field) |
257 | 265 | assert actual_feature.doc.desc == mlc_field.description |
258 | 266 | assert isinstance(actual_feature, feature_type) |
@@ -411,7 +419,9 @@ def test_version_converter(tmp_path): |
411 | 419 |
|
412 | 420 |
|
413 | 421 | @pytest.fixture(name="crs_builder") |
414 | | -def mock_croissant_dataset_builder(tmp_path, request): |
| 422 | +def mock_croissant_dataset_builder( |
| 423 | + tmp_path, request |
| 424 | +) -> croissant_builder.CroissantBuilder: |
415 | 425 | dataset_name = request.param["dataset_name"] |
416 | 426 | with testing.dummy_croissant_file( |
417 | 427 | dataset_name=dataset_name, |
@@ -477,7 +487,11 @@ def test_croissant_builder(crs_builder): |
477 | 487 | indirect=["crs_builder"], |
478 | 488 | ) |
479 | 489 | @pytest.mark.parametrize("split_name", ["train", "test"]) |
480 | | -def test_download_and_prepare(crs_builder, expected_entries, split_name): |
| 490 | +def test_download_and_prepare( |
| 491 | + crs_builder: croissant_builder.CroissantBuilder, |
| 492 | + expected_entries: List[Dict[str, Any]], |
| 493 | + split_name: str, |
| 494 | +): |
481 | 495 | crs_builder.download_and_prepare() |
482 | 496 | data_source = crs_builder.as_data_source(split=split_name) |
483 | 497 | expected_entries = [ |
|
0 commit comments