diff --git a/src/h5ad/core/subset.py b/src/h5ad/core/subset.py index d9e7829..da66a4b 100644 --- a/src/h5ad/core/subset.py +++ b/src/h5ad/core/subset.py @@ -44,6 +44,11 @@ def _group_get(parent: Any, key: str) -> Any | None: return parent[key] if key in parent else None +def _ensure_optional_anndata_groups(dst: Any) -> None: + for key in ("layers", "obsm", "obsp", "varm", "varp"): + _ensure_group(dst, key) + + def _decode_attr(value: Any) -> Any: if isinstance(value, bytes): return value.decode("utf-8") @@ -517,6 +522,8 @@ def subset_h5ad( total=1, ) + _ensure_optional_anndata_groups(dst) + if inplace: if file.exists(): if file.is_dir(): diff --git a/src/h5ad/storage/__init__.py b/src/h5ad/storage/__init__.py index 43d876d..29c2227 100644 --- a/src/h5ad/storage/__init__.py +++ b/src/h5ad/storage/__init__.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Iterable, Optional, Sequence import shutil +import warnings import h5py @@ -15,6 +16,10 @@ import numpy as np +ROOT_ENCODING_TYPE = "anndata" +ROOT_ENCODING_VERSION = "0.1.0" + + @dataclass class Store: backend: str @@ -96,19 +101,63 @@ def open_store(path: Path, mode: str) -> Store: if backend == "zarr": _require_zarr() root = zarr.open_group(str(path), mode=mode) + if _is_writable_mode(mode): + ensure_anndata_root_attrs(root) + else: + warn_if_missing_anndata_root_attrs(root, path=path) return Store(backend="zarr", root=root, path=path) root = h5py.File(path, mode) + if _is_writable_mode(mode): + ensure_anndata_root_attrs(root) + else: + warn_if_missing_anndata_root_attrs(root, path=path) return Store(backend="hdf5", root=root, path=path) +def _decode_attr(value: Any) -> Any: + if isinstance(value, bytes): + return value.decode("utf-8") + return value + + +def _is_writable_mode(mode: str) -> bool: + return any(flag in mode for flag in ("w", "a", "+", "x")) + + +def has_valid_anndata_root_attrs(root: Any) -> bool: + enc_type = _decode_attr(root.attrs.get("encoding-type", None)) + enc_ver = 
_decode_attr(root.attrs.get("encoding-version", None)) + return enc_type == ROOT_ENCODING_TYPE and enc_ver == ROOT_ENCODING_VERSION + + +def ensure_anndata_root_attrs(root: Any) -> None: + root.attrs["encoding-type"] = ROOT_ENCODING_TYPE + root.attrs["encoding-version"] = ROOT_ENCODING_VERSION + + +def warn_if_missing_anndata_root_attrs(root: Any, *, path: Path) -> None: + if has_valid_anndata_root_attrs(root): + return + + enc_type = _decode_attr(root.attrs.get("encoding-type", None)) + enc_ver = _decode_attr(root.attrs.get("encoding-version", None)) + warnings.warn( + ( + f"Store '{path}' root has missing or invalid AnnData attrs " + f"(encoding-type={ROOT_ENCODING_TYPE!r}, encoding-version={ROOT_ENCODING_VERSION!r}). " + f"Found encoding-type={enc_type!r}, encoding-version={enc_ver!r}." + ), + UserWarning, + stacklevel=2, + ) + + def _normalize_attr_value(value: Any, target_backend: str) -> Any: if target_backend == "zarr": if isinstance(value, bytes): return value.decode("utf-8") if isinstance(value, (list, tuple)): - return [ - v.decode("utf-8") if isinstance(v, bytes) else v for v in value - ] + return [v.decode("utf-8") if isinstance(v, bytes) else v for v in value] if isinstance(value, np.ndarray): if value.dtype.kind in ("S", "O"): return [ @@ -187,7 +236,9 @@ def create_dataset( if zarr_format == 3: kwargs = dict(kwargs) kwargs.pop("compressor", None) - elif zarr_format == 2 and "compressors" in kwargs and "compressor" not in kwargs: + elif ( + zarr_format == 2 and "compressors" in kwargs and "compressor" not in kwargs + ): kwargs = dict(kwargs) compressors = kwargs.pop("compressors") if isinstance(compressors, (list, tuple)) and len(compressors) == 1: @@ -234,8 +285,12 @@ def copy_dataset(src: Any, dst_group: Any, name: str) -> Any: return ds -def copy_tree(src_obj: Any, dst_group: Any, name: str, *, exclude: Iterable[str] = ()) -> Any: - if is_hdf5_group(dst_group) and (is_hdf5_group(src_obj) or is_hdf5_dataset(src_obj)): +def copy_tree( + src_obj: Any, 
dst_group: Any, name: str, *, exclude: Iterable[str] = () +) -> Any: + if is_hdf5_group(dst_group) and ( + is_hdf5_group(src_obj) or is_hdf5_dataset(src_obj) + ): if not exclude: dst_group.copy(src_obj, dst_group, name) return dst_group[name] @@ -256,6 +311,9 @@ def copy_tree(src_obj: Any, dst_group: Any, name: str, *, exclude: Iterable[str] def copy_store_contents(src_root: Any, dst_root: Any) -> None: + target_backend = "zarr" if is_zarr_group(dst_root) else "hdf5" + copy_attrs(src_root.attrs, dst_root.attrs, target_backend=target_backend) + ensure_anndata_root_attrs(dst_root) for key in src_root.keys(): copy_tree(src_root[key], dst_root, key) diff --git a/tests/test_storage_root_attrs.py b/tests/test_storage_root_attrs.py new file mode 100644 index 0000000..13f4621 --- /dev/null +++ b/tests/test_storage_root_attrs.py @@ -0,0 +1,42 @@ +"""Tests for AnnData root encoding attributes enforcement/warnings.""" + +from pathlib import Path + +import h5py +import pytest + +from h5ad.storage import open_store + + +def _make_minimal_h5ad(path: Path) -> None: + with h5py.File(path, "w") as f: + obs = f.create_group("obs") + obs.attrs["_index"] = "obs_names" + obs.create_dataset("obs_names", data=[b"cell_1"]) + + var = f.create_group("var") + var.attrs["_index"] = "var_names" + var.create_dataset("var_names", data=[b"gene_1"]) + + f.create_dataset("X", data=[[1.0]]) + + +def test_open_store_read_warns_for_missing_root_attrs(temp_dir: Path) -> None: + file_path = temp_dir / "missing_root_attrs.h5ad" + _make_minimal_h5ad(file_path) + + with pytest.warns(UserWarning, match="missing or invalid AnnData attrs"): + with open_store(file_path, "r"): + pass + + +def test_open_store_writable_mode_sets_root_attrs(temp_dir: Path) -> None: + file_path = temp_dir / "set_root_attrs.h5ad" + _make_minimal_h5ad(file_path) + + with open_store(file_path, "a"): + pass + + with h5py.File(file_path, "r") as f: + assert f.attrs.get("encoding-type") == "anndata" + assert f.attrs.get("encoding-version")
== "0.1.0" diff --git a/tests/test_subset.py b/tests/test_subset.py index 78c5cf8..42d5ec6 100644 --- a/tests/test_subset.py +++ b/tests/test_subset.py @@ -284,6 +284,30 @@ def test_subset_sparse_empty_result(self, sample_sparse_csr_h5ad, temp_dir): class TestSubsetH5ad: """Integration tests for subset_h5ad function.""" + def test_subset_h5ad_creates_optional_empty_groups( + self, sample_h5ad_file, temp_dir + ): + """Subset output should include optional AnnData groups even if absent in source.""" + obs_file = temp_dir / "obs_names.txt" + obs_file.write_text("cell_1\ncell_3\n") + + output = temp_dir / "subset.h5ad" + console = Console(stderr=True) + + subset_h5ad( + file=sample_h5ad_file, + output=output, + obs_file=obs_file, + var_file=None, + chunk_rows=1024, + console=console, + ) + + with h5py.File(output, "r") as f: + for key in ("layers", "obsm", "obsp", "varm", "varp"): + assert key in f + assert isinstance(f[key], h5py.Group) + def test_subset_h5ad_obs_only(self, sample_h5ad_file, temp_dir): """Test subsetting h5ad file by obs only.""" obs_file = temp_dir / "obs_names.txt" @@ -438,7 +462,9 @@ def test_subset_h5ad_obsp_sparse_group(self, temp_dir): conn.attrs["shape"] = np.array([4, 4], dtype=np.int64) conn.create_dataset("data", data=np.array([1.0, 2.0, 3.0, 4.0])) conn.create_dataset("indices", data=np.array([0, 1, 2, 3], dtype=np.int64)) - conn.create_dataset("indptr", data=np.array([0, 1, 2, 3, 4], dtype=np.int64)) + conn.create_dataset( + "indptr", data=np.array([0, 1, 2, 3, 4], dtype=np.int64) + ) obs_file = temp_dir / "obs_names.txt" obs_file.write_text("cell_1\ncell_3\n") @@ -562,7 +588,8 @@ def _csr_group(parent, name, shape): obs = f.create_group("obs") obs.attrs["_index"] = "obs_names" obs.create_dataset( - "obs_names", data=np.array(["cell_1", "cell_2", "cell_3", "cell_4"], dtype="S") + "obs_names", + data=np.array(["cell_1", "cell_2", "cell_3", "cell_4"], dtype="S"), ) var = f.create_group("var") diff --git a/uv.lock b/uv.lock index 
cb342bd..7884068 100644 --- a/uv.lock +++ b/uv.lock @@ -134,7 +134,7 @@ wheels = [ [[package]] name = "h5ad" -version = "0.3.0" +version = "0.3.1" source = { editable = "." } dependencies = [ { name = "h5py" },