Skip to content

Commit e0dc7fd

Browse files
committed
fea: site_charges -> partial_charges
1 parent afb45e6 commit e0dc7fd

3 files changed

Lines changed: 26 additions & 19 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,12 @@ test = [
4141
"torch-sim-atomistic[io,symmetry,vesin]",
4242
"platformdirs>=4.0.0",
4343
"psutil>=7.0.0",
44-
"pymatgen>=2025.6.14",
4544
"pytest-cov>=6",
4645
"pytest>=8",
4746
"spglib>=2.6",
48-
"vesin[torch]>=0.5.3",
4947
]
5048
vesin = ["vesin[torch]>=0.5.3"]
51-
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"]
49+
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2026.3.23"]
5250
symmetry = ["moyopy>=0.7.8"]
5351
mace = ["mace-torch>=0.3.15"]
5452
mattersim = ["mattersim>=0.1.2"]

tests/models/test_electrostatics.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _make_charged_state(
3030
charges = torch.empty(n, dtype=dtype, device=device)
3131
charges[::2] = 1.0
3232
charges[1::2] = -1.0
33-
state._atom_extras["site_charges"] = charges # noqa: SLF001
33+
state._atom_extras["partial_charges"] = charges # noqa: SLF001
3434
return state
3535

3636

@@ -49,33 +49,33 @@ def pme_model() -> PMEModel:
4949
return PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
5050

5151

52-
def _add_site_charges(state: ts.SimState) -> ts.SimState:
52+
def _add_partial_charges(state: ts.SimState) -> ts.SimState:
5353
"""Inject alternating +/-0.5 site charges into a state."""
5454
n = state.n_atoms
5555
charges = torch.zeros(n, dtype=state.positions.dtype, device=state.device)
5656
charges[::2] = 0.5
5757
charges[1::2] = -0.5
58-
state._atom_extras["site_charges"] = charges # noqa: SLF001
58+
state._atom_extras["partial_charges"] = charges # noqa: SLF001
5959
return state
6060

6161

6262
test_dsf_model_outputs = make_validate_model_outputs_test(
6363
model_fixture_name="dsf_model",
6464
device=DEVICE,
6565
dtype=DTYPE,
66-
state_modifiers=[_add_site_charges],
66+
state_modifiers=[_add_partial_charges],
6767
)
6868
test_ewald_model_outputs = make_validate_model_outputs_test(
6969
model_fixture_name="ewald_model",
7070
device=DEVICE,
7171
dtype=DTYPE,
72-
state_modifiers=[_add_site_charges],
72+
state_modifiers=[_add_partial_charges],
7373
)
7474
test_pme_model_outputs = make_validate_model_outputs_test(
7575
model_fixture_name="pme_model",
7676
device=DEVICE,
7777
dtype=DTYPE,
78-
state_modifiers=[_add_site_charges],
78+
state_modifiers=[_add_partial_charges],
7979
)
8080

8181

torch_sim/models/electrostatics.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Wraps the ``nvalchemiops`` Warp-accelerated electrostatics implementations as
44
:class:`~torch_sim.models.interface.ModelInterface` subclasses, with full PBC,
55
stress (virial), and batched system support. Per-atom partial charges are read
6-
from ``state.site_charges`` (a SimState atom extra).
6+
from ``state.partial_charges`` (a SimState atom extra).
77
"""
88

99
from __future__ import annotations
@@ -80,7 +80,7 @@ class DSFCoulombModel(ModelInterface):
8080
forces, and (optionally) stress. All user-facing quantities are in
8181
metal units (Angstrom / eV); the Coulomb constant ``ke`` is baked in.
8282
83-
Per-atom partial charges are read from ``state.site_charges``.
83+
Per-atom partial charges are read from ``state.partial_charges``.
8484
8585
Args:
8686
cutoff: Real-space cutoff in Angstrom.
@@ -121,15 +121,18 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
121121
"""Compute DSF electrostatic energy, forces, and stress.
122122
123123
Args:
124-
state: Simulation state with ``site_charges`` set as an
124+
state: Simulation state with ``partial_charges`` set as an
125125
atom extra (shape ``[n_atoms]``).
126126
**_kwargs: Unused; accepted for interface compatibility.
127127
128128
Returns:
129129
dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3],
130130
and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3].
131131
"""
132-
charges = state.site_charges
132+
if not state.has_extras("partial_charges"):
133+
raise ValueError("Partial charges are required for DSF Coulomb summation.")
134+
135+
charges = state.partial_charges
133136
edge_index, neighbor_ptr, unit_shifts = _build_csr(
134137
state, self.cutoff, self.neighbor_list_fn
135138
)
@@ -170,7 +173,7 @@ class EwaldModel(ModelInterface):
170173
Returns per-atom energies that are aggregated to per-system. All
171174
user-facing quantities are in metal units (Angstrom / eV).
172175
173-
Per-atom partial charges are read from ``state.site_charges``.
176+
Per-atom partial charges are read from ``state.partial_charges``.
174177
175178
Requires periodic boundary conditions.
176179
@@ -213,7 +216,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
213216
"""Compute Ewald electrostatic energy, forces, and stress.
214217
215218
Args:
216-
state: Simulation state with ``site_charges`` set as an
219+
state: Simulation state with ``partial_charges`` set as an
217220
atom extra (shape ``[n_atoms]``). Returns zeros for
218221
non-periodic states.
219222
**_kwargs: Unused; accepted for interface compatibility.
@@ -222,11 +225,14 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
222225
dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3],
223226
and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3].
224227
"""
228+
if not state.has_extras("partial_charges"):
229+
raise ValueError("Partial charges are required for Ewald summation.")
230+
225231
if not state.pbc.any():
226232
return _zero_result(
227233
state, self._dtype, self._compute_forces, self._compute_stress
228234
)
229-
charges = state.site_charges
235+
charges = state.partial_charges
230236
edge_index, neighbor_ptr, unit_shifts = _build_csr(
231237
state, self.cutoff, self.neighbor_list_fn
232238
)
@@ -270,7 +276,7 @@ class PMEModel(ModelInterface):
270276
per-system. All user-facing quantities are in metal units
271277
(Angstrom / eV).
272278
273-
Per-atom partial charges are read from ``state.site_charges``.
279+
Per-atom partial charges are read from ``state.partial_charges``.
274280
275281
Requires periodic boundary conditions.
276282
@@ -322,7 +328,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
322328
"""Compute PME electrostatic energy, forces, and stress.
323329
324330
Args:
325-
state: Simulation state with ``site_charges`` set as an
331+
state: Simulation state with ``partial_charges`` set as an
326332
atom extra (shape ``[n_atoms]``). Returns zeros for
327333
non-periodic states.
328334
**_kwargs: Unused; accepted for interface compatibility.
@@ -331,11 +337,14 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
331337
dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3],
332338
and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3].
333339
"""
340+
if not state.has_extras("partial_charges"):
341+
raise ValueError("Partial charges are required for PME summation.")
342+
334343
if not state.pbc.any():
335344
return _zero_result(
336345
state, self._dtype, self._compute_forces, self._compute_stress
337346
)
338-
charges = state.site_charges
347+
charges = state.partial_charges
339348
edge_index, neighbor_ptr, unit_shifts = _build_csr(
340349
state, self.cutoff, self.neighbor_list_fn
341350
)

0 commit comments

Comments
 (0)