Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/asebytes/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,21 @@ def parse_uri(path: str) -> tuple[str | None, str]:
# ---------------------------------------------------------------------------


def _is_h5md_file(path: str) -> bool:
"""Check if an HDF5 file contains H5MD metadata."""
from pathlib import Path

if not Path(path).is_file():
return False
try:
import h5py

with h5py.File(path, "r") as f:
return "h5md" in f
except Exception:
return False


def _pick_class(entry: _RegistryEntry, path: str, writable: bool | None):
"""Import the module from *entry* and return the appropriate class."""
mod = _import_module(entry.module_path)
Expand Down Expand Up @@ -209,6 +224,27 @@ def resolve_backend(

candidates.append(entry)

# -- Sniff *.h5 files for H5MD content ---------------------------------
if (
scheme is None
and layer == "object"
and candidates
and any(
e.match_type == "pattern" and e.match_value == "*.h5"
for e in candidates
)
and _is_h5md_file(path_or_uri)
):
# Replace with H5MD backend entry
h5md_entries = [
e for e in _REGISTRY
if e.match_type == "pattern"
and e.match_value == "*.h5md"
and e.layer == layer
]
if h5md_entries:
candidates = h5md_entries

# -- No direct match -> cross-layer adapter wrapping --------------------
if not candidates:
if _allow_fallback:
Expand Down
12 changes: 7 additions & 5 deletions src/asebytes/h5md/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,13 @@ def _discover(self) -> None:
if col.startswith("_"):
continue
self._columns.append(col)
# Per-atom detection: particles/ columns with ndim >= 2
if col not in ("cell", "pbc") and self._store.has_array(col):
shape = self._store.get_shape(col)
if len(shape) >= 2 and shape[1] > 1:
self._per_atom_cols.add(col)
# Per-atom detection: columns in /particles/ with ndim >= 2
if col not in ("cell", "pbc"):
h5_path = self._store._path_cache.get(col)
if h5_path is not None and h5_path.startswith("/particles/"):
shape = self._store.get_shape(col)
if len(shape) >= 2:
self._per_atom_cols.add(col)

# Rebuild known_arrays and shapes
self._known_arrays = set()
Expand Down
17 changes: 14 additions & 3 deletions src/asebytes/h5md/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
self._compression_opts = compression_opts
self._chunk_frames = chunk_frames
self._ds_cache: dict[str, Any] = {} # column name -> h5py.Dataset
self._path_cache: dict[str, str] = {} # column name -> actual h5 path

# ------------------------------------------------------------------
# Path translation
Expand Down Expand Up @@ -210,11 +211,19 @@ def _get_element_origin(self, h5_path: str) -> str | None:
# Internal helpers
# ------------------------------------------------------------------

def _resolve_h5_path(self, key: str) -> str | None:
"""Resolve column name to H5 path, checking cache first."""
h5_path = self._path_cache.get(key)
if h5_path is not None:
return h5_path
h5_path, _ = self._column_to_h5(key)
return h5_path

def _get_ds(self, key: str) -> Any:
"""Return cached h5py.Dataset for the ``value`` sub-dataset."""
ds = self._ds_cache.get(key)
if ds is None:
h5_path, _ = self._column_to_h5(key)
h5_path = self._resolve_h5_path(key)
if h5_path is None:
raise KeyError(f"Unknown column: {key!r}")
ds = self._file[f"{h5_path}/value"]
Expand Down Expand Up @@ -397,7 +406,7 @@ def has_array(self, name: str) -> bool:
return name in self._file[meta_path]
except KeyError:
return False
h5_path, _ = self._column_to_h5(name)
h5_path = self._resolve_h5_path(name)
if h5_path is None:
return False
try:
Expand Down Expand Up @@ -445,9 +454,11 @@ def _walk_elements(
if isinstance(child, h5py.Group):
if "value" in child:
# This is an H5MD element
col = self._h5_to_column(f"{path}/{child_name}")
h5_path = f"{path}/{child_name}"
col = self._h5_to_column(h5_path)
if col is not None:
out.append(col)
self._path_cache[col] = h5_path
else:
# Recurse (e.g. into box/)
self._walk_elements(child, f"{path}/{child_name}", out)
Expand Down
107 changes: 107 additions & 0 deletions tests/contract/test_znh5md_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""znh5md <-> asebytes full round-trip tests.

Covers ALL s22_* fixtures to ensure every property type (positions, numbers,
cell, pbc, all calc results, custom info, custom per-atom arrays, velocities,
mixed pbc/cell, sparse properties) round-trips correctly between znh5md and
asebytes.

The existing interop tests only checked positions+numbers+forces, which use
standard H5MD element names and missed mapping bugs where per-atom calc
results placed by znh5md in particles/ (with ASE_ENTRY_ORIGIN=calc) could
not be found on read.
"""

from __future__ import annotations

import pytest
import znh5md

from asebytes import ASEIO

from .conftest import assert_atoms_equal

# Every s22_* fixture from conftest producing list[ase.Atoms]; each one
# stresses a different slice of the property space.
S22_FIXTURES = [
    "s22",  # bare structures
    "s22_energy",  # scalar calc result
    "s22_energy_forces",  # scalar + per-atom calc results
    "s22_all_properties",  # full calc property set
    "s22_info_arrays_calc",  # custom info + per-atom arrays + calc
    "s22_mixed_pbc_cell",  # heterogeneous pbc/cell
    "s22_info_arrays_calc_missing_inbetween",  # sparse properties
]


# ---------------------------------------------------------------------------
# znh5md -> asebytes (*.h5md and *.h5 via content sniffing)
# ---------------------------------------------------------------------------


class TestZnH5MDToAsebytes:
    """asebytes must be able to open files produced by znh5md.

    Exercised for the native ``.h5md`` suffix as well as ``.h5``, where
    content sniffing has to route to H5MDBackend rather than the
    extension-matched ColumnarBackend.
    """

    @pytest.mark.parametrize("fixture_name", S22_FIXTURES)
    @pytest.mark.parametrize("ext", [".h5md", ".h5"])
    def test_read(self, tmp_path, fixture_name, ext, request):
        expected_frames = request.getfixturevalue(fixture_name)
        target = str(tmp_path / f"znh5md{ext}")
        znh5md.IO(target).extend(expected_frames)

        reader = ASEIO(target)
        assert len(reader) == len(expected_frames)
        for idx, expected in enumerate(expected_frames):
            assert_atoms_equal(reader[idx], expected)


# ---------------------------------------------------------------------------
# asebytes -> znh5md
# ---------------------------------------------------------------------------


class TestAsebytesToZnH5MD:
    """znh5md must be able to open files produced by asebytes."""

    @pytest.mark.parametrize("fixture_name", S22_FIXTURES)
    def test_write(self, tmp_path, fixture_name, request):
        expected_frames = request.getfixturevalue(fixture_name)
        target = str(tmp_path / "asebytes.h5md")
        ASEIO(target).extend(expected_frames)

        reader = znh5md.IO(target)
        assert len(reader) == len(expected_frames)
        for idx, expected in enumerate(expected_frames):
            assert_atoms_equal(reader[idx], expected, atol=1e-6)


# ---------------------------------------------------------------------------
# Full bidirectional: znh5md -> asebytes -> znh5md
# ---------------------------------------------------------------------------


class TestBidirectionalRoundtrip:
    """Full loop: znh5md write -> asebytes read -> asebytes write -> znh5md read."""

    @pytest.mark.parametrize("fixture_name", S22_FIXTURES)
    def test_roundtrip(self, tmp_path, fixture_name, request):
        frames = request.getfixturevalue(fixture_name)

        # Leg 1: written by znh5md, consumed by asebytes.
        first = str(tmp_path / "step1.h5md")
        znh5md.IO(first).extend(frames)
        source = ASEIO(first)
        intermediate = [source[idx] for idx in range(len(source))]

        # Leg 2: written back by asebytes, consumed by znh5md.
        second = str(tmp_path / "step2.h5md")
        ASEIO(second).extend(intermediate)
        reader = znh5md.IO(second)
        assert len(reader) == len(frames)
        for idx, expected in enumerate(frames):
            assert_atoms_equal(reader[idx], expected, atol=1e-6)