From 69856af44374e5c1cf500d1fe3bc99b74d0bca36 Mon Sep 17 00:00:00 2001
From: ialarmedalien
Date: Mon, 30 Mar 2026 15:37:16 -0700
Subject: [PATCH 1/4] Add file batcher

---
 src/cdm_data_loaders/utils/file_system.py |  72 ++++
 tests/utils/test_file_system.py           | 442 ++++++++++++++++++++++
 2 files changed, 514 insertions(+)
 create mode 100644 src/cdm_data_loaders/utils/file_system.py
 create mode 100644 tests/utils/test_file_system.py

diff --git a/src/cdm_data_loaders/utils/file_system.py b/src/cdm_data_loaders/utils/file_system.py
new file mode 100644
index 00000000..1bbc691b
--- /dev/null
+++ b/src/cdm_data_loaders/utils/file_system.py
@@ -0,0 +1,72 @@
+"""File system-related utilities."""
+
+import re
+from pathlib import Path
+from re import Pattern
+
+# Matches files like: name_00001.ext or name_00001.ext.gz
+FILE_NAME_REGEX = re.compile(r"^\w+_(\d+)(\.\w+)+$")
+
+
+class BatchCursor:
+    """A batcher that can be used to retrieve batches of files from a directory."""
+
+    def __init__(self, directory: str | Path, batch_size: int = 1, start_at: int = 1, end_at: int = 0) -> None:
+        """Initialise a new directory batch cursor.
+
+        :param directory: directory to retrieve files from
+        :type directory: str | Path
+        :param batch_size: number of files to return per invocation, defaults to 1
+        :type batch_size: int, optional
+        :param start_at: file number to start at, defaults to 1
+        :type start_at: int, optional
+        :param end_at: file number to end at, inclusive (i.e. if set to 15, file_0015.txt will be the last file)
+            Defaults to 0, which implies no end_at
+        :type end_at: int, optional
+        """
+        errs = []
+        if not isinstance(batch_size, int) or batch_size < 1:
+            errs.append("batch_size must be an integer, 1 or greater")
+        if not isinstance(start_at, int) or start_at < 0:
+            errs.append("start_at must be an integer, 0 or greater")
+        if not isinstance(end_at, int) or end_at < 0:
+            errs.append("end_at must be an integer, 0 or greater")
+        elif end_at > 0 and end_at < start_at:
+            # a positive end_at must not precede start_at; end_at == 0 means "no end"
+            errs.append("end_at must be greater than or equal to start_at")
+        if errs:
+            err_msg = f"Error{'' if len(errs) == 1 else 's'} initialising BatchCursor:\n- " + "\n- ".join(errs)
+            raise ValueError(err_msg)
+
+        self.directory = Path(directory)
+        self.batch_size: int = batch_size
+        self.start_at: int = start_at
+        self.end_at: int | None = end_at if end_at > 0 else None
+        self.file_regex: Pattern[str] = FILE_NAME_REGEX
+
+    def _get_sequence_number(self, path: Path) -> int:
+        match = self.file_regex.match(path.name)
+        return int(match.group(1))  # pyright: ignore[reportOptionalMemberAccess]
+
+    def get_batch(self) -> list[Path]:
+        """Return the next `batch_size` files whose sequence number >= start_at.
+
+        Re-scans the directory on every call to pick up newly added files and
+        advances `start_at` to one past the sequence number of the last file returned.
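+
+        A minimal usage sketch (illustrative only; assumes a directory holding
+        files named like ``report_00001.csv``)::
+
+            cursor = BatchCursor("/data/reports", batch_size=5)
+            while batch := cursor.get_batch():
+                ...  # process up to five files, then re-scan for more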
+ """ + if self.end_at and self.start_at > self.end_at: + return [] + + matched = sorted(p for p in self.directory.iterdir() if p.is_file() and self.file_regex.match(p.name)) + eligible = [ + p + for p in matched + if self._get_sequence_number(p) >= self.start_at + and (self.end_at is None or self._get_sequence_number(p) <= self.end_at) + ] + + batch = eligible[: self.batch_size] + if batch: + self.start_at = self._get_sequence_number(batch[-1]) + 1 + + return batch diff --git a/tests/utils/test_file_system.py b/tests/utils/test_file_system.py new file mode 100644 index 00000000..d5430d64 --- /dev/null +++ b/tests/utils/test_file_system.py @@ -0,0 +1,442 @@ +"""Tests of file system-related utilities.""" + +from copy import deepcopy +from pathlib import Path +from typing import Any + +import pytest + +from cdm_data_loaders.utils.file_system import FILE_NAME_REGEX, BatchCursor + +# the maximum file number in the directory +MAX_FILE_NUMBER = 15 +# what the range should be set to (max range is exclusive) +MAX_RANGE_VALUE = 16 + +EXPECTED: dict[int | None, dict[int | None, Any]] = { + # batch_size + 1: { + # start_at + 1: [[r] for r in range(1, MAX_RANGE_VALUE)], + 6: [[r] for r in range(6, MAX_RANGE_VALUE)], + 8: [[r] for r in range(8, MAX_RANGE_VALUE)], + 11: [[r] for r in range(11, MAX_RANGE_VALUE)], + }, + 5: { + # start_at + 1: [range(1, 6), range(6, 11), range(11, MAX_RANGE_VALUE)], + 6: [range(6, 11), range(11, MAX_RANGE_VALUE)], + 8: [range(8, 13), range(13, MAX_RANGE_VALUE)], + 11: [range(11, MAX_RANGE_VALUE)], + }, + 8: { + 1: [range(1, 9), range(9, MAX_RANGE_VALUE)], + 6: [range(6, 14), range(14, MAX_RANGE_VALUE)], + 8: [range(8, MAX_RANGE_VALUE)], + 11: [range(11, MAX_RANGE_VALUE)], + }, + 15: { + 1: [range(1, MAX_RANGE_VALUE)], + 6: [range(6, MAX_RANGE_VALUE)], + 8: [range(8, MAX_RANGE_VALUE)], + 11: [range(11, MAX_RANGE_VALUE)], + }, +} +# batch_size is not specified +EXPECTED[None] = EXPECTED[1] +# batch_size greater than # of records +EXPECTED[20] = EXPECTED[15] +# add in results for start_at == 0 or not specified +for ix, vals in EXPECTED.items(): + if ix is not None: + EXPECTED[ix][None] = vals[1] + EXPECTED[ix][0] = vals[1] + +EXPECTED_END_AT = deepcopy(EXPECTED) + + +def make_files(directory: Path, names: list[str]) -> list[Path]: + """Touch each filename in *directory* and return the sorted Path list.""" + paths = [] + for name in names: + p = directory / name + p.parent.mkdir(parents=True, exist_ok=True) + p.touch() + paths.append(p) + return sorted(paths) + + +def make_file_names(prefix: str, ext: str, numbers: list[int] | range) -> list[str]: + """Create the file names for a given set of numbers.""" + if not isinstance(numbers, list): + # convert the range into a list + numbers = list(numbers) + return [f"{prefix}_{n:05}.{ext}" for n in numbers] + + +def make_sequence(directory: Path, prefix: str, ext: str, numbers: list[int] | range) -> list[Path]: + """Create files for the given sequence numbers and return sorted Paths.""" + names = make_file_names(prefix, ext, numbers) + return make_files(directory, names) + + +@pytest.fixture +def file_dir(tmp_path: Path) -> Path: + """Directory pre-populated with files numbered 00001 - 00015.""" + make_sequence(tmp_path, "report", "csv", range(1, MAX_RANGE_VALUE)) + return tmp_path + + +def test_batcher_defaults() -> None: + """Ensure defaults are set correctly.""" + bc = BatchCursor(".") + assert bc.start_at == 1 + assert bc.batch_size == 1 + assert bc.end_at is None + assert bc.file_regex == FILE_NAME_REGEX + + +def 
test_batcher_default_end_at() -> None:
+    """Ensure default end_at is set correctly."""
+    bc = BatchCursor(".", end_at=0)
+    assert bc.end_at is None
+
+
+@pytest.mark.parametrize("batch_size", [None, 0, -1, 1.0, 1.2345678, -15, -1234567890, "something", "50", "None"])
+def test_invalid_batch_size(batch_size: float | str | None) -> None:
+    """Test invalid batch_size parameters."""
+    with pytest.raises(ValueError, match="batch_size must be an integer, 1 or greater"):
+        BatchCursor(".", batch_size)  # pyright: ignore[reportArgumentType]
+
+
+@pytest.mark.parametrize("start_at", [None, -1, 1.0, 1.2345678, -15, -1234567890, "something", "50", "None"])
+def test_invalid_start_at_params(start_at: float | str | None) -> None:
+    """Test invalid start_at parameters."""
+    with pytest.raises(ValueError, match="start_at must be an integer, 0 or greater"):
+        BatchCursor(".", start_at=start_at)  # pyright: ignore[reportArgumentType]
+
+
+@pytest.mark.parametrize("end_at", [None, -1, 1.0, 1.2345678, -15, -1234567890, "something", "50", "None"])
+def test_invalid_end_at_params(end_at: float | str | None) -> None:
+    """Test invalid end_at parameters."""
+    with pytest.raises(ValueError, match="end_at must be an integer, 0 or greater"):
+        BatchCursor(".", end_at=end_at)  # pyright: ignore[reportArgumentType]
+
+
+def test_invalid_start_vs_end_at_params() -> None:
+    """Ensure that an error is thrown if start_at and end_at are not compatible."""
+    with pytest.raises(ValueError, match="end_at must be greater than or equal to start_at"):
+        BatchCursor(".", start_at=2, end_at=1)
+
+
+def test_ok_start_end_at_params() -> None:
+    """Ensure that 0 is a valid end_at parameter, regardless of start_at value."""
+    bc = BatchCursor(".", start_at=5, end_at=0)
+    assert bc.end_at is None
+    assert bc.start_at == 5  # noqa: PLR2004
+
+
+def test_end_at_greater_than_start_at_during_iteration() -> None:
+    """Ensure that if end_at is smaller than start_at during iteration, an empty list is returned."""
+    bc = BatchCursor(".", start_at=0, end_at=5)
+    assert bc.end_at == 5  # noqa: PLR2004
+    bc.start_at = 10
+    assert bc.get_batch() == []
+
+
+CUTOFF_VALUE = 12
+
+
+# your basic batch
+@pytest.mark.parametrize("end_at", [None, 0, CUTOFF_VALUE])
+@pytest.mark.parametrize("start_at", [None, 0, 1, 6, 8, 11])
+@pytest.mark.parametrize("batch_size", EXPECTED.keys())
+def test_get_batch_parametrized(
+    file_dir: Path,
+    batch_size: int | None,
+    start_at: int | None,
+    end_at: int | None,
+) -> None:
+    """Test retrieval of batches of files."""
+    cursor_params = {}
+    if batch_size is not None:
+        cursor_params["batch_size"] = batch_size
+    if start_at is not None:
+        cursor_params["start_at"] = start_at
+    if end_at is not None:
+        cursor_params["end_at"] = end_at
+
+    cursor = BatchCursor(file_dir, **cursor_params)
+
+    # generate the expected files
+    expected_files: list[list[Path]] = [
+        [file_dir / fn for fn in make_file_names("report", "csv", numbers)]
+        for numbers in EXPECTED[batch_size][start_at]
+    ]
+    if end_at:
+        expected_files = []
+        for numbers in EXPECTED[batch_size][start_at]:
+            if within_cutoff := [n for n in numbers if n <= end_at]:
+                expected_files.append([file_dir / fn for fn in make_file_names("report", "csv", within_cutoff)])
+
+    output: list[list[Path]] = []
+    while batch := cursor.get_batch():
+        output.append(batch)
+        if cursor.start_at >= MAX_RANGE_VALUE:
+            break
+
+    # check the number of batches is correct
+    assert len(output) == len(expected_files)
+
+    # results are sorted
+    for batch in output:
+        # results are all file 
paths
+        assert all(isinstance(p, Path) for p in batch)
+        assert sorted(batch) == batch
+    assert output == expected_files
+
+    # if end_at is defined, the start_at value will be one greater than the end_at value
+    if end_at:
+        assert cursor.start_at == CUTOFF_VALUE + 1
+
+
+def test_default_start_at_matches_explicit_zero(file_dir: Path) -> None:
+    """Ensure that the default start_at returns the same first batch as an explicit start_at=0."""
+    cursor_default = BatchCursor(file_dir, batch_size=3)
+    cursor_explicit = BatchCursor(file_dir, batch_size=3, start_at=0)
+    assert cursor_default.get_batch() == cursor_explicit.get_batch()
+
+
+def test_start_at_matches_sequence_number(file_dir: Path) -> None:
+    """Ensure start_at value matches sequence number."""
+    cursor = BatchCursor(file_dir, batch_size=5, start_at=15)
+    result = cursor.get_batch()
+    assert len(result) == 1
+    assert result[0].name == "report_00015.csv"
+
+
+# advancing the cursor
+def test_start_at_advances_after_get_batch(file_dir: Path) -> None:
+    """Ensure that the start_at value changes after each successful get_batch operation."""
+    batch_size = 5
+    cursor = BatchCursor(file_dir, batch_size=batch_size, start_at=0)
+    assert cursor.start_at == 0
+    batch_1 = cursor.get_batch()
+    assert cursor.start_at == batch_size + 1  # next file is report_00006.csv
+    batch_2 = cursor.get_batch()
+    assert cursor.start_at == batch_size * 2 + 1  # report_00011.csv
+    batch_3 = cursor.get_batch()
+    assert cursor.start_at == batch_size * 3 + 1  # report_00016.csv (does not exist)
+    # next call returns nothing
+    assert cursor.get_batch() == []
+
+    # all files should be the sequential list of existing files
+    all_files = batch_1 + batch_2 + batch_3
+    assert all_files == [file_dir / f"report_{n:05}.csv" for n in range(1, MAX_RANGE_VALUE)]
+
+
+def test_cursor_does_not_advance_on_empty_result(file_dir: Path) -> None:
+    """Ensure that the cursor does not advance if the batch is empty."""
+    start_at = 999
+    cursor = BatchCursor(file_dir, batch_size=5, start_at=start_at)
+    cursor.get_batch()
+    assert cursor.start_at == start_at
+
+
+def test_partial_batch_advances_correctly(file_dir: Path) -> None:
+    """Ensure that the cursor only advances as far as the last file in the batch."""
+    # Only 3 files remain from 13 onward
+    cursor = BatchCursor(file_dir, batch_size=5, start_at=13)
+    result = cursor.get_batch()
+    assert result == [file_dir / f"report_{n:05}.csv" for n in [13, 14, 15]]
+    assert cursor.start_at == 16  # noqa: PLR2004
+
+
+def test_cursor_can_be_reset(file_dir: Path) -> None:
+    """Ensure that the cursor can be reset."""
+    batch_size = 5
+    cursor = BatchCursor(file_dir, batch_size=batch_size)
+    original_result = cursor.get_batch()
+    assert cursor.start_at == batch_size + 1
+    # set cursor to 0
+    cursor.start_at = 0
+    reset_result = cursor.get_batch()
+    assert cursor.start_at == batch_size + 1
+    assert original_result == reset_result
+    assert reset_result[0].name == "report_00001.csv"
+
+
+# Edge cases -- boundaries
+def test_start_at_beyond_end_returns_empty_list(file_dir: Path) -> None:
+    """Ensure that nothing is returned if start_at is too high."""
+    cursor = BatchCursor(file_dir, batch_size=5, start_at=999)
+    assert cursor.get_batch() == []
+
+
+def test_empty_directory_returns_empty_list(tmp_path: Path) -> None:
+    """Ensure that an empty dir returns nothing."""
+    cursor = BatchCursor(tmp_path, batch_size=5)
+    assert cursor.get_batch() == []
+
+
+def test_batch_size_larger_than_remaining_files(file_dir: Path) -> None:
+    """Ensure that batches are sized correctly for partial batches."""
+    cursor = 
BatchCursor(file_dir, batch_size=10, start_at=10) + result = cursor.get_batch() + # should have 00010 - 00015 + assert result[-1].name == "report_00015.csv" + assert result == [file_dir / f"report_{n:05}.csv" for n in range(10, MAX_RANGE_VALUE)] + assert cursor.start_at == 16 # noqa: PLR2004 + + +# gaps in the sequence +def test_start_at_skips_to_next_available_when_gap(tmp_path: Path) -> None: + """Ensure that gaps in the sequence are dealt with correctly.""" + # Files exist for 1,2,3 then jump to 10,11,12 — no 4-9 + make_sequence(tmp_path, "data", "csv.gz", [1, 2, 3, 10, 11, 12]) + cursor = BatchCursor(tmp_path, batch_size=5, start_at=5) + # retrieve 5 files, starting at 00005 + assert cursor.get_batch() == [tmp_path / f"data_{n:05}.csv.gz" for n in [10, 11, 12]] + assert cursor.start_at == 13 # noqa: PLR2004 + assert cursor.get_batch() == [] + + +def test_sequential_calls_across_gap(tmp_path: Path) -> None: + """Ensure that files are correctly retrieved across gaps in the sequence.""" + make_sequence(tmp_path, "data", "csv.gz", [1, 2, 3, 10, 11, 12]) + cursor = BatchCursor(tmp_path, batch_size=2) + assert cursor.get_batch() == [tmp_path / f"data_{n:05}.csv.gz" for n in [1, 2]] + assert cursor.start_at == 3 # noqa: PLR2004 + + assert cursor.get_batch() == [tmp_path / f"data_{n:05}.csv.gz" for n in [3, 10]] + assert cursor.start_at == 11 # noqa: PLR2004 + + assert cursor.get_batch() == [tmp_path / f"data_{n:05}.csv.gz" for n in [11, 12]] + assert cursor.start_at == 13 # noqa: PLR2004 + + assert cursor.get_batch() == [] + assert cursor.start_at == 13 # noqa: PLR2004 + + +def test_sequential_calls_across_gap_with_end_at(tmp_path: Path) -> None: + """Ensure that files are correctly retrieved across gaps in the sequence when end_at is specified.""" + make_sequence(tmp_path, "data", "csv.gz", [1, 2, 3, 5, 8, 11, 15]) + cursor = BatchCursor(tmp_path, batch_size=2, end_at=10) + assert cursor.get_batch() == [tmp_path / f"data_{n:05}.csv.gz" for n in [1, 2]] + assert cursor.start_at == 3 # noqa: PLR2004 + + assert cursor.get_batch() == [tmp_path / f"data_{n:05}.csv.gz" for n in [3, 5]] + assert cursor.start_at == 6 # noqa: PLR2004 + + assert cursor.get_batch() == [tmp_path / f"data_{n:05}.csv.gz" for n in [8]] + assert cursor.start_at == 9 # noqa: PLR2004 + + assert cursor.get_batch() == [] + assert cursor.start_at == 9 # noqa: PLR2004 + + +@pytest.fixture +def mixed_dir(tmp_path: Path) -> Path: + """Directory containing valid files alongside files that should be ignored.""" + make_sequence(tmp_path, "data", "txt", range(1, 6)) + more_files = [ + "data_123.txt", + "data_000001.txt", + "data_000100.txt.gz", + "data_000200.txt.tar.gz", + "data_000400.csv.gz", + "file_000300.txt.tar.gz", + # no numbers + "README.md", + # contains non-\w character + ".hidden_00001.txt", + # no extension + "data_00001", + # files in nested dirs -- will not be found + "nested/data_00010.txt", + "nested/dir1/data_00020.txt", + ] + make_files(tmp_path, more_files) + return tmp_path + + +# File-name pattern filtering +def test_ignores_invalid_filenames(mixed_dir: Path) -> None: + """Ensure that filenames are matched correctly.""" + cursor = BatchCursor(mixed_dir, batch_size=20) + generated_file_names = [f"data_{n:05}.txt" for n in range(1, 6)] + file_names = sorted( + [ + *generated_file_names, + "data_123.txt", + "data_000001.txt", + "data_000100.txt.gz", + "data_000200.txt.tar.gz", + "data_000400.csv.gz", + "file_000300.txt.tar.gz", + ] + ) + assert cursor.get_batch() == [mixed_dir / fn for fn in file_names] + + +def 
test_mixed_extensions_sorted_correctly(tmp_path: Path) -> None:
+    """Ensure that files with a mix of extensions are sorted numerically."""
+    names = ["data_00001.csv", "data_00001.tar.gz", "data_00002.tar.gz", "data_00003.txt"]
+    make_files(tmp_path, names)
+    cursor = BatchCursor(tmp_path, batch_size=10)
+    assert [p.name for p in cursor.get_batch()] == names
+
+
+# Dynamic / live-directory behaviour
+def test_picks_up_newly_added_files(tmp_path: Path) -> None:
+    """Ensure that adding files to a dir during batching picks up new files correctly."""
+    # dir contains log_00001.txt -> log_00003.txt
+    make_sequence(tmp_path, "log", "txt", range(1, 4))
+    cursor = BatchCursor(tmp_path, batch_size=10)
+    assert cursor.get_batch() == [tmp_path / f"log_{n:05}.txt" for n in [1, 2, 3]]
+
+    (tmp_path / "log_00004.txt").touch()
+    # Reset cursor to re-scan from the start
+    cursor.start_at = 0
+    assert cursor.get_batch() == [tmp_path / f"log_{n:05}.txt" for n in [1, 2, 3, 4]]
+
+
+def test_new_files_within_current_window_are_included(tmp_path: Path) -> None:
+    """Ensure that all new files within the current batching params are included, regardless of sequence position."""
+    # dir contains log_00001.txt -> log_00005.txt
+    make_sequence(tmp_path, "log", "txt", range(1, 6))
+    cursor = BatchCursor(tmp_path, batch_size=3, end_at=13)
+
+    assert cursor.get_batch() == [tmp_path / f"log_{n:05}.txt" for n in [1, 2, 3]]
+    assert cursor.start_at == 4  # noqa: PLR2004
+
+    # New files added within the next window before the next call
+    (tmp_path / "log_00006.txt").touch()
+    (tmp_path / "log_00007.txt").touch()
+    assert cursor.get_batch() == [tmp_path / f"log_{n:05}.txt" for n in [4, 5, 6]]
+    assert cursor.start_at == 7  # noqa: PLR2004
+
+    # New files added within the next window before the next call
+    # Note: we are missing 00008 and 00009
+    (tmp_path / "log_00010.txt").touch()
+    (tmp_path / "log_00011.txt").touch()
+    assert cursor.get_batch() == [tmp_path / f"log_{n:05}.txt" for n in [7, 10, 11]]
+    assert cursor.start_at == 12  # noqa: PLR2004
+
+    # add in missing files -- nothing is returned as start_at is at 12
+    (tmp_path / "log_00008.txt").touch()
+    (tmp_path / "log_00009.txt").touch()
+    assert cursor.get_batch() == []
+    assert cursor.start_at == 12  # noqa: PLR2004
+
+    # add more files
+    (tmp_path / "log_00012.txt").touch()
+    (tmp_path / "log_00013.txt").touch()
+    assert cursor.get_batch() == [tmp_path / f"log_{n:05}.txt" for n in [12, 13]]
+    assert cursor.start_at == 14  # noqa: PLR2004
+
+    # add more files beyond the end_at value
+    (tmp_path / "log_00014.txt").touch()
+    (tmp_path / "log_00015.txt").touch()
+    assert cursor.get_batch() == []
+    assert cursor.start_at == 14  # noqa: PLR2004

From f6b8ace38aebaaa59515b3cd4394cd25af5d2c53 Mon Sep 17 00:00:00 2001
From: i alarmed alien
Date: Mon, 30 Mar 2026 15:56:04 -0700
Subject: [PATCH 2/4] Potential fix for pull request finding 'Module is imported with 'import' and 'import from''

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
---
 src/cdm_data_loaders/utils/file_system.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/cdm_data_loaders/utils/file_system.py b/src/cdm_data_loaders/utils/file_system.py
index 1bbc691b..0660b881 100644
--- a/src/cdm_data_loaders/utils/file_system.py
+++ b/src/cdm_data_loaders/utils/file_system.py
@@ -2,7 +2,6 @@
 
 import re
 from pathlib import Path
-from re import Pattern
 
 # Matches files like: name_00001.ext or name_00001.ext.gz
 FILE_NAME_REGEX = 
re.compile(r"^\w+_(\d+)(\.\w+)+$")
@@ -42,7 +41,7 @@ def __init__(self, directory: str | Path, batch_size: int = 1, start_at: int = 1
         self.batch_size: int = batch_size
         self.start_at: int = start_at
         self.end_at: int | None = end_at if end_at > 0 else None
-        self.file_regex: Pattern[str] = FILE_NAME_REGEX
+        self.file_regex: re.Pattern[str] = FILE_NAME_REGEX
 
     def _get_sequence_number(self, path: Path) -> int:
         match = self.file_regex.match(path.name)

From 522a81e12f84889085f82f4537b5c076b4934fbb Mon Sep 17 00:00:00 2001
From: ialarmedalien
Date: Tue, 31 Mar 2026 10:47:52 -0700
Subject: [PATCH 3/4] Updating uniref pipeline to use Pydantic CLI

---
 README.md                                     |   4 +-
 pyproject.toml                                |   7 +-
 scripts/entrypoint.sh                         |  16 +-
 .../parsers/uniprot/uniref.py                 |   2 +-
 .../pipelines/cts_defaults.py                 |   4 +
 .../{uniprot_kb_pipeline.py => uniprot_kb.py} |   2 +-
 src/cdm_data_loaders/pipelines/uniref.py      | 153 ++++++++
 .../pipelines/uniref_pipeline.py              | 132 -------
 tests/pipelines/test_uniref.py                | 364 ++++++++++++++++++
 uv.lock                                       |  29 +-
 10 files changed, 566 insertions(+), 147 deletions(-)
 create mode 100644 src/cdm_data_loaders/pipelines/cts_defaults.py
 rename src/cdm_data_loaders/pipelines/{uniprot_kb_pipeline.py => uniprot_kb.py} (98%)
 create mode 100644 src/cdm_data_loaders/pipelines/uniref.py
 delete mode 100644 src/cdm_data_loaders/pipelines/uniref_pipeline.py
 create mode 100644 tests/pipelines/test_uniref.py

diff --git a/README.md b/README.md
index 0881220c..4f2a1326 100644
--- a/README.md
+++ b/README.md
@@ -59,8 +59,8 @@ The repo provides a Docker container that can be used to run several import pipe
 Current endpoints include:
 
 - `test`: run the unit tests that do _not_ require external dependencies like Spark
-- `uniprot`: run the UniProtKB (UniProt protein database) import pipeline; see [the UniProtKB pipeline](src/cdm_data_loaders/pipelines/uniprot_kb_pipeline.py) for arguments
-- `uniref`: run the UniRef import pipeline; the [the UniRef pipeline](src/cdm_data_loaders/pipelines/uniref_pipeline.py) for arguments
+- `uniprot`: run the UniProtKB (UniProt protein database) import pipeline; see [the UniProtKB pipeline](src/cdm_data_loaders/pipelines/uniprot_kb.py) for arguments
+- `uniref`: run the UniRef import pipeline; see [the UniRef pipeline](src/cdm_data_loaders/pipelines/uniref.py) for arguments
 
 ## Development

diff --git a/pyproject.toml b/pyproject.toml
index 6f0b13fc..7851c85e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
     "click>=8.3.1",
     "defusedxml>=0.7.1",
     "delta-spark>=4.1.0",
-    "dlt[deltalake,filesystem,parquet]>=1.22.2",
+    "dlt[deltalake,duckdb,filesystem,parquet]>=1.22.2",
     "lxml>=6.0.2",
    "pydantic>=2.12.5",
     "pydantic-settings>=2.12.0",
@@ -23,8 +23,9 @@ dependencies = [
 
 [project.scripts]
 idmapping = "cdm_data_loaders.parsers.uniprot.idmapping:cli"
-uniprot_pipeline = "cdm_data_loaders.pipelines.uniprot_kb_pipeline:cli"
-uniref_pipeline = "cdm_data_loaders.pipelines.uniref_pipeline:cli"
+uniprot = "cdm_data_loaders.pipelines.uniprot_kb:cli"
+uniref = "cdm_data_loaders.pipelines.uniref:cli"
+ncbi_rest_api = "cdm_data_loaders.pipelines.ncbi_rest_api:cli"
 
 [dependency-groups]
 dev = [

diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh
index dc83be94..bfdb2fb1 100755
--- a/scripts/entrypoint.sh
+++ b/scripts/entrypoint.sh
@@ -3,7 +3,7 @@ set -euo pipefail
 
 # Ensure at least one argument is provided
 if [ "$#" -eq 0 ]; then
-    echo "Usage: $0 {uniref|uniprot|test|xml_split} [args...]"
+    echo "Usage: $0 {uniref|uniprot|ncbi_api|test|xml_split} [args...]"
     exit 1
 fi
 
@@ 
-16,12 +16,16 @@ case "$cmd" in exec /usr/bin/tini -- xml_file_splitter "$@" ;; uniref) - # Run the uniref pipeline with any additional arguments via tini - exec /usr/bin/tini -- uv run --no-sync uniref_pipeline "$@" + # Run the uniref pipeline with any additional arguments + exec /usr/bin/tini -- uv run --no-sync uniref "$@" ;; uniprot) - # Run the uniprot pipeline with any additional arguments via tini - exec /usr/bin/tini -- uv run --no-sync uniprot_pipeline "$@" + # Run the uniprot pipeline with any additional arguments + exec /usr/bin/tini -- uv run --no-sync uniprot "$@" + ;; + ncbi_rest_api) + # Run the NCBI datasets API importer + exec /usr/bin/tini -- uv run --no-sync ncbi_rest_api "$@" ;; test) # run the tests @@ -31,7 +35,7 @@ case "$cmd" in exec /usr/bin/tini -- /bin/bash ;; *) - echo "Error: unknown command '$cmd'; valid commands are 'uniref' or 'uniprot'." >&2 + echo "Error: unknown command '$cmd'; valid commands are 'uniref', 'uniprot', 'ncbi_api', or 'xml_split'." >&2 exit 1 ;; esac diff --git a/src/cdm_data_loaders/parsers/uniprot/uniref.py b/src/cdm_data_loaders/parsers/uniprot/uniref.py index 73ac7200..074f7326 100644 --- a/src/cdm_data_loaders/parsers/uniprot/uniref.py +++ b/src/cdm_data_loaders/parsers/uniprot/uniref.py @@ -25,7 +25,7 @@ NS = "ns" UNIREF_NS = {NS: UNIREF_URL, "": UNIREF_URL} UNIREF = "UniRef" -UNIREF_VARIANTS = [100, 90, 50] +UNIREF_VARIANTS = ["100", "90", "50"] ENTITY_ID = "entity_id" PREFIX_TRANSLATION = { diff --git a/src/cdm_data_loaders/pipelines/cts_defaults.py b/src/cdm_data_loaders/pipelines/cts_defaults.py new file mode 100644 index 00000000..5cfb966b --- /dev/null +++ b/src/cdm_data_loaders/pipelines/cts_defaults.py @@ -0,0 +1,4 @@ +"""Common defaults for running pipelines on the KBase CTS.""" + +INPUT_MOUNT = "/input_dir" +OUTPUT_MOUNT = "/output_dir" diff --git a/src/cdm_data_loaders/pipelines/uniprot_kb_pipeline.py b/src/cdm_data_loaders/pipelines/uniprot_kb.py similarity index 98% rename from src/cdm_data_loaders/pipelines/uniprot_kb_pipeline.py rename to src/cdm_data_loaders/pipelines/uniprot_kb.py index df9e6bbe..be01e0a0 100644 --- a/src/cdm_data_loaders/pipelines/uniprot_kb_pipeline.py +++ b/src/cdm_data_loaders/pipelines/uniprot_kb.py @@ -57,7 +57,7 @@ def run_pipeline(config: Settings) -> None: :param config: config for running the pipeline. 
:type config: Settings
     """
-    for uniprot_file in sorted(config.input_dir.glob("*.xml.gz")):
+    for uniprot_file in sorted(config.input_dir.glob("*.xml*")):
         if config.start_at:
             # get the integer part of the file name
             f_int = uniprot_file.stem.replace("parts_", "")

diff --git a/src/cdm_data_loaders/pipelines/uniref.py b/src/cdm_data_loaders/pipelines/uniref.py
new file mode 100644
index 00000000..1e54367a
--- /dev/null
+++ b/src/cdm_data_loaders/pipelines/uniref.py
@@ -0,0 +1,153 @@
+"""DLT pipeline to import UniRef data."""
+
+import datetime
+from collections.abc import Generator
+from typing import Any
+
+import dlt
+from dlt.extract.items import DataItemWithMeta
+from pydantic import AliasChoices, Field, ValidationError, field_validator
+from pydantic_settings import BaseSettings, SettingsConfigDict, SettingsError
+
+from cdm_data_loaders.parsers.uniprot.uniref import UNIREF_URL, UNIREF_VARIANTS, parse_uniref_entry
+from cdm_data_loaders.pipelines.cts_defaults import INPUT_MOUNT
+from cdm_data_loaders.utils.cdm_logger import get_cdm_logger
+from cdm_data_loaders.utils.file_system import BatchCursor
+from cdm_data_loaders.utils.xml_utils import stream_xml_file
+
+logger = get_cdm_logger()
+
+APP_NAME = "uniref_importer"
+
+VALID_DESTINATIONS = ["local_fs", "minio"]
+DEFAULT_BATCH_SIZE = 50
+
+
+class Settings(BaseSettings):
+    """Configuration for running the UniRef import pipeline."""
+
+    model_config = SettingsConfigDict(
+        cli_parse_args=True,
+        cli_prog_name="uniref",
+        cli_exit_on_error=False,
+        cli_ignore_unknown_args=True,
+    )
+    input_dir: str = Field(
+        default=INPUT_MOUNT,
+        description="Location of directory containing UniRef XML files to import",
+        # explicitly allow both kebab case and snake case
+        validation_alias=AliasChoices("i", "input-dir", "input_dir"),
+    )
+    destination: str = Field(
+        default="local_fs",
+        description=f"Destination configuration to use for data output. Choices: {VALID_DESTINATIONS}",
+        validation_alias=AliasChoices("d", "destination"),
+    )
+    uniref_variant: str = Field(
+        description=f"Which UniRef variant to import. Choices: {UNIREF_VARIANTS}",
+        validation_alias=AliasChoices("u", "uniref", "uniref-variant", "uniref_variant"),
+    )
+    start_at: int = Field(
+        default=0,
+        description="File to start import at",
+        validation_alias=AliasChoices("s", "start", "start-at", "start_at"),
+    )
+    output: str | None = Field(
+        default=None,
+        description="Location to save imported data to, if different from the default supplied by the destination config",
+        validation_alias=AliasChoices("o", "output"),
+    )
+
+    @field_validator("uniref_variant")
+    @classmethod
+    def validate_uniref_variant(cls, v: str) -> str:
+        """Validate the uniref variant against valid choices.
+
+        :param v: uniref variant specified
+        :type v: str
+        :raises ValueError: if the uniref variant is not valid
+        :return: valid uniref variant
+        :rtype: str
+        """
+        if v not in UNIREF_VARIANTS:
+            err_msg = f"uniref_variant must be one of {UNIREF_VARIANTS}, got '{v}'"
+            raise ValueError(err_msg)
+        return v
+
+    @field_validator("destination")
+    @classmethod
+    def validate_destination(cls, v: str) -> str:
+        """Validate the destination against valid choices.
+
+        :param v: destination specified
+        :type v: str
+        :raises ValueError: if the destination is not valid
+        :return: valid destination
+        :rtype: str
+        """
+        if v not in VALID_DESTINATIONS:
+            err_msg = f"destination must be one of {VALID_DESTINATIONS}, got '{v}'"
+            raise ValueError(err_msg)
+        return v
+
+
+@dlt.resource(name="parse_uniref", parallelized=True)
+def parse_uniref(config: Settings) -> Generator[DataItemWithMeta, Any]:
+    """Parse the information from UniRef files, batch by batch.
+
+    :param config: config for running the pipeline.
+    :type config: Settings
+    """
+    timestamp = datetime.datetime.now(tz=datetime.UTC)
+    batch_params: dict[str, Any] = {}
+    if config.start_at:
+        batch_params["start_at"] = config.start_at
+
+    batcher = BatchCursor(config.input_dir, batch_size=DEFAULT_BATCH_SIZE, **batch_params)
+    while uniref_files := batcher.get_batch():
+        for file_path in uniref_files:
+            logger.info("Reading from %s", str(file_path))
+            for n_entries, entry in enumerate(stream_xml_file(file_path, f"{{{UNIREF_URL}}}entry")):
+                parsed_entry = parse_uniref_entry(entry, timestamp, f"UniRef {config.uniref_variant}", file_path)
+                for table, rows in parsed_entry.items():
+                    yield dlt.mark.with_table_name(rows, table)
+                if (n_entries + 1) % 10000 == 0:
+                    logger.info("Processed %d entries", n_entries + 1)
+
+
+def run_pipeline(config: Settings) -> None:
+    """Execute the pipeline.
+
+    :param config: config for running the pipeline.
+    :type config: Settings
+    """
+    # check whether there is a custom output location; if so, set it in the config
+    if config.output:
+        dlt.config[f"destination.{config.destination}.bucket_url"] = config.output
+
+    pipeline = dlt.pipeline(
+        pipeline_name=f"uniref_{config.uniref_variant}",
+        destination=dlt.destination(config.destination, max_table_nesting=0),
+        dataset_name="uniprot_kb",
+    )
+    load_info = pipeline.run(parse_uniref(config), table_format="delta")
+    logger.info(load_info)
+    logger.info("Work complete!")
+
+
+def cli() -> None:
+    """CLI interface for the UniRef importer pipeline.
+
+    See the ``Settings`` object for parameters.
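+
+    Example invocation (illustrative only; the flag spellings follow the field
+    ``AliasChoices`` definitions on ``Settings``)::
+
+        uniref --uniref-variant 50 --input-dir /input_dir --destination local_fs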
+ """ + try: + config = Settings() # pyright: ignore[reportCallIssue] + except (Exception, SettingsError, ValidationError) as e: + print(f"Error initialising config: {e}") + raise + + run_pipeline(config) + + +if __name__ == "__main__": + cli() diff --git a/src/cdm_data_loaders/pipelines/uniref_pipeline.py b/src/cdm_data_loaders/pipelines/uniref_pipeline.py deleted file mode 100644 index 4e77503b..00000000 --- a/src/cdm_data_loaders/pipelines/uniref_pipeline.py +++ /dev/null @@ -1,132 +0,0 @@ -"""DLT pipeline to import UniProt data.""" - -import datetime -from collections.abc import Generator -from pathlib import Path -from typing import Any - -import click -import dlt -from dlt.extract.items import DataItemWithMeta -from pydantic import Field -from pydantic_settings import BaseSettings - -from cdm_data_loaders.parsers.uniprot.uniref import UNIREF_URL, UNIREF_VARIANTS, parse_uniref_entry -from cdm_data_loaders.utils.cdm_logger import get_cdm_logger -from cdm_data_loaders.utils.xml_utils import stream_xml_file - -logger = get_cdm_logger() - -APP_NAME = "uniref_importer" - -VALID_DESTINATIONS = ["local_fs", "minio"] - -DEFAULT_UNIPROT_DIR = "s3://cdm-lake/tenant-general-warehouse/kbase/datasets/uniprot/" -UNIREF_DIR = Path("/global_share") / "uniprot" / "derived" / "2025_03" / "uniref" - - -class Settings(BaseSettings): - """Configuration for running the UniRef import pipeline.""" - - input_dir: Path = Field() - destination: str = Field() - uniref_variant: int = Field() - start_at: int = Field(0) - timestamp: datetime.datetime = Field(datetime.datetime.now(tz=datetime.UTC)) - - -@dlt.resource(name="parse_uniref", parallelized=True) -def parse_uniref( - file_path: str | Path, current_timestamp: datetime.datetime, uniref_value: int -) -> Generator[DataItemWithMeta, Any]: - """Parse the information from a UniProt entry. - - :param entry: _description_ - :type entry: Element - :return: _description_ - :rtype: _type_ - """ - for n, entry in enumerate(stream_xml_file(file_path, f"{{{UNIREF_URL}}}entry")): - parsed_entry = parse_uniref_entry(entry, current_timestamp, f"UniRef {uniref_value}", file_path) - for table, rows in parsed_entry.items(): - yield dlt.mark.with_table_name(rows, table) - if n + 1 % 10000 == 0: - print(f"Processed {n + 1} entries") - - -def run_pipeline(config: Settings) -> None: - """Execute the pipeline. - - :param config: config for running the pipeline. 
- :type config: Settings - """ - for uniref_file in sorted(config.input_dir.glob("*.xml.gz")): - if config.start_at: - # get the integer part of the file name - f_int = uniref_file.stem.replace("part_", "") - if int(f_int) < config.start_at: - logger.info("Skipping %s", str(uniref_file)) - continue - logger.info("Reading from %s", str(uniref_file)) - pipeline = dlt.pipeline( - pipeline_name=f"uniref_{config.uniref_variant}", - destination=dlt.destination(config.destination, max_table_nesting=0), - dataset_name="uniprot_kb", - ) - load_info = pipeline.run( - parse_uniref(uniref_file, config.timestamp, config.uniref_variant), table_format="delta" - ) - logger.info("Work complete!") - logger.info(load_info) - - -@click.command() -@click.option( - "-n", "--uniref", required=True, type=click.Choice(UNIREF_VARIANTS), help="Which UniRef variant to import" -) -@click.option("-s", "--start", type=int, default=0, help="File to start import at") -@click.option("-i", "--input_dir", default=str(UNIREF_DIR), help="Location of UniRef XML files to import") -@click.option( - "-d", - "--destination", - type=click.Choice(VALID_DESTINATIONS), - default="local_fs", - help="Destination configuration to use for data output", -) -@click.option( - "-o", - "--output", - type=str, - default="", - help="Location to save imported data to, if different from the default supplied by the destination config", -) -def cli(input_dir: str, destination: str, output: str | None, start: int, uniref: int) -> None: - """CLI interface for the UniRef importer pipeline. - - :param input_dir: Location of the directory containing the UniRef XML files - :type input_dir: str - :param destination: destination configuration to use - :type destination: str | None - :param output: location in the object store to save files to - :type output: str - :param start: if provided, which file to start the import at - :type start: int - :param uniref: which UniRef dataset to import (50 / 90 / 100) - :type uniref: int - """ - # check whether there is a custom output location; if so, set it in the config - if output: - dlt.config[f"destination.{destination}.bucket_url"] = output - - runtime_config = Settings( - input_dir=Path(input_dir), - destination=destination, - uniref_variant=uniref, - start_at=start, - timestamp=datetime.datetime.now(tz=datetime.UTC), - ) - run_pipeline(runtime_config) - - -if __name__ == "__main__": - cli() diff --git a/tests/pipelines/test_uniref.py b/tests/pipelines/test_uniref.py new file mode 100644 index 00000000..c1144906 --- /dev/null +++ b/tests/pipelines/test_uniref.py @@ -0,0 +1,364 @@ +"""Tests for the UniRef DLT pipeline.""" + +import datetime +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from pydantic_settings import CliApp + +from cdm_data_loaders.pipelines.cts_defaults import INPUT_MOUNT +from cdm_data_loaders.pipelines.uniref import ( + DEFAULT_BATCH_SIZE, + Settings, + parse_uniref, + run_pipeline, + UNIREF_URL, + UNIREF_VARIANTS, +) + +VALID_DESTINATIONS = ["local_fs", "minio"] +TEST_DEFAULT_UNIREF_VARIANT = "50" + + +def make_settings( + extra_argv: list[str] | None = None, *, uniref_variant: str = TEST_DEFAULT_UNIREF_VARIANT, **kwargs +) -> Settings: + """Generate a validated Settings object.""" + data = {"uniref": uniref_variant, **kwargs} + return Settings.model_validate(data) + + +def test_settings_defaults() -> None: + """Ensure the settings defaults are set up correctly.""" + s = make_settings() + assert 
s.destination == "local_fs"
+    assert s.start_at == 0
+    assert s.output is None
+    assert s.input_dir == INPUT_MOUNT
+    assert s.uniref_variant == TEST_DEFAULT_UNIREF_VARIANT
+
+
+def test_settings_all_params_set() -> None:
+    """Ensure that settings are set correctly when all args are specified."""
+    s = make_settings(
+        input_dir="/dir/path",
+        destination=VALID_DESTINATIONS[0],
+        uniref_variant="100",
+        start_at="50",
+        output="/some/dir",
+    )
+    assert s.input_dir == "/dir/path"
+    assert s.destination == VALID_DESTINATIONS[0]
+    assert s.uniref_variant == "100"
+    assert s.start_at == 50
+    assert s.output == "/some/dir"
+
+
+@pytest.mark.parametrize("destination", VALID_DESTINATIONS)
+@pytest.mark.parametrize("uniref_variant", UNIREF_VARIANTS)
+def test_settings_valid_variants_accepted(uniref_variant: str, destination: str) -> None:
+    """Ensure that each valid uniref_variant value is accepted without error."""
+    s = make_settings(uniref_variant=uniref_variant, destination=destination)
+    assert s.uniref_variant == uniref_variant
+    assert s.destination == destination
+
+
+@pytest.mark.parametrize("bad", ["25", "75", "uniref50", "", "ALL"])
+def test_invalid_variant_raises(bad: str) -> None:
+    """Ensure that an unrecognised uniref_variant raises a ValidationError."""
+    with pytest.raises(ValidationError, match="uniref_variant must be one of"):
+        make_settings(uniref_variant=bad)
+
+
+@pytest.mark.parametrize("bad", ["s3", "gcs", "filesystem", "", "LocalFs"])
+def test_invalid_destination_raises(bad: str) -> None:
+    """Ensure that an unrecognised destination raises a ValidationError."""
+    with pytest.raises(ValidationError, match="destination must be one of"):
+        make_settings(destination=bad)
+
+
+def _cliapp_run(cli_args: list[str]) -> Settings:
+    """Parse the given command-line arguments into a Settings object via CliApp.
+
+    Uses CliApp.run with explicit cli_args to avoid mutating sys.argv globally.
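+
+    Example (illustrative): ``_cliapp_run(["-u", "50"]).uniref_variant == "50"``.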
+ """ + return CliApp.run(Settings, cli_args=cli_args) + + +@pytest.mark.parametrize("input_dir", ["-i", "--input-dir", "--input_dir"]) +@pytest.mark.parametrize("destination", ["-d", "--destination"]) +@pytest.mark.parametrize("uniref_variant", ["-u", "--uniref", "--uniref-variant", "--uniref_variant"]) +@pytest.mark.parametrize("start_at", ["-s", "--start-at", "--start_at"]) +@pytest.mark.parametrize("output", ["-o", "--output"]) +def test_cli_all_variants(input_dir: str, destination: str, uniref_variant: str, start_at: str, output: str) -> None: + """Test all the variants of the Settings fields.""" + s = _cliapp_run( + [ + input_dir, + "/dir/path", + destination, + VALID_DESTINATIONS[0], + uniref_variant, + TEST_DEFAULT_UNIREF_VARIANT, + start_at, + "50", + output, + "/some/dir", + ] + ) + assert s.input_dir == "/dir/path" + assert s.destination == VALID_DESTINATIONS[0] + assert s.uniref_variant == TEST_DEFAULT_UNIREF_VARIANT + assert s.start_at == 50 + assert s.output == "/some/dir" + + +def test_cli_invalid_variant_via_cli_raises() -> None: + """Ensure that an invalid uniref_variant passed via CLI causes a SystemExit.""" + with pytest.raises(ValidationError, match="Value error, uniref_variant must be one of"): + _cliapp_run(["--uniref-variant", "999"]) + + +def test_cli_invalid_destination_via_cli_raises() -> None: + """Ensure that an invalid destination passed via CLI causes a SystemExit.""" + with pytest.raises(ValidationError, match="Value error, destination must be one of"): + _cliapp_run(["--uniref-variant", "50", "--destination", "s3"]) + + +def test_cli_missing_required_uniref_variant_raises() -> None: + """Ensure that omitting the required uniref_variant argument causes a SystemExit.""" + with pytest.raises(ValidationError, match="Field required"): + _cliapp_run([]) + + +def _collect(config: Settings): + """Drain the generator returned by parse_uniref.""" + return list(parse_uniref(config)) + + +def test_empty_batch_yields_nothing(config: Settings) -> None: + """Ensure that no items are yielded when BatchCursor returns an empty batch.""" + with patch("cdm_data_loaders.pipelines.uniref.BatchCursor") as mock_batcher_cls: + mock_batcher = MagicMock() + mock_batcher.get_batch.return_value = [] + mock_batcher_cls.return_value = mock_batcher + + results = _collect(config) + + assert results == [] + + +def test_start_at_zero_not_passed_to_batch_cursor(config: Settings) -> None: + """Ensure that start_at=0 (falsy) is not forwarded as a kwarg to BatchCursor.""" + with patch("cdm_data_loaders.pipelines.uniref.BatchCursor") as mock_batcher_cls: + mock_batcher = MagicMock() + mock_batcher.get_batch.return_value = [] + mock_batcher_cls.return_value = mock_batcher + + _collect(config) + + _, kwargs = mock_batcher_cls.call_args + assert "start_at" not in kwargs + + +def test_start_at_nonzero_passed_to_batch_cursor() -> None: + """Ensure that a non-zero start_at value is forwarded as a kwarg to BatchCursor.""" + start_at = 2 + config_with_start: Settings = make_settings( + uniref_variant="90", + input_dir="/fake/input", + start_at=start_at, + ) + with patch("cdm_data_loaders.pipelines.uniref.BatchCursor") as mock_batcher_cls: + mock_batcher = MagicMock() + mock_batcher.get_batch.return_value = [] + mock_batcher_cls.return_value = mock_batcher + + _collect(config_with_start) + + _, kwargs = mock_batcher_cls.call_args + assert kwargs.get("start_at") == start_at + + +def test_batch_cursor_receives_correct_input_dir_and_batch_size(config: Settings) -> None: + """Ensure BatchCursor is constructed 
with the configured input_dir and DEFAULT_BATCH_SIZE.""" + with patch("cdm_data_loaders.pipelines.uniref.BatchCursor") as mock_batcher_cls: + mock_batcher = MagicMock() + mock_batcher.get_batch.return_value = [] + mock_batcher_cls.return_value = mock_batcher + + _collect(config) + + args, kwargs = mock_batcher_cls.call_args + assert args[0] == "/fake/input" + assert kwargs.get("batch_size") == DEFAULT_BATCH_SIZE + + +def test_yields_items_for_each_table_in_parsed_entry(config: Settings) -> None: + """Ensure one item is yielded per table key returned by parse_uniref_entry.""" + fake_file = Path("/fake/input/uniref50_part1.xml") + fake_entry = MagicMock() + parsed_entry = { + "uniref_member": [{"id": "A"}, {"id": "B"}], + "uniref_cluster": [{"cluster_id": "UniRef50_A0A000"}], + } + + with ( + patch("cdm_data_loaders.pipelines.uniref.BatchCursor") as mock_batcher_cls, + patch("cdm_data_loaders.pipelines.uniref.stream_xml_file") as mock_stream, + patch("cdm_data_loaders.pipelines.uniref.parse_uniref_entry") as mock_parse, + patch("cdm_data_loaders.pipelines.uniref.dlt") as mock_dlt, + ): + mock_batcher = MagicMock() + mock_batcher.get_batch.side_effect = [[fake_file], []] + mock_batcher_cls.return_value = mock_batcher + mock_stream.return_value = [fake_entry] + mock_parse.return_value = parsed_entry + mock_dlt.mark.with_table_name.return_value = object() + + results = _collect(config) + + assert len(results) == len(parsed_entry) + assert mock_dlt.mark.with_table_name.call_count == 2 + + +def test_parse_uniref_entry_called_with_correct_args(config: Settings) -> None: + """Ensure parse_uniref_entry is called with the entry, timestamp, dataset label, and file path.""" + fake_file = Path("/fake/input/uniref50_part1.xml") + mock_stream_return = ["one", "two", "three"] + with ( + patch("cdm_data_loaders.pipelines.uniref.BatchCursor") as mock_batcher_cls, + patch("cdm_data_loaders.pipelines.uniref.stream_xml_file") as mock_stream, + patch("cdm_data_loaders.pipelines.uniref.parse_uniref_entry") as mock_parse, + patch("cdm_data_loaders.pipelines.uniref.dlt"), + ): + mock_batcher = MagicMock() + mock_batcher.get_batch.side_effect = [[fake_file], []] + mock_batcher_cls.return_value = mock_batcher + mock_stream.return_value = mock_stream_return + mock_parse.return_value = {} + + _collect(config) + + assert mock_parse.call_count == len(mock_stream_return) + call_args = [list(c[0]) for c in mock_parse.call_args_list] + for idx, ca in enumerate(call_args): + assert isinstance(ca[1], datetime.datetime) + assert ca[0] == mock_stream_return[idx] + assert ca[2] == "UniRef 50" + assert ca[3] == fake_file + + +def test_multiple_files_in_batch_are_all_processed(config: Settings) -> None: + """Ensure every file in a batch is passed to stream_xml_file.""" + files = [Path(f"/fake/input/part{i}.xml") for i in range(3)] + + with ( + patch("cdm_data_loaders.pipelines.uniref.BatchCursor") as mock_batcher_cls, + patch("cdm_data_loaders.pipelines.uniref.stream_xml_file") as mock_stream, + patch("cdm_data_loaders.pipelines.uniref.parse_uniref_entry") as mock_parse, + patch("cdm_data_loaders.pipelines.uniref.dlt"), + ): + mock_batcher = MagicMock() + mock_batcher.get_batch.side_effect = [files, []] + mock_batcher_cls.return_value = mock_batcher + mock_stream.return_value = [] + mock_parse.return_value = {} + + _collect(config) + + assert mock_stream.call_count == 3 + + +def test_multiple_batches_are_consumed(config: Settings) -> None: + """Ensure the generator continues processing until BatchCursor returns an empty batch.""" 
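+    # the cursor should yield three non-empty batches, then the empty batch
+    # that ends the generator's while loop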
+ fake_files = [Path(f"/fake/input/part_{n}.xml") for n in [1, 2, 3, 4, 5]] + batch1 = fake_files[0:2] + batch2 = fake_files[2:4] + batch3 = fake_files[4:] + + with ( + patch("cdm_data_loaders.pipelines.uniref.BatchCursor") as mock_batcher_cls, + patch("cdm_data_loaders.pipelines.uniref.stream_xml_file") as mock_stream, + patch("cdm_data_loaders.pipelines.uniref.parse_uniref_entry") as mock_parse, + patch("cdm_data_loaders.pipelines.uniref.dlt"), + ): + mock_batcher = MagicMock() + mock_batcher.get_batch.side_effect = [batch1, batch2, batch3, []] + mock_batcher_cls.return_value = mock_batcher + mock_stream.return_value = [] + mock_parse.return_value = {} + + _collect(config) + call_args = [list(c[0]) for c in mock_stream.call_args_list] + assert call_args == [[f, f"{{{UNIREF_URL}}}entry"] for f in fake_files] + + +"""Smoke tests for run_pipeline, verifying pipeline construction and execution.""" + + +@pytest.fixture +def config() -> Settings: + """Provide a minimal valid Settings object.""" + return make_settings(uniref_variant="50", input_dir="/fake/input") + + +def test_pipeline_is_executed(config: Settings) -> None: + """Ensure pipeline.run is called when run_pipeline is invoked.""" + with ( + patch("cdm_data_loaders.pipelines.uniref.dlt") as mock_dlt, + patch("cdm_data_loaders.pipelines.uniref.parse_uniref"), + ): + mock_pipeline = MagicMock() + mock_dlt.pipeline.return_value = mock_pipeline + + run_pipeline(config) + + mock_pipeline.run.assert_called_once() + + +def test_custom_output_sets_dlt_config() -> None: + """Ensure a non-empty output sets the correct dlt.config bucket_url key.""" + config = make_settings(uniref_variant="50", output="/custom/output", destination="minio") + + with ( + patch("cdm_data_loaders.pipelines.uniref.dlt") as mock_dlt, + patch("cdm_data_loaders.pipelines.uniref.parse_uniref"), + ): + mock_pipeline = MagicMock() + mock_dlt.pipeline.return_value = mock_pipeline + + run_pipeline(config) + + mock_dlt.config.__setitem__.assert_called_once_with("destination.minio.bucket_url", "/custom/output") + + +def test_no_custom_output_does_not_set_dlt_config(config: Settings) -> None: + """Ensure that an empty output does not mutate dlt.config.""" + with ( + patch("cdm_data_loaders.pipelines.uniref.dlt") as mock_dlt, + patch("cdm_data_loaders.pipelines.uniref.parse_uniref"), + ): + mock_pipeline = MagicMock() + mock_dlt.pipeline.return_value = mock_pipeline + + run_pipeline(config) + + mock_dlt.config.__setitem__.assert_not_called() + + +def test_pipeline_name_includes_uniref_variant(config: Settings) -> None: + """Ensure the pipeline is created with a name derived from the uniref_variant and the correct dataset_name.""" + with ( + patch("cdm_data_loaders.pipelines.uniref.dlt") as mock_dlt, + patch("cdm_data_loaders.pipelines.uniref.parse_uniref"), + ): + mock_dlt.pipeline.return_value = MagicMock() + run_pipeline(config) + + _, kwargs = mock_dlt.pipeline.call_args + assert kwargs["pipeline_name"] == "uniref_50" + assert kwargs["dataset_name"] == "uniprot_kb" diff --git a/uv.lock b/uv.lock index c591e129..2db7a025 100644 --- a/uv.lock +++ b/uv.lock @@ -390,7 +390,7 @@ dependencies = [ { name = "click" }, { name = "defusedxml" }, { name = "delta-spark" }, - { name = "dlt", extra = ["deltalake", "filesystem", "parquet"] }, + { name = "dlt", extra = ["deltalake", "duckdb", "filesystem", "parquet"] }, { name = "lxml" }, { name = "pydantic" }, { name = "pydantic-settings" }, @@ -429,7 +429,7 @@ requires-dist = [ { name = "click", specifier = ">=8.3.1" }, { name = 
"defusedxml", specifier = ">=0.7.1" }, { name = "delta-spark", specifier = ">=4.1.0" }, - { name = "dlt", extras = ["deltalake", "filesystem", "parquet"], specifier = ">=1.22.2" }, + { name = "dlt", extras = ["deltalake", "duckdb", "filesystem", "parquet"], specifier = ">=1.22.2" }, { name = "lxml", specifier = ">=6.0.2" }, { name = "pydantic", specifier = ">=2.12.5" }, { name = "pydantic-settings", specifier = ">=2.12.0" }, @@ -813,6 +813,9 @@ deltalake = [ { name = "deltalake" }, { name = "pyarrow" }, ] +duckdb = [ + { name = "duckdb" }, +] filesystem = [ { name = "botocore" }, { name = "s3fs" }, @@ -839,6 +842,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, ] +[[package]] +name = "duckdb" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/11/e05a7eb73a373d523e45d83c261025e02bc31ebf868e6282c30c4d02cc59/duckdb-1.5.0.tar.gz", hash = "sha256:f974b61b1c375888ee62bc3125c60ac11c4e45e4457dd1bb31a8f8d3cf277edd", size = 17981141, upload-time = "2026-03-09T12:50:26.372Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/5d/af5501221f42e4e3662c047ecec4dcd0761229fceeba3c67ad4d9d8741df/duckdb-1.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:11dd05b827846c87f0ae2f67b9ae1d60985882a7c08ce855379e4a08d5be0e1d", size = 30057396, upload-time = "2026-03-09T12:49:39.95Z" }, + { url = "https://files.pythonhosted.org/packages/43/bd/a278d73fedbd3783bf9aedb09cad4171fe8e55bd522952a84f6849522eb6/duckdb-1.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5ad8d9c91b7c280ab6811f59deff554b845706c20baa28c4e8f80a95690b252b", size = 15962700, upload-time = "2026-03-09T12:49:43.504Z" }, + { url = "https://files.pythonhosted.org/packages/76/fc/c916e928606946209c20fb50898dabf120241fb528a244e2bd8cde1bd9e2/duckdb-1.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0ee4dabe03ed810d64d93927e0fd18cd137060b81ee75dcaeaaff32cbc816656", size = 14220272, upload-time = "2026-03-09T12:49:46.867Z" }, + { url = "https://files.pythonhosted.org/packages/53/07/1390e69db922423b2e111e32ed342b3e8fad0a31c144db70681ea1ba4d56/duckdb-1.5.0-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9409ed1184b363ddea239609c5926f5148ee412b8d9e5ffa617718d755d942f6", size = 19244401, upload-time = "2026-03-09T12:49:49.865Z" }, + { url = "https://files.pythonhosted.org/packages/54/13/b58d718415cde993823a54952ea511d2612302f1d2bc220549d0cef752a4/duckdb-1.5.0-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1df8c4f9c853a45f3ec1e79ed7fe1957a203e5ec893bbbb853e727eb93e0090f", size = 21345827, upload-time = "2026-03-09T12:49:52.977Z" }, + { url = "https://files.pythonhosted.org/packages/e0/96/4460429651e371eb5ff745a4790e7fa0509c7a58c71fc4f0f893404c9646/duckdb-1.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:9a3d3dfa2d8bc74008ce3ad9564761ae23505a9e4282f6a36df29bd87249620b", size = 13053101, upload-time = "2026-03-09T12:49:56.134Z" }, + { url = "https://files.pythonhosted.org/packages/ba/54/6d5b805113214b830fa3c267bb3383fb8febaa30760d0162ef59aadb110a/duckdb-1.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:2deebcbafd9d39c04f31ec968f4dd7cee832c021e10d96b32ab0752453e247c8", size = 13865071, upload-time = "2026-03-09T12:49:59.282Z" }, 
+ { url = "https://files.pythonhosted.org/packages/66/9f/dd806d4e8ecd99006eb240068f34e1054533da1857ad06ac726305cd102d/duckdb-1.5.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:d4b618de670cd2271dd7b3397508c7b3c62d8ea70c592c755643211a6f9154fa", size = 30065704, upload-time = "2026-03-09T12:50:02.671Z" }, + { url = "https://files.pythonhosted.org/packages/79/c2/7b7b8a5c65d5535c88a513e267b5e6d7a55ab3e9b67e4ddd474454653268/duckdb-1.5.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:065ae50cb185bac4b904287df72e6b4801b3bee2ad85679576dd712b8ba07021", size = 15964883, upload-time = "2026-03-09T12:50:06.343Z" }, + { url = "https://files.pythonhosted.org/packages/23/c5/9a52a2cdb228b8d8d191a603254364d929274d9cc7d285beada8f7daa712/duckdb-1.5.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:6be5e48e287a24d98306ce9dd55093c3b105a8fbd8a2e7a45e13df34bf081985", size = 14221498, upload-time = "2026-03-09T12:50:10.567Z" }, + { url = "https://files.pythonhosted.org/packages/b8/68/646045cb97982702a8a143dc2e45f3bdcb79fbe2d559a98d74b8c160e5e2/duckdb-1.5.0-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a5ee41a0bf793882f02192ce105b9a113c3e8c505a27c7ef9437d7b756317113", size = 19249787, upload-time = "2026-03-09T12:50:13.524Z" }, + { url = "https://files.pythonhosted.org/packages/15/1b/5abf0c7f38febb3b4a231c784223fceccfd3f2bfd957699d786f46e41ce6/duckdb-1.5.0-cp314-cp314-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f8e42aaf3cd217417c5dc9ff522dc3939d18b25a6fe5f846348277e831e6f59c", size = 21351583, upload-time = "2026-03-09T12:50:16.701Z" }, + { url = "https://files.pythonhosted.org/packages/93/a4/a90f2901cc0a1ce7ca4f0564b8492b9dbfe048a6395b27933d46ae9be473/duckdb-1.5.0-cp314-cp314-win_amd64.whl", hash = "sha256:11ae50aaeda2145b50294ee0247e4f11fb9448b3cc3d2aea1cfc456637dfb977", size = 13575130, upload-time = "2026-03-09T12:50:19.716Z" }, + { url = "https://files.pythonhosted.org/packages/64/aa/f14dd5e241ec80d9f9d82196ca65e0c53badfc8a7a619d5497c5626657ad/duckdb-1.5.0-cp314-cp314-win_arm64.whl", hash = "sha256:d6d2858c734d1a7e7a1b6e9b8403b3fce26dfefb4e0a2479c420fba6cd36db36", size = 14341879, upload-time = "2026-03-09T12:50:22.347Z" }, +] + [[package]] name = "elementpath" version = "5.1.1" From 1b41d70ad054f48dfbc78ba768bd0b506b37f5d5 Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Tue, 31 Mar 2026 10:52:14 -0700 Subject: [PATCH 4/4] Fixing minor typos, linting --- scripts/entrypoint.sh | 12 ++++++------ tests/pipelines/test_uniref.py | 9 +++------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh index bfdb2fb1..271a85a1 100755 --- a/scripts/entrypoint.sh +++ b/scripts/entrypoint.sh @@ -3,7 +3,7 @@ set -euo pipefail # Ensure at least one argument is provided if [ "$#" -eq 0 ]; then - echo "Usage: $0 {uniref|uniprot|ncbi_api|test|xml_split} [args...]" + echo "Usage: $0 {uniref|uniprot|xml_split|test} [args...]" exit 1 fi @@ -23,10 +23,10 @@ case "$cmd" in # Run the uniprot pipeline with any additional arguments exec /usr/bin/tini -- uv run --no-sync uniprot "$@" ;; - ncbi_rest_api) - # Run the NCBI datasets API importer - exec /usr/bin/tini -- uv run --no-sync ncbi_rest_api "$@" - ;; + # ncbi_rest_api) + # # Run the NCBI datasets API importer + # exec /usr/bin/tini -- uv run --no-sync ncbi_rest_api "$@" + # ;; test) # run the tests exec /usr/bin/tini -- uv run --no-sync pytest -m "not requires_spark" @@ -35,7 +35,7 @@ case "$cmd" in exec /usr/bin/tini -- /bin/bash ;; *) - echo "Error: 
unknown command '$cmd'; valid commands are 'uniref', 'uniprot', 'ncbi_api', or 'xml_split'." >&2 + echo "Error: unknown command '$cmd'; valid commands are 'uniref', 'uniprot', or 'xml_split'." >&2 exit 1 ;; esac diff --git a/tests/pipelines/test_uniref.py b/tests/pipelines/test_uniref.py index c1144906..eedeb9c7 100644 --- a/tests/pipelines/test_uniref.py +++ b/tests/pipelines/test_uniref.py @@ -1,7 +1,6 @@ """Tests for the UniRef DLT pipeline.""" import datetime -import sys from pathlib import Path from unittest.mock import MagicMock, patch @@ -12,20 +11,18 @@ from cdm_data_loaders.pipelines.cts_defaults import INPUT_MOUNT from cdm_data_loaders.pipelines.uniref import ( DEFAULT_BATCH_SIZE, + UNIREF_URL, + UNIREF_VARIANTS, Settings, parse_uniref, run_pipeline, - UNIREF_URL, - UNIREF_VARIANTS, ) VALID_DESTINATIONS = ["local_fs", "minio"] TEST_DEFAULT_UNIREF_VARIANT = "50" -def make_settings( - extra_argv: list[str] | None = None, *, uniref_variant: str = TEST_DEFAULT_UNIREF_VARIANT, **kwargs -) -> Settings: +def make_settings(uniref_variant: str = TEST_DEFAULT_UNIREF_VARIANT, **kwargs: str | int) -> Settings: """Generate a validated Settings object.""" data = {"uniref": uniref_variant, **kwargs} return Settings.model_validate(data)
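
A minimal end-to-end sketch of driving the new pipeline from Python rather than the CLI (illustrative only; it assumes the package is installed and that /input_dir holds UniRef XML part files — `Settings.model_validate` accepts the `uniref` alias, just as the tests above do):

    from cdm_data_loaders.pipelines.uniref import Settings, run_pipeline

    # "uniref" is a validation alias for uniref_variant; "50" is one of the
    # valid UNIREF_VARIANTS, and input_dir defaults to INPUT_MOUNT ("/input_dir")
    config = Settings.model_validate({"uniref": "50", "input_dir": "/input_dir"})
    run_pipeline(config)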