Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 260 additions & 0 deletions ipsuite/data_loading/add_data_gromacs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import logging
import typing
import warnings
from pathlib import Path

import h5py
import MDAnalysis as mda
import numpy as np
import tqdm
import znh5md
import zntrack
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.units import kJ, mol
from MDAnalysis.auxiliary.EDR import EDRReader

logger = logging.getLogger(__name__)

_TYPE_TO_ELEMENT = {
"CL": "Cl",
"NA": "Na",
"MG": "Mg",
"ZN": "Zn",
"FE": "Fe",
"CA": "Ca",
"MN": "Mn",
"CU": "Cu",
"LI": "Li",
"AL": "Al",
"SI": "Si",
"BR": "Br",
"SE": "Se",
}


def _get_symbols(u: mda.Universe) -> list[str]:
"""Extract element symbols from a Universe, trying multiple strategies."""
# 1. Use elements attribute if available
try:
return list(u.atoms.elements)
except (mda.exceptions.NoDataError, AttributeError):
pass

# 2. Use atom types (usually cleaner than names for CHARMM-GUI)
types = u.atoms.types
symbols = []
for t in types:
t_upper = t.upper()
if t_upper in _TYPE_TO_ELEMENT:
symbols.append(_TYPE_TO_ELEMENT[t_upper])
elif len(t) <= 2 and t[0].isalpha():
# Capitalize properly: first letter upper, rest lower
symbols.append(t[0].upper() + t[1:].lower() if len(t) > 1 else t.upper())
else:
# Last resort: take leading alphabetic characters from atom name
symbols.append(t[0].upper())
return symbols
Comment on lines +44 to +57
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n ipsuite/data_loading/add_data_gromacs.py | head -100

Repository: zincware/IPSuite

Length of output: 3699


🏁 Script executed:

# Check the full _get_symbols function and surrounding context
rg -A 30 "def _get_symbols" ipsuite/data_loading/add_data_gromacs.py

Repository: zincware/IPSuite

Length of output: 1268


🏁 Script executed:

# Look for how _get_symbols is used in the file
rg "_get_symbols" ipsuite/data_loading/add_data_gromacs.py -B 2 -A 2

Repository: zincware/IPSuite

Length of output: 335


🏁 Script executed:

# Check if there's any existing logic that uses u.atoms.names
rg "u\.atoms\.(names|types)" ipsuite/data_loading/add_data_gromacs.py -B 2 -A 2

Repository: zincware/IPSuite

Length of output: 191


🏁 Script executed:

# Let's understand MDAnalysis element guessing behavior better
python << 'EOF'
# Check what MDAnalysis does with element guessing
import sys
# Try to get info about the default guesser
try:
    from MDAnalysis.guesser import DEFAULT_GUESSER
    print("DEFAULT_GUESSER available")
    print(f"Type: {type(DEFAULT_GUESSER)}")
except ImportError:
    print("DEFAULT_GUESSER not directly available")

# Check the implementation
try:
    from MDAnalysis.guesser.default import DefaultGuesser
    print("DefaultGuesser available")
    # Check what the guesser does
    import inspect
    print(inspect.getsourcefile(DefaultGuesser))
except ImportError as e:
    print(f"Could not import DefaultGuesser: {e}")

EOF

Repository: zincware/IPSuite

Length of output: 160


🏁 Script executed:

# Let's search for documentation or usage of element guessing in the codebase
rg "guess.*element|element.*guess" ipsuite/ -i

Repository: zincware/IPSuite

Length of output: 42


🏁 Script executed:

# Check if there are any tests or examples showing how _get_symbols should work
rg -l "_get_symbols|gmx_to_ase" ipsuite/ --type py

Repository: zincware/IPSuite

Length of output: 100


🏁 Script executed:

# Look at the broader context - is there any test file for this module?
fd "test.*gromacs|gromacs.*test" ipsuite/ -type f

Repository: zincware/IPSuite

Length of output: 230


🏁 Script executed:

# Let's check what MDAnalysis version is being used and look at its documentation
rg "MDAnalysis" ipsuite/ setup.py pyproject.toml requirements.txt -A 1 -B 1

Repository: zincware/IPSuite

Length of output: 2179


🏁 Script executed:

# Check if there are any comments or documentation about element inference in the file
rg "element|symbol" ipsuite/data_loading/add_data_gromacs.py -i

Repository: zincware/IPSuite

Length of output: 647


🏁 Script executed:

# Let's check if there are other usages of this in the codebase to understand the impact
rg "gmx_to_ase|_get_symbols" ipsuite/ -l

