Skip to content
36 changes: 28 additions & 8 deletions ipsuite/analysis/model/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,47 @@ def compute_rot_forces(mol, key: str = "forces"):
return result


def force_decomposition(atom, mapping, key: str = "forces"):
def force_decomposition(
atom,
mapping,
full_forces: np.ndarray | None = None,
key: str = "forces",
map: np.ndarray | None = None,
):
if key not in ["forces", "forces_ensemble"]:
raise KeyError("Unknown force decomposition key")
_, molecules = mapping.forward_mapping(atom)
full_forces = np.zeros_like(atom.calc.results[key])
atom_trans_forces = np.zeros_like(atom.calc.results[key])
atom_rot_forces = np.zeros_like(atom.calc.results[key])

if full_forces is not None:
if map is None:
_, molecules, map = mapping.forward_mapping(atom, forces=full_forces)
else:
_, molecules, map = mapping.forward_mapping(atom, forces=full_forces, map=map)
atom_trans_forces = np.zeros_like(full_forces)
atom_rot_forces = np.zeros_like(full_forces)
full_forces = np.zeros_like(full_forces)

elif atom.calc is not None:
try:
_, molecules, map = mapping.forward_mapping(atom, map=map)
except NameError:
_, molecules, map = mapping.forward_mapping(atom)
full_forces = np.zeros_like(atom.calc.results[key])
atom_trans_forces = np.zeros_like(atom.calc.results[key])
atom_rot_forces = np.zeros_like(atom.calc.results[key])

total_n_atoms = 0

for molecule in molecules:
n_atoms = len(molecule)
mol_slice = slice(total_n_atoms, total_n_atoms + n_atoms)
# TODO: What if molecule indices are not ordered?
full_forces[mol_slice] = molecule.calc.results[key]
atom_rot_forces[mol_slice] = compute_rot_forces(molecule, key)
atom_trans_forces[mol_slice] = compute_trans_forces(molecule, key)
total_n_atoms += n_atoms

# print(full_forces-test)
atom_vib_forces = full_forces - atom_trans_forces - atom_rot_forces

return atom_trans_forces, atom_rot_forces, atom_vib_forces
return atom_trans_forces, atom_rot_forces, atom_vib_forces, map


def decompose_stress_tensor(stresses):
Expand Down
32 changes: 26 additions & 6 deletions ipsuite/geometry/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing as t

import ase
import numpy as np

from ipsuite.geometry import barycenter_coarse_grain, graphs, unwrap

Expand All @@ -27,21 +28,40 @@ class BarycenterMapping:
The indices of the molecules will be frozen for all configurations.
"""

frozen: bool = False
frozen: bool = True

_components: t.Any | None = None

def forward_mapping(self, atoms: ase.Atoms) -> tuple[ase.Atoms, list[ase.Atoms]]:
if self._components is None:
def forward_mapping(
self,
atoms: ase.Atoms,
forces: np.ndarray | None = None,
map: np.ndarray | None = None,
) -> tuple[ase.Atoms, list[ase.Atoms]]:
if map is None:
components = graphs.identify_molecules(atoms)
print("recompute")
else:
components = map

"""if self._components is None:
components = graphs.identify_molecules(atoms)
print("\n got new comps")
else:
components = self._components
print("\n using frozen comps")

if self.frozen:
self._components = components
molecules = unwrap.unwrap_system(atoms, components)
print("\n is frozen")
self._components = components"""

# components = np.arange(0, 3*40).reshape(-1,3)
if forces is not None:
molecules = unwrap.unwrap_system(atoms, components, forces=forces.copy())
else:
molecules = unwrap.unwrap_system(atoms, components)
cg_atoms = barycenter_coarse_grain.coarse_grain_to_barycenter(molecules)
return cg_atoms, molecules
return cg_atoms, molecules, components

def backward_mapping(
self, cg_atoms: ase.Atoms, molecules: list[ase.Atoms]
Expand Down
15 changes: 11 additions & 4 deletions ipsuite/geometry/unwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def unwrap(atoms, edges, idx):
displace_neighbors(atoms, e)


def unwrap_system(atoms: ase.Atoms, components: list[np.ndarray]) -> list[ase.Atom]:
def unwrap_system(
atoms: ase.Atoms, components: list[np.ndarray], forces: np.ndarray | None = None
) -> list[ase.Atom]:
"""Molecules in a system which extend across periodic boundaries are mapped such that
they are connected but dangle out of the cell.
Mapping to the side where the fragment of molecule is closest
Expand All @@ -49,9 +51,15 @@ def unwrap_system(atoms: ase.Atoms, components: list[np.ndarray]) -> list[ase.At
and calling the `atoms.wrap()` method.
"""
molecules = []

for component in components:
mol = atoms[component].copy()
if atoms.calc is not None:
component = np.sort(component)
mol = atoms[component]
if forces is not None:
results = {"forces": forces[component].copy()}
mol.calc = SinglePointCalculator(mol, **results)

elif atoms.calc is not None:
results = {"forces": atoms.get_forces()[component]}
if "forces_uncertainty" in atoms.calc.results.keys():
f_unc = atoms.calc.results["forces_uncertainty"][component]
Expand All @@ -66,5 +74,4 @@ def unwrap_system(atoms: ase.Atoms, components: list[np.ndarray]) -> list[ase.At
closest_atom = closest_atom_to_center(mol)
unwrap(mol, edges, idx=closest_atom)
molecules.append(mol)

return molecules