diff --git a/src/asebytes/_registry.py b/src/asebytes/_registry.py index 42a4a14..060b50c 100644 --- a/src/asebytes/_registry.py +++ b/src/asebytes/_registry.py @@ -127,6 +127,21 @@ def parse_uri(path: str) -> tuple[str | None, str]: # --------------------------------------------------------------------------- +def _is_h5md_file(path: str) -> bool: + """Check if an HDF5 file contains H5MD metadata.""" + from pathlib import Path + + if not Path(path).is_file(): + return False + try: + import h5py + + with h5py.File(path, "r") as f: + return "h5md" in f + except Exception: + return False + + def _pick_class(entry: _RegistryEntry, path: str, writable: bool | None): """Import the module from *entry* and return the appropriate class.""" mod = _import_module(entry.module_path) @@ -209,6 +224,27 @@ def resolve_backend( candidates.append(entry) + # -- Sniff *.h5 files for H5MD content --------------------------------- + if ( + scheme is None + and layer == "object" + and candidates + and any( + e.match_type == "pattern" and e.match_value == "*.h5" + for e in candidates + ) + and _is_h5md_file(path_or_uri) + ): + # Replace with H5MD backend entry + h5md_entries = [ + e for e in _REGISTRY + if e.match_type == "pattern" + and e.match_value == "*.h5md" + and e.layer == layer + ] + if h5md_entries: + candidates = h5md_entries + # -- No direct match -> cross-layer adapter wrapping -------------------- if not candidates: if _allow_fallback: diff --git a/src/asebytes/h5md/_backend.py b/src/asebytes/h5md/_backend.py index a146404..f15ec10 100644 --- a/src/asebytes/h5md/_backend.py +++ b/src/asebytes/h5md/_backend.py @@ -199,11 +199,13 @@ def _discover(self) -> None: if col.startswith("_"): continue self._columns.append(col) - # Per-atom detection: particles/ columns with ndim >= 2 - if col not in ("cell", "pbc") and self._store.has_array(col): - shape = self._store.get_shape(col) - if len(shape) >= 2 and shape[1] > 1: - self._per_atom_cols.add(col) + # Per-atom detection: columns 
in /particles/ with ndim >= 2 + if col not in ("cell", "pbc"): + h5_path = self._store._path_cache.get(col) + if h5_path is not None and h5_path.startswith("/particles/"): + shape = self._store.get_shape(col) + if len(shape) >= 2: + self._per_atom_cols.add(col) # Rebuild known_arrays and shapes self._known_arrays = set() diff --git a/src/asebytes/h5md/_store.py b/src/asebytes/h5md/_store.py index 00c08ce..a6cf87d 100644 --- a/src/asebytes/h5md/_store.py +++ b/src/asebytes/h5md/_store.py @@ -108,6 +108,7 @@ def __init__( self._compression_opts = compression_opts self._chunk_frames = chunk_frames self._ds_cache: dict[str, Any] = {} # column name -> h5py.Dataset + self._path_cache: dict[str, str] = {} # column name -> actual h5 path # ------------------------------------------------------------------ # Path translation @@ -210,11 +211,19 @@ def _get_element_origin(self, h5_path: str) -> str | None: # Internal helpers # ------------------------------------------------------------------ + def _resolve_h5_path(self, key: str) -> str | None: + """Resolve column name to H5 path, checking cache first.""" + h5_path = self._path_cache.get(key) + if h5_path is not None: + return h5_path + h5_path, _ = self._column_to_h5(key) + return h5_path + def _get_ds(self, key: str) -> Any: """Return cached h5py.Dataset for the ``value`` sub-dataset.""" ds = self._ds_cache.get(key) if ds is None: - h5_path, _ = self._column_to_h5(key) + h5_path = self._resolve_h5_path(key) if h5_path is None: raise KeyError(f"Unknown column: {key!r}") ds = self._file[f"{h5_path}/value"] @@ -397,7 +406,7 @@ def has_array(self, name: str) -> bool: return name in self._file[meta_path] except KeyError: return False - h5_path, _ = self._column_to_h5(name) + h5_path = self._resolve_h5_path(name) if h5_path is None: return False try: @@ -445,9 +454,11 @@ def _walk_elements( if isinstance(child, h5py.Group): if "value" in child: # This is an H5MD element - col = self._h5_to_column(f"{path}/{child_name}") + 
"""Round-trip contract tests between znh5md and asebytes.

Every ``s22_*`` fixture is exercised so that each property kind (positions,
numbers, cell, pbc, all calc results, custom info, custom per-atom arrays,
velocities, mixed pbc/cell, sparse properties) is checked in both
directions.

Earlier interop tests compared only positions, numbers and forces. Those use
standard H5MD element names, so they missed mapping bugs where per-atom calc
results that znh5md stores under ``particles/`` (with
``ASE_ENTRY_ORIGIN=calc``) could not be located when reading.
"""

