33Wraps the ``nvalchemiops`` Warp-accelerated electrostatics implementations as
44:class:`~torch_sim.models.interface.ModelInterface` subclasses, with full PBC,
55stress (virial), and batched system support. Per-atom partial charges are read
6- from ``state.site_charges `` (a SimState atom extra).
6+ from ``state.partial_charges `` (a SimState atom extra).
77"""
88
99from __future__ import annotations
@@ -80,7 +80,7 @@ class DSFCoulombModel(ModelInterface):
8080 forces, and (optionally) stress. All user-facing quantities are in
8181 metal units (Angstrom / eV); the Coulomb constant ``ke`` is baked in.
8282
83- Per-atom partial charges are read from ``state.site_charges ``.
83+ Per-atom partial charges are read from ``state.partial_charges ``.
8484
8585 Args:
8686 cutoff: Real-space cutoff in Angstrom.
@@ -121,15 +121,18 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
121121 """Compute DSF electrostatic energy, forces, and stress.
122122
123123 Args:
124- state: Simulation state with ``site_charges `` set as an
124+ state: Simulation state with ``partial_charges `` set as an
125125 atom extra (shape ``[n_atoms]``).
126126 **_kwargs: Unused; accepted for interface compatibility.
127127
128128 Returns:
129129 dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3],
130130 and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3].
131131 """
132- charges = state .site_charges
132+ if not state .has_extras ("partial_charges" ):
133+ raise ValueError ("Partial charges are required for DSF Coulomb summation." )
134+
135+ charges = state .partial_charges
133136 edge_index , neighbor_ptr , unit_shifts = _build_csr (
134137 state , self .cutoff , self .neighbor_list_fn
135138 )
@@ -170,7 +173,7 @@ class EwaldModel(ModelInterface):
170173 Returns per-atom energies that are aggregated to per-system. All
171174 user-facing quantities are in metal units (Angstrom / eV).
172175
173- Per-atom partial charges are read from ``state.site_charges ``.
176+ Per-atom partial charges are read from ``state.partial_charges ``.
174177
175178 Requires periodic boundary conditions.
176179
@@ -213,7 +216,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
213216 """Compute Ewald electrostatic energy, forces, and stress.
214217
215218 Args:
216- state: Simulation state with ``site_charges `` set as an
219+ state: Simulation state with ``partial_charges `` set as an
217220 atom extra (shape ``[n_atoms]``). Returns zeros for
218221 non-periodic states.
219222 **_kwargs: Unused; accepted for interface compatibility.
@@ -222,11 +225,14 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
222225 dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3],
223226 and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3].
224227 """
228+ if not state .has_extras ("partial_charges" ):
229+ raise ValueError ("Partial charges are required for Ewald summation." )
230+
225231 if not state .pbc .any ():
226232 return _zero_result (
227233 state , self ._dtype , self ._compute_forces , self ._compute_stress
228234 )
229- charges = state .site_charges
235+ charges = state .partial_charges
230236 edge_index , neighbor_ptr , unit_shifts = _build_csr (
231237 state , self .cutoff , self .neighbor_list_fn
232238 )
@@ -270,7 +276,7 @@ class PMEModel(ModelInterface):
270276 per-system. All user-facing quantities are in metal units
271277 (Angstrom / eV).
272278
273- Per-atom partial charges are read from ``state.site_charges ``.
279+ Per-atom partial charges are read from ``state.partial_charges ``.
274280
275281 Requires periodic boundary conditions.
276282
@@ -322,7 +328,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
322328 """Compute PME electrostatic energy, forces, and stress.
323329
324330 Args:
325- state: Simulation state with ``site_charges `` set as an
331+ state: Simulation state with ``partial_charges `` set as an
326332 atom extra (shape ``[n_atoms]``). Returns zeros for
327333 non-periodic states.
328334 **_kwargs: Unused; accepted for interface compatibility.
@@ -331,11 +337,14 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
331337 dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3],
332338 and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3].
333339 """
340+ if not state .has_extras ("partial_charges" ):
341+ raise ValueError ("Partial charges are required for PME summation." )
342+
334343 if not state .pbc .any ():
335344 return _zero_result (
336345 state , self ._dtype , self ._compute_forces , self ._compute_stress
337346 )
338- charges = state .site_charges
347+ charges = state .partial_charges
339348 edge_index , neighbor_ptr , unit_shifts = _build_csr (
340349 state , self .cutoff , self .neighbor_list_fn
341350 )
0 commit comments