Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/h5ad/core/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def _group_get(parent: Any, key: str) -> Any | None:
return parent[key] if key in parent else None


def _ensure_optional_anndata_groups(dst: Any) -> None:
    """Create the optional AnnData mapping groups in *dst* when absent.

    The AnnData on-disk schema treats layers/obsm/obsp/varm/varp as mapping
    elements, and mappings must carry ``encoding-type="dict"`` /
    ``encoding-version="0.1.0"``. A bare group with no attrs would be a
    malformed mapping for schema-aware readers, so stamp the attrs here.
    """
    for key in ("layers", "obsm", "obsp", "varm", "varp"):
        _ensure_group(dst, key)
        group = dst[key]
        # Only fill in missing attrs so pre-existing source metadata wins.
        if "encoding-type" not in group.attrs:
            group.attrs["encoding-type"] = "dict"
        if "encoding-version" not in group.attrs:
            group.attrs["encoding-version"] = "0.1.0"
Comment on lines +47 to +49

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Set mapping attrs when creating optional AnnData groups

subset_h5ad now always creates layers/obsm/obsp/varm/varp when absent, but these groups are created without any encoding-type/encoding-version metadata. In this codebase’s AnnData element docs, those members are mappings and mappings must carry encoding-type="dict" and encoding-version="0.1.0"; creating the groups without attrs turns previously-valid “absent optional member” outputs into malformed mapping groups, which can break schema-aware downstream readers/concat workflows.

Useful? React with 👍 / 👎.



def _decode_attr(value: Any) -> Any:
if isinstance(value, bytes):
return value.decode("utf-8")
Expand Down Expand Up @@ -517,6 +522,8 @@ def subset_h5ad(
total=1,
)

_ensure_optional_anndata_groups(dst)

if inplace:
if file.exists():
if file.is_dir():
Expand Down
70 changes: 64 additions & 6 deletions src/h5ad/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Any, Iterable, Optional, Sequence
import shutil
import warnings

import h5py

Expand All @@ -15,6 +16,10 @@
import numpy as np


# Attrs the AnnData format expects on the store root; written by
# ensure_anndata_root_attrs() and checked by has_valid_anndata_root_attrs().
ROOT_ENCODING_TYPE = "anndata"
ROOT_ENCODING_VERSION = "0.1.0"


@dataclass
class Store:
backend: str
Expand Down Expand Up @@ -96,19 +101,63 @@ def open_store(path: Path, mode: str) -> Store:
if backend == "zarr":
_require_zarr()
root = zarr.open_group(str(path), mode=mode)
if _is_writable_mode(mode):
ensure_anndata_root_attrs(root)
else:
warn_if_missing_anndata_root_attrs(root, path=path)
return Store(backend="zarr", root=root, path=path)
root = h5py.File(path, mode)
if _is_writable_mode(mode):
ensure_anndata_root_attrs(root)
else:
warn_if_missing_anndata_root_attrs(root, path=path)
return Store(backend="hdf5", root=root, path=path)


def _decode_attr(value: Any) -> Any:
if isinstance(value, bytes):
return value.decode("utf-8")
return value


def _is_writable_mode(mode: str) -> bool:
return any(flag in mode for flag in ("w", "a", "+", "x"))


def has_valid_anndata_root_attrs(root: Any) -> bool:
    """Return True when *root* carries the canonical AnnData root attrs."""
    found_type = _decode_attr(root.attrs.get("encoding-type", None))
    found_version = _decode_attr(root.attrs.get("encoding-version", None))
    expected = (ROOT_ENCODING_TYPE, ROOT_ENCODING_VERSION)
    return (found_type, found_version) == expected


def ensure_anndata_root_attrs(root: Any) -> None:
    """Stamp the canonical AnnData root attrs on *root* (unconditional overwrite)."""
    for attr_name, attr_value in (
        ("encoding-type", ROOT_ENCODING_TYPE),
        ("encoding-version", ROOT_ENCODING_VERSION),
    ):
        root.attrs[attr_name] = attr_value