from __future__ import annotations

import pytest
import znh5md

from asebytes import ASEIO

from .conftest import assert_atoms_equal

# Names of every s22_* conftest fixture that yields list[ase.Atoms].
S22_FIXTURES = [
    "s22",
    "s22_energy",
    "s22_energy_forces",
    "s22_all_properties",
    "s22_info_arrays_calc",
    "s22_mixed_pbc_cell",
    "s22_info_arrays_calc_missing_inbetween",
]


def _assert_frames_match(reader, frames, **compare_kwargs):
    """Check *reader* has ``len(frames)`` entries, each equal to its reference.

    Entries are fetched by integer index. Extra keyword arguments
    (e.g. ``atol``) are passed straight to ``assert_atoms_equal``.
    """
    assert len(reader) == len(frames)
    for index, reference in enumerate(frames):
        loaded = reader[index]
        assert_atoms_equal(loaded, reference, **compare_kwargs)


# ---------------------------------------------------------------------------
# znh5md -> asebytes (*.h5md and *.h5 via content sniffing)
# ---------------------------------------------------------------------------


class TestZnH5MDToAsebytes:
    """Files written by znh5md must be readable through asebytes.

    Both the native ``*.h5md`` extension and plain ``*.h5`` are covered; the
    latter depends on content sniffing to select H5MDBackend rather than
    ColumnarBackend.
    """

    @pytest.mark.parametrize("fixture_name", S22_FIXTURES)
    @pytest.mark.parametrize("ext", [".h5md", ".h5"])
    def test_read(self, tmp_path, fixture_name, ext, request):
        frames = request.getfixturevalue(fixture_name)
        target = str(tmp_path / f"znh5md{ext}")
        writer = znh5md.IO(target)
        writer.extend(frames)

        _assert_frames_match(ASEIO(target), frames)


# ---------------------------------------------------------------------------
# asebytes -> znh5md
# ---------------------------------------------------------------------------


class TestAsebytesToZnH5MD:
    """Files written by asebytes must be readable through znh5md."""

    @pytest.mark.parametrize("fixture_name", S22_FIXTURES)
    def test_write(self, tmp_path, fixture_name, request):
        frames = request.getfixturevalue(fixture_name)
        target = str(tmp_path / "asebytes.h5md")
        writer = ASEIO(target)
        writer.extend(frames)

        _assert_frames_match(znh5md.IO(target), frames, atol=1e-6)


# ---------------------------------------------------------------------------
# Full bidirectional: znh5md -> asebytes -> znh5md
# ---------------------------------------------------------------------------


class TestBidirectionalRoundtrip:
    """znh5md write, asebytes read, asebytes write, znh5md read."""

    @pytest.mark.parametrize("fixture_name", S22_FIXTURES)
    def test_roundtrip(self, tmp_path, fixture_name, request):
        frames = request.getfixturevalue(fixture_name)

        # Leg 1: written by znh5md, loaded frame by frame through asebytes.
        first_leg = str(tmp_path / "step1.h5md")
        znh5md.IO(first_leg).extend(frames)
        db = ASEIO(first_leg)
        intermediate = [db[index] for index in range(len(db))]

        # Leg 2: written back by asebytes, verified through znh5md.
        second_leg = str(tmp_path / "step2.h5md")
        ASEIO(second_leg).extend(intermediate)
        _assert_frames_match(znh5md.IO(second_leg), frames, atol=1e-6)