Skip to content

Commit 6cda07f

Browse files
author
The TensorFlow Datasets Authors
committed
Add typing to parametrized tests in croissant_builder_test
PiperOrigin-RevId: 797622999
1 parent fed6e92 commit 6cda07f

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Tests for croissant_builder."""
1717

18+
from typing import Any, Dict, List, Type
1819
import numpy as np
1920
import pytest
2021
from tensorflow_datasets import testing
@@ -146,7 +147,10 @@ def _create_mlc_field(
146147
],
147148
)
148149
def test_simple_datatype_converter(
149-
mlc_field, expected_feature, int_dtype, float_dtype
150+
mlc_field: mlc.Field,
151+
expected_feature: type[Any],
152+
int_dtype: np.dtype | None,
153+
float_dtype: np.dtype | None,
150154
):
151155
actual_feature = croissant_builder.datatype_converter(
152156
mlc_field,
@@ -252,7 +256,11 @@ def test_datatype_converter_bbox_with_invalid_format():
252256
),
253257
],
254258
)
255-
def test_datatype_converter_complex(mlc_field, feature_type, subfield_types):
259+
def test_datatype_converter_complex(
260+
mlc_field: mlc.Field,
261+
feature_type: Type[Any],
262+
subfield_types: Dict[str, Type[Any]] | None,
263+
):
256264
actual_feature = croissant_builder.datatype_converter(mlc_field)
257265
assert actual_feature.doc.desc == mlc_field.description
258266
assert isinstance(actual_feature, feature_type)
@@ -411,7 +419,9 @@ def test_version_converter(tmp_path):
411419

412420

413421
@pytest.fixture(name="crs_builder")
414-
def mock_croissant_dataset_builder(tmp_path, request):
422+
def mock_croissant_dataset_builder(
423+
tmp_path, request
424+
) -> croissant_builder.CroissantBuilder:
415425
dataset_name = request.param["dataset_name"]
416426
with testing.dummy_croissant_file(
417427
dataset_name=dataset_name,
@@ -477,7 +487,11 @@ def test_croissant_builder(crs_builder):
477487
indirect=["crs_builder"],
478488
)
479489
@pytest.mark.parametrize("split_name", ["train", "test"])
480-
def test_download_and_prepare(crs_builder, expected_entries, split_name):
490+
def test_download_and_prepare(
491+
crs_builder: croissant_builder.CroissantBuilder,
492+
expected_entries: List[Dict[str, Any]],
493+
split_name: str,
494+
):
481495
crs_builder.download_and_prepare()
482496
data_source = crs_builder.as_data_source(split=split_name)
483497
expected_entries = [

0 commit comments

Comments
 (0)