Skip to content
1 change: 1 addition & 0 deletions docs/release-notes/4142.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Download cached dataset files atomically so that parallel processes sharing a dataset directory (e.g. `pytest-xdist` workers) can no longer read a partially written file {smaller}`gaoflow`
20 changes: 16 additions & 4 deletions src/scanpy/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,7 @@ def _get_filename_from_key(key, ext=None) -> Path:

def _download(url: str, path: Path):
from ssl import create_default_context
from tempfile import NamedTemporaryFile
from urllib.request import Request, urlopen

from certifi import contents
Expand All @@ -1092,6 +1093,8 @@ def _download(url: str, path: Path):
blocksize = 1024 * 8
blocknum = 0

# Write to a temp file and rename so readers never see a partial file (#4097).
tmp_path: Path | None = None
try:
req = Request(url, headers={"User-agent": "scanpy-user"})

Expand All @@ -1105,19 +1108,28 @@ def _download(url: str, path: Path):
unit_divisor=1024,
total=total if total is None else int(total),
) as t,
path.open("wb") as f,
NamedTemporaryFile(
dir=path.parent,
prefix=f"{path.name}.",
suffix=".part",
delete=False,
) as f,
):
tmp_path = Path(f.name)
block = resp.read(blocksize)
while block:
f.write(block)
blocknum += 1
t.update(len(block))
block = resp.read(blocksize)

tmp_path.replace(path)
tmp_path = None

except (KeyboardInterrupt, Exception):
# Make sure file doesn’t exist half-downloaded
if path.is_file():
path.unlink()
# Only remove our own temp file; leave path, which may be another process's.
if tmp_path is not None:
tmp_path.unlink(missing_ok=True)
raise


Expand Down
77 changes: 77 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Self

from anndata import AnnData

Expand Down Expand Up @@ -168,6 +169,82 @@ def test_download_failure() -> None:
excinfo.value.close()


def test_download_atomic(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that the destination only shows up once the download is complete (#4097).

    A fake HTTP response records, on every non-empty chunk it hands out,
    whether the destination file already exists. After the download, the
    destination must hold the full payload, must never have been observed
    mid-transfer, and must be the only entry left in its directory.
    """
    import io
    import urllib.request

    from scanpy.readwrite import _download

    payload = b"0123456789" * 5_000
    target = tmp_path / "cache" / "data.bin"
    target.parent.mkdir()
    # One entry per non-empty chunk served: did the target exist at that moment?
    seen_on_disk: list[bool] = []

    class FakeResponse:
        def __init__(self) -> None:
            self._stream = io.BytesIO(payload)

        def info(self) -> dict[str, str]:
            return {"content-length": str(len(payload))}

        def read(self, size: int) -> bytes:
            data = self._stream.read(size)
            if not data:
                # End of stream: nothing is being written anymore, so don't sample.
                return data
            seen_on_disk.append(target.exists())
            return data

        def __enter__(self) -> Self:
            return self

        def __exit__(self, *exc: object) -> bool:
            return False

    def fake_urlopen(*args: object, **kwargs: object) -> FakeResponse:
        return FakeResponse()

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)

    _download("http://example.invalid/data.bin", target)

    assert target.read_bytes() == payload
    # The payload must have been delivered in several chunks for the check to mean anything.
    assert len(seen_on_disk) > 1
    assert True not in seen_on_disk
    # No leftover temporary file next to the target.
    assert [*target.parent.iterdir()] == [target]


def test_download_failure_keeps_existing_file(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Check that a failing download leaves a pre-existing destination untouched (#4097).

    The destination is created up front, then a download is attempted
    against a fake response whose ``read`` always raises. The error must
    propagate, the existing file must keep its content, and no other file
    may be left behind in the directory.
    """
    import urllib.request

    from scanpy.readwrite import _download

    target = tmp_path / "cache" / "data.bin"
    target.parent.mkdir()
    target.write_bytes(b"complete")

    class FailingResponse:
        def info(self) -> dict[str, str]:
            return {"content-length": "100"}

        def read(self, size: int) -> bytes:
            reason = "connection reset"
            raise OSError(reason)

        def __enter__(self) -> Self:
            return self

        def __exit__(self, *exc: object) -> bool:
            return False

    def broken_urlopen(*args: object, **kwargs: object) -> FailingResponse:
        return FailingResponse()

    monkeypatch.setattr(urllib.request, "urlopen", broken_urlopen)

    with pytest.raises(OSError, match="connection reset"):
        _download("http://example.invalid/data.bin", target)

    # The file that was already there survives the failed attempt ...
    assert target.read_bytes() == b"complete"
    # ... and it is still the only entry in the directory.
    assert [*target.parent.iterdir()] == [target]


# These are tested via doctest
DS_INCLUDED = frozenset({"krumsiek11", "toggleswitch", "pbmc68k_reduced"})
# These have parameters that affect shape and so on
Expand Down
Loading