Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 44 additions & 13 deletions atompack-py/python/atompack/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Type stubs for atompack"""

from typing import Any, Sequence, overload
from typing import Any, Literal, Sequence, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -201,9 +201,10 @@ class Molecule:
This reads directly from the molecule getters, so it works for both
owned and view-backed molecules without going through `atoms()`.
Geometry/species always become the ASE structure, supported builtin
results are attached through `SinglePointCalculator`, per-atom custom
arrays go to `atoms.arrays`, and remaining custom properties go to
`atoms.info`.
results are attached through `SinglePointCalculator`, atom-scope custom
properties go to `atoms.arrays`, molecule-scope tensor properties stay
in `atoms.info`, and legacy molecule arrays may still be shape-routed to
`atoms.arrays`.
"""
...

Expand Down Expand Up @@ -336,20 +337,30 @@ class Molecule:
"""
...

def set_property(self, key: str, value: float | int | str | npt.NDArray | None) -> None:
def set_property(
self,
key: str,
value: float | int | str | npt.NDArray | None,
*,
scope: Literal["molecule", "atom"] | None = None,
) -> None:
"""
Set a custom property.

Supported types: None, float, int, str, 1D float32/float64/int32/int64 arrays,
and 2D float32/float64 arrays with shape (n, 3). Input dtype is preserved.
The key 'stress' is reserved; use the dedicated ``stress`` property instead.
Supported types: None, float, int, str, and numeric ndarrays with dtype
float32, float64, int32, or int64. Input dtype and tensor shape are preserved.
New atom properties require ``scope="atom"``; existing atom properties keep
atom scope when overwritten. The key 'stress' is reserved; use the dedicated
``stress`` property instead.

Parameters
----------
key : str
Property key
value : float, int, str, ndarray, or None
Property value
scope : {"molecule", "atom"}, optional
Property scope. Defaults to molecule for new keys.

Raises
------
Expand All @@ -358,7 +369,7 @@ class Molecule:
"""
...

def property_keys(self) -> list[str]:
def property_keys(self, *, scope: Literal["molecule", "atom"] | None = None) -> list[str]:
"""
Get all property keys.

Expand All @@ -369,14 +380,16 @@ class Molecule:
"""
...

def has_property(self, key: str) -> bool:
def has_property(self, key: str, *, scope: Literal["molecule", "atom"] | None = None) -> bool:
"""
Check if a property exists.

Parameters
----------
key : str
Property key
scope : {"molecule", "atom"}, optional
Restrict the lookup to one scope.

Returns
-------
Expand All @@ -385,6 +398,17 @@ class Molecule:
"""
...

def delete_property(self, key: str) -> None:
    """
    Delete a custom property by key.

    Parameters
    ----------
    key : str
        Property key

    Raises
    ------
    KeyError
        If property key does not exist
    """
    ...

@overload
def __getitem__(self, index: int) -> Atom: ...
@overload
Expand Down Expand Up @@ -630,15 +654,17 @@ def from_ase(
cell: npt.NDArray[np.float64] | None = None,
stress: npt.NDArray[np.float64] | None = None,
copy_info: bool = True,
copy_arrays: bool = True,
info: dict | None = None,
) -> Molecule:
"""
Convert an ASE Atoms object to an atompack Molecule.

This function extracts positions, atomic numbers, and available properties
(forces, energy, charges, velocities, cell) from an ASE Atoms object and
creates a corresponding atompack Molecule. Custom properties from atoms.info
dict are also copied by default.
creates a corresponding atompack Molecule. Custom properties from
atoms.info, atoms.arrays, and calculator results are copied as
molecule-scope properties by default.

