diff --git a/docs/release-notes/4142.fix.md b/docs/release-notes/4142.fix.md new file mode 100644 index 0000000000..9305dcfed4 --- /dev/null +++ b/docs/release-notes/4142.fix.md @@ -0,0 +1 @@ +Download cached dataset files atomically so that parallel processes sharing a dataset directory (e.g. `pytest-xdist` workers) can no longer read a partially-written file {smaller}`gaoflow` diff --git a/src/scanpy/readwrite.py b/src/scanpy/readwrite.py index 9a3d3748fe..fc10b07bee 100644 --- a/src/scanpy/readwrite.py +++ b/src/scanpy/readwrite.py @@ -1084,6 +1084,7 @@ def _get_filename_from_key(key, ext=None) -> Path: def _download(url: str, path: Path): from ssl import create_default_context + from tempfile import NamedTemporaryFile from urllib.request import Request, urlopen from certifi import contents @@ -1092,6 +1093,8 @@ def _download(url: str, path: Path): blocksize = 1024 * 8 blocknum = 0 + # Write to a temp file and rename so readers never see a partial file (#4097). + tmp_path: Path | None = None try: req = Request(url, headers={"User-agent": "scanpy-user"}) @@ -1105,8 +1108,14 @@ def _download(url: str, path: Path): unit_divisor=1024, total=total if total is None else int(total), ) as t, - path.open("wb") as f, + NamedTemporaryFile( + dir=path.parent, + prefix=f"{path.name}.", + suffix=".part", + delete=False, + ) as f, ): + tmp_path = Path(f.name) block = resp.read(blocksize) while block: f.write(block) @@ -1114,10 +1123,13 @@ def _download(url: str, path: Path): t.update(len(block)) block = resp.read(blocksize) + tmp_path.replace(path) + tmp_path = None + except (KeyboardInterrupt, Exception): - # Make sure file doesn’t exist half-downloaded - if path.is_file(): - path.unlink() + # Only remove our own temp file; leave path, which may be another process's. + if tmp_path is not None: + tmp_path.unlink(missing_ok=True) raise diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e97d12a65a..7d72ddd236 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from typing import Self from anndata import AnnData @@ -168,6 +169,82 @@ def test_download_failure() -> None: excinfo.value.close() +def test_download_atomic(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """The destination must not appear until the download finished (#4097).""" + import io + import urllib.request + + from scanpy.readwrite import _download + + content = b"0123456789" * 5_000 + dest = tmp_path / "cache" / "data.bin" + dest.parent.mkdir() + dest_present_during_download: list[bool] = [] + + class FakeResponse: + def __init__(self) -> None: + self._buf = io.BytesIO(content) + + def info(self) -> dict[str, str]: + return {"content-length": str(len(content))} + + def read(self, size: int) -> bytes: + chunk = self._buf.read(size) + if chunk: + dest_present_during_download.append(dest.exists()) + return chunk + + def __enter__(self) -> Self: + return self + + def __exit__(self, *exc: object) -> bool: + return False + + monkeypatch.setattr(urllib.request, "urlopen", lambda *a, **k: FakeResponse()) + + _download("http://example.invalid/data.bin", dest) + + assert dest.read_bytes() == content + assert len(dest_present_during_download) > 1 + assert not any(dest_present_during_download) + assert list(dest.parent.iterdir()) == [dest] + + +def test_download_failure_keeps_existing_file( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """A failed download must not delete an already-present destination (#4097).""" + import urllib.request + + from scanpy.readwrite import _download + + dest = tmp_path / "cache" / "data.bin" + dest.parent.mkdir() + dest.write_bytes(b"complete") + + class FailingResponse: + def info(self) -> dict[str, str]: + return {"content-length": "100"} + + def read(self, size: int) -> bytes: + msg = "connection reset" + raise OSError(msg) + + def __enter__(self) -> Self: + return self + + def __exit__(self, *exc: object) -> bool: + return False + + monkeypatch.setattr(urllib.request, "urlopen", lambda *a, **k: FailingResponse()) + + with pytest.raises(OSError, match="connection reset"): + _download("http://example.invalid/data.bin", dest) + + assert dest.read_bytes() == b"complete" + assert list(dest.parent.iterdir()) == [dest] + + # These are tested via doctest DS_INCLUDED = frozenset({"krumsiek11", "toggleswitch", "pbmc68k_reduced"}) # These have parameters that affect shape and so on