Repository: zincware/IPSuite

Length of output: 100


🏁 Script executed:

# Check the pyproject.toml or similar for MDAnalysis version info
find ipsuite -name "pyproject.toml" -o -name "setup.cfg" | head -5 | xargs cat 2>/dev/null | head -100

Repository: zincware/IPSuite

Length of output: 42


🏁 Script executed:

# Look at what ASE Atoms expects for symbols
python << 'EOF'
# Check what happens when you create ASE atoms with invalid symbols
# This will help us understand the severity

test_cases = [
    "Ow",  # Invalid - OW type
    "Ca",  # Calcium
    "C",   # Carbon
    "O",   # Oxygen
]

# Check if these are valid element symbols in ASE
try:
    from ase.data import atomic_numbers
    for sym in test_cases:
        try:
            an = atomic_numbers[sym]
            print(f"✓ {sym:3s} -> atomic number {an}")
        except KeyError:
            print(f"✗ {sym:3s} -> INVALID/UNKNOWN")
except ImportError:
    print("ASE not available in sandbox, but we can reason about this")
    print("Standard element symbols are 1-2 chars, capitalized properly")
    for sym in test_cases:
        # Standard chemistry rules
        if sym == "Ow":
            print(f"✗ {sym} -> INVALID (no such element)")
        elif sym == "Ca":
            print(f"✓ {sym} -> Calcium (valid but semantically wrong for Carbon Alpha)")
        elif sym in ["C", "O"]:
            print(f"✓ {sym} -> Valid")

EOF

Repository: zincware/IPSuite

Length of output: 303


🏁 Script executed:

# Let's also verify what the comment says about MDAnalysis names vs types
# Look at the gmx.py file that might have similar logic
cat -n ipsuite/configuration_generation/gmx.py | grep -A 20 "def timestep_to_atoms"

Repository: zincware/IPSuite

Length of output: 842


🏁 Script executed:

# Let's check what happens after symbols are used - does ASE validate them?
rg -A 5 "Atoms(symbols=symbols" ipsuite/data_loading/add_data_gromacs.py

Repository: zincware/IPSuite

Length of output: 137


🏁 Script executed:

# Let's understand the problem better by looking at actual GROMACS atom types
# Check if there are any test files or examples
find . -name "*.gro" -o -name "*.tpr" -o -name "*.xtc" 2>/dev/null | head -5

Repository: zincware/IPSuite

Length of output: 42


🏁 Script executed:

# Search for any documentation on the expected behavior
rg "water|OW|CA" ipsuite/data_loading/ -i

Repository: zincware/IPSuite

Length of output: 1415


🏁 Script executed:

# Let's check what the MDAnalysis documentation actually says about element guessing
web_search
MDAnalysis element guessing default guesser names types

Repository: zincware/IPSuite

Length of output: 159


🌐 Web query:

MDAnalysis element guessing default guesser documentation

💡 Result:

The default guesser for element guessing in MDAnalysis is DefaultGuesser from MDAnalysis.guesser.default_guesser. It is the generic, default guessing class used for most topologies, particularly following biological naming conventions (e.g., "CA" guessed as carbon, not calcium). Element guessing behavior: The guess_atom_element(atomname) method strips numbers and symbols from the atom name, looks it up in MDAnalysis.guesser.tables.atomelements (derived from CHARMM/AMBER tables), removes the last character if not found, or falls back to the first character of the name if still not found. Empty names return empty string. When guessing types, they represent the element using the same method. It is used automatically as the default context (context='default') when creating a Universe with to_guess=['types', 'elements'] or via u.guess_TopologyAttrs(to_guess=['elements']). Pre-2.8.0, standalone functions like guess_atom_element in MDAnalysis.topology.guessers were used (now deprecated). Example: import MDAnalysis as mda u = mda.Universe('topology.gro', context='default', to_guess=['elements']) # or u.guess_TopologyAttrs(to_guess=['elements']) Official documentation: https://docs.mdanalysis.org/stable/documentation_pages/guesser_modules/default_guesser.html and https://userguide.mdanalysis.org/stable/formats/guessers/default.html (MDAnalysis 2.10.0 as of latest docs).

Citations:


🏁 Script executed:

# Let's look at the actual usage and see if invalid symbols cause problems
cat -n ipsuite/data_loading/add_data_gromacs.py | sed -n '120,160p'

Repository: zincware/IPSuite

Length of output: 1823


🏁 Script executed:

# Check if there's validation happening elsewhere
rg "atomic_numbers\|invalid.*symbol\|element.*valid" ipsuite/data_loading/add_data_gromacs.py

Repository: zincware/IPSuite

Length of output: 42


🏁 Script executed:

# Let's examine what happens with "Ow" type - is it even a realistic scenario?
# Check if there are any CHARMM-GUI or common GROMACS force field references
rg "CHARMM|amber|oplsaa|gromos" ipsuite/ -i | head -10

Repository: zincware/IPSuite

Length of output: 168


🏁 Script executed:

# Let's look at the exact lines where symbols are used to create Atoms
cat -n ipsuite/data_loading/add_data_gromacs.py | sed -n '136,155p'

Repository: zincware/IPSuite

Length of output: 928


Use MDAnalysis's proper element guessing from atom names, not types.

The current approach using u.atoms.types creates semantic bugs:

  • OW (water oxygen in CHARMM-GUI) becomes invalid Ow
  • CA (carbon alpha) is hardcoded to Ca (calcium) instead of being guessed as carbon

MDAnalysis's DefaultGuesser operates on atom names using domain-standard rules (stripping digits/symbols, lookups in CHARMM/AMBER tables). Replace the manual type-based logic with u.atoms.names and leverage MDAnalysis's built-in element inference, or call u.guess_TopologyAttrs(to_guess=['elements']) to populate u.atoms.elements properly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ipsuite/data_loading/add_data_gromacs.py` around lines 44 - 57, The code
currently derives element symbols from u.atoms.types which misclassifies
CHARMM/AMBER types; change the logic to use atom names and MDAnalysis's element
guessing instead: use u.atoms.names (or call
u.guess_TopologyAttrs(to_guess=['elements']) and read u.atoms.elements) and
build the symbols list from those guessed elements (or rely on
MDAnalysis.DefaultGuesser behavior) rather than transforming types; update the
block that references types, t, and symbols to use names/elements (keeping the
same return of symbols) and remove the manual type-based heuristics.



def _match_edr_frame(
edr_times: np.ndarray, traj_time: float, tolerance: float = 0.1
) -> int:
"""Find the EDR index closest to a trajectory time, warning on large gaps."""
idx = int(np.argmin(np.abs(edr_times - traj_time)))
time_diff = abs(edr_times[idx] - traj_time)
if time_diff > tolerance:
logger.warning(
"EDR time %.3f ps does not match trajectory time %.3f ps (diff=%.3f ps)",
edr_times[idx],
traj_time,
time_diff,
)
return idx
Comment on lines +60 to +73
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Avoid quadratic EDR matching and reject missed timestamps.

np.argmin(np.abs(edr_times - traj_time)) rescans the full EDR axis for every frame, so matching becomes O(n×m) on long runs. It also still returns a row after tolerance is exceeded, which means shifted or downsampled .edr files silently annotate frames with the wrong labels. If the time axes are monotonic, np.searchsorted or a moving pointer avoids the repeated scan and makes it easy to skip/raise on out-of-tolerance matches.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ipsuite/data_loading/add_data_gromacs.py` around lines 60 - 73,
_match_edr_frame currently scans the whole edr_times array per frame which is
O(n×m) and silently returns an out-of-tolerance index; replace the np.argmin
approach with np.searchsorted on the monotonic edr_times (or use a moving
pointer maintained by the caller) to find the nearest candidate on the
left/right, compute the nearest index and time_diff, and if time_diff >
tolerance reject the match (raise a ValueError or return a sentinel like
None/-1) instead of returning a bad index; update callers to handle the
rejection and keep the function signature (_match_edr_frame(edr_times,
traj_time, tolerance)) and logging behavior.



def gmx_to_ase(
topology: str,
trajectory: str | None = None,
edr: str | None = None,
start: int | None = None,
stop: int | None = None,
step: int | None = None,
) -> list[Atoms]:
"""Convert a GROMACS trajectory to a list of ASE Atoms objects.

Extracts all available information: positions, velocities, forces,
and (via the .edr file) energies and stress.

Parameters
----------
topology : str
Path to a GROMACS topology/structure file (.gro, .tpr).
trajectory : str | None
Path to a trajectory file (.xtc, .trr). If None, only the single
structure from the topology file is returned.
edr : str | None
Path to a GROMACS energy file (.edr). If given, per-frame energies
and stress tensors are attached via SinglePointCalculator.
start, stop, step : int | None
Slice parameters for selecting a subset of frames.

Returns
-------
list[Atoms]
One ASE Atoms object per frame. Each Atoms has:
- positions (always)
- cell and pbc (always)
- velocities (if present in trajectory)
- forces (if present in trajectory, e.g. .trr)
- calculator with energy/stress/forces (if .edr provided or forces
present), plus all EDR terms stored in calc.results
"""
if trajectory is not None:
u = mda.Universe(topology, trajectory)
else:
u = mda.Universe(topology)

symbols = _get_symbols(u)

# Load EDR data if provided
edr_data = None
if edr is not None:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
reader = EDRReader(edr)
edr_all = reader.get_data(list(reader.terms))
edr_times = edr_all.pop("Time")
edr_data = dict(edr_all)
edr_terms = list(edr_data.keys())

frames = []
for ts in tqdm.tqdm(u.trajectory[start:stop:step]):
positions = ts.positions.copy()
box = ts.dimensions

atoms = Atoms(symbols=symbols, positions=positions, pbc=True)

if box is not None and all(box[:3] > 0):
atoms.set_cell(box, scale_atoms=False)

# Velocities (e.g. from .gro or .trr)
if ts.has_velocities:
# MDAnalysis: Å/ps, ASE: Å/fs -> divide by 1000
atoms.set_velocities(ts.velocities / 1000.0)

# Forces and energies via SinglePointCalculator
forces = ts.forces.copy() if ts.has_forces else None
energy = None
stress = None
extra_results = {}

if edr_data is not None:
idx = _match_edr_frame(edr_times, ts.time)
energy = float(edr_data["Potential"][idx]) * (kJ / mol) # convert to eV

# Build Voigt stress from pressure tensor if available
try:
pxx = edr_data["Pres-XX"][idx]
pyy = edr_data["Pres-YY"][idx]
pzz = edr_data["Pres-ZZ"][idx]
pyz = edr_data["Pres-YZ"][idx]
pxz = edr_data["Pres-XZ"][idx]
pxy = edr_data["Pres-XY"][idx]
# GROMACS pressure in bar -> store as-is (not ASE native eV/ų)
stress = np.array([pxx, pyy, pzz, pyz, pxz, pxy])
Comment on lines +164 to +165
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Stress units mismatch: GROMACS pressure is in bar, ASE expects eV/ų.

The comment acknowledges the unit discrepancy, but storing stress in bar breaks ASE's stress conventions. Any code calling atoms.get_stress() will receive values in the wrong units, leading to incorrect virial/pressure calculations downstream.

Convert bar to eV/ų: 1 bar = 1e-4 GPa = 6.2415e-7 eV/ų (use ase.units for precision).

Proposed fix
+from ase import units
+
+# bar -> eV/ų
+BAR_TO_EV_ANG3 = 1.0 / (units.bar / (units.eV / units.Ang**3))
+
 ...
-                # GROMACS pressure in bar -> store as-is (not ASE native eV/ų)
-                stress = np.array([pxx, pyy, pzz, pyz, pxz, pxy])
+                # Convert GROMACS pressure (bar) to ASE stress (eV/ų)
+                stress = np.array([pxx, pyy, pzz, pyz, pxz, pxy]) * BAR_TO_EV_ANG3

except KeyError:
pass

# Store all EDR terms for this frame
for term in edr_terms:
extra_results[term] = float(edr_data[term][idx])

if energy is not None or forces is not None:
calc = SinglePointCalculator(
atoms,
energy=energy,
forces=forces,
stress=stress,
)
Comment on lines +147 to +179
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the full file context
cd /tmp && find . -name "add_data_gromacs.py" -type f 2>/dev/null | head -5

Repository: zincware/IPSuite

Length of output: 42


🏁 Script executed:

# Find the file in the repository
find . -name "add_data_gromacs.py" -type f 2>/dev/null

Repository: zincware/IPSuite

Length of output: 102


🏁 Script executed:

# Read the relevant file section to understand context
cat -n ipsuite/data_loading/add_data_gromacs.py | head -200

Repository: zincware/IPSuite

Length of output: 7743


🏁 Script executed:

# Check imports to see how kJ/mol is defined
head -50 ipsuite/data_loading/add_data_gromacs.py

Repository: zincware/IPSuite

Length of output: 1203


🌐 Web query:

MDAnalysis force units trajectory trajectories

💡 Result:

In MDAnalysis, the base unit for forces in trajectories is kJ/(mol·Å). This applies when forces are present in the trajectory file format (e.g., via Universe.atoms.forces). MDAnalysis automatically converts forces to this unit upon reading trajectories, regardless of the original MD simulation format. Other units like kcal/(mol·Å) or N can be converted using MDAnalysis.units.convert.

Citations:


🌐 Web query:

ASE SinglePointCalculator force units eV Angstrom

💡 Result:

In the Atomic Simulation Environment (ASE), the SinglePointCalculator stores pre-computed results (energy, forces, stress, etc.) for a specific atomic configuration from an external calculation. It does not perform new calculations but returns the stored values via methods like get_forces. ASE uses consistent internal units: energies in eV, lengths in Angstroms (Å). Consequently, forces are in eV/Å. The SinglePointCalculator preserves whatever units the input data had when created (typically from I/O readers that convert external code outputs to ASE units). When reading results from formats like CASTEP .geom or ABINIT output, forces are explicitly converted to eV/Angstrom. For example: - CASTEP .geom reader: "Note that everything in the .geom file is in atomic units, which has been conversed to commonly used unit angstrom(length) and eV (energy)." - ABINIT reader: forces from 'cartesian_forces' converted using Hartree / Bohr = eV/Å. All ASE calculators, including those populating SinglePointCalculator, follow this convention (eV for energy, eV/Å for forces), as confirmed by documentation, tutorials (e.g., EMT calculator forces in eV/Å), and community sources. Thus, forces from ASE SinglePointCalculator are in eV/Å.

Citations:


Convert forces to eV/Å before creating the calculator.

Line 147 forwards ts.forces unchanged, but MDAnalysis uses kJ/(mol·Å) as its base force unit while ASE expects forces in eV/Å. That leaves every stored force off by the same kJ/mol → eV factor you're already applying to Potential on line 154.

Suggested fix
-        forces = ts.forces.copy() if ts.has_forces else None
+        forces = ts.forces.copy() * (kJ / mol) if ts.has_forces else None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ipsuite/data_loading/add_data_gromacs.py` around lines 147 - 179, ts.forces
are left in MDAnalysis units (kJ/(mol·Å)) but ASE expects eV/Å; multiply the
copied forces by the same conversion factor used for energy (the (kJ / mol)
factor) before passing into SinglePointCalculator. Update the code that sets
forces (the ts.forces.copy() branch) to convert units (e.g., forces =
ts.forces.copy() * (kJ / mol) when ts.has_forces) so the forces variable passed
to SinglePointCalculator(atoms, energy=..., forces=forces, stress=...) is in
eV/Å.

calc.results.update(extra_results)
atoms.calc = calc

frames.append(atoms)

return frames


class Gmx2Frames(zntrack.Node):
"""Convert GROMACS output files to ASE Atoms frames.

Reads topology, trajectory, and optionally energy (.edr) files
to produce a list of ASE Atoms with positions, velocities, forces,
energies, and stress where available.

Parameters
----------
topology : Path
Path to a GROMACS topology/structure file (.gro, .tpr).
trajectory : Path, optional
Path to a trajectory file (.xtc, .trr).
edr : Path, optional
Path to a GROMACS energy file (.edr).
start : int, optional
First frame index to read.
stop : int, optional
Last frame index (exclusive) to read.
step : int, optional
Step size for frame selection.

Examples
--------
>>> with project:
... md = ips.Gmx2Frames(
... topology="gromacs/system.gro",
... trajectory="gromacs/production.xtc",
... edr="gromacs/production.edr",
... start=1,
... )
"""

topology: Path = zntrack.deps_path()
trajectory: Path | None = zntrack.deps_path(None)
edr: Path | None = zntrack.deps_path(None)
start: int | None = zntrack.params(None)
stop: int | None = zntrack.params(None)
step: int | None = zntrack.params(None)

frames_path: Path = zntrack.outs_path(zntrack.nwd / "frames.h5")

def run(self) -> None:
data = gmx_to_ase(
topology=str(self.topology),
trajectory=str(self.trajectory) if self.trajectory else None,
edr=str(self.edr) if self.edr else None,
start=self.start,
stop=self.stop,
step=self.step,
)
frame_io = znh5md.IO(self.frames_path)
frame_io.extend(data)

@property
def frames(self) -> typing.List[Atoms]:
with self.state.fs.open(self.frames_path, "rb") as f:
with h5py.File(f) as file:
return znh5md.IO(file_handle=file)[:]


if __name__ == "__main__":
# Example: load the production trajectory with energies
frames = gmx_to_ase(
"gromacs/system.gro",
"gromacs/production.xtc",
edr="gromacs/production.edr",
)
print(f"Loaded {len(frames)} frames, {len(frames[0])} atoms per frame")
print(f"Cell: {frames[0].cell.cellpar()}")
print(f"Potential energy (frame 0): {frames[0].get_potential_energy()} eV")
if len(frames) >= 2:
print(f"All EDR terms on frame 1: {list(frames[1].calc.results.keys())}")
Loading