|
19 | 19 | import pathlib |
20 | 20 | from typing import Type, TypeAlias |
21 | 21 |
|
| 22 | +from etils import epath |
22 | 23 | import pytest |
23 | 24 | from tensorflow_datasets import testing |
24 | 25 | from tensorflow_datasets.core import dataset_builder |
25 | 26 | from tensorflow_datasets.core import file_adapters |
| 27 | +from tensorflow_datasets.core import naming |
26 | 28 |
|
27 | 29 |
|
28 | 30 | FileFormat: TypeAlias = file_adapters.FileFormat |
@@ -138,3 +140,41 @@ def test_prase_file_format(format_enum_value, file_format): |
138 | 140 | def test_convert_path_to_file_format(path, file_format, expected_path): |
139 | 141 | converted_path = file_adapters.convert_path_to_file_format(path, file_format) |
140 | 142 | assert os.fspath(converted_path) == expected_path |
| 143 | + |
| 144 | + |
| 145 | +@pytest.mark.parametrize( |
| 146 | + 'adapter_cls', |
| 147 | + ( |
| 148 | + (file_adapters.TfRecordFileAdapter), |
| 149 | + (file_adapters.ArrayRecordFileAdapter), |
| 150 | + ), |
| 151 | +) |
| 152 | +def test_shard_lengths( |
| 153 | + tmp_path: pathlib.Path, adapter_cls: file_adapters.FileAdapter |
| 154 | +): |
| 155 | + file_template = naming.ShardedFileTemplate( |
| 156 | + data_dir=tmp_path, |
| 157 | + dataset_name='data', |
| 158 | + filetype_suffix=adapter_cls.FILE_SUFFIX, |
| 159 | + split='train', |
| 160 | + ) |
| 161 | + tmp_path_1 = file_template.sharded_filepath(shard_index=0, num_shards=2) |
| 162 | + tmp_path_2 = file_template.sharded_filepath(shard_index=1, num_shards=2) |
| 163 | + adapter_cls.write_examples( |
| 164 | + tmp_path_1, [(0, b'0'), (1, b'1'), (2, b'2222'), (3, b'33333')] |
| 165 | + ) |
| 166 | + adapter_cls.write_examples(tmp_path_2, [(3, b'3'), (4, b'4'), (5, b'555')]) |
| 167 | + size_1 = epath.Path(tmp_path_1).stat().length |
| 168 | + size_2 = epath.Path(tmp_path_2).stat().length |
| 169 | + expected_shard_lengths = [(4, size_1), (3, size_2)] |
| 170 | + |
| 171 | + # First test without passing the number of shards explicitly. |
| 172 | + actual_no_num_shards = adapter_cls.shard_lengths_and_sizes(file_template) |
| 173 | + assert actual_no_num_shards == expected_shard_lengths, 'no num_shards passed' |
| 174 | + |
| 175 | + # Now test with passing the number of shards explicitly. |
| 176 | + actual_with_num_shards = adapter_cls.shard_lengths_and_sizes( |
| 177 | + file_template, |
| 178 | + num_shards=2, |
| 179 | + ) |
| 180 | + assert actual_with_num_shards == expected_shard_lengths, 'num_shards passed' |
0 commit comments