Parameters
----------
Expand All @@ -656,7 +682,11 @@ def from_ase(
Override cell from ASE Atoms. If None, attempts to extract from atoms.
copy_info : bool, default=True
If True, copies custom properties from atoms.info dict to molecule properties.
Supports: str, int, float, 1D float/int arrays, 2D arrays with shape (n, 3).
Supported values are None, scalar int/float/bool, str, and ndarray-like
values with dtype float32, float64, int32, or int64.
copy_arrays : bool, default=True
If True, copies custom values from atoms.arrays to molecule properties.
Shape is not used to infer atom-property scope.
info : dict, optional
Additional properties to store in the molecule. These will be added after
copying atoms.info (if copy_info=True), so they can override atoms.info values.
Expand Down Expand Up @@ -731,6 +761,7 @@ def add_ase_batch(
atoms_list: list[object],
*,
copy_info: bool = True,
copy_arrays: bool = True,
info: dict | list[dict | None] | None = None,
batch_size: int = 512,
) -> None:
Expand Down
22 changes: 18 additions & 4 deletions atompack-py/python/atompack/_atompack_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Sequence, overload
from typing import Any, Literal, Sequence, overload

import numpy as np

Expand Down Expand Up @@ -329,7 +329,13 @@ class PyMolecule:
"""
...

def set_property(self, key: str, value: Any) -> None:
def set_property(
self,
key: str,
value: Any,
*,
scope: Literal["molecule", "atom"] | None = None,
) -> None:
"""
Set a custom property.

Expand All @@ -339,10 +345,12 @@ class PyMolecule:
Property key
value : Any
Property value
scope : {"molecule", "atom"}, optional
Property scope. Defaults to molecule for new keys.
"""
...

def property_keys(self) -> list[str]:
def property_keys(self, *, scope: Literal["molecule", "atom"] | None = None) -> list[str]:
"""
Get all property keys.

Expand All @@ -353,14 +361,16 @@ class PyMolecule:
"""
...

def has_property(self, key: str) -> bool:
def has_property(self, key: str, *, scope: Literal["molecule", "atom"] | None = None) -> bool:
"""
Check if a property exists.

Parameters
----------
key : str
Property key
scope : {"molecule", "atom"}, optional
Restrict the lookup to one scope.

Returns
-------
Expand All @@ -369,6 +379,10 @@ class PyMolecule:
"""
...

def delete_property(self, key: str) -> None:
    """
    Delete a custom property by key.

    Parameters
    ----------
    key : str
        Property key
    """
    # NOTE(review): the public ``Molecule.delete_property`` stub documents a
    # KeyError for a missing key; presumably this binding raises the same
    # exception -- confirm against the Rust implementation before documenting.
    ...

@overload
def __getitem__(self, index: int) -> PyAtom: ...
@overload
Expand Down
98 changes: 65 additions & 33 deletions atompack-py/python/atompack/ase_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@
_ASE_TYPES = None
_CALC_MODES = {"singlepoint", "nocopy", "none"}
_UNSUPPORTED_PROPERTY = object()
# Array dtypes accepted as-is by `_coerce_property`: an array whose dtype is
# in this set is returned with that dtype; other array dtypes fall through to
# the `_UNSUPPORTED_PROPERTY` sentinel.
_SUPPORTED_ARRAY_DTYPES = {
    np.dtype(np.float32),
    np.dtype(np.float64),
    np.dtype(np.int32),
    np.dtype(np.int64),
}
# Human-readable summary of the accepted value kinds. It is appended to the
# reason text built by `_unsupported_property_reason`, so keep it in sync with
# `_SUPPORTED_ARRAY_DTYPES` and the scalar handling in `_coerce_property`.
_SUPPORTED_CUSTOM_PROPERTY_TEXT = (
    "supported values are None, scalar int/float/bool, str, or ndarray-like "
    "values with dtype float32, float64, int32, or int64"
)


def _voigt6_to_mat3x3(stress):
Expand Down Expand Up @@ -53,36 +63,50 @@ def _get_stress(atoms):
return None


def _coerce_property(value, n_atoms):
def _coerce_property(value):
if value is None:
return None
if isinstance(value, (str, bool, int, float, np.integer, np.floating)):
if isinstance(value, str):
return value
if isinstance(value, str):
return value
if isinstance(value, (bool, int, float, np.integer, np.floating)):
if isinstance(value, (bool, int, np.integer)):
return int(value)
return float(value)

arr = np.asarray(value)
try:
arr = np.asarray(value)
except Exception:
return _UNSUPPORTED_PROPERTY

if arr.ndim == 0 and arr.dtype.kind in {"b", "i", "u", "f"}:
return arr.item()
if arr.ndim == 1 and arr.shape[0] == n_atoms:
if arr.dtype == np.float32:
return arr.astype(np.float32, copy=False)
if arr.dtype.kind == "f":
return arr.astype(np.float64, copy=False)
if arr.dtype == np.int32:
return arr.astype(np.int32, copy=False)
if arr.dtype.kind in {"b", "i", "u"}:
return arr.astype(np.int64, copy=False)
if arr.ndim == 2 and arr.shape == (n_atoms, 3) and arr.dtype.kind == "f":
if arr.dtype == np.float32:
return arr.astype(np.float32, copy=False)
return arr.astype(np.float64, copy=False)
if arr.dtype in _SUPPORTED_ARRAY_DTYPES:
return arr.astype(arr.dtype, copy=False)
return _UNSUPPORTED_PROPERTY


def _merge_properties(properties, builtins, values, n_atoms):
def _unsupported_property_reason(value):
    """Explain why ``value`` was rejected as a custom property.

    The value is converted with ``np.asarray``. If that conversion raises,
    the returned text reports the conversion error. Otherwise it reports the
    value's Python type name together with the dtype and shape NumPy gave it,
    followed by the summary of supported value kinds.

    Returns
    -------
    str
        Message fragment suitable for embedding in an error message.
    """
    try:
        as_array = np.asarray(value)
    except Exception as exc:
        # Conversion itself failed; there is no dtype or shape to report.
        reason = f"could not convert value to an ndarray: {exc}"
    else:
        reason = (
            f"got value of type {type(value).__name__} with ndarray dtype "
            f"{as_array.dtype} and shape {as_array.shape}; {_SUPPORTED_CUSTOM_PROPERTY_TEXT}"
        )
    return reason


def _coerce_custom_property(key, value, source):
    """Coerce one custom property value, raising if it is unsupported.

    Parameters
    ----------
    key
        Property name; used only in the error message.
    value
        Raw value to pass through ``_coerce_property``.
    source
        Label describing where the value came from; used only in the error
        message.

    Returns
    -------
    The result of ``_coerce_property(value)``.

    Raises
    ------
    TypeError
        If ``_coerce_property`` returns the ``_UNSUPPORTED_PROPERTY`` sentinel.
    """
    result = _coerce_property(value)
    if result is not _UNSUPPORTED_PROPERTY:
        return result

    # Only build the diagnostic text on the failure path.
    reason = _unsupported_property_reason(value)
    raise TypeError(
        f"Unsupported ASE custom property {key!r} from {source}: "
        f"{reason}"
    )


def _merge_properties(properties, builtins, values, source):
for key, value in values.items():
if key in _BUILTIN_FIELDS:
# Builtin keys in atoms.info / info-override go to the builtins
Expand All @@ -95,9 +119,7 @@ def _merge_properties(properties, builtins, values, n_atoms):
if arr.shape == (3, 3) and arr.dtype.kind == "f":
builtins["stress"] = arr.astype(np.float64, copy=False)
continue
coerced = _coerce_property(value, n_atoms)
if coerced is not _UNSUPPORTED_PROPERTY:
properties[key] = coerced
properties[key] = _coerce_custom_property(key, value, source)


def _extract_ase_record(
Expand All @@ -110,6 +132,7 @@ def _extract_ase_record(
cell=None,
stress=None,
copy_info=True,
copy_arrays=True,
info=None,
):
positions = np.asarray(atoms.get_positions(), dtype=np.float32)
Expand Down Expand Up @@ -181,31 +204,27 @@ def _extract_ase_record(
properties = {}

arrays = getattr(atoms, "arrays", None)
if isinstance(arrays, dict):
if copy_arrays and isinstance(arrays, dict):
for key, value in arrays.items():
# Skip both ASE-reserved geometry keys ("positions", "numbers")
# and atompack builtin field names. A user who stashes "forces"
# in atoms.arrays must not have it duplicated into both
# builtins["forces"] (from get_forces()) and properties["forces"].
if key in _ASE_RESERVED_ARRAYS or key in _BUILTIN_FIELDS:
continue
coerced = _coerce_property(value, n_atoms)
if coerced is not _UNSUPPORTED_PROPERTY:
properties[key] = coerced
properties[key] = _coerce_custom_property(key, value, "atoms.arrays")

calc = getattr(atoms, "calc", None)
results = getattr(calc, "results", None)
if isinstance(results, dict):
for key, value in results.items():
if key not in _BUILTIN_FIELDS:
coerced = _coerce_property(value, n_atoms)
if coerced is not _UNSUPPORTED_PROPERTY:
properties[key] = coerced
properties[key] = _coerce_custom_property(key, value, "atoms.calc.results")

if copy_info and getattr(atoms, "info", None):
_merge_properties(properties, builtins, atoms.info, n_atoms)
_merge_properties(properties, builtins, atoms.info, "atoms.info")
if info is not None:
_merge_properties(properties, builtins, info, n_atoms)
_merge_properties(properties, builtins, info, "info override")

return {
"positions": positions,
Expand Down Expand Up @@ -661,9 +680,15 @@ def from_ase(
cell=None,
stress=None,
copy_info=True,
copy_arrays=True,
info=None,
):
"""Convert one ASE Atoms object to an atompack Molecule."""
"""Convert one ASE Atoms object to an atompack Molecule.

Custom values from ``atoms.info``, ``atoms.arrays``, calculator results,
and explicit ``info=`` overrides are stored as molecule-scope properties.
Array shape is not used to infer atom-property scope during ingestion.
"""
return _record_to_molecule(
_extract_ase_record(
atoms,
Expand All @@ -674,6 +699,7 @@ def from_ase(
cell=cell,
stress=stress,
copy_info=copy_info,
copy_arrays=copy_arrays,
info=info,
)
)
Expand All @@ -684,6 +710,7 @@ def add_ase_batch(
atoms_list,
*,
copy_info=True,
copy_arrays=True,
info=None,
batch_size=512,
):
Expand All @@ -710,7 +737,12 @@ def flush_slow():
slow_records.clear()

for atoms, info_override in zip(atoms_list, info_overrides):
record = _extract_ase_record(atoms, copy_info=copy_info, info=info_override)
record = _extract_ase_record(
atoms,
copy_info=copy_info,
copy_arrays=copy_arrays,
info=info_override,
)
if record["properties"]:
flush_fast()
slow_records.append(_record_to_molecule(record))
Expand Down
Loading
Loading