From f28b78343000d3950bab4f66d9880c1c17cfb6ef Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Tue, 10 Mar 2026 15:16:20 +0100 Subject: [PATCH 01/16] add system conditioning to docs and ase calculator --- docs/src/engines/ase.rst | 25 +++++++ docs/src/engines/lammps.rst | 10 +++ .../metatomic/torch/ase_calculator.py | 48 ++++++++++++- .../metatomic_torch/tests/ase_calculator.py | 67 +++++++++++++++++++ 4 files changed, 149 insertions(+), 1 deletion(-) diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index 2bcdd8400..d9ff35f29 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -34,6 +34,31 @@ How to install the code The code is available in the ``metatomic-torch`` package, in the :py:class:`metatomic.torch.ase_calculator.MetatomicCalculator` class. +Supported model inputs +^^^^^^^^^^^^^^^^^^^^^^ + +The ASE calculator can provide per-atom inputs (e.g. ``"charges"``, +``"momenta"``, ``"velocities"``) as well as the following **system-level** +integer inputs used for model conditioning: + +.. list-table:: + :header-rows: 1 + :widths: 2 3 5 + + * - Input name + - Default + - How to set + * - ``"mtt::charge"`` + - ``0`` + - ``atoms.info["charge"] = `` + * - ``"mtt::spin"`` + - ``1`` + - ``atoms.info["multiplicity"] = `` + +``"mtt::charge"`` is the total charge of the simulation cell in elementary +charges. ``"mtt::spin"`` is the spin multiplicity (2S+1). Both values are +passed to the model as 64-bit integers. + How to use the code ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/engines/lammps.rst b/docs/src/engines/lammps.rst index 661b3053b..4ea519e80 100644 --- a/docs/src/engines/lammps.rst +++ b/docs/src/engines/lammps.rst @@ -330,6 +330,16 @@ documentation. specifies which variant of the model outputs should be uses for making non-conservative stress predictions. Overrides the value given to the ``variant`` keyword. Defaults to no variant. + **charge** values = integer + total charge of the simulation cell in elementary charges, passed to the + model as the ``"mtt::charge"`` system-level input when the model requests + it. Only relevant for models using system-level charge conditioning (e.g. + PET with ``system_conditioning`` enabled in metatrain). Defaults to ``0``. + **spin** values = integer + spin multiplicity of the simulation cell (2S+1), passed to the model as + the ``"mtt::spin"`` system-level input when the model requests it. Only + relevant for models using system-level spin conditioning. Defaults to + ``1`` (closed-shell singlet). Examples -------- diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 76964aaae..4af311f92 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -51,6 +51,28 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: return atoms.get_initial_charges() +SYSTEM_QUANTITIES = { + "mtt::charge": { + "quantity": "charge", + "getter": lambda atoms: np.array([[atoms.info.get("charge", 0)]]), + "unit": "e", + }, + "mtt::spin": { + "quantity": "spin_multiplicity", + "getter": lambda atoms: np.array([[atoms.info.get("multiplicity", 1)]]), + "unit": "", + }, +} +""" +Per-system scalar inputs provided by ASE via ``atoms.info``. + +- ``"mtt::charge"``: total system charge in elementary charges, read from + ``atoms.info["charge"]``, defaults to ``0``. +- ``"mtt::spin"``: spin multiplicity (2S+1), read from + ``atoms.info["multiplicity"]``, defaults to ``1``. +""" + + ARRAY_QUANTITIES = { "momenta": { "quantity": "momentum", @@ -981,9 +1003,33 @@ def _get_ase_input( dtype: torch.dtype, device: torch.device, ) -> "TensorMap": + if name in SYSTEM_QUANTITIES: + infos = SYSTEM_QUANTITIES[name] + # shape: (1, 1) — one system, one scalar property + values = torch.tensor( + infos["getter"](atoms), dtype=torch.int64, device=device + ) + block = TensorBlock( + values, + samples=Labels(["system"], torch.tensor([[0]], device=device)), + components=[], + properties=Labels( + [infos["quantity"]], torch.tensor([[0]], device=device) + ), + ) + tensor = TensorMap( + Labels(["_"], torch.tensor([[0]], device=device)), [block] + ) + tensor.set_info("quantity", infos["quantity"]) + tensor.set_info("unit", infos["unit"]) + return tensor + if name not in ARRAY_QUANTITIES: raise ValueError( - f"The model requested '{name}', which is not available in `ase`." + f"The model requested '{name}', which is not available in `ase`. " + "System-level quantities like 'mtt::charge' or 'mtt::spin' can be " + "set via atoms.info['charge'] and atoms.info['multiplicity'] " + "respectively." ) infos = ARRAY_QUANTITIES[name] diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index a1d1e57fc..b0a9229df 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -27,6 +27,7 @@ ) from metatomic.torch.ase_calculator import ( ARRAY_QUANTITIES, + SYSTEM_QUANTITIES, MetatomicCalculator, _compute_ase_neighbors, _full_3x3_to_voigt_6_stress, @@ -871,3 +872,69 @@ def test_additional_input(atoms): ) # ase velocity is in (eV/u)^(1/2) and we want A/fs assert np.allclose(values, expected) + + +def test_system_level_input(atoms): + """mtt::charge and mtt::spin are per-system integer inputs read from atoms.info.""" + inputs = { + "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "mtt::spin": ModelOutput( + quantity="spin_multiplicity", unit="", per_atom=False + ), + } + outputs = {("extra::" + n): inputs[n] for n in inputs} + capabilities = ModelCapabilities( + outputs=outputs, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + + model = AtomisticModel( + AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities + ) + atoms.info["charge"] = -2 + atoms.info["multiplicity"] = 3 + calculator = MetatomicCalculator(model, check_consistency=False) + results = calculator.run_model(atoms, outputs) + + charge_tensor = results["extra::mtt::charge"] + assert charge_tensor[0].samples.names == ["system"] + assert charge_tensor[0].values.dtype == torch.int64 + assert int(charge_tensor[0].values[0, 0]) == -2 + + spin_tensor = results["extra::mtt::spin"] + assert spin_tensor[0].samples.names == ["system"] + assert spin_tensor[0].values.dtype == torch.int64 + assert int(spin_tensor[0].values[0, 0]) == 3 + + +def test_system_level_input_defaults(atoms): + """mtt::charge defaults to 0 and mtt::spin to 1 when not set in atoms.info.""" + inputs = { + "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "mtt::spin": ModelOutput( + quantity="spin_multiplicity", unit="", per_atom=False + ), + } + outputs = {("extra::" + n): inputs[n] for n in inputs} + capabilities = ModelCapabilities( + outputs=outputs, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + + model = AtomisticModel( + AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities + ) + # ensure the keys are absent + atoms.info.pop("charge", None) + atoms.info.pop("multiplicity", None) + calculator = MetatomicCalculator(model, check_consistency=False) + results = calculator.run_model(atoms, outputs) + + assert int(results["extra::mtt::charge"][0].values[0, 0]) == 0 + assert int(results["extra::mtt::spin"][0].values[0, 0]) == 1 From 628e59a921869efca48c3b129eacf2fa23a1021f Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Thu, 19 Mar 2026 22:23:31 +0100 Subject: [PATCH 02/16] multiplicity to spin --- docs/src/engines/ase.rst | 7 ++++--- python/metatomic_torch/metatomic/torch/ase_calculator.py | 6 +++--- python/metatomic_torch/tests/ase_calculator.py | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index d9ff35f29..c94ee7680 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -53,11 +53,12 @@ integer inputs used for model conditioning: - ``atoms.info["charge"] = `` * - ``"mtt::spin"`` - ``1`` - - ``atoms.info["multiplicity"] = `` + - ``atoms.info["spin"] = `` ``"mtt::charge"`` is the total charge of the simulation cell in elementary -charges. ``"mtt::spin"`` is the spin multiplicity (2S+1). Both values are -passed to the model as 64-bit integers. +charges. ``"mtt::spin"`` is the spin multiplicity (2S+1) — a singlet is +``spin=1``, a doublet is ``spin=2``, a triplet is ``spin=3``, and so on. +Both values are passed to the model as 64-bit integers. How to use the code ^^^^^^^^^^^^^^^^^^^ diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 4af311f92..a3445f29e 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -59,7 +59,7 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: }, "mtt::spin": { "quantity": "spin_multiplicity", - "getter": lambda atoms: np.array([[atoms.info.get("multiplicity", 1)]]), + "getter": lambda atoms: np.array([[atoms.info.get("spin", 1)]]), "unit": "", }, } @@ -69,7 +69,7 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: - ``"mtt::charge"``: total system charge in elementary charges, read from ``atoms.info["charge"]``, defaults to ``0``. - ``"mtt::spin"``: spin multiplicity (2S+1), read from - ``atoms.info["multiplicity"]``, defaults to ``1``. + ``atoms.info["spin"]``, defaults to ``1``. """ @@ -1028,7 +1028,7 @@ def _get_ase_input( raise ValueError( f"The model requested '{name}', which is not available in `ase`. " "System-level quantities like 'mtt::charge' or 'mtt::spin' can be " - "set via atoms.info['charge'] and atoms.info['multiplicity'] " + "set via atoms.info['charge'] and atoms.info['spin'] " "respectively." ) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index b0a9229df..ea0611604 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -895,7 +895,7 @@ def test_system_level_input(atoms): AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities ) atoms.info["charge"] = -2 - atoms.info["multiplicity"] = 3 + atoms.info["spin"] = 3 calculator = MetatomicCalculator(model, check_consistency=False) results = calculator.run_model(atoms, outputs) @@ -932,7 +932,7 @@ def test_system_level_input_defaults(atoms): ) # ensure the keys are absent atoms.info.pop("charge", None) - atoms.info.pop("multiplicity", None) + atoms.info.pop("spin", None) calculator = MetatomicCalculator(model, check_consistency=False) results = calculator.run_model(atoms, outputs) From 81eb6f6b865277e706975a7f2f74b3e0685404fc Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Thu, 19 Mar 2026 22:31:40 +0100 Subject: [PATCH 03/16] fix dtype --- python/metatomic_torch/metatomic/torch/ase_calculator.py | 2 +- python/metatomic_torch/tests/ase_calculator.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index a3445f29e..ce5ba75cb 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -1007,7 +1007,7 @@ def _get_ase_input( infos = SYSTEM_QUANTITIES[name] # shape: (1, 1) — one system, one scalar property values = torch.tensor( - infos["getter"](atoms), dtype=torch.int64, device=device + infos["getter"](atoms), dtype=dtype, device=device ) block = TensorBlock( values, diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index ea0611604..528666fb5 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -901,12 +901,12 @@ def test_system_level_input(atoms): charge_tensor = results["extra::mtt::charge"] assert charge_tensor[0].samples.names == ["system"] - assert charge_tensor[0].values.dtype == torch.int64 + assert charge_tensor[0].values.dtype == torch.float64 # matches model dtype assert int(charge_tensor[0].values[0, 0]) == -2 spin_tensor = results["extra::mtt::spin"] assert spin_tensor[0].samples.names == ["system"] - assert spin_tensor[0].values.dtype == torch.int64 + assert spin_tensor[0].values.dtype == torch.float64 # matches model dtype assert int(spin_tensor[0].values[0, 0]) == 3 From 377cb04adc3a33b50370370e55341172f7fdbcca Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Thu, 19 Mar 2026 23:32:33 +0100 Subject: [PATCH 04/16] adjust state check --- docs/src/engines/ase.rst | 4 ++- .../metatomic/torch/ase_calculator.py | 32 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index c94ee7680..688bd39c8 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -58,7 +58,9 @@ integer inputs used for model conditioning: ``"mtt::charge"`` is the total charge of the simulation cell in elementary charges. ``"mtt::spin"`` is the spin multiplicity (2S+1) — a singlet is ``spin=1``, a doublet is ``spin=2``, a triplet is ``spin=3``, and so on. -Both values are passed to the model as 64-bit integers. +Both values are read as integers from ``atoms.info`` and stored in the +system as the model's floating-point dtype (float32 or float64); the model +converts them back to integers internally for the embedding lookup. How to use the code ^^^^^^^^^^^^^^^^^^^ diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index ce5ba75cb..c31369ccf 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -56,11 +56,15 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: "quantity": "charge", "getter": lambda atoms: np.array([[atoms.info.get("charge", 0)]]), "unit": "e", + "info_key": "charge", + "default": 0, }, "mtt::spin": { "quantity": "spin_multiplicity", "getter": lambda atoms: np.array([[atoms.info.get("spin", 1)]]), "unit": "", + "info_key": "spin", + "default": 1, }, } """ @@ -323,6 +327,15 @@ def __init__( self._model = model.to(device=self._device) + # Cache which atoms.info keys need change-detection so that check_state + # does only plain Python list iteration on every MD step, avoiding a + # TorchScript JIT dispatch per step to requested_inputs(). + self._system_info_watch: List[Tuple[str, int]] = [ + (infos["info_key"], infos["default"]) + for name, infos in SYSTEM_QUANTITIES.items() + if name in self._model.requested_inputs() + ] + self._calculate_uncertainty = ( self._energy_uq_key in self._model.capabilities().outputs # we require per-atom uncertainties to capture local effects @@ -441,6 +454,25 @@ def run_model( check_consistency=self.parameters["check_consistency"], ) + def check_state(self, atoms: ase.Atoms, tol: float = 1e-15) -> List[str]: + """Detect system changes, including ``atoms.info`` keys used as model inputs. + + ASE's default :py:meth:`~ase.calculators.calculator.Calculator.check_state` + only tracks per-atom arrays (positions, numbers, …) and cell/pbc. Changes + to ``atoms.info["charge"]`` or ``atoms.info["spin"]`` are invisible to it, + causing stale cached results when the charge or spin is updated between calls. + + This override appends the name of any ``atoms.info`` key that has changed + since the last calculation to the standard change list, which forces a + fresh calculation. + """ + changes = super().check_state(atoms, tol=tol) + if self.atoms is not None: + for key, default in self._system_info_watch: + if self.atoms.info.get(key, default) != atoms.info.get(key, default): + changes.append(key) + return changes + def calculate( self, atoms: ase.Atoms, From a47163c5c485481268fcd17f95961e2a204a9f5d Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Thu, 19 Mar 2026 23:36:36 +0100 Subject: [PATCH 05/16] lint --- .../metatomic/torch/ase_calculator.py | 12 +++--------- python/metatomic_torch/tests/ase_calculator.py | 8 ++------ 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index c31369ccf..f6ae7c106 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -1038,20 +1038,14 @@ def _get_ase_input( if name in SYSTEM_QUANTITIES: infos = SYSTEM_QUANTITIES[name] # shape: (1, 1) — one system, one scalar property - values = torch.tensor( - infos["getter"](atoms), dtype=dtype, device=device - ) + values = torch.tensor(infos["getter"](atoms), dtype=dtype, device=device) block = TensorBlock( values, samples=Labels(["system"], torch.tensor([[0]], device=device)), components=[], - properties=Labels( - [infos["quantity"]], torch.tensor([[0]], device=device) - ), - ) - tensor = TensorMap( - Labels(["_"], torch.tensor([[0]], device=device)), [block] + properties=Labels([infos["quantity"]], torch.tensor([[0]], device=device)), ) + tensor = TensorMap(Labels(["_"], torch.tensor([[0]], device=device)), [block]) tensor.set_info("quantity", infos["quantity"]) tensor.set_info("unit", infos["unit"]) return tensor diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 528666fb5..2fddee246 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -878,9 +878,7 @@ def test_system_level_input(atoms): """mtt::charge and mtt::spin are per-system integer inputs read from atoms.info.""" inputs = { "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput( - quantity="spin_multiplicity", unit="", per_atom=False - ), + "mtt::spin": ModelOutput(quantity="spin_multiplicity", unit="", per_atom=False), } outputs = {("extra::" + n): inputs[n] for n in inputs} capabilities = ModelCapabilities( @@ -914,9 +912,7 @@ def test_system_level_input_defaults(atoms): """mtt::charge defaults to 0 and mtt::spin to 1 when not set in atoms.info.""" inputs = { "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput( - quantity="spin_multiplicity", unit="", per_atom=False - ), + "mtt::spin": ModelOutput(quantity="spin_multiplicity", unit="", per_atom=False), } outputs = {("extra::" + n): inputs[n] for n in inputs} capabilities = ModelCapabilities( From 7183374040943b0526ddcc69034a3da1c5582d24 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Thu, 19 Mar 2026 23:37:36 +0100 Subject: [PATCH 06/16] lint --- python/metatomic_torch/tests/ase_calculator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 2fddee246..287f74166 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -27,7 +27,6 @@ ) from metatomic.torch.ase_calculator import ( ARRAY_QUANTITIES, - SYSTEM_QUANTITIES, MetatomicCalculator, _compute_ase_neighbors, _full_3x3_to_voigt_6_stress, From cb91a650ee40fcbe17c77108dbd53791942a8794 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Fri, 20 Mar 2026 00:20:45 +0100 Subject: [PATCH 07/16] spin calculator test --- .../metatomic_torch/tests/ase_calculator.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 287f74166..eb5ce34c2 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -933,3 +933,89 @@ def test_system_level_input_defaults(atoms): assert int(results["extra::mtt::charge"][0].values[0, 0]) == 0 assert int(results["extra::mtt::spin"][0].values[0, 0]) == 1 + + +class ChargeSpinEnergyModel(torch.nn.Module): + """Minimal energy model whose output depends on charge and spin. + + Returns energy = charge_value + 10 * spin_value so that different + charge/spin inputs always produce different energies. + """ + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return { + "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "mtt::spin": ModelOutput( + quantity="spin_multiplicity", unit="", per_atom=False + ), + } + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + system = systems[0] + charge = float(system.get_data("mtt::charge").block(0).values[0, 0]) + spin = float(system.get_data("mtt::spin").block(0).values[0, 0]) + energy_value = charge + 10.0 * spin + block = TensorBlock( + values=torch.tensor([[energy_value]], dtype=torch.float64), + samples=Labels("system", torch.tensor([[0]])), + components=torch.jit.annotate(List[Labels], []), + properties=Labels("energy", torch.tensor([[0]])), + ) + return {"energy": TensorMap(Labels("_", torch.tensor([[0]])), [block])} + + +def test_system_level_input_changes_energy(atoms): + """Different charge/spin values must produce different energies from the calculator.""" + capabilities = ModelCapabilities( + outputs={"energy": ModelOutput(per_atom=False)}, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + model = AtomisticModel( + ChargeSpinEnergyModel().eval(), ModelMetadata(), capabilities + ) + + # --- varying charge --- + atoms.info["spin"] = 1 + atoms.info["charge"] = 0 + calc = MetatomicCalculator(model, check_consistency=False) + atoms.calc = calc + e_neutral = atoms.get_potential_energy() + + atoms.info["charge"] = 2 + atoms.calc.reset() + e_charged = atoms.get_potential_energy() + + assert e_neutral != e_charged, "Different charges must give different energies" + + # --- varying spin --- + atoms.info["charge"] = 0 + atoms.info["spin"] = 1 + atoms.calc.reset() + e_singlet = atoms.get_potential_energy() + + atoms.info["spin"] = 3 + atoms.calc.reset() + e_triplet = atoms.get_potential_energy() + + assert e_singlet != e_triplet, "Different spins must give different energies" + + # --- cache invalidation: check_state detects atoms.info changes --- + atoms.info["charge"] = 0 + atoms.info["spin"] = 1 + atoms.calc.reset() + e_before = atoms.get_potential_energy() + + atoms.info["charge"] = 1 # change without explicit reset + e_after = atoms.get_potential_energy() + + assert e_before != e_after, ( + "check_state must invalidate cache when atoms.info['charge'] changes" + ) From ba8cec776b963d743e75e06fd2472ec971dbac8d Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Fri, 20 Mar 2026 00:22:35 +0100 Subject: [PATCH 08/16] linting --- python/metatomic_torch/tests/ase_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index eb5ce34c2..2f25a2df1 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -970,7 +970,7 @@ def forward( def test_system_level_input_changes_energy(atoms): - """Different charge/spin values must produce different energies from the calculator.""" + """Different charge/spin values must produce different energies.""" capabilities = ModelCapabilities( outputs={"energy": ModelOutput(per_atom=False)}, atomic_types=[28], From 94fbbb6502a0731b1f2c4473e23c62ca32c3f447 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Fri, 20 Mar 2026 11:07:11 +0100 Subject: [PATCH 09/16] adjust logic that checks for changed system attributes --- .../metatomic_torch/metatomic/torch/ase_calculator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 6a01a1111..81950e6fa 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -472,7 +472,16 @@ def check_state(self, atoms: ase.Atoms, tol: float = 1e-15) -> List[str]: changes = super().check_state(atoms, tol=tol) if self.atoms is not None: for key, default in self._system_info_watch: - if self.atoms.info.get(key, default) != atoms.info.get(key, default): + old = self.atoms.info.get(key, default) + new = atoms.info.get(key, default) + try: + equal = old == new + # numpy arrays and similar objects return array-like booleans; + # treat anything that is not a plain bool as "changed" to be safe + if not isinstance(equal, bool) or not equal: + changes.append(key) + except Exception: + # comparison raised (e.g. mixed types); assume changed changes.append(key) return changes From 7b2454c826e3ebc1817f0d05f7d5a7b3a869a22f Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Fri, 20 Mar 2026 12:27:05 +0100 Subject: [PATCH 10/16] remove last mention of multiplicity to have common keywords --- python/metatomic_torch/metatomic/torch/ase_calculator.py | 2 +- python/metatomic_torch/tests/ase_calculator.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 81950e6fa..e32ffc12d 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -67,7 +67,7 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: "default": 0, }, "mtt::spin": { - "quantity": "spin_multiplicity", + "quantity": "spin", "getter": lambda atoms: np.array([[atoms.info.get("spin", 1)]]), "unit": "", "info_key": "spin", diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index fd7ccf02b..1202e3e19 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -955,7 +955,7 @@ def test_system_level_input(atoms): """mtt::charge and mtt::spin are per-system integer inputs read from atoms.info.""" inputs = { "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput(quantity="spin_multiplicity", unit="", per_atom=False), + "mtt::spin": ModelOutput(quantity="spin", unit="", per_atom=False), } outputs = {("extra::" + n): inputs[n] for n in inputs} capabilities = ModelCapabilities( @@ -989,7 +989,7 @@ def test_system_level_input_defaults(atoms): """mtt::charge defaults to 0 and mtt::spin to 1 when not set in atoms.info.""" inputs = { "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput(quantity="spin_multiplicity", unit="", per_atom=False), + "mtt::spin": ModelOutput(quantity="spin", unit="", per_atom=False), } outputs = {("extra::" + n): inputs[n] for n in inputs} capabilities = ModelCapabilities( @@ -1023,9 +1023,7 @@ class ChargeSpinEnergyModel(torch.nn.Module): def requested_inputs(self) -> Dict[str, ModelOutput]: return { "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput( - quantity="spin_multiplicity", unit="", per_atom=False - ), + "mtt::spin": ModelOutput(quantity="spin", unit="", per_atom=False), } def forward( From d04eb6931d446db40a34e56b0a51a93d60677a47 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Sun, 22 Mar 2026 12:32:07 +0100 Subject: [PATCH 11/16] add requested properties to compute energy to enable symmetrized calculator --- python/metatomic_torch/metatomic/torch/ase_calculator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index e32ffc12d..ad0f342d8 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -783,6 +783,11 @@ def compute_energy( cell = cell @ strain strains.append(strain) system = System(types, positions, cell, pbc) + for name, option in self._model.requested_inputs().items(): + input_tensormap = _get_ase_input( + atoms, name, option, dtype=self._dtype, device=self._device + ) + system.add_data(name, input_tensormap) systems.append(system) # Compute the neighbors lists requested by the model From 08c627a165cba163011198afe28fb17f8e1e8a2f Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Sun, 22 Mar 2026 14:01:17 +0100 Subject: [PATCH 12/16] revert lammps docs changes --- docs/src/engines/lammps.rst | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/docs/src/engines/lammps.rst b/docs/src/engines/lammps.rst index 4ea519e80..b7fa84a93 100644 --- a/docs/src/engines/lammps.rst +++ b/docs/src/engines/lammps.rst @@ -330,16 +330,7 @@ documentation. specifies which variant of the model outputs should be uses for making non-conservative stress predictions. Overrides the value given to the ``variant`` keyword. Defaults to no variant. - **charge** values = integer - total charge of the simulation cell in elementary charges, passed to the - model as the ``"mtt::charge"`` system-level input when the model requests - it. Only relevant for models using system-level charge conditioning (e.g. - PET with ``system_conditioning`` enabled in metatrain). Defaults to ``0``. - **spin** values = integer - spin multiplicity of the simulation cell (2S+1), passed to the model as - the ``"mtt::spin"`` system-level input when the model requests it. Only - relevant for models using system-level spin conditioning. Defaults to - ``1`` (closed-shell singlet). + Examples -------- From 568feda1263e7ba76683d3dfda80877798e4ef01 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Sun, 22 Mar 2026 14:07:09 +0100 Subject: [PATCH 13/16] add test for symmetrized calculator --- .../tests/symmetrized_ase_calculator.py | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index 3e1e83a96..9ecf250d9 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -485,6 +485,103 @@ def test_choose_quadrature_rules(): assert n_gamma == 2 * L + 1 +# -- Charge / spin conditioning tests ---------------------------------------- + + +class _ChargeSpinAnisoModel(torch.nn.Module): + """Minimal model whose energy = charge + 10*spin + P1(cos θ). + + The P1(cos θ) term is orientation-dependent and cancels exactly under O(3) + rotational averaging (Lebedev l_max >= 1). What remains is + ``charge + 10*spin``, which is rotation-invariant. + + This lets us verify two things in a single test: + - charge/spin values from ``atoms.info`` reach the model in every rotated + copy (the bug fixed in ``compute_energy``). + - the orientation-dependent part is correctly averaged away. + """ + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return { + "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "mtt::spin": ModelOutput(quantity="spin", unit="", per_atom=False), + } + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + energies: List[torch.Tensor] = [] + for system in systems: + charge = system.get_data("mtt::charge").block(0).values[0, 0] + spin = system.get_data("mtt::spin").block(0).values[0, 0] + + # Orientation-dependent P1 term (averages to zero under O(3)) + b = _body_axis_from_system(system).to(dtype=charge.dtype) + zhat = torch.tensor( + [0.0, 0.0, 1.0], dtype=charge.dtype, device=charge.device + ) + P1 = torch.dot(b, zhat) + + energies.append((charge + 10.0 * spin + P1).reshape(1, 1)) + + values = torch.cat(energies, dim=0) + block = TensorBlock( + values=values, + samples=Labels( + "system", + torch.arange( + len(systems), dtype=torch.int32 + ).reshape(-1, 1), + ), + components=torch.jit.annotate(List[Labels], []), + properties=Labels("energy", torch.tensor([[0]])), + ) + return {"energy": TensorMap(Labels("_", torch.tensor([[0]])), [block])} + + +def _charge_spin_calculator(charge: float, spin: float) -> mta.ase_calculator.MetatomicCalculator: + """Wrap _ChargeSpinAnisoModel in a MetatomicCalculator.""" + atomistic_model = mta.AtomisticModel( + _ChargeSpinAnisoModel().eval(), + mta.ModelMetadata(), + mta.ModelCapabilities( + outputs={"energy": mta.ModelOutput(per_atom=False)}, + atomic_types=list(range(1, 10)), + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ), + ) + calc = mta.ase_calculator.MetatomicCalculator(atomistic_model) + return calc + + +@pytest.mark.parametrize("charge,spin", [(0.0, 1.0), (2.0, 1.0), (-1.0, 3.0)]) +def test_symmetrized_calculator_passes_charge_spin(dimer: Atoms, charge: float, spin: float) -> None: + """SymmetrizedCalculator must pass charge/spin to each rotated evaluation. + + The model returns ``charge + 10*spin + P1(cos θ)``. After O(3) averaging + the P1 term cancels, so the result must equal ``charge + 10*spin`` exactly. + If charge/spin were silently dropped, every evaluation would use the default + values (0 and 1) and the test would fail for non-default inputs. + """ + dimer.info["charge"] = charge + dimer.info["spin"] = spin + + base = _charge_spin_calculator(charge, spin) + calc = SymmetrizedCalculator(base, l_max=3, include_inversion=True) + dimer.calc = calc + energy = dimer.get_potential_energy() + + expected = charge + 10.0 * spin + assert np.isclose(energy, expected, atol=1e-8), ( + f"Expected energy={expected} for charge={charge}, spin={spin}, got {energy}" + ) + + def test_get_quadrature_properties(): """Check properties of the quadrature returned by _get_quadrature.""" from metatomic.torch.ase_calculator import _get_quadrature From ed237cd2a7a07a7b9245fe6b9a8f6220521a8496 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Sun, 22 Mar 2026 16:57:02 +0100 Subject: [PATCH 14/16] fix non-conservative stress/nan-stress for molecules --- .../metatomic/torch/ase_calculator.py | 42 ++++++++++++------- .../tests/symmetrized_ase_calculator.py | 8 ++-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index ad0f342d8..12d852e0f 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -310,11 +310,19 @@ def __init__( outputs, resolved_variants["non_conservative_forces"], ) - self._nc_stress_key = pick_output( - "non_conservative_stress", - outputs, - resolved_variants["non_conservative_stress"], + has_nc_stress = any( + key == "non_conservative_stress" + or key.startswith("non_conservative_stress/") + for key in outputs.keys() ) + if has_nc_stress: + self._nc_stress_key = pick_output( + "non_conservative_stress", + outputs, + resolved_variants["non_conservative_stress"], + ) + else: + self._nc_stress_key = None else: self._nc_forces_key = "non_conservative_forces" self._nc_stress_key = "non_conservative_stress" @@ -547,7 +555,8 @@ def calculate( if self.parameters["do_gradients_with_energy"]: if calculate_energies or calculate_energy: calculate_forces = True - calculate_stress = True + if atoms.pbc.all(): + calculate_stress = True with record_function("MetatomicCalculator::prepare_inputs"): outputs = self._ase_properties_to_metatensor_outputs( @@ -697,16 +706,19 @@ def calculate( forces_values = forces_values.cpu().double() self.results["forces"] = forces_values.numpy() - if calculate_stress: - if self.parameters["non_conservative"]: + if calculate_stress and atoms.pbc.all(): + if self.parameters["non_conservative"] and self._nc_stress_key is not None: stress_values = outputs[self._nc_stress_key].block().values.detach() - else: + elif not self.parameters["non_conservative"]: stress_values = strain.grad / atoms.cell.volume - stress_values = stress_values.reshape(3, 3) - stress_values = stress_values.cpu().double() - self.results["stress"] = _full_3x3_to_voigt_6_stress( - stress_values.numpy() - ) + else: + stress_values = None + if stress_values is not None: + stress_values = stress_values.reshape(3, 3) + stress_values = stress_values.cpu().double() + self.results["stress"] = _full_3x3_to_voigt_6_stress( + stress_values.numpy() + ) self.additional_outputs = {} for name in self._additional_output_requests: @@ -871,7 +883,7 @@ def compute_energy( for f in results_as_numpy_arrays["forces"] ] - if all(atoms.pbc.all() for atoms in atoms_list): + if all(atoms.pbc.all() for atoms in atoms_list) and self._nc_stress_key is not None: results_as_numpy_arrays["stress"] = [ s for s in predictions[self._nc_stress_key] @@ -938,7 +950,7 @@ def _ase_properties_to_metatensor_outputs( per_atom=True, ) - if calculate_stress and self.parameters["non_conservative"]: + if calculate_stress and self.parameters["non_conservative"] and self._nc_stress_key is not None: metatensor_outputs[self._nc_stress_key] = ModelOutput( quantity="pressure", unit="eV/Angstrom^3", diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index 9ecf250d9..654a2a527 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -542,7 +542,7 @@ def forward( return {"energy": TensorMap(Labels("_", torch.tensor([[0]])), [block])} -def _charge_spin_calculator(charge: float, spin: float) -> mta.ase_calculator.MetatomicCalculator: +def _charge_spin_calculator() -> mta.ase_calculator.MetatomicCalculator: """Wrap _ChargeSpinAnisoModel in a MetatomicCalculator.""" atomistic_model = mta.AtomisticModel( _ChargeSpinAnisoModel().eval(), @@ -560,7 +560,9 @@ def _charge_spin_calculator(charge: float, spin: float) -> mta.ase_calculator.Me @pytest.mark.parametrize("charge,spin", [(0.0, 1.0), (2.0, 1.0), (-1.0, 3.0)]) -def test_symmetrized_calculator_passes_charge_spin(dimer: Atoms, charge: float, spin: float) -> None: +def test_symmetrized_calculator_passes_charge_spin( + dimer: Atoms, charge: float, spin: float +) -> None: """SymmetrizedCalculator must pass charge/spin to each rotated evaluation. The model returns ``charge + 10*spin + P1(cos θ)``. After O(3) averaging @@ -571,7 +573,7 @@ def test_symmetrized_calculator_passes_charge_spin(dimer: Atoms, charge: float, dimer.info["charge"] = charge dimer.info["spin"] = spin - base = _charge_spin_calculator(charge, spin) + base = _charge_spin_calculator() calc = SymmetrizedCalculator(base, l_max=3, include_inversion=True) dimer.calc = calc energy = dimer.get_potential_energy() From b907c8940ef8b60f5b185af329a7613422d398b7 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Thu, 26 Mar 2026 13:23:33 +0100 Subject: [PATCH 15/16] mtt::spin/charge to spin/charge --- docs/src/engines/ase.rst | 8 +-- .../metatomic/torch/ase_calculator.py | 10 +-- .../metatomic_torch/tests/ase_calculator.py | 63 ++++++++++++++----- .../tests/symmetrized_ase_calculator.py | 8 +-- 4 files changed, 62 insertions(+), 27 deletions(-) diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index 688bd39c8..cab3722f5 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -48,15 +48,15 @@ integer inputs used for model conditioning: * - Input name - Default - How to set - * - ``"mtt::charge"`` + * - ``"charge"`` - ``0`` - ``atoms.info["charge"] = `` - * - ``"mtt::spin"`` + * - ``"spin"`` - ``1`` - ``atoms.info["spin"] = `` -``"mtt::charge"`` is the total charge of the simulation cell in elementary -charges. ``"mtt::spin"`` is the spin multiplicity (2S+1) — a singlet is +``"charge"`` is the total charge of the simulation cell in elementary +charges. ``"spin"`` is the spin multiplicity (2S+1) — a singlet is ``spin=1``, a doublet is ``spin=2``, a triplet is ``spin=3``, and so on. Both values are read as integers from ``atoms.info`` and stored in the system as the model's floating-point dtype (float32 or float64); the model diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 12d852e0f..87d9cd2aa 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -59,14 +59,14 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: SYSTEM_QUANTITIES = { - "mtt::charge": { + "charge": { "quantity": "charge", "getter": lambda atoms: np.array([[atoms.info.get("charge", 0)]]), "unit": "e", "info_key": "charge", "default": 0, }, - "mtt::spin": { + "spin": { "quantity": "spin", "getter": lambda atoms: np.array([[atoms.info.get("spin", 1)]]), "unit": "", @@ -77,9 +77,9 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: """ Per-system scalar inputs provided by ASE via ``atoms.info``. -- ``"mtt::charge"``: total system charge in elementary charges, read from +- ``"charge"``: total system charge in elementary charges, read from ``atoms.info["charge"]``, defaults to ``0``. -- ``"mtt::spin"``: spin multiplicity (2S+1), read from +- ``"spin"``: spin multiplicity (2S+1), read from ``atoms.info["spin"]``, defaults to ``1``. """ @@ -995,7 +995,7 @@ def _get_ase_input( if name not in ARRAY_QUANTITIES: raise ValueError( f"The model requested '{name}', which is not available in `ase`. " - "System-level quantities like 'mtt::charge' or 'mtt::spin' can be " + "System-level quantities like 'charge' or 'spin' can be " "set via atoms.info['charge'] and atoms.info['spin'] " "respectively." ) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 1202e3e19..0b8e35ce0 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -952,10 +952,10 @@ def test_additional_input(atoms): def test_system_level_input(atoms): - """mtt::charge and mtt::spin are per-system integer inputs read from atoms.info.""" + """charge and spin are per-system integer inputs read from atoms.info.""" inputs = { - "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput(quantity="spin", unit="", per_atom=False), + "charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "spin": ModelOutput(quantity="spin", unit="", per_atom=False), } outputs = {("extra::" + n): inputs[n] for n in inputs} capabilities = ModelCapabilities( @@ -974,22 +974,22 @@ def test_system_level_input(atoms): calculator = MetatomicCalculator(model, check_consistency=False) results = calculator.run_model(atoms, outputs) - charge_tensor = results["extra::mtt::charge"] + charge_tensor = results["extra::charge"] assert charge_tensor[0].samples.names == ["system"] assert charge_tensor[0].values.dtype == torch.float64 # matches model dtype assert int(charge_tensor[0].values[0, 0]) == -2 - spin_tensor = results["extra::mtt::spin"] + spin_tensor = results["extra::spin"] assert spin_tensor[0].samples.names == ["system"] assert spin_tensor[0].values.dtype == torch.float64 # matches model dtype assert int(spin_tensor[0].values[0, 0]) == 3 def test_system_level_input_defaults(atoms): - """mtt::charge defaults to 0 and mtt::spin to 1 when not set in atoms.info.""" + """charge defaults to 0 and spin to 1 when not set in atoms.info.""" inputs = { - "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput(quantity="spin", unit="", per_atom=False), + "charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "spin": ModelOutput(quantity="spin", unit="", per_atom=False), } outputs = {("extra::" + n): inputs[n] for n in inputs} capabilities = ModelCapabilities( @@ -1009,8 +1009,8 @@ def test_system_level_input_defaults(atoms): calculator = MetatomicCalculator(model, check_consistency=False) results = calculator.run_model(atoms, outputs) - assert int(results["extra::mtt::charge"][0].values[0, 0]) == 0 - assert int(results["extra::mtt::spin"][0].values[0, 0]) == 1 + assert int(results["extra::charge"][0].values[0, 0]) == 0 + assert int(results["extra::spin"][0].values[0, 0]) == 1 class ChargeSpinEnergyModel(torch.nn.Module): @@ -1022,8 +1022,8 @@ class ChargeSpinEnergyModel(torch.nn.Module): def requested_inputs(self) -> Dict[str, ModelOutput]: return { - "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput(quantity="spin", unit="", per_atom=False), + "charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "spin": ModelOutput(quantity="spin", unit="", per_atom=False), } def forward( @@ -1033,8 +1033,8 @@ def forward( selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: system = systems[0] - charge = float(system.get_data("mtt::charge").block(0).values[0, 0]) - spin = float(system.get_data("mtt::spin").block(0).values[0, 0]) + charge = float(system.get_data("charge").block(0).values[0, 0]) + spin = float(system.get_data("spin").block(0).values[0, 0]) energy_value = charge + 10.0 * spin block = TensorBlock( values=torch.tensor([[energy_value]], dtype=torch.float64), @@ -1097,6 +1097,41 @@ def test_system_level_input_changes_energy(atoms): ) +def test_system_level_input_export_roundtrip(atoms, tmp_path): + """Export a charge/spin model to disk and reload via MetatomicCalculator. + + Covers the full pipeline: build → export → save(".pt") → load from file → + run with atoms.info["charge"]/["spin"]. This is the path exercised by + end-users who load a saved model, so it must work end-to-end. + """ + capabilities = ModelCapabilities( + outputs={"energy": ModelOutput(per_atom=False)}, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + model = AtomisticModel( + ChargeSpinEnergyModel().eval(), ModelMetadata(), capabilities + ) + model_path = str(tmp_path / "charge_spin_model.pt") + model.save(model_path) + + atoms.info["charge"] = 0 + atoms.info["spin"] = 1 + calc = MetatomicCalculator(model_path, check_consistency=True) + atoms.calc = calc + e_neutral = atoms.get_potential_energy() + + atoms.info["charge"] = 2 + atoms.calc.reset() + e_charged = atoms.get_potential_energy() + + assert e_neutral != e_charged, ( + "Loaded model must produce charge-dependent energies" + ) + + @pytest.mark.parametrize("device,dtype", ALL_DEVICE_DTYPE) def test_mixed_pbc(model, device, dtype): """Test that the calculator works on a mixed-PBC system""" diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index 654a2a527..06d004305 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -503,8 +503,8 @@ class _ChargeSpinAnisoModel(torch.nn.Module): def requested_inputs(self) -> Dict[str, ModelOutput]: return { - "mtt::charge": ModelOutput(quantity="charge", unit="e", per_atom=False), - "mtt::spin": ModelOutput(quantity="spin", unit="", per_atom=False), + "charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "spin": ModelOutput(quantity="spin", unit="", per_atom=False), } def forward( @@ -515,8 +515,8 @@ def forward( ) -> Dict[str, TensorMap]: energies: List[torch.Tensor] = [] for system in systems: - charge = system.get_data("mtt::charge").block(0).values[0, 0] - spin = system.get_data("mtt::spin").block(0).values[0, 0] + charge = system.get_data("charge").block(0).values[0, 0] + spin = system.get_data("spin").block(0).values[0, 0] # Orientation-dependent P1 term (averages to zero under O(3)) b = _body_axis_from_system(system).to(dtype=charge.dtype) From 0a1eef83499a47b12047c35b14d3c7e25c8c83b3 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Thu, 26 Mar 2026 13:35:16 +0100 Subject: [PATCH 16/16] add standard charge/spin in c part --- metatomic-torch/src/misc.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index ab4a24982..7da67692a 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -427,6 +427,8 @@ inline std::unordered_set KNOWN_INPUTS_OUTPUTS = { "velocities", "masses", "charges", + "charge", + "spin", }; std::tuple details::validate_name_and_check_variant(