def warn_if_missing_anndata_root_attrs(root: Any, *, path: Path) -> None:
    """Emit a UserWarning when *root* lacks valid AnnData root attrs.

    Used for read-only opens, where the store cannot be repaired in place;
    writable opens call ensure_anndata_root_attrs() instead.
    """
    if has_valid_anndata_root_attrs(root):
        return

    enc_type = _decode_attr(root.attrs.get("encoding-type", None))
    enc_ver = _decode_attr(root.attrs.get("encoding-version", None))
    # Phrase the message as expected-vs-found: the original parenthetical
    # listed the expected values in a way that read as what was found.
    warnings.warn(
        (
            f"Store '{path}' root has missing or invalid AnnData attrs: "
            f"expected encoding-type={ROOT_ENCODING_TYPE!r} and "
            f"encoding-version={ROOT_ENCODING_VERSION!r}, but found "
            f"encoding-type={enc_type!r}, encoding-version={enc_ver!r}."
        ),
        UserWarning,
        stacklevel=2,
    )


def _normalize_attr_value(value: Any, target_backend: str) -> Any:
if target_backend == "zarr":
if isinstance(value, bytes):
return value.decode("utf-8")
if isinstance(value, (list, tuple)):
return [
v.decode("utf-8") if isinstance(v, bytes) else v for v in value
]
return [v.decode("utf-8") if isinstance(v, bytes) else v for v in value]
if isinstance(value, np.ndarray):
if value.dtype.kind in ("S", "O"):
return [
Expand Down Expand Up @@ -187,7 +236,9 @@ def create_dataset(
if zarr_format == 3:
kwargs = dict(kwargs)
kwargs.pop("compressor", None)
elif zarr_format == 2 and "compressors" in kwargs and "compressor" not in kwargs:
elif (
zarr_format == 2 and "compressors" in kwargs and "compressor" not in kwargs
):
kwargs = dict(kwargs)
compressors = kwargs.pop("compressors")
if isinstance(compressors, (list, tuple)) and len(compressors) == 1:
Expand Down Expand Up @@ -234,8 +285,12 @@ def copy_dataset(src: Any, dst_group: Any, name: str) -> Any:
return ds


def copy_tree(src_obj: Any, dst_group: Any, name: str, *, exclude: Iterable[str] = ()) -> Any:
if is_hdf5_group(dst_group) and (is_hdf5_group(src_obj) or is_hdf5_dataset(src_obj)):
def copy_tree(
src_obj: Any, dst_group: Any, name: str, *, exclude: Iterable[str] = ()
) -> Any:
if is_hdf5_group(dst_group) and (
is_hdf5_group(src_obj) or is_hdf5_dataset(src_obj)
):
if not exclude:
dst_group.copy(src_obj, dst_group, name)
return dst_group[name]
Expand All @@ -256,6 +311,9 @@ def copy_tree(src_obj: Any, dst_group: Any, name: str, *, exclude: Iterable[str]


def copy_store_contents(src_root: Any, dst_root: Any) -> None:
    """Replicate root attrs and all top-level members of *src_root* into *dst_root*."""
    backend = "zarr" if is_zarr_group(dst_root) else "hdf5"
    copy_attrs(src_root.attrs, dst_root.attrs, target_backend=backend)
    # Re-stamp the canonical root attrs after the attr copy so the destination
    # is always a valid AnnData root regardless of what the source carried.
    ensure_anndata_root_attrs(dst_root)
    for member_name in src_root.keys():
        copy_tree(src_root[member_name], dst_root, member_name)

Expand Down
42 changes: 42 additions & 0 deletions tests/test_storage_root_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Tests for AnnData root encoding attributes enforcement/warnings."""

from pathlib import Path

import h5py
import pytest

from h5ad.storage import open_store


def _make_minimal_h5ad(path: Path) -> None:
    """Write the smallest .h5ad-shaped file: obs/var index groups plus X."""
    index_specs = (
        ("obs", "obs_names", [b"cell_1"]),
        ("var", "var_names", [b"gene_1"]),
    )
    with h5py.File(path, "w") as handle:
        for group_name, index_name, values in index_specs:
            group = handle.create_group(group_name)
            group.attrs["_index"] = index_name
            group.create_dataset(index_name, data=values)
        handle.create_dataset("X", data=[[1.0]])


def test_open_store_read_warns_for_missing_root_attrs(temp_dir: Path) -> None:
    """A read-only open of a store without root attrs warns instead of repairing."""
    file_path = temp_dir / "missing_root_attrs.h5ad"
    _make_minimal_h5ad(file_path)

    # The warning emitted by open_store reads "... has missing or invalid
    # AnnData attrs ..."; the previous pattern ("missing required AnnData
    # attrs") could never match it, so this test always failed.
    with pytest.warns(UserWarning, match="missing or invalid AnnData attrs"):
        with open_store(file_path, "r"):
            pass


def test_open_store_writable_mode_sets_root_attrs(temp_dir: Path) -> None:
    """Opening a store in append mode should stamp the AnnData root attrs."""
    file_path = temp_dir / "set_root_attrs.h5ad"
    _make_minimal_h5ad(file_path)

    with open_store(file_path, "a"):
        pass

    with h5py.File(file_path, "r") as handle:
        root_attrs = handle.attrs
        assert root_attrs.get("encoding-type") == "anndata"
        assert root_attrs.get("encoding-version") == "0.1.0"
31 changes: 29 additions & 2 deletions tests/test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,30 @@ def test_subset_sparse_empty_result(self, sample_sparse_csr_h5ad, temp_dir):
class TestSubsetH5ad:
"""Integration tests for subset_h5ad function."""

def test_subset_h5ad_creates_optional_empty_groups(
self, sample_h5ad_file, temp_dir
):
"""Subset output should include optional AnnData groups even if absent in source."""
obs_file = temp_dir / "obs_names.txt"
obs_file.write_text("cell_1\ncell_3\n")

output = temp_dir / "subset.h5ad"
console = Console(stderr=True)

subset_h5ad(
file=sample_h5ad_file,
output=output,
obs_file=obs_file,
var_file=None,
chunk_rows=1024,
console=console,
)

with h5py.File(output, "r") as f:
for key in ("layers", "obsm", "obsp", "varm", "varp"):
assert key in f
assert isinstance(f[key], h5py.Group)

def test_subset_h5ad_obs_only(self, sample_h5ad_file, temp_dir):
"""Test subsetting h5ad file by obs only."""
obs_file = temp_dir / "obs_names.txt"
Expand Down Expand Up @@ -438,7 +462,9 @@ def test_subset_h5ad_obsp_sparse_group(self, temp_dir):
conn.attrs["shape"] = np.array([4, 4], dtype=np.int64)
conn.create_dataset("data", data=np.array([1.0, 2.0, 3.0, 4.0]))
conn.create_dataset("indices", data=np.array([0, 1, 2, 3], dtype=np.int64))
conn.create_dataset("indptr", data=np.array([0, 1, 2, 3, 4], dtype=np.int64))
conn.create_dataset(
"indptr", data=np.array([0, 1, 2, 3, 4], dtype=np.int64)
)

obs_file = temp_dir / "obs_names.txt"
obs_file.write_text("cell_1\ncell_3\n")
Expand Down Expand Up @@ -562,7 +588,8 @@ def _csr_group(parent, name, shape):
obs = f.create_group("obs")
obs.attrs["_index"] = "obs_names"
obs.create_dataset(
"obs_names", data=np.array(["cell_1", "cell_2", "cell_3", "cell_4"], dtype="S")
"obs_names",
data=np.array(["cell_1", "cell_2", "cell_3", "cell_4"], dtype="S"),
)

var = f.create_group("var")
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading