diff --git a/atompack-py/python/atompack/__init__.pyi b/atompack-py/python/atompack/__init__.pyi index 9d49536..d2fbece 100644 --- a/atompack-py/python/atompack/__init__.pyi +++ b/atompack-py/python/atompack/__init__.pyi @@ -1,6 +1,6 @@ """Type stubs for atompack""" -from typing import Any, Sequence, overload +from typing import Any, Literal, Sequence, overload import numpy as np import numpy.typing as npt @@ -201,9 +201,10 @@ class Molecule: This reads directly from the molecule getters, so it works for both owned and view-backed molecules without going through `atoms()`. Geometry/species always become the ASE structure, supported builtin - results are attached through `SinglePointCalculator`, per-atom custom - arrays go to `atoms.arrays`, and remaining custom properties go to - `atoms.info`. + results are attached through `SinglePointCalculator`, atom-scope custom + properties go to `atoms.arrays`, molecule-scope tensor properties stay + in `atoms.info`, and legacy molecule arrays may still be shape-routed to + `atoms.arrays`. """ ... @@ -336,13 +337,21 @@ class Molecule: """ ... - def set_property(self, key: str, value: float | int | str | npt.NDArray | None) -> None: + def set_property( + self, + key: str, + value: float | int | str | npt.NDArray | None, + *, + scope: Literal["molecule", "atom"] | None = None, + ) -> None: """ Set a custom property. - Supported types: None, float, int, str, 1D float32/float64/int32/int64 arrays, - and 2D float32/float64 arrays with shape (n, 3). Input dtype is preserved. - The key 'stress' is reserved; use the dedicated ``stress`` property instead. + Supported types: None, float, int, str, and numeric ndarrays with dtype + float32, float64, int32, or int64. Input dtype and tensor shape are preserved. + New atom properties require ``scope="atom"``; existing atom properties keep + atom scope when overwritten. The key 'stress' is reserved; use the dedicated + ``stress`` property instead. Parameters ---------- @@ -350,6 +359,8 @@ class Molecule: Property key value : float, int, str, ndarray, or None Property value + scope : {"molecule", "atom"}, optional + Property scope. Defaults to molecule for new keys. Raises ------ @@ -358,7 +369,7 @@ class Molecule: """ ... - def property_keys(self) -> list[str]: + def property_keys(self, *, scope: Literal["molecule", "atom"] | None = None) -> list[str]: """ Get all property keys. @@ -369,7 +380,7 @@ class Molecule: """ ... - def has_property(self, key: str) -> bool: + def has_property(self, key: str, *, scope: Literal["molecule", "atom"] | None = None) -> bool: """ Check if a property exists. @@ -377,6 +388,8 @@ class Molecule: ---------- key : str Property key + scope : {"molecule", "atom"}, optional + Restrict the lookup to one scope. Returns ------- @@ -385,6 +398,17 @@ class Molecule: """ ... + def delete_property(self, key: str) -> None: + """ + Delete a custom property by key. + + Raises + ------ + KeyError + If property key does not exist + """ + ... + @overload def __getitem__(self, index: int) -> Atom: ... @overload @@ -630,6 +654,7 @@ def from_ase( cell: npt.NDArray[np.float64] | None = None, stress: npt.NDArray[np.float64] | None = None, copy_info: bool = True, + copy_arrays: bool = True, info: dict | None = None, ) -> Molecule: """ @@ -637,8 +662,9 @@ def from_ase( This function extracts positions, atomic numbers, and available properties (forces, energy, charges, velocities, cell) from an ASE Atoms object and - creates a corresponding atompack Molecule. Custom properties from atoms.info - dict are also copied by default. + creates a corresponding atompack Molecule. Custom properties from + atoms.info, atoms.arrays, and calculator results are copied as + molecule-scope properties by default. Parameters ---------- @@ -656,7 +682,11 @@ def from_ase( Override cell from ASE Atoms. If None, attempts to extract from atoms. copy_info : bool, default=True If True, copies custom properties from atoms.info dict to molecule properties. - Supports: str, int, float, 1D float/int arrays, 2D arrays with shape (n, 3). + Supported values are None, scalar int/float/bool, str, and ndarray-like + values with dtype float32, float64, int32, or int64. + copy_arrays : bool, default=True + If True, copies custom values from atoms.arrays to molecule properties. + Shape is not used to infer atom-property scope. info : dict, optional Additional properties to store in the molecule. These will be added after copying atoms.info (if copy_info=True), so they can override atoms.info values. @@ -731,6 +761,7 @@ def add_ase_batch( atoms_list: list[object], *, copy_info: bool = True, + copy_arrays: bool = True, info: dict | list[dict | None] | None = None, batch_size: int = 512, ) -> None: diff --git a/atompack-py/python/atompack/_atompack_rs.pyi b/atompack-py/python/atompack/_atompack_rs.pyi index 99e8730..36a1ff0 100644 --- a/atompack-py/python/atompack/_atompack_rs.pyi +++ b/atompack-py/python/atompack/_atompack_rs.pyi @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Sequence, overload +from typing import Any, Literal, Sequence, overload import numpy as np @@ -329,7 +329,13 @@ class PyMolecule: """ ... - def set_property(self, key: str, value: Any) -> None: + def set_property( + self, + key: str, + value: Any, + *, + scope: Literal["molecule", "atom"] | None = None, + ) -> None: """ Set a custom property. @@ -339,10 +345,12 @@ class PyMolecule: Property key value : Any Property value + scope : {"molecule", "atom"}, optional + Property scope. Defaults to molecule for new keys. """ ... - def property_keys(self) -> list[str]: + def property_keys(self, *, scope: Literal["molecule", "atom"] | None = None) -> list[str]: """ Get all property keys. @@ -353,7 +361,7 @@ class PyMolecule: """ ... - def has_property(self, key: str) -> bool: + def has_property(self, key: str, *, scope: Literal["molecule", "atom"] | None = None) -> bool: """ Check if a property exists. @@ -361,6 +369,8 @@ class PyMolecule: ---------- key : str Property key + scope : {"molecule", "atom"}, optional + Restrict the lookup to one scope. Returns ------- @@ -369,6 +379,10 @@ class PyMolecule: """ ... + def delete_property(self, key: str) -> None: + """Delete a custom property by key.""" + ... + @overload def __getitem__(self, index: int) -> PyAtom: ... @overload diff --git a/atompack-py/python/atompack/ase_bridge.py b/atompack-py/python/atompack/ase_bridge.py index a1cc155..b9ab31d 100644 --- a/atompack-py/python/atompack/ase_bridge.py +++ b/atompack-py/python/atompack/ase_bridge.py @@ -22,6 +22,16 @@ _ASE_TYPES = None _CALC_MODES = {"singlepoint", "nocopy", "none"} _UNSUPPORTED_PROPERTY = object() +_SUPPORTED_ARRAY_DTYPES = { + np.dtype(np.float32), + np.dtype(np.float64), + np.dtype(np.int32), + np.dtype(np.int64), +} +_SUPPORTED_CUSTOM_PROPERTY_TEXT = ( + "supported values are None, scalar int/float/bool, str, or ndarray-like " + "values with dtype float32, float64, int32, or int64" +) def _voigt6_to_mat3x3(stress): @@ -53,36 +63,50 @@ def _get_stress(atoms): return None -def _coerce_property(value, n_atoms): +def _coerce_property(value): if value is None: return None - if isinstance(value, (str, bool, int, float, np.integer, np.floating)): - if isinstance(value, str): - return value + if isinstance(value, str): + return value + if isinstance(value, (bool, int, float, np.integer, np.floating)): if isinstance(value, (bool, int, np.integer)): return int(value) return float(value) - arr = np.asarray(value) + try: + arr = np.asarray(value) + except Exception: + return _UNSUPPORTED_PROPERTY + if arr.ndim == 0 and arr.dtype.kind in {"b", "i", "u", "f"}: return arr.item() - if arr.ndim == 1 and arr.shape[0] == n_atoms: - if arr.dtype == np.float32: - return arr.astype(np.float32, copy=False) - if arr.dtype.kind == "f": - return arr.astype(np.float64, copy=False) - if arr.dtype == np.int32: - return arr.astype(np.int32, copy=False) - if arr.dtype.kind in {"b", "i", "u"}: - return arr.astype(np.int64, copy=False) - if arr.ndim == 2 and arr.shape == (n_atoms, 3) and arr.dtype.kind == "f": - if arr.dtype == np.float32: - return arr.astype(np.float32, copy=False) - return arr.astype(np.float64, copy=False) + if arr.dtype in _SUPPORTED_ARRAY_DTYPES: + return arr.astype(arr.dtype, copy=False) return _UNSUPPORTED_PROPERTY -def _merge_properties(properties, builtins, values, n_atoms): +def _unsupported_property_reason(value): + try: + arr = np.asarray(value) + except Exception as exc: + return f"could not convert value to an ndarray: {exc}" + return ( + f"got value of type {type(value).__name__} with ndarray dtype " + f"{arr.dtype} and shape {arr.shape}; {_SUPPORTED_CUSTOM_PROPERTY_TEXT}" + ) + + +def _coerce_custom_property(key, value, source): + coerced = _coerce_property(value) + if coerced is _UNSUPPORTED_PROPERTY: + raise TypeError( + f"Unsupported ASE custom property {key!r} from {source}: " + f"{_unsupported_property_reason(value)}" + ) + return coerced + + +def _merge_properties(properties, builtins, values, source): for key, value in values.items(): if key in _BUILTIN_FIELDS: # Builtin keys in atoms.info / info-override go to the builtins @@ -95,9 +119,7 @@ def _merge_properties(properties, builtins, values, n_atoms): if arr.shape == (3, 3) and arr.dtype.kind == "f": builtins["stress"] = arr.astype(np.float64, copy=False) continue - coerced = _coerce_property(value, n_atoms) - if coerced is not _UNSUPPORTED_PROPERTY: - properties[key] = coerced + properties[key] = _coerce_custom_property(key, value, source) def _extract_ase_record( @@ -110,6 +132,7 @@ def _extract_ase_record( cell=None, stress=None, copy_info=True, + copy_arrays=True, info=None, ): positions = np.asarray(atoms.get_positions(), dtype=np.float32) @@ -181,7 +204,7 @@ def _extract_ase_record( properties = {} arrays = getattr(atoms, "arrays", None) - if isinstance(arrays, dict): + if copy_arrays and isinstance(arrays, dict): for key, value in arrays.items(): # Skip both ASE-reserved geometry keys ("positions", "numbers") # and atompack builtin field names. A user who stashes "forces" @@ -189,23 +212,19 @@ def _extract_ase_record( # builtins["forces"] (from get_forces()) and properties["forces"]. if key in _ASE_RESERVED_ARRAYS or key in _BUILTIN_FIELDS: continue - coerced = _coerce_property(value, n_atoms) - if coerced is not _UNSUPPORTED_PROPERTY: - properties[key] = coerced + properties[key] = _coerce_custom_property(key, value, "atoms.arrays") calc = getattr(atoms, "calc", None) results = getattr(calc, "results", None) if isinstance(results, dict): for key, value in results.items(): if key not in _BUILTIN_FIELDS: - coerced = _coerce_property(value, n_atoms) - if coerced is not _UNSUPPORTED_PROPERTY: - properties[key] = coerced + properties[key] = _coerce_custom_property(key, value, "atoms.calc.results") if copy_info and getattr(atoms, "info", None): - _merge_properties(properties, builtins, atoms.info, n_atoms) + _merge_properties(properties, builtins, atoms.info, "atoms.info") if info is not None: - _merge_properties(properties, builtins, info, n_atoms) + _merge_properties(properties, builtins, info, "info override") return { "positions": positions, @@ -661,9 +680,15 @@ def from_ase( cell=None, stress=None, copy_info=True, + copy_arrays=True, info=None, ): - """Convert one ASE Atoms object to an atompack Molecule.""" + """Convert one ASE Atoms object to an atompack Molecule. + + Custom values from ``atoms.info``, ``atoms.arrays``, calculator results, + and explicit ``info=`` overrides are stored as molecule-scope properties. + Array shape is not used to infer atom-property scope during ingestion. + """ return _record_to_molecule( _extract_ase_record( atoms, @@ -674,6 +699,7 @@ def from_ase( cell=cell, stress=stress, copy_info=copy_info, + copy_arrays=copy_arrays, info=info, ) ) @@ -684,6 +710,7 @@ def add_ase_batch( atoms_list, *, copy_info=True, + copy_arrays=True, info=None, batch_size=512, ): @@ -710,7 +737,12 @@ def flush_slow(): slow_records.clear() for atoms, info_override in zip(atoms_list, info_overrides): - record = _extract_ase_record(atoms, copy_info=copy_info, info=info_override) + record = _extract_ase_record( + atoms, + copy_info=copy_info, + copy_arrays=copy_arrays, + info=info_override, + ) if record["properties"]: flush_fast() slow_records.append(_record_to_molecule(record)) diff --git a/atompack-py/src/database.rs b/atompack-py/src/database.rs index c890ae8..2e78d4e 100644 --- a/atompack-py/src/database.rs +++ b/atompack-py/src/database.rs @@ -160,9 +160,11 @@ impl PyAtomDatabase { ) .map_err(|e| PyValueError::new_err(format!("{}", e))); } - let owned = molecule.clone_as_owned()?; + let owned = molecule.as_owned().ok_or_else(|| { + PyValueError::new_err("Molecule is missing both owned and view state") + })?; self.inner - .add_molecule(&owned) + .add_molecule(owned) .map_err(|e| PyValueError::new_err(format!("{}", e))) } @@ -170,14 +172,18 @@ impl PyAtomDatabase { fn add_molecules(&mut self, molecules: Vec>) -> PyResult<()> { let mut raw_records: Vec<(&[u8], u32)> = Vec::new(); let mut raw_views: Vec<&SoaMoleculeView> = Vec::new(); - let mut owned_molecules: Vec = Vec::new(); + let mut owned_molecules: Vec<&Molecule> = Vec::new(); for m in &molecules { if let Some(view) = m.as_view() { raw_records.push((view.raw_bytes(), view.n_atoms as u32)); raw_views.push(view); + } else if let Some(owned) = m.as_owned() { + owned_molecules.push(owned); } else { - owned_molecules.push(m.clone_as_owned()?); + return Err(PyValueError::new_err( + "Molecule is missing both owned and view state", + )); } } @@ -211,9 +217,8 @@ impl PyAtomDatabase { } } if !owned_molecules.is_empty() { - let mol_refs: Vec<&Molecule> = owned_molecules.iter().collect(); self.inner - .add_molecules(&mol_refs) + .add_molecules(&owned_molecules) .map_err(|e| PyValueError::new_err(format!("{}", e)))?; } Ok(()) diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index 92639ea..546d239 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -9,9 +9,11 @@ use atompack::{ Atom, AtomDatabase, FloatArrayData, FloatScalarData, Mat3Data, Molecule, SharedMmapBytes, - Vec3Data, atom::PropertyValue, compression::CompressionType, + Vec3Data, + atom::{PropertyValue, TensorData}, + compression::CompressionType, }; -use numpy::{Element, PyArray1, PyArray2, PyArray3, PyArrayMethods}; +use numpy::{Element, PyArray1, PyArray2, PyArray3, PyArrayDyn, PyArrayMethods}; use pyo3::exceptions::{PyFileExistsError, PyIndexError, PyKeyError, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyTuple}; @@ -114,6 +116,10 @@ const TYPE_MAT3X3_F64: u8 = 10; const TYPE_FLOAT32: u8 = 11; const TYPE_MAT3X3_F32: u8 = 12; const TYPE_NONE: u8 = 13; +const TYPE_TENSOR_F32: u8 = 14; +const TYPE_TENSOR_F64: u8 = 15; +const TYPE_TENSOR_I32: u8 = 16; +const TYPE_TENSOR_I64: u8 = 17; const RECORD_FORMAT_SOA_V2: u32 = 2; const RECORD_FORMAT_SOA_V3: u32 = 3; @@ -126,8 +132,9 @@ pub(crate) use self::py_dtypes::{ parse_mat3_field, parse_positions_field, parse_property_value, parse_vec3_field, }; pub(crate) use self::soa::{ - LazySection, SectionSchema, SoaContext, SoaMoleculeView, parse_mol_fast_soa, read_f64_scalar, - read_i64_scalar, section_schema_from_ref, type_tag_elem_bytes, + LazySection, SectionSchema, SoaContext, SoaMoleculeView, decode_property_value, + is_tensor_type_tag, parse_mol_fast_soa, read_f64_scalar, read_i64_scalar, + section_schema_from_ref, type_tag_elem_bytes, }; mod database; diff --git a/atompack-py/src/molecule.rs b/atompack-py/src/molecule.rs index 3d53029..b7718e9 100644 --- a/atompack-py/src/molecule.rs +++ b/atompack-py/src/molecule.rs @@ -65,6 +65,75 @@ pub(crate) use self::helpers::{ }; use self::helpers::{into_py_any, property_section_to_pyobject, property_value_to_pyobject}; +#[derive(Clone, Copy, PartialEq, Eq)] +enum CustomPropertyScope { + Molecule, + Atom, +} + +fn parse_custom_property_scope(scope: Option<&str>) -> PyResult> { + match scope { + None => Ok(None), + Some("molecule") => Ok(Some(CustomPropertyScope::Molecule)), + Some("atom") => Ok(Some(CustomPropertyScope::Atom)), + Some(other) => Err(PyValueError::new_err(format!( + "scope must be 'molecule' or 'atom', got '{}'", + other + ))), + } +} + +fn atom_property_shape_error(key: &str, actual: usize, n_atoms: usize) -> PyErr { + PyValueError::new_err(format!( + "Atom property '{}' first dimension ({}) doesn't match atom count ({})", + key, actual, n_atoms + )) +} + +fn validate_atom_property_value(key: &str, value: &PropertyValue, n_atoms: usize) -> PyResult<()> { + match value { + PropertyValue::FloatArray(values) if values.len() == n_atoms => Ok(()), + PropertyValue::Vec3Array(values) if values.len() == n_atoms => Ok(()), + PropertyValue::IntArray(values) if values.len() == n_atoms => Ok(()), + PropertyValue::Float32Array(values) if values.len() == n_atoms => Ok(()), + PropertyValue::Vec3ArrayF64(values) if values.len() == n_atoms => Ok(()), + PropertyValue::Int32Array(values) if values.len() == n_atoms => Ok(()), + PropertyValue::Tensor(values) => match values.shape().first() { + Some(first_dim) if *first_dim == n_atoms => Ok(()), + Some(first_dim) => Err(atom_property_shape_error(key, *first_dim, n_atoms)), + None => Err(PyValueError::new_err(format!( + "Atom tensor property '{}' must have at least one dimension", + key + ))), + }, + PropertyValue::FloatArray(values) => { + Err(atom_property_shape_error(key, values.len(), n_atoms)) + } + PropertyValue::Vec3Array(values) => { + Err(atom_property_shape_error(key, values.len(), n_atoms)) + } + PropertyValue::IntArray(values) => { + Err(atom_property_shape_error(key, values.len(), n_atoms)) + } + PropertyValue::Float32Array(values) => { + Err(atom_property_shape_error(key, values.len(), n_atoms)) + } + PropertyValue::Vec3ArrayF64(values) => { + Err(atom_property_shape_error(key, values.len(), n_atoms)) + } + PropertyValue::Int32Array(values) => { + Err(atom_property_shape_error(key, values.len(), n_atoms)) + } + PropertyValue::None + | PropertyValue::Float(_) + | PropertyValue::Int(_) + | PropertyValue::String(_) => Err(PyValueError::new_err(format!( + "Atom property '{}' must be a numeric ndarray with first dimension equal to atom count ({})", + key, n_atoms + ))), + } +} + #[pymethods] impl PyMolecule { /// Create a new molecule from numpy arrays. @@ -434,55 +503,170 @@ impl PyMolecule { let py = slf.py(); let molecule = slf.borrow(); if let Some(inner) = molecule.as_owned() { - return match inner.properties.get(key) { - Some(v) => property_value_to_pyobject(py, v), - None => Err(PyKeyError::new_err(format!("Property '{}' not found", key))), + let molecule_value = inner.properties.get(key); + let atom_value = inner.atom_properties.get(key); + return match (molecule_value, atom_value) { + (Some(_), Some(_)) => Err(PyValueError::new_err(format!( + "Property '{}' exists in both molecule and atom scopes", + key + ))), + (Some(v), None) | (None, Some(v)) => property_value_to_pyobject(py, v), + (None, None) => Err(PyKeyError::new_err(format!("Property '{}' not found", key))), }; } let view = molecule.as_view().ok_or_else(|| { PyValueError::new_err("Molecule is missing both owned and view state") })?; - match view.find_custom_section(KIND_MOL_PROP, key)? { - Some(section) => property_section_to_pyobject(py, view, section), - None => Err(PyKeyError::new_err(format!("Property '{}' not found", key))), + let molecule_section = view.find_custom_section(KIND_MOL_PROP, key)?; + let atom_section = view.find_custom_section(KIND_ATOM_PROP, key)?; + match (molecule_section, atom_section) { + (Some(_), Some(_)) => Err(PyValueError::new_err(format!( + "Property '{}' exists in both molecule and atom scopes", + key + ))), + (Some(section), None) | (None, Some(section)) => { + property_section_to_pyobject(py, view, section) + } + (None, None) => Err(PyKeyError::new_err(format!("Property '{}' not found", key))), } } /// Set a custom property - fn set_property(&mut self, py: Python<'_>, key: String, value: Py) -> PyResult<()> { - let inner = self.ensure_owned()?; - let value = value.bind(py); + #[pyo3(signature = (key, value, *, scope=None))] + fn set_property( + &mut self, + py: Python<'_>, + key: String, + value: Py, + scope: Option<&str>, + ) -> PyResult<()> { + let requested_scope = parse_custom_property_scope(scope)?; + let n_atoms = self.len(); if key == "stress" { return Err(PyValueError::new_err( "'stress' is a reserved field; use molecule.stress instead", )); } - inner.properties.insert(key, parse_property_value(value)?); + let parsed = parse_property_value(value.bind(py))?; + let inner = self.ensure_owned()?; + let has_molecule = inner.properties.contains_key(&key); + let has_atom = inner.atom_properties.contains_key(&key); + if has_molecule && has_atom { + return Err(PyValueError::new_err(format!( + "Property '{}' exists in both molecule and atom scopes", + key + ))); + } + let target_scope = match (has_molecule, has_atom, requested_scope) { + (true, false, None | Some(CustomPropertyScope::Molecule)) => { + CustomPropertyScope::Molecule + } + (true, false, Some(CustomPropertyScope::Atom)) => { + return Err(PyValueError::new_err(format!( + "Property '{}' already exists as a molecule property; delete it before setting atom scope", + key + ))); + } + (false, true, None | Some(CustomPropertyScope::Atom)) => CustomPropertyScope::Atom, + (false, true, Some(CustomPropertyScope::Molecule)) => { + return Err(PyValueError::new_err(format!( + "Property '{}' already exists as an atom property; delete it before setting molecule scope", + key + ))); + } + (false, false, Some(CustomPropertyScope::Atom)) => CustomPropertyScope::Atom, + (false, false, None | Some(CustomPropertyScope::Molecule)) => { + CustomPropertyScope::Molecule + } + (true, true, _) => unreachable!(), + }; + match target_scope { + CustomPropertyScope::Molecule => { + inner.properties.insert(key, parsed); + } + CustomPropertyScope::Atom => { + validate_atom_property_value(&key, &parsed, n_atoms)?; + inner.atom_properties.insert(key, parsed); + } + } Ok(()) } /// Get all property keys - fn property_keys(&self) -> PyResult> { + #[pyo3(signature = (*, scope=None))] + fn property_keys(&self, scope: Option<&str>) -> PyResult> { + let requested_scope = parse_custom_property_scope(scope)?; if let Some(inner) = self.as_owned() { - Ok(inner.properties.keys().cloned().collect()) + let mut keys = Vec::new(); + if requested_scope != Some(CustomPropertyScope::Atom) { + keys.extend(inner.properties.keys().cloned()); + } + if requested_scope != Some(CustomPropertyScope::Molecule) { + keys.extend(inner.atom_properties.keys().cloned()); + } + keys.sort(); + Ok(keys) } else if let Some(view) = self.as_view() { - view.property_keys() + let mut keys = Vec::new(); + for section in &view.custom_sections { + let include = match requested_scope { + Some(CustomPropertyScope::Molecule) => section.kind == KIND_MOL_PROP, + Some(CustomPropertyScope::Atom) => section.kind == KIND_ATOM_PROP, + None => matches!(section.kind, KIND_MOL_PROP | KIND_ATOM_PROP), + }; + if include { + keys.push(view.lazy_section_key(section)?.to_string()); + } + } + keys.sort(); + Ok(keys) } else { Ok(Vec::new()) } } /// Check if a property exists - fn has_property(&self, key: &str) -> PyResult { + #[pyo3(signature = (key, *, scope=None))] + fn has_property(&self, key: &str, scope: Option<&str>) -> PyResult { + let requested_scope = parse_custom_property_scope(scope)?; if let Some(inner) = self.as_owned() { - Ok(inner.properties.contains_key(key)) + Ok(match requested_scope { + Some(CustomPropertyScope::Molecule) => inner.properties.contains_key(key), + Some(CustomPropertyScope::Atom) => inner.atom_properties.contains_key(key), + None => { + inner.properties.contains_key(key) || inner.atom_properties.contains_key(key) + } + }) } else if let Some(view) = self.as_view() { - Ok(view.find_custom_section(KIND_MOL_PROP, key)?.is_some()) + Ok(match requested_scope { + Some(CustomPropertyScope::Molecule) => { + view.find_custom_section(KIND_MOL_PROP, key)?.is_some() + } + Some(CustomPropertyScope::Atom) => { + view.find_custom_section(KIND_ATOM_PROP, key)?.is_some() + } + None => { + view.find_custom_section(KIND_MOL_PROP, key)?.is_some() + || view.find_custom_section(KIND_ATOM_PROP, key)?.is_some() + } + }) } else { Ok(false) } } + /// Delete a custom property by key. + fn delete_property(&mut self, key: &str) -> PyResult<()> { + let inner = self.ensure_owned()?; + let removed_molecule = inner.properties.remove(key).is_some(); + let removed_atom = inner.atom_properties.remove(key).is_some(); + if removed_molecule || removed_atom { + Ok(()) + } else { + Err(PyKeyError::new_err(format!("Property '{}' not found", key))) + } + } + /// Index molecule atoms by integer, or custom properties by string. fn __getitem__<'py>(slf: Bound<'py, Self>, index: &Bound<'py, PyAny>) -> PyResult> { let py = slf.py(); @@ -490,17 +674,33 @@ impl PyMolecule { if let Ok(key) = index.extract::() { if let Some(inner) = molecule.as_owned() { - return match inner.properties.get(&key) { - Some(v) => property_value_to_pyobject(py, v), - None => Err(PyKeyError::new_err(format!("Property '{}' not found", key))), + let molecule_value = inner.properties.get(&key); + let atom_value = inner.atom_properties.get(&key); + return match (molecule_value, atom_value) { + (Some(_), Some(_)) => Err(PyValueError::new_err(format!( + "Property '{}' exists in both molecule and atom scopes", + key + ))), + (Some(v), None) | (None, Some(v)) => property_value_to_pyobject(py, v), + (None, None) => { + Err(PyKeyError::new_err(format!("Property '{}' not found", key))) + } }; } let view = molecule.as_view().ok_or_else(|| { PyValueError::new_err("Molecule is missing both owned and view state") })?; - return match view.find_custom_section(KIND_MOL_PROP, &key)? { - Some(section) => property_section_to_pyobject(py, view, section), - None => Err(PyKeyError::new_err(format!("Property '{}' not found", key))), + let molecule_section = view.find_custom_section(KIND_MOL_PROP, &key)?; + let atom_section = view.find_custom_section(KIND_ATOM_PROP, &key)?; + return match (molecule_section, atom_section) { + (Some(_), Some(_)) => Err(PyValueError::new_err(format!( + "Property '{}' exists in both molecule and atom scopes", + key + ))), + (Some(section), None) | (None, Some(section)) => { + property_section_to_pyobject(py, view, section) + } + (None, None) => Err(PyKeyError::new_err(format!("Property '{}' not found", key))), }; } diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index 12aa81a..6b95c8f 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -392,6 +392,22 @@ pub(super) fn property_value_to_pyobject( } PropertyValue::IntArray(v) => into_py_any(py, PyArray1::from_slice(py, v))?, PropertyValue::Int32Array(v) => into_py_any(py, PyArray1::from_slice(py, v))?, + PropertyValue::Tensor(TensorData::F32 { shape, values }) => into_py_any( + py, + PyArray1::from_vec(py, values.clone()).reshape(shape.as_slice())?, + )?, + PropertyValue::Tensor(TensorData::F64 { shape, values }) => into_py_any( + py, + PyArray1::from_vec(py, values.clone()).reshape(shape.as_slice())?, + )?, + PropertyValue::Tensor(TensorData::I32 { shape, values }) => into_py_any( + py, + PyArray1::from_vec(py, values.clone()).reshape(shape.as_slice())?, + )?, + PropertyValue::Tensor(TensorData::I64 { shape, values }) => into_py_any( + py, + PyArray1::from_vec(py, values.clone()).reshape(shape.as_slice())?, + )?, }) } @@ -444,6 +460,10 @@ pub(super) fn property_section_to_pyobject<'py>( )? } TYPE_I32_ARRAY => into_py_any(py, pyarray1_from_cow(py, cast_or_decode_i32(payload)?))?, + tag if is_tensor_type_tag(tag) => { + let value = decode_property_value(tag, payload)?; + property_value_to_pyobject(py, &value)? + } _ => { return Err(PyValueError::new_err(format!( "Unsupported property type tag {}", @@ -462,6 +482,7 @@ fn property_value_is_atom_array(value: &PropertyValue, n_atoms: usize) -> bool { PropertyValue::Float32Array(values) => values.len() == n_atoms, PropertyValue::Vec3ArrayF64(values) => values.len() == n_atoms, PropertyValue::Int32Array(values) => values.len() == n_atoms, + PropertyValue::Tensor(_) => false, PropertyValue::Float(_) | PropertyValue::Int(_) | PropertyValue::String(_) => false, } } diff --git a/atompack-py/src/py_dtypes.rs b/atompack-py/src/py_dtypes.rs index 95a0bd3..4ba7fb9 100644 --- a/atompack-py/src/py_dtypes.rs +++ b/atompack-py/src/py_dtypes.rs @@ -242,6 +242,76 @@ pub(crate) fn parse_mat3_field(value: &Bound<'_, PyAny>, label: &str) -> PyResul array.parse_mat3_data(label) } +fn tensor_shape(shape: &[usize], values_len: usize) -> PyResult> { + let expected = shape.iter().try_fold(1usize, |acc, dim| { + acc.checked_mul(*dim) + .ok_or_else(|| PyValueError::new_err("Tensor shape overflows usize")) + })?; + if expected != values_len { + return Err(PyValueError::new_err(format!( + "Tensor shape {:?} expects {} values, got {}", + shape, expected, values_len + ))); + } + if shape.len() > u8::MAX as usize { + return Err(PyValueError::new_err(format!( + "Tensor rank {} exceeds maximum {}", + shape.len(), + u8::MAX + ))); + } + if shape.iter().any(|dim| *dim > u32::MAX as usize) { + return Err(PyValueError::new_err( + "Tensor dimensions must fit in u32 for storage", + )); + } + Ok(shape.to_vec()) +} + +fn parse_tensor_property_value(value: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let values: Vec = view.iter().copied().collect(); + let shape = tensor_shape(view.shape(), values.len())?; + return Ok(Some(PropertyValue::Tensor(TensorData::F32 { + shape, + values, + }))); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let values: Vec = view.iter().copied().collect(); + let shape = tensor_shape(view.shape(), values.len())?; + return Ok(Some(PropertyValue::Tensor(TensorData::F64 { + shape, + values, + }))); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let values: Vec = view.iter().copied().collect(); + let shape = tensor_shape(view.shape(), values.len())?; + return Ok(Some(PropertyValue::Tensor(TensorData::I32 { + shape, + values, + }))); + } + if let Ok(arr) = value.cast::>() { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let values: Vec = view.iter().copied().collect(); + let shape = tensor_shape(view.shape(), values.len())?; + return Ok(Some(PropertyValue::Tensor(TensorData::I64 { + shape, + values, + }))); + } + Ok(None) +} + pub(crate) fn parse_property_value(value: &Bound<'_, PyAny>) -> PyResult { if value.is_none() { return Ok(PropertyValue::None); @@ -272,42 +342,39 @@ pub(crate) fn parse_property_value(value: &Bound<'_, PyAny>) -> PyResult { let readonly = arr.readonly(); let arr_view = readonly.as_array(); let shape = arr_view.shape(); - if shape[1] != 3 { - return Err(PyValueError::new_err( - "Vec3Array properties must have shape (n, 3)", + if shape[1] == 3 { + return Ok(PropertyValue::Vec3Array( + arr_view + .outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), )); } - PropertyValue::Vec3Array( - arr_view - .outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(), - ) } PyFloatArray2::F64(arr) => { let readonly = arr.readonly(); let arr_view = readonly.as_array(); let shape = arr_view.shape(); - if shape[1] != 3 { - return Err(PyValueError::new_err( - "Vec3Array properties must have shape (n, 3)", + if shape[1] == 3 { + return Ok(PropertyValue::Vec3ArrayF64( + arr_view + .outer_iter() + .map(|row| [row[0], row[1], row[2]]) + .collect(), )); } - PropertyValue::Vec3ArrayF64( - arr_view - .outer_iter() - .map(|row| [row[0], row[1], row[2]]) - .collect(), - ) } - }); + } + } + if let Some(value) = parse_tensor_property_value(value)? { + return Ok(value); } Err(PyValueError::new_err( - "Unsupported property type. Supported: None, float, int, str, ndarray", + "Unsupported property type. Supported: None, float, int, str, and numeric ndarray with dtype float32, float64, int32, or int64", )) } diff --git a/atompack-py/src/soa.rs b/atompack-py/src/soa.rs index d496c9b..002cdb6 100644 --- a/atompack-py/src/soa.rs +++ b/atompack-py/src/soa.rs @@ -149,9 +149,16 @@ pub(crate) fn section_schema_from_ref( n_atoms: usize, ) -> atompack::Result { let per_atom = is_per_atom(section.kind, section.key, section.type_tag); + let tensor_type = is_tensor_type_tag(section.type_tag); let elem_bytes = match section.type_tag { TYPE_NONE => 0, TYPE_STRING => 0, + tag if tensor_type => tensor_type_tag_elem_bytes(tag).ok_or_else(|| { + invalid_data(format!( + "Unsupported tensor type tag {} for key '{}'", + tag, section.key + )) + })?, tag if per_atom => { let elem_bytes = type_tag_elem_bytes(tag); if elem_bytes == 0 { @@ -169,7 +176,7 @@ pub(crate) fn section_schema_from_ref( TYPE_MAT3X3_F64 => 72, _ => section.payload.len(), }; - let slot_bytes = if matches!(section.type_tag, TYPE_STRING | TYPE_NONE) { + let slot_bytes = if matches!(section.type_tag, TYPE_STRING | TYPE_NONE) || tensor_type { 0 } else if per_atom { elem_bytes @@ -216,6 +223,26 @@ pub(crate) fn validate_section_payload( ))); } } + tag if is_tensor_type_tag(tag) => { + let (shape, _) = tensor_shape_from_payload(tag, section.payload)?; + if per_atom { + match shape.first() { + Some(first_dim) if *first_dim == n_atoms => {} + Some(first_dim) => { + return Err(invalid_data(format!( + "Atom tensor property '{}' first dimension ({}) doesn't match atom count ({})", + section.key, first_dim, n_atoms + ))); + } + None => { + return Err(invalid_data(format!( + "Atom tensor property '{}' must have at least one dimension", + section.key + ))); + } + } + } + } TYPE_FLOAT | TYPE_INT | TYPE_FLOAT32 | TYPE_BOOL3 | TYPE_MAT3X3_F32 | TYPE_MAT3X3_F64 => { if section.payload.len() != slot_bytes { return Err(invalid_data(format!( @@ -278,10 +305,73 @@ pub(crate) fn type_tag_elem_bytes(tag: u8) -> usize { TYPE_FLOAT32 => 4, TYPE_MAT3X3_F32 => 36, TYPE_MAT3X3_F64 => 72, + TYPE_TENSOR_F32 | TYPE_TENSOR_I32 => 4, + TYPE_TENSOR_F64 | TYPE_TENSOR_I64 => 8, _ => 0, } } +pub(crate) fn is_tensor_type_tag(tag: u8) -> bool { + matches!( + tag, + TYPE_TENSOR_F32 | TYPE_TENSOR_F64 | TYPE_TENSOR_I32 | TYPE_TENSOR_I64 + ) +} + +fn tensor_type_tag_elem_bytes(tag: u8) -> Option { + match tag { + TYPE_TENSOR_F32 | TYPE_TENSOR_I32 => Some(4), + TYPE_TENSOR_F64 | TYPE_TENSOR_I64 => Some(8), + _ => None, + } +} + +fn tensor_value_count(shape: &[usize]) -> atompack::Result { + shape.iter().try_fold(1usize, |acc, dim| { + acc.checked_mul(*dim) + .ok_or_else(|| invalid_data("Tensor shape overflows usize")) + }) +} + +pub(crate) fn tensor_shape_from_payload( + type_tag: u8, + payload: &[u8], +) -> atompack::Result<(Vec, usize)> { + let elem_bytes = tensor_type_tag_elem_bytes(type_tag) + .ok_or_else(|| invalid_data(format!("Type tag {} is not a tensor", type_tag)))?; + let Some((&rank, rest)) = payload.split_first() else { + return Err(invalid_data("Tensor payload missing rank")); + }; + let rank = rank as usize; + let shape_bytes = rank + .checked_mul(4) + .ok_or_else(|| invalid_data("Tensor rank overflow"))?; + if rest.len() < shape_bytes { + return Err(invalid_data("Tensor payload truncated at shape")); + } + let mut shape = Vec::with_capacity(rank); + for chunk in rest[..shape_bytes].chunks_exact(4) { + shape.push(u32::from_le_bytes(slice_to_array(chunk, "tensor shape")?) as usize); + } + let data_offset = 1 + shape_bytes; + let data_len = payload.len() - data_offset; + if !data_len.is_multiple_of(elem_bytes) { + return Err(invalid_data(format!( + "Tensor payload length {} is not divisible by element size {}", + data_len, elem_bytes + ))); + } + let expected = tensor_value_count(&shape)?; + let actual = data_len / elem_bytes; + if expected != actual { + return Err(invalid_data(format!( + "Tensor shape {:?} expects {} values, got {}", + shape, expected, actual + ))); + } + Ok((shape, data_offset)) +} + /// Whether a section with the given kind/key/type_tag is per-atom (vs per-molecule). pub(crate) fn is_per_atom(kind: u8, key: &str, _type_tag: u8) -> bool { match kind { @@ -300,8 +390,16 @@ fn database_schema_section( n_atoms: usize, ) -> PyResult { let per_atom = is_per_atom(kind, key, type_tag); + let tensor_type = is_tensor_type_tag(type_tag); let elem_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) { 0 + } else if tensor_type { + tensor_type_tag_elem_bytes(type_tag).ok_or_else(|| { + PyValueError::new_err(format!( + "Unsupported tensor type tag {} for '{}'", + type_tag, key + )) + })? } else { let elem_bytes = type_tag_elem_bytes(type_tag); if elem_bytes == 0 { @@ -312,7 +410,7 @@ fn database_schema_section( } elem_bytes }; - let slot_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) { + let slot_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) || tensor_type { 0 } else if per_atom { let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| { @@ -752,14 +850,6 @@ impl SoaMoleculeView { Ok(None) } - pub(crate) fn property_keys(&self) -> PyResult> { - self.custom_sections - .iter() - .filter(|s| s.kind == KIND_MOL_PROP) - .map(|s| Ok(self.lazy_section_key(s)?.to_string())) - .collect() - } - pub(crate) fn atom_at(&self, index: usize) -> PyResult> { if index >= self.n_atoms { return Ok(None); @@ -1105,7 +1195,7 @@ fn decode_mat3x3_f32(payload: &[u8]) -> PyResult<[[f32; 3]; 3]> { ]) } -fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult { +pub(crate) fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult { Ok(match type_tag { TYPE_NONE => { if !payload.is_empty() { @@ -1126,6 +1216,38 @@ fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult PropertyValue::Float32Array(decode_f32_array(payload)?), TYPE_VEC3_F64 => PropertyValue::Vec3ArrayF64(decode_vec3_f64(payload)?), TYPE_I32_ARRAY => PropertyValue::Int32Array(decode_i32_array(payload)?), + TYPE_TENSOR_F32 => { + let (shape, data_offset) = tensor_shape_from_payload(type_tag, payload) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + PropertyValue::Tensor(TensorData::F32 { + shape, + values: decode_f32_array(&payload[data_offset..])?, + }) + } + TYPE_TENSOR_F64 => { + let (shape, data_offset) = tensor_shape_from_payload(type_tag, payload) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + PropertyValue::Tensor(TensorData::F64 { + shape, + values: decode_f64_array(&payload[data_offset..])?, + }) + } + TYPE_TENSOR_I32 => { + let (shape, data_offset) = tensor_shape_from_payload(type_tag, payload) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + PropertyValue::Tensor(TensorData::I32 { + shape, + values: decode_i32_array(&payload[data_offset..])?, + }) + } + TYPE_TENSOR_I64 => { + let (shape, data_offset) = tensor_shape_from_payload(type_tag, payload) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + PropertyValue::Tensor(TensorData::I64 { + shape, + values: decode_i64_array(&payload[data_offset..])?, + }) + } _ => { return Err(PyValueError::new_err(format!( "Unsupported property type tag {}", diff --git a/atompack-py/tests/test_atom_molecule.py b/atompack-py/tests/test_atom_molecule.py index 6c03f83..a1a36b9 100644 --- a/atompack-py/tests/test_atom_molecule.py +++ b/atompack-py/tests/test_atom_molecule.py @@ -169,6 +169,8 @@ def test_molecule_custom_properties() -> None: mol.set_property("int_vec32", np.array([3, 4], dtype=np.int32)) mol.set_property("vec3", np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) mol.set_property("vec3_f64", np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], dtype=np.float64)) + mol.set_property("tensor_f32", np.arange(24, dtype=np.float32).reshape(2, 3, 4)) + mol.set_property("tensor_i64", np.arange(6, dtype=np.int64).reshape(2, 3)) mol.set_property("optional_label", None) mol.stress = np.eye(3, dtype=np.float64) * 3.0 @@ -209,6 +211,17 @@ def test_molecule_custom_properties() -> None: np.testing.assert_allclose( vec3_f64, np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], dtype=np.float64) ) + + tensor_f32 = mol.get_property("tensor_f32") + assert tensor_f32.dtype == np.float32 + assert tensor_f32.shape == (2, 3, 4) + np.testing.assert_allclose(tensor_f32, np.arange(24, dtype=np.float32).reshape(2, 3, 4)) + + tensor_i64 = mol.get_property("tensor_i64") + assert tensor_i64.dtype == np.int64 + assert tensor_i64.shape == (2, 3) + np.testing.assert_array_equal(tensor_i64, np.arange(6, dtype=np.int64).reshape(2, 3)) + assert mol.get_property("optional_label") is None np.testing.assert_allclose(mol.stress, np.eye(3, dtype=np.float64) * 3.0) @@ -226,6 +239,8 @@ def test_molecule_custom_properties() -> None: "int_vec32", "vec3", "vec3_f64", + "tensor_f32", + "tensor_i64", "optional_label", } @@ -236,6 +251,47 @@ def test_molecule_custom_properties() -> None: mol.get_property("does_not_exist") +def test_molecule_atom_property_scope_and_delete() -> None: + mol = _make_molecule() + + mol.set_property("partial_charge", np.array([0.1, 0.2], dtype=np.float32), scope="atom") + mol.set_property("descriptor", np.arange(8, dtype=np.float64).reshape(2, 2, 2), scope="atom") + + partial_charge = mol.get_property("partial_charge") + assert partial_charge.dtype == np.float32 + np.testing.assert_allclose(partial_charge, np.array([0.1, 0.2], dtype=np.float32)) + + descriptor = mol["descriptor"] + assert descriptor.dtype == np.float64 + assert descriptor.shape == (2, 2, 2) + + assert mol.has_property("partial_charge") is True + assert mol.has_property("partial_charge", scope="atom") is True + assert mol.has_property("partial_charge", scope="molecule") is False + assert "partial_charge" in mol.property_keys() + assert "partial_charge" in mol.property_keys(scope="atom") + assert "partial_charge" not in mol.property_keys(scope="molecule") + + # Existing atom properties keep atom scope when overwritten without scope. + mol.set_property("partial_charge", np.array([0.3, 0.4], dtype=np.float32)) + np.testing.assert_allclose( + mol.get_property("partial_charge"), + np.array([0.3, 0.4], dtype=np.float32), + ) + + with pytest.raises(ValueError, match=r"first dimension"): + mol.set_property("bad_atom", np.array([1.0], dtype=np.float64), scope="atom") + + mol.set_property("label", "train") + with pytest.raises(ValueError, match=r"already exists as a molecule property"): + mol.set_property("label", np.array([1.0, 2.0], dtype=np.float64), scope="atom") + + mol.delete_property("label") + assert mol.has_property("label") is False + with pytest.raises(KeyError, match=r"not found"): + mol.delete_property("label") + + def test_missing_property_raises_keyerror_consistently() -> None: # Lock in symmetry: both indexing patterns must raise the same exception # type for missing custom properties. They previously disagreed. diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index ecf9107..58a2ffe 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -64,6 +64,7 @@ def test_database_rejects_invalid_compression(tmp_path: Path) -> None: with pytest.raises(ValueError, match=r"Invalid compression"): atompack.Database(str(tmp_path / "bad.atp"), compression="definitely-not-a-codec") + def test_database_add_arrays_batch_rejects_v2_incompatible_builtin_dtype(tmp_path: Path) -> None: path = tmp_path / "batch_arrays_v2_compat.atp" db = atompack.Database(str(path)) @@ -134,6 +135,7 @@ def test_database_roundtrip_from_arrays_with_builtins(tmp_path: Path) -> None: np.testing.assert_allclose(read.forces, forces) np.testing.assert_allclose(read.cell, cell) + def test_database_add_molecules_roundtrip_from_arrays_with_builtins(tmp_path: Path) -> None: path = tmp_path / "from_arrays_add_molecules.atp" positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32) @@ -172,7 +174,9 @@ def test_database_add_molecules_roundtrip_from_arrays_with_builtins(tmp_path: Pa np.testing.assert_allclose(first.cell, cell) assert second.energy == pytest.approx(-8.5) - np.testing.assert_allclose(second.positions, positions + np.array([[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], dtype=np.float32)) + np.testing.assert_allclose( + second.positions, positions + np.array([[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], dtype=np.float32) + ) np.testing.assert_array_equal(second.atomic_numbers, atomic_numbers) np.testing.assert_allclose(second.forces, forces * 2.0) np.testing.assert_allclose(second.cell, cell * 2.0) @@ -494,6 +498,51 @@ def test_database_custom_array_properties_roundtrip_all_numeric_tags( ) +@pytest.mark.parametrize("mmap", [False, True]) +@pytest.mark.parametrize("compression", ["none", "lz4", "zstd"]) +def test_database_tensor_properties_allow_value_level_shapes( + tmp_path: Path, + mmap: bool, + compression: str, +) -> None: + path = tmp_path / f"tensor_value_shapes_{compression}_{mmap}.atp" + mol1 = _make_molecule(-6.0) + mol1.set_property("density", np.arange(24, dtype=np.float32).reshape(2, 3, 4)) + mol1.set_property( + "atom_descriptor", np.arange(8, dtype=np.float32).reshape(2, 2, 2), scope="atom" + ) + + mol2 = _make_molecule(-7.0) + mol2.set_property("density", np.arange(5, dtype=np.float32).reshape(1, 5)) + mol2.set_property("atom_descriptor", np.arange(8, dtype=np.float32).reshape(2, 4), scope="atom") + + db = atompack.Database(str(path), compression=compression) + db.add_molecules([mol1, mol2]) + db.flush() + + reopened = atompack.Database.open(str(path), mmap=mmap) + first = reopened[0] + second = reopened[1] + + first_density = first.get_property("density") + assert first_density.dtype == np.float32 + assert first_density.shape == (2, 3, 4) + np.testing.assert_allclose(first_density, np.arange(24, dtype=np.float32).reshape(2, 3, 4)) + + second_density = second.get_property("density") + assert second_density.dtype == np.float32 + assert second_density.shape == (1, 5) + np.testing.assert_allclose(second_density, np.arange(5, dtype=np.float32).reshape(1, 5)) + + first_descriptor = first.get_property("atom_descriptor") + assert first_descriptor.shape == (2, 2, 2) + np.testing.assert_allclose(first_descriptor, np.arange(8, dtype=np.float32).reshape(2, 2, 2)) + + second_descriptor = second.get_property("atom_descriptor") + assert second_descriptor.shape == (2, 4) + np.testing.assert_allclose(second_descriptor, np.arange(8, dtype=np.float32).reshape(2, 4)) + + @pytest.mark.parametrize("mmap", [False, True]) @pytest.mark.parametrize("compression", ["none", "lz4", "zstd"]) def test_database_single_item_mutation_is_copy_on_write( @@ -632,7 +681,6 @@ def test_database_open_defaults_to_read_only_mmap(tmp_path: Path) -> None: assert len(check) == 2 - def test_database_sequence_iteration_stops_cleanly(tmp_path: Path) -> None: path = tmp_path / "iter.atp" db = atompack.Database(str(path)) diff --git a/atompack-py/tests/test_from_ase.py b/atompack-py/tests/test_from_ase.py index c485cd4..a244322 100644 --- a/atompack-py/tests/test_from_ase.py +++ b/atompack-py/tests/test_from_ase.py @@ -87,11 +87,11 @@ def test_from_ase_extracts_core_fields() -> None: "int_vec32": np.array([3, 4], dtype=np.int32), "vec3": np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32), "vec3_f64": np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float64), + "density_grid": np.arange(24, dtype=np.float32).reshape(2, 3, 4), "nullable": None, "stress": np.array( [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3]], dtype=np.float64 ), - "unsupported": {"nested": True}, }, ) @@ -127,6 +127,13 @@ def test_from_ase_extracts_core_fields() -> None: np.testing.assert_allclose( mol.get_property("vec3_f64"), np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float64) ) + density_grid = mol.get_property("density_grid") + assert density_grid.dtype == np.float32 + assert density_grid.shape == (2, 3, 4) + np.testing.assert_allclose( + density_grid, + np.arange(24, dtype=np.float32).reshape(2, 3, 4), + ) assert mol.get_property("nullable") is None np.testing.assert_allclose( mol.stress, @@ -135,8 +142,6 @@ def test_from_ase_extracts_core_fields() -> None: with pytest.raises(KeyError, match=r"not found"): mol.get_property("stress") - assert mol.has_property("unsupported") is False - def test_from_ase_info_override_and_copy_toggle() -> None: atoms = FakeASEAtoms( @@ -153,6 +158,93 @@ def test_from_ase_info_override_and_copy_toggle() -> None: assert mol_override.get_property("temperature") == pytest.approx(500.0) +def test_from_ase_copies_tensor_custom_properties_from_info_and_calc_results() -> None: + info_tensor = np.arange(24, dtype=np.float32).reshape(2, 3, 4) + calc_tensor = np.arange(12, dtype=np.int64).reshape(1, 3, 4) + atoms = FakeASEAtoms( + positions=np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float64), + atomic_numbers=np.array([6, 8], dtype=np.int64), + pbc=np.array([False, False, False]), + info={"density_grid": info_tensor}, + calc=FakeCalc({"band_tensor": calc_tensor}), + ) + + mol = atompack.from_ase(atoms) + + density_grid = mol.get_property("density_grid") + assert density_grid.dtype == np.float32 + assert density_grid.shape == (2, 3, 4) + np.testing.assert_allclose(density_grid, info_tensor) + + band_tensor = mol.get_property("band_tensor") + assert band_tensor.dtype == np.int64 + assert band_tensor.shape == (1, 3, 4) + np.testing.assert_array_equal(band_tensor, calc_tensor) + + +def test_from_ase_custom_arrays_remain_molecule_properties() -> None: + descriptor = np.arange(8, dtype=np.float32).reshape(2, 2, 2) + atoms = FakeASEAtoms( + positions=np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float64), + atomic_numbers=np.array([6, 8], dtype=np.int64), + pbc=np.array([False, False, False]), + arrays={"descriptor": descriptor}, + ) + + mol = atompack.from_ase(atoms) + + restored = mol.get_property("descriptor") + assert restored.dtype == np.float32 + assert restored.shape == (2, 2, 2) + np.testing.assert_allclose(restored, descriptor) + assert mol.has_property("descriptor", scope="molecule") is True + assert mol.has_property("descriptor", scope="atom") is False + + +def test_from_ase_rejects_unsupported_enabled_custom_values_and_honors_optouts() -> None: + atoms_with_bad_info = FakeASEAtoms( + positions=np.array([[0.0, 0.0, 0.0]], dtype=np.float64), + atomic_numbers=np.array([1], dtype=np.int64), + pbc=np.array([False, False, False]), + info={"bad_info": {"nested": True}}, + ) + with pytest.raises( + TypeError, + match=r"Unsupported ASE custom property 'bad_info' from atoms\.info", + ): + atompack.from_ase(atoms_with_bad_info) + + mol = atompack.from_ase(atoms_with_bad_info, copy_info=False) + assert mol.has_property("bad_info") is False + + atoms_with_bad_array = FakeASEAtoms( + positions=np.array([[0.0, 0.0, 0.0]], dtype=np.float64), + atomic_numbers=np.array([1], dtype=np.int64), + pbc=np.array([False, False, False]), + arrays={"bad_array": np.array([object()], dtype=object)}, + ) + with pytest.raises( + TypeError, + match=r"Unsupported ASE custom property 'bad_array' from atoms\.arrays", + ): + atompack.from_ase(atoms_with_bad_array) + + mol = atompack.from_ase(atoms_with_bad_array, copy_arrays=False) + assert mol.has_property("bad_array") is False + + atoms_with_bad_calc_result = FakeASEAtoms( + positions=np.array([[0.0, 0.0, 0.0]], dtype=np.float64), + atomic_numbers=np.array([1], dtype=np.int64), + pbc=np.array([False, False, False]), + calc=FakeCalc({"bad_result": {"nested": True}}), + ) + with pytest.raises( + TypeError, + match=r"Unsupported ASE custom property 'bad_result' from atoms\.calc\.results", + ): + atompack.from_ase(atoms_with_bad_calc_result) + + def test_from_ase_extracts_arrays_and_calc_results() -> None: atoms = FakeASEAtoms( positions=np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float64), @@ -178,17 +270,13 @@ def test_from_ase_extracts_arrays_and_calc_results() -> None: mol = atompack.from_ase(atoms) np.testing.assert_array_equal(mol.get_property("tags"), np.array([1, 2], dtype=np.int32)) - np.testing.assert_allclose( - mol.get_property("masses"), np.array([12.0, 16.0], dtype=np.float64) - ) + np.testing.assert_allclose(mol.get_property("masses"), np.array([12.0, 16.0], dtype=np.float64)) np.testing.assert_allclose( mol.get_property("momenta"), np.array([[0.1, 0.0, 0.0], [0.0, 0.2, 0.0]], dtype=np.float64), ) assert mol.get_property("free_energy") == pytest.approx(-5.5) - np.testing.assert_allclose( - mol.get_property("magmoms"), np.array([0.3, 0.4], dtype=np.float64) - ) + np.testing.assert_allclose(mol.get_property("magmoms"), np.array([0.3, 0.4], dtype=np.float64)) np.testing.assert_allclose(mol.stress, np.eye(3, dtype=np.float64) * 3.0) with pytest.raises(KeyError, match=r"not found"): mol.get_property("forces") @@ -362,6 +450,41 @@ def test_to_ase_owned_maps_builtins_and_properties() -> None: np.testing.assert_array_equal(atoms.arrays["tags"], np.array([3, 4], dtype=np.int32)) +@pytest.mark.parametrize("view_backed", [False, True]) +def test_to_ase_routes_tensor_properties_by_scope(tmp_path, view_backed: bool) -> None: + molecule_tensor = np.arange(8, dtype=np.float32).reshape(2, 2, 2) + atom_tensor = np.arange(8, dtype=np.float64).reshape(2, 2, 2) + mol = atompack.Molecule.from_arrays( + np.array([[0.0, 0.0, 0.0], [1.0, 0.5, 0.0]], dtype=np.float32), + np.array([6, 8], dtype=np.uint8), + ) + mol.set_property("molecule_tensor", molecule_tensor) + mol.set_property("atom_tensor", atom_tensor, scope="atom") + + if view_backed: + path = tmp_path / "to_ase_tensor_scope.atp" + db = atompack.Database(str(path), compression="none") + db.add_molecule(mol) + db.flush() + mol = atompack.Database.open(str(path))[0] + + atoms = mol.to_ase() + + assert "molecule_tensor" in atoms.info + assert "molecule_tensor" not in atoms.arrays + restored_molecule_tensor = atoms.info["molecule_tensor"] + assert restored_molecule_tensor.dtype == np.float32 + assert restored_molecule_tensor.shape == (2, 2, 2) + np.testing.assert_allclose(restored_molecule_tensor, molecule_tensor) + + assert "atom_tensor" in atoms.arrays + assert "atom_tensor" not in atoms.info + restored_atom_tensor = atoms.arrays["atom_tensor"] + assert restored_atom_tensor.dtype == np.float64 + assert restored_atom_tensor.shape == (2, 2, 2) + np.testing.assert_allclose(restored_atom_tensor, atom_tensor) + + def test_to_ase_roundtrip_preserves_none_custom_property() -> None: mol = atompack.Molecule.from_arrays( np.array([[0.0, 0.0, 0.0]], dtype=np.float32), diff --git a/atompack/src/atom.rs b/atompack/src/atom.rs index ed7b575..16db2d1 100644 --- a/atompack/src/atom.rs +++ b/atompack/src/atom.rs @@ -3,7 +3,9 @@ use bytemuck::{Pod, Zeroable}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -pub use crate::types::{FloatArrayData, FloatScalarData, Mat3Data, PropertyValue, Vec3Data}; +pub use crate::types::{ + FloatArrayData, FloatScalarData, Mat3Data, PropertyValue, TensorData, Vec3Data, +}; #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Pod, Zeroable)] #[repr(C)] diff --git a/atompack/src/storage/dtypes.rs b/atompack/src/storage/dtypes.rs index 309eec1..17bc20d 100644 --- a/atompack/src/storage/dtypes.rs +++ b/atompack/src/storage/dtypes.rs @@ -1,5 +1,5 @@ use super::*; -use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, PropertyValue, Vec3Data}; +use crate::atom::{FloatArrayData, FloatScalarData, Mat3Data, PropertyValue, TensorData, Vec3Data}; pub(super) fn arr(bytes: &[u8]) -> Result<[u8; N]> { bytes @@ -79,6 +79,10 @@ pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 { PropertyValue::Float32Array(_) => TYPE_F32_ARRAY, PropertyValue::Vec3ArrayF64(_) => TYPE_VEC3_F64, PropertyValue::Int32Array(_) => TYPE_I32_ARRAY, + PropertyValue::Tensor(TensorData::F32 { .. }) => TYPE_TENSOR_F32, + PropertyValue::Tensor(TensorData::F64 { .. }) => TYPE_TENSOR_F64, + PropertyValue::Tensor(TensorData::I32 { .. }) => TYPE_TENSOR_I32, + PropertyValue::Tensor(TensorData::I64 { .. }) => TYPE_TENSOR_I64, } } @@ -93,9 +97,114 @@ pub(super) fn property_value_payload_len(value: &PropertyValue) -> usize { PropertyValue::Float32Array(values) => values.len() * 4, PropertyValue::Vec3ArrayF64(values) => values.len() * 24, PropertyValue::Int32Array(values) => values.len() * 4, + PropertyValue::Tensor(values) => tensor_data_payload_len(values), } } +pub(super) fn is_tensor_type_tag(type_tag: u8) -> bool { + matches!( + type_tag, + TYPE_TENSOR_F32 | TYPE_TENSOR_F64 | TYPE_TENSOR_I32 | TYPE_TENSOR_I64 + ) +} + +pub(super) fn tensor_type_tag_elem_bytes(type_tag: u8) -> Option { + match type_tag { + TYPE_TENSOR_F32 | TYPE_TENSOR_I32 => Some(4), + TYPE_TENSOR_F64 | TYPE_TENSOR_I64 => Some(8), + _ => None, + } +} + +fn tensor_value_count(shape: &[usize]) -> Result { + shape.iter().try_fold(1usize, |acc, dim| { + acc.checked_mul(*dim) + .ok_or_else(|| Error::InvalidData("Tensor shape overflows usize".into())) + }) +} + +fn tensor_data_payload_len(value: &TensorData) -> usize { + let shape = value.shape(); + let values_len = match value { + TensorData::F32 { values, .. } => values.len() * 4, + TensorData::F64 { values, .. } => values.len() * 8, + TensorData::I32 { values, .. } => values.len() * 4, + TensorData::I64 { values, .. } => values.len() * 8, + }; + 1 + shape.len() * 4 + values_len +} + +fn extend_tensor_header(buf: &mut Vec, shape: &[usize]) { + buf.push(shape.len() as u8); + for dim in shape { + buf.extend_from_slice(&(*dim as u32).to_le_bytes()); + } +} + +fn validate_tensor_data_shape(shape: &[usize], values_len: usize) -> Result<()> { + if shape.len() > u8::MAX as usize { + return Err(Error::InvalidData(format!( + "Tensor rank {} exceeds maximum {}", + shape.len(), + u8::MAX + ))); + } + if shape.iter().any(|dim| *dim > u32::MAX as usize) { + return Err(Error::InvalidData( + "Tensor dimensions must fit in u32 for storage".into(), + )); + } + let expected = tensor_value_count(shape)?; + if expected != values_len { + return Err(Error::InvalidData(format!( + "Tensor shape {:?} expects {} values, got {}", + shape, expected, values_len + ))); + } + Ok(()) +} + +pub(super) fn tensor_shape_from_payload( + type_tag: u8, + payload: &[u8], +) -> Result<(Vec, usize)> { + let elem_bytes = tensor_type_tag_elem_bytes(type_tag) + .ok_or_else(|| Error::InvalidData(format!("Type tag {} is not a tensor", type_tag)))?; + let Some((&rank, rest)) = payload.split_first() else { + return Err(Error::InvalidData("Tensor payload missing rank".into())); + }; + let rank = rank as usize; + let shape_bytes = rank + .checked_mul(4) + .ok_or_else(|| Error::InvalidData("Tensor rank overflow".into()))?; + if rest.len() < shape_bytes { + return Err(Error::InvalidData( + "Tensor payload truncated at shape".into(), + )); + } + let mut shape = Vec::with_capacity(rank); + for chunk in rest[..shape_bytes].chunks_exact(4) { + shape.push(u32::from_le_bytes(arr(chunk)?) as usize); + } + let data_offset = 1 + shape_bytes; + let data_len = payload.len() - data_offset; + if !data_len.is_multiple_of(elem_bytes) { + return Err(Error::InvalidData(format!( + "Tensor payload length {} is not divisible by element size {}", + data_len, elem_bytes + ))); + } + let expected = tensor_value_count(&shape)?; + let actual = data_len / elem_bytes; + if expected != actual { + return Err(Error::InvalidData(format!( + "Tensor shape {:?} expects {} values, got {}", + shape, expected, actual + ))); + } + Ok((shape, data_offset)) +} + fn extend_f64(buf: &mut Vec, values: &[f64]) { for value in values { buf.extend_from_slice(&f64::to_le_bytes(*value)); @@ -160,6 +269,34 @@ pub(super) fn property_value_to_bytes(value: &PropertyValue) -> Vec { extend_i32(&mut payload, values); payload } + PropertyValue::Tensor(TensorData::F32 { shape, values }) => { + validate_tensor_data_shape(shape, values.len()).expect("invalid f32 tensor property"); + let mut payload = Vec::with_capacity(1 + shape.len() * 4 + values.len() * 4); + extend_tensor_header(&mut payload, shape); + extend_f32(&mut payload, values); + payload + } + PropertyValue::Tensor(TensorData::F64 { shape, values }) => { + validate_tensor_data_shape(shape, values.len()).expect("invalid f64 tensor property"); + let mut payload = Vec::with_capacity(1 + shape.len() * 4 + values.len() * 8); + extend_tensor_header(&mut payload, shape); + extend_f64(&mut payload, values); + payload + } + PropertyValue::Tensor(TensorData::I32 { shape, values }) => { + validate_tensor_data_shape(shape, values.len()).expect("invalid i32 tensor property"); + let mut payload = Vec::with_capacity(1 + shape.len() * 4 + values.len() * 4); + extend_tensor_header(&mut payload, shape); + extend_i32(&mut payload, values); + payload + } + PropertyValue::Tensor(TensorData::I64 { shape, values }) => { + validate_tensor_data_shape(shape, values.len()).expect("invalid i64 tensor property"); + let mut payload = Vec::with_capacity(1 + shape.len() * 4 + values.len() * 8); + extend_tensor_header(&mut payload, shape); + extend_i64(&mut payload, values); + payload + } } } @@ -349,6 +486,40 @@ pub(super) fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result>()?, ) } + TYPE_TENSOR_F32 => { + let (shape, data_offset) = tensor_shape_from_payload(type_tag, payload)?; + PropertyValue::Tensor(TensorData::F32 { + shape, + values: decode_f32_array(&payload[data_offset..])?, + }) + } + TYPE_TENSOR_F64 => { + let (shape, data_offset) = tensor_shape_from_payload(type_tag, payload)?; + PropertyValue::Tensor(TensorData::F64 { + shape, + values: decode_f64_array(&payload[data_offset..])?, + }) + } + TYPE_TENSOR_I32 => { + let (shape, data_offset) = tensor_shape_from_payload(type_tag, payload)?; + PropertyValue::Tensor(TensorData::I32 { + shape, + values: payload[data_offset..] + .chunks_exact(4) + .map(|chunk| Ok(i32::from_le_bytes(arr(chunk)?))) + .collect::>()?, + }) + } + TYPE_TENSOR_I64 => { + let (shape, data_offset) = tensor_shape_from_payload(type_tag, payload)?; + PropertyValue::Tensor(TensorData::I64 { + shape, + values: payload[data_offset..] + .chunks_exact(8) + .map(|chunk| Ok(i64::from_le_bytes(arr(chunk)?))) + .collect::>()?, + }) + } _ => return Err(Error::InvalidData(format!("Unknown type tag {}", type_tag))), }) } diff --git a/atompack/src/storage/mod.rs b/atompack/src/storage/mod.rs index 405bb4c..e0869b7 100644 --- a/atompack/src/storage/mod.rs +++ b/atompack/src/storage/mod.rs @@ -80,6 +80,10 @@ const TYPE_MAT3X3_F64: u8 = 10; // [[f64; 3]; 3] const TYPE_FLOAT32: u8 = 11; // f32 scalar const TYPE_MAT3X3_F32: u8 = 12; // [[f32; 3]; 3] const TYPE_NONE: u8 = 13; // explicit null property +const TYPE_TENSOR_F32: u8 = 14; // [ndim:u8][dims:u32...][f32...] +const TYPE_TENSOR_F64: u8 = 15; // [ndim:u8][dims:u32...][f64...] +const TYPE_TENSOR_I32: u8 = 16; // [ndim:u8][dims:u32...][i32...] +const TYPE_TENSOR_I64: u8 = 17; // [ndim:u8][dims:u32...][i64...] // Two redundant page-aligned header slots for crash safety. const HEADER_SLOT_SIZE: usize = 4096; @@ -1368,7 +1372,7 @@ mod tests { #[test] fn test_soa_all_property_value_types() { - use crate::atom::PropertyValue; + use crate::atom::{PropertyValue, TensorData}; let temp = NamedTempFile::new().unwrap(); let path = temp.path().to_path_buf(); @@ -1376,24 +1380,35 @@ mod tests { let mut mol = molecule_from_atoms(vec![Atom::new(0.0, 0.0, 0.0, 1)]); // Test every PropertyValue variant in atom_properties - mol.atom_properties - .insert("f64arr".to_string(), PropertyValue::FloatArray(vec![1.5])); mol.atom_properties.insert( - "vec3f32".to_string(), + "atom_f64arr".to_string(), + PropertyValue::FloatArray(vec![1.5]), + ); + mol.atom_properties.insert( + "atom_vec3f32".to_string(), PropertyValue::Vec3Array(vec![[1.0, 2.0, 3.0]]), ); mol.atom_properties - .insert("i64arr".to_string(), PropertyValue::IntArray(vec![42])); + .insert("atom_i64arr".to_string(), PropertyValue::IntArray(vec![42])); mol.atom_properties.insert( - "f32arr".to_string(), + "atom_f32arr".to_string(), PropertyValue::Float32Array(vec![3.125]), ); mol.atom_properties.insert( - "vec3f64".to_string(), + "atom_vec3f64".to_string(), PropertyValue::Vec3ArrayF64(vec![[1.0, 2.0, 3.0]]), ); - mol.atom_properties - .insert("i32arr".to_string(), PropertyValue::Int32Array(vec![-7])); + mol.atom_properties.insert( + "atom_i32arr".to_string(), + PropertyValue::Int32Array(vec![-7]), + ); + mol.atom_properties.insert( + "atom_tensor".to_string(), + PropertyValue::Tensor(TensorData::F32 { + shape: vec![1, 2], + values: vec![0.25, 0.75], + }), + ); // Test every PropertyValue variant in properties mol.properties @@ -1428,6 +1443,13 @@ mod tests { ); mol.properties .insert("none_val".to_string(), PropertyValue::None); + mol.properties.insert( + "tensor_val".to_string(), + PropertyValue::Tensor(TensorData::I64 { + shape: vec![2, 2], + values: vec![1, 2, 3, 4], + }), + ); { let mut db = AtomDatabase::create(&path, CompressionType::Lz4).unwrap(); @@ -1439,34 +1461,41 @@ mod tests { let r = db.get_molecule(0).unwrap(); // Verify atom_properties - assert_eq!(r.atom_properties.len(), 6); - match r.atom_properties.get("f64arr").unwrap() { + assert_eq!(r.atom_properties.len(), 7); + match r.atom_properties.get("atom_f64arr").unwrap() { PropertyValue::FloatArray(v) => assert_eq!(v, &[1.5]), other => panic!("expected FloatArray, got {:?}", other), } - match r.atom_properties.get("vec3f32").unwrap() { + match r.atom_properties.get("atom_vec3f32").unwrap() { PropertyValue::Vec3Array(v) => assert_eq!(v, &[[1.0, 2.0, 3.0]]), other => panic!("expected Vec3Array, got {:?}", other), } - match r.atom_properties.get("i64arr").unwrap() { + match r.atom_properties.get("atom_i64arr").unwrap() { PropertyValue::IntArray(v) => assert_eq!(v, &[42]), other => panic!("expected IntArray, got {:?}", other), } - match r.atom_properties.get("f32arr").unwrap() { + match r.atom_properties.get("atom_f32arr").unwrap() { PropertyValue::Float32Array(v) => assert_eq!(v, &[3.125f32]), other => panic!("expected Float32Array, got {:?}", other), } - match r.atom_properties.get("vec3f64").unwrap() { + match r.atom_properties.get("atom_vec3f64").unwrap() { PropertyValue::Vec3ArrayF64(v) => assert_eq!(v, &[[1.0, 2.0, 3.0]]), other => panic!("expected Vec3ArrayF64, got {:?}", other), } - match r.atom_properties.get("i32arr").unwrap() { + match r.atom_properties.get("atom_i32arr").unwrap() { PropertyValue::Int32Array(v) => assert_eq!(v, &[-7]), other => panic!("expected Int32Array, got {:?}", other), } + match r.atom_properties.get("atom_tensor").unwrap() { + PropertyValue::Tensor(TensorData::F32 { shape, values }) => { + assert_eq!(shape, &[1, 2]); + assert_eq!(values, &[0.25, 0.75]); + } + other => panic!("expected TensorData::F32, got {:?}", other), + } // Verify properties - assert_eq!(r.properties.len(), 10); + assert_eq!(r.properties.len(), 11); match r.properties.get("scalar_f").unwrap() { PropertyValue::Float(v) => assert_eq!(*v, 99.9), other => panic!("expected Float, got {:?}", other), @@ -1483,6 +1512,13 @@ mod tests { PropertyValue::None => {} other => panic!("expected None, got {:?}", other), } + match r.properties.get("tensor_val").unwrap() { + PropertyValue::Tensor(TensorData::I64 { shape, values }) => { + assert_eq!(shape, &[2, 2]); + assert_eq!(values, &[1, 2, 3, 4]); + } + other => panic!("expected TensorData::I64, got {:?}", other), + } } #[test] diff --git a/atompack/src/storage/schema.rs b/atompack/src/storage/schema.rs index 7ba950f..d5ab94b 100644 --- a/atompack/src/storage/schema.rs +++ b/atompack/src/storage/schema.rs @@ -1,7 +1,8 @@ use super::dtypes::{ arr, float_array_data_type_tag, float_array_payload_len, float_scalar_data_type_tag, - float_scalar_payload_len, mat3_data_type_tag, mat3_payload_len, positions_type_from_molecule, - property_value_payload_len, property_value_type_tag, + float_scalar_payload_len, is_tensor_type_tag, mat3_data_type_tag, mat3_payload_len, + positions_type_from_molecule, property_value_payload_len, property_value_type_tag, + tensor_shape_from_payload, tensor_type_tag_elem_bytes, validate_builtin_type_tag_for_record_format, vec3_data_type_tag, vec3_payload_len, }; use super::soa::{SoaLayout, resolve_layout}; @@ -139,6 +140,8 @@ fn schema_type_tag_elem_bytes(tag: u8) -> Result { TYPE_MAT3X3_F64 => Ok(72), TYPE_FLOAT32 => Ok(4), TYPE_MAT3X3_F32 => Ok(36), + TYPE_TENSOR_F32 | TYPE_TENSOR_I32 => Ok(4), + TYPE_TENSOR_F64 | TYPE_TENSOR_I64 => Ok(8), _ => Err(Error::InvalidData(format!( "Unsupported section type tag {}", tag @@ -160,11 +163,13 @@ fn schema_entry( key: &str, type_tag: u8, payload_len: usize, + tensor_payload: Option<&[u8]>, n_atoms: usize, ) -> Result { let per_atom = schema_is_per_atom(kind, key); let elem_bytes = schema_type_tag_elem_bytes(type_tag)?; - let slot_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) { + let slot_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) || is_tensor_type_tag(type_tag) + { 0 } else if per_atom { match type_tag { @@ -178,7 +183,38 @@ fn schema_entry( payload_len }; - if per_atom { + if per_atom && is_tensor_type_tag(type_tag) { + let payload = tensor_payload.ok_or_else(|| { + Error::InvalidData(format!("Tensor section '{}' is missing payload", key)) + })?; + let (shape, _) = tensor_shape_from_payload(type_tag, payload)?; + match shape.first() { + Some(first_dim) if *first_dim == n_atoms => {} + Some(first_dim) => { + return Err(Error::InvalidData(format!( + "Atom tensor property '{}' first dimension ({}) doesn't match atom count ({})", + key, first_dim, n_atoms + ))); + } + None => { + return Err(Error::InvalidData(format!( + "Atom tensor property '{}' must have at least one dimension", + key + ))); + } + } + if tensor_type_tag_elem_bytes(type_tag) != Some(elem_bytes) { + return Err(Error::InvalidData(format!( + "Tensor section '{}' has inconsistent element size", + key + ))); + } + } else if is_tensor_type_tag(type_tag) { + let payload = tensor_payload.ok_or_else(|| { + Error::InvalidData(format!("Tensor section '{}' is missing payload", key)) + })?; + let _ = tensor_shape_from_payload(type_tag, payload)?; + } else if per_atom { let expected = elem_bytes .checked_mul(n_atoms) .ok_or_else(|| Error::InvalidData(format!("Schema overflow for section '{}'", key)))?; @@ -211,6 +247,20 @@ pub(super) fn validate_schema_lock_for_record_format( Ok(()) } +fn insert_schema_entry( + schema: &mut SchemaLock, + kind: u8, + key: &str, + type_tag: u8, + payload_len: usize, + tensor_payload: Option<&[u8]>, + n_atoms: usize, +) -> Result<()> { + let entry = schema_entry(kind, key, type_tag, payload_len, tensor_payload, n_atoms)?; + schema.sections.insert((kind, key.to_string()), entry); + Ok(()) +} + pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { let n_atoms = molecule.len(); let mut schema = SchemaLock { @@ -218,82 +268,153 @@ pub(super) fn schema_from_molecule(molecule: &Molecule) -> Result { sections: BTreeMap::new(), }; - let mut insert = |kind: u8, key: &str, type_tag: u8, payload_len: usize| -> Result<()> { - let entry = schema_entry(kind, key, type_tag, payload_len, n_atoms)?; - schema.sections.insert((kind, key.to_string()), entry); - Ok(()) - }; + for key in molecule.atom_properties.keys() { + if molecule.properties.contains_key(key) { + return Err(Error::InvalidData(format!( + "Custom property '{}' exists in both atom and molecule scopes", + key + ))); + } + } if let Some(charges) = &molecule.charges { - insert( + insert_schema_entry( + &mut schema, KIND_BUILTIN, "charges", float_array_data_type_tag(charges), float_array_payload_len(charges), + None, + n_atoms, )?; } if let Some(cell) = &molecule.cell { - insert( + insert_schema_entry( + &mut schema, KIND_BUILTIN, "cell", mat3_data_type_tag(cell), mat3_payload_len(cell), + None, + n_atoms, )?; } if let Some(energy) = &molecule.energy { - insert( + insert_schema_entry( + &mut schema, KIND_BUILTIN, "energy", float_scalar_data_type_tag(energy), float_scalar_payload_len(energy), + None, + n_atoms, )?; } if let Some(forces) = &molecule.forces { - insert( + insert_schema_entry( + &mut schema, KIND_BUILTIN, "forces", vec3_data_type_tag(forces), vec3_payload_len(forces), + None, + n_atoms, )?; } if let Some(name) = &molecule.name { - insert(KIND_BUILTIN, "name", TYPE_STRING, name.len())?; + insert_schema_entry( + &mut schema, + KIND_BUILTIN, + "name", + TYPE_STRING, + name.len(), + None, + n_atoms, + )?; } if molecule.pbc.is_some() { - insert(KIND_BUILTIN, "pbc", TYPE_BOOL3, 3)?; + insert_schema_entry( + &mut schema, + KIND_BUILTIN, + "pbc", + TYPE_BOOL3, + 3, + None, + n_atoms, + )?; } if let Some(stress) = &molecule.stress { - insert( + insert_schema_entry( + &mut schema, KIND_BUILTIN, "stress", mat3_data_type_tag(stress), mat3_payload_len(stress), + None, + n_atoms, )?; } if let Some(velocities) = &molecule.velocities { - insert( + insert_schema_entry( + &mut schema, KIND_BUILTIN, "velocities", vec3_data_type_tag(velocities), vec3_payload_len(velocities), + None, + n_atoms, )?; } for (key, value) in &molecule.atom_properties { - insert( - KIND_ATOM_PROP, - key, - property_value_type_tag(value), - property_value_payload_len(value), - )?; + let type_tag = property_value_type_tag(value); + if is_tensor_type_tag(type_tag) { + let payload = super::dtypes::property_value_to_bytes(value); + insert_schema_entry( + &mut schema, + KIND_ATOM_PROP, + key, + type_tag, + payload.len(), + Some(&payload), + n_atoms, + )?; + } else { + insert_schema_entry( + &mut schema, + KIND_ATOM_PROP, + key, + type_tag, + property_value_payload_len(value), + None, + n_atoms, + )?; + } } for (key, value) in &molecule.properties { - insert( - KIND_MOL_PROP, - key, - property_value_type_tag(value), - property_value_payload_len(value), - )?; + let type_tag = property_value_type_tag(value); + if is_tensor_type_tag(type_tag) { + let payload = super::dtypes::property_value_to_bytes(value); + insert_schema_entry( + &mut schema, + KIND_MOL_PROP, + key, + type_tag, + payload.len(), + Some(&payload), + n_atoms, + )?; + } else { + insert_schema_entry( + &mut schema, + KIND_MOL_PROP, + key, + type_tag, + property_value_payload_len(value), + None, + n_atoms, + )?; + } } Ok(schema) @@ -374,11 +495,13 @@ fn parse_record_schema_with_layout( if payload_end > bytes.len() { return Err(Error::InvalidData("SOA section payload truncated".into())); } + let payload = &bytes[pos..payload_end]; pos = payload_end; if kind == KIND_BUILTIN { validate_builtin_type_tag_for_record_format(record_format, &key, type_tag)?; } - let entry = schema_entry(kind, &key, type_tag, payload_len, n_atoms)?; + let tensor_payload = is_tensor_type_tag(type_tag).then_some(payload); + let entry = schema_entry(kind, &key, type_tag, payload_len, tensor_payload, n_atoms)?; schema.sections.insert((kind, key), entry); } diff --git a/atompack/src/types.rs b/atompack/src/types.rs index e6d8b54..1b92008 100644 --- a/atompack/src/types.rs +++ b/atompack/src/types.rs @@ -13,6 +13,7 @@ pub enum PropertyValue { Float32Array(Vec), Vec3ArrayF64(Vec<[f64; 3]>), Int32Array(Vec), + Tensor(TensorData), } impl PropertyValue { @@ -24,6 +25,7 @@ impl PropertyValue { PropertyValue::Float32Array(values) => Some(values.len()), PropertyValue::Vec3ArrayF64(values) => Some(values.len()), PropertyValue::Int32Array(values) => Some(values.len()), + PropertyValue::Tensor(values) => values.first_dim(), _ => None, } } @@ -33,6 +35,29 @@ impl PropertyValue { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum TensorData { + F32 { shape: Vec, values: Vec }, + F64 { shape: Vec, values: Vec }, + I32 { shape: Vec, values: Vec }, + I64 { shape: Vec, values: Vec }, +} + +impl TensorData { + pub fn shape(&self) -> &[usize] { + match self { + Self::F32 { shape, .. } + | Self::F64 { shape, .. } + | Self::I32 { shape, .. } + | Self::I64 { shape, .. } => shape, + } + } + + pub fn first_dim(&self) -> Option { + self.shape().first().copied() + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Vec3Data { F32(Vec<[f32; 3]>), diff --git a/docs/source/getting-started.rst b/docs/source/getting-started.rst index 08c4889..477458d 100644 --- a/docs/source/getting-started.rst +++ b/docs/source/getting-started.rst @@ -126,6 +126,12 @@ structures efficiently with ``add_ase_batch(...)``: ``add_ase_batch(...)`` is the preferred path when you already have an iterator or list of ``ase.Atoms`` objects and want to ingest them directly into a database. +Custom ASE values are copied as molecule properties during ingestion. ``atoms.info``, +custom ``atoms.arrays``, and non-builtin calculator results support ``None``, strings, +numeric scalars, and numeric arrays with dtype ``float32``, ``float64``, ``int32``, or +``int64``. Higher-rank arrays are stored as tensor properties; Atompack does not infer +atom-property scope from ASE array shape. + When Atompack Is A Good Fit ---------------------------