From 49ede965a2677eb4c216f1059ddb6c4d12644c8e Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 1 Jan 2026 09:23:41 -0500 Subject: [PATCH 01/11] move the protocols results to a single file and deduplicate --- .../openmm_afe/afe_protocol_results.py | 561 ++++++++++++++++++ .../openmm_afe/equil_binding_afe_method.py | 414 ------------- .../openmm_afe/equil_solvation_afe_method.py | 330 +---------- 3 files changed, 574 insertions(+), 731 deletions(-) create mode 100644 openfe/protocols/openmm_afe/afe_protocol_results.py diff --git a/openfe/protocols/openmm_afe/afe_protocol_results.py b/openfe/protocols/openmm_afe/afe_protocol_results.py new file mode 100644 index 000000000..e193aebe1 --- /dev/null +++ b/openfe/protocols/openmm_afe/afe_protocol_results.py @@ -0,0 +1,561 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""OpenMM Equilibrium Solvation AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_solvation_afe_method` +=============================================================================================================== + +This module implements the necessary methodology tooling to run calculate an +absolute solvation free energy using OpenMM tools and one of the following +alchemical sampling methods: + +* Hamiltonian Replica Exchange +* Self-adjusted mixture sampling +* Independent window sampling + +Current limitations +------------------- +* Alchemical species with a net charge are not currently supported. +* Disapearing molecules are only allowed in state A. Support for + appearing molecules will be added in due course. +* Only small molecules are allowed to act as alchemical molecules. + Alchemically changing protein or solvent components would induce + perturbations which are too large to be handled by this Protocol. + + +Acknowledgements +---------------- +* Originally based on hydration.py in + `espaloma_charge `_ + +""" + +from __future__ import annotations + +import itertools +import logging +import pathlib +import warnings +from typing import Optional, Union + +import gufe +import numpy as np +import numpy.typing as npt +from openff.units import Quantity +from openmmtools import multistate + +from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry + + +logger = logging.getLogger(__name__) + + +class AbsoluteProtocolResultMixin: + bound_state = "solvent" + unbound_state = "vacuum" + + def __init__(self, **data): + super().__init__(**data) + # TODO: Detect when we have extensions and stitch these together? + if any( + len(pur_list) > 2 + for pur_list in itertools.chain( + self.data[self.bound_state].values(), self.data[self.unbound_state].values() + ) + ): + raise NotImplementedError("Can't stitch together results yet") + + def get_forward_and_reverse_energy_analysis( + self, + ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]: + """ + Get the reverse and forward analysis of the free energies. + + Returns + ------- + forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]] + A dictionary, keyed for each leg of the thermodynamic cycle, + either ``solvent`` and ``vaccuum` for a solvation free energy or + ``solvent`` and ``complex`` for a binding free energy, + with each containing a list of dictionaries containing the forward + and reverse analysis of each repeat of that simulation type. + + The forward and reverse analysis dictionaries contain: + - `fractions`: npt.NDArray + The fractions of data used for the estimates + - `forward_DGs`, `reverse_DGs`: openff.units.Quantity + The forward and reverse estimates for each fraction of data + - `forward_dDGs`, `reverse_dDGs`: openff.units.Quantity + The forward and reverse estimate uncertainty for each + fraction of data. + + If one of the cycle leg list entries is ``None``, this indicates + that the analysis could not be carried out for that repeat. This + is most likely caused by MBAR convergence issues when attempting to + calculate free energies from too few samples. + + Raises + ------ + UserWarning + * If any of the forward and reverse dictionaries are ``None`` in a + given thermodynamic cycle leg. + """ + + forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} + + for key in [self.bound_state, self.unbound_state]: + forward_reverse[key] = [ + pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() + ] + + if None in forward_reverse[key]: + wmsg = ( + "One or more ``None`` entries were found in the forward " + f"and reverse dictionaries of the repeats of the {key} " + "calculations. This is likely caused by an MBAR convergence " + "failure caused by too few independent samples when " + "calculating the free energies of the 10% timeseries slice." + ) + warnings.warn(wmsg) + + return forward_reverse + + def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get a the MBAR overlap estimates for all legs of the simulation. + + Returns + ------- + overlap_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary keyed for each leg of the thermodynamic cycle, either + ``solvent`` and ``vaccuum` for a solvation free energy or + ``solvent`` and ``complex`` for a binding free energy, + with each containing a list of dictionaries with the MBAR overlap + estimates of each repeat of that simulation type. + + The underlying MBAR dictionaries contain the following keys: + * ``scalar``: One minus the largest nontrivial eigenvalue + * ``eigenvalues``: The sorted (descending) eigenvalues of the + overlap matrix + * ``matrix``: Estimated overlap matrix of observing a sample from + state i in state j + """ + # Loop through and get the repeats and get the matrices + overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + + for key in [self.bound_state, self.unbound_state]: + overlap_stats[key] = [ + pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values() + ] + + return overlap_stats + + def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get the replica exchange transition statistics for all + legs of the simulation. + + Note + ---- + This is currently only available in cases where a replica exchange + simulation was run. + + Returns + ------- + repex_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary with keys for each leg of the thermodynamic cycle, either + ``solvent`` and ``vaccuum` for a solvation free energy or + ``solvent`` and ``complex`` for a binding free energy, + with each containing a list of dictionaries containing the replica + transition statistics for each repeat of that simulation type. + + The replica transition statistics dictionaries contain the following: + * ``eigenvalues``: The sorted (descending) eigenvalues of the + lambda state transition matrix + * ``matrix``: The transition matrix estimate of a replica switching + from state i to state j. + """ + repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + try: + for key in [self.bound_state, self.unbound_state]: + repex_stats[key] = [ + pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() + ] + except KeyError: + errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" + raise ValueError(errmsg) + + return repex_stats + + def get_replica_states(self) -> dict[str, list[npt.NDArray]]: + """ + Get the timeseries of replica states for all simulation legs. + + Returns + ------- + replica_states : dict[str, list[npt.NDArray]] + Dictionary keyed for each leg of the thermodynamic cycle, either + `solvent` and `vacuum` for solvation free energies, + or `complex` and `solvent` for binding free energies, + with lists of replica states timeseries for each repeat of that + simulation type. + """ + replica_states: dict[str, list[npt.NDArray]] = { + self.bound_state: [], + self.unbound_state: [] + } + + def is_file(filename: str): + p = pathlib.Path(filename) + + if not p.exists(): + errmsg = f"File could not be found {p}" + raise ValueError(errmsg) + + return p + + def get_replica_state(nc, chk): + nc = is_file(nc) + dir_path = nc.parents[0] + chk = is_file(dir_path / chk).name + + reporter = multistate.MultiStateReporter( + storage=nc, checkpoint_storage=chk, open_mode="r" + ) + + retval = np.asarray(reporter.read_replica_thermodynamic_states()) + reporter.close() + + return retval + + for key in [self.bound_state, self.unbound_state]: + for pus in self.data[key].values(): + states = get_replica_state( + pus[0].outputs["nc"], + pus[0].outputs["last_checkpoint"], + ) + replica_states[key].append(states) + + return replica_states + + def equilibration_iterations(self) -> dict[str, list[float]]: + """ + Get the number of equilibration iterations for each simulation. + + Returns + ------- + equilibration_lengths : dict[str, list[float]] + Dictionary keyed for each leg of the thermodynamic cycle, either + `solvent` and `vacuum` for solvation free energies, + or `complex` and `solvent` for binding free energies, + with lists containing the number of equilibration iterations for + each repeat of that simulation type. + """ + equilibration_lengths: dict[str, list[float]] = {} + + for key in [self.bound_state, self.unbound_state]: + equilibration_lengths[key] = [ + pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() + ] + + return equilibration_lengths + + def production_iterations(self) -> dict[str, list[float]]: + """ + Get the number of production iterations for each simulation. + Returns the number of uncorrelated production samples for each + repeat of the calculation. + + Returns + ------- + production_lengths : dict[str, list[float]] + Dictionary keyed for each leg of the thermodynamic cycle, either + `solvent` and `vacuum` for solvation free energies, + or `complex` and `solvent` for binding free energies, + with lists containing the number of equilibration iterations for + each repeat of that simulation type. + """ + production_lengths: dict[str, list[float]] = {} + + for key in [self.bound_state, self.unbound_state]: + production_lengths[key] = [ + pus[0].outputs["production_iterations"] for pus in self.data[key].values() + ] + + return production_lengths + + def selection_indices(self) -> dict[str, list[Optional[npt.NDArray]]]: + """ + Get the system selection indices used to write PDB and + trajectory files. + + Returns + ------- + indices : dict[str, list[npt.NDArray]] + A dictionary keyed for each state, either + `solvent` and `vacuum` for solvation free energies, + or `complex` and `solvent` for binding free energies, + each containing a list of NDArrays containing the corresponding + full system atom indices for each atom written in the production + trajectory files for each replica. + """ + indices: dict[str, list[Optional[npt.NDArray]]] = {} + + for key in [self.bound_state, self.unbound_state]: + indices[key] = [] + for pus in self.data[key].values(): + indices[key].append(pus[0].outputs["selection_indices"]) + + return indices + + +class AbsoluteSolvationProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin): + """Dict-like container for the output of a AbsoluteSolvationProtocol""" + + bound_state = "solvent" + unbound_state = "vacuum" + + def get_individual_estimates(self) -> dict[str, list[tuple[Quantity, Quantity]]]: + """ + Get the individual estimate of the free energies. + + Returns + ------- + dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]] + A dictionary, keyed `solvent` and `vacuum` for each leg + of the thermodynamic cycle, with lists of tuples containing + the individual free energy estimates and associated MBAR + uncertainties for each repeat of that simulation type. + """ + vac_dGs = [] + solv_dGs = [] + + for pus in self.data["vacuum"].values(): + vac_dGs.append((pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])) + + for pus in self.data["solvent"].values(): + solv_dGs.append( + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + ) + + return {"solvent": solv_dGs, "vacuum": vac_dGs} + + def get_estimate(self): + """Get the solvation free energy estimate for this calculation. + + Returns + ------- + dG : openff.units.Quantity + The solvation free energy. This is a Quantity defined with units. + """ + + def _get_average(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.average(dGs) * u + + individual_estimates = self.get_individual_estimates() + vac_dG = _get_average(individual_estimates["vacuum"]) + solv_dG = _get_average(individual_estimates["solvent"]) + + return vac_dG - solv_dG + + def get_uncertainty(self): + """Get the solvation free energy error for this calculation. + + Returns + ------- + err : openff.units.Quantity + The standard deviation between estimates of the solvation free + energy. This is a Quantity defined with units. + """ + + def _get_stdev(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.std(dGs) * u + + individual_estimates = self.get_individual_estimates() + vac_err = _get_stdev(individual_estimates["vacuum"]) + solv_err = _get_stdev(individual_estimates["solvent"]) + + # return the combined error + return np.sqrt(vac_err**2 + solv_err**2) + + +class AbsoluteBindingProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin): + """Dict-like container for the output of a AbsoluteBindingProtocol""" + + bound_state = "complex" + unbound_state = "solvent" + + def get_individual_estimates( + self, + ) -> dict[str, list[tuple[Quantity, Quantity]]]: + """ + Get the individual estimate of the free energies. + + Returns + ------- + dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]] + A dictionary, keyed `solvent`, `complex`, and 'standard_state' + representing each portion of the thermodynamic cycle, + with lists of tuples containing the individual free energy + estimates and, for 'solvent' and 'complex', the associated MBAR + uncertainties for each repeat of that simulation type. + + Notes + ----- + * Standard state correction has no error and so will return a value + of 0. + """ + complex_dGs = [] + correction_dGs = [] + solv_dGs = [] + + for pus in self.data["complex"].values(): + complex_dGs.append( + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + ) + correction_dGs.append( + ( + pus[0].outputs["standard_state_correction"], + 0 * offunit.kilocalorie_per_mole, # correction has no error + ) + ) + + for pus in self.data["solvent"].values(): + solv_dGs.append( + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + ) + + return { + "solvent": solv_dGs, + "complex": complex_dGs, + "standard_state_correction": correction_dGs, + } + + @staticmethod + def _add_complex_standard_state_corr( + complex_dG: list[tuple[Quantity, Quantity]], + standard_state_dG: list[tuple[Quantity, Quantity]], + ) -> list[tuple[Quantity, Quantity]]: + """ + Helper method to combine the + complex & standard state corrections legs. + + Parameters + ---------- + complex_dG : list[tuple[openff.units.Quantity, openff.units.Quantity]] + The individual estimates of the complex leg, + where the first entry of each tuple is the dG estimate + and the second entry is the MBAR error. + standard_state_dG : list[tuple[Quantity, Quantity]] + The individual standard state corrections for each corresponding + complex leg. The first entry is the correction, the second + is an empty error value of 0. + + Returns + ------- + combined_dG : list[tuple[openff.units.Quantity,openff.units. Quantity]] + A list of dG estimates & MBAR errors for the combined + complex & standard state correction of each repeat. + + Notes + ----- + We assume that both list of items are in the right order. + """ + combined_dG: list[tuple[Quantity, Quantity]] = [] + for comp, corr in zip(complex_dG, standard_state_dG): + # No need to convert unit types, since pint takes care of that + # except that mypy hates it because pint isn't typed properly... + # No need to add errors since there's just the one + combined_dG.append((comp[0] + corr[0], comp[1])) # type: ignore[operator] + + return combined_dG + + def get_estimate(self) -> Quantity: + """Get the binding free energy estimate for this calculation. + + Returns + ------- + dG : openff.units.Quantity + The binding free energy. This is a Quantity defined with units. + """ + + def _get_average(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.average(dGs) * u + + individual_estimates = self.get_individual_estimates() + complex_dG = _get_average( + self._add_complex_standard_state_corr( + individual_estimates["complex"], + individual_estimates["standard_state_correction"], + ) + ) + solv_dG = _get_average(individual_estimates["solvent"]) + + return -complex_dG + solv_dG + + def get_uncertainty(self) -> Quantity: + """Get the binding free energy error for this calculation. + + Returns + ------- + err : openff.units.Quantity + The standard deviation between estimates of the binding free + energy. This is a Quantity defined with units. + """ + + def _get_stdev(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.std(dGs) * u + + individual_estimates = self.get_individual_estimates() + + complex_err = _get_stdev( + self._add_complex_standard_state_corr( + individual_estimates["complex"], + individual_estimates["standard_state_correction"], + ) + ) + solv_err = _get_stdev(individual_estimates["solvent"]) + + # return the combined error + return np.sqrt(complex_err**2 + solv_err**2) + + def restraint_geometries(self) -> list[BoreschRestraintGeometry]: + """ + Get a list of the restraint geometries for the + complex simulations. These define the atoms that have + been restrained in the system. + + Returns + ------- + geometries : list[dict[str, Any]] + A list of dictionaries containing the details of the atoms + in the system that are involved in the restraint. + """ + geometries = [ + BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry"]) + for pus in self.data["complex"].values() + ] + + return geometries diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py index dc12f1a8e..fa644928a 100644 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -24,7 +24,6 @@ """ -import itertools import logging import pathlib import uuid @@ -50,7 +49,6 @@ from openmm import System from openmm import unit as ommunit from openmm.app import Topology as omm_topology -from openmmtools import multistate from openmmtools.states import GlobalParameterState, ThermodynamicState from rdkit import Chem @@ -103,418 +101,6 @@ logger = logging.getLogger(__name__) -class AbsoluteBindingProtocolResult(gufe.ProtocolResult): - """Dict-like container for the output of a AbsoluteBindingProtocol""" - - def __init__(self, **data): - super().__init__(**data) - # TODO: Detect when we have extensions and stitch these together? - if any( - len(pur_list) > 2 - for pur_list in itertools.chain( - self.data["solvent"].values(), self.data["complex"].values() - ) - ): - raise NotImplementedError("Can't stitch together results yet") - - def get_individual_estimates( - self, - ) -> dict[str, list[tuple[Quantity, Quantity]]]: - """ - Get the individual estimate of the free energies. - - Returns - ------- - dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]] - A dictionary, keyed `solvent`, `complex`, and 'standard_state' - representing each portion of the thermodynamic cycle, - with lists of tuples containing the individual free energy - estimates and, for 'solvent' and 'complex', the associated MBAR - uncertainties for each repeat of that simulation type. - - Notes - ----- - * Standard state correction has no error and so will return a value - of 0. - """ - complex_dGs = [] - correction_dGs = [] - solv_dGs = [] - - for pus in self.data["complex"].values(): - complex_dGs.append( - (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) - ) - correction_dGs.append( - ( - pus[0].outputs["standard_state_correction"], - 0 * offunit.kilocalorie_per_mole, # correction has no error - ) - ) - - for pus in self.data["solvent"].values(): - solv_dGs.append( - (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) - ) - - return { - "solvent": solv_dGs, - "complex": complex_dGs, - "standard_state_correction": correction_dGs, - } - - @staticmethod - def _add_complex_standard_state_corr( - complex_dG: list[tuple[Quantity, Quantity]], - standard_state_dG: list[tuple[Quantity, Quantity]], - ) -> list[tuple[Quantity, Quantity]]: - """ - Helper method to combine the - complex & standard state corrections legs. - - Parameters - ---------- - complex_dG : list[tuple[openff.units.Quantity, openff.units.Quantity]] - The individual estimates of the complex leg, - where the first entry of each tuple is the dG estimate - and the second entry is the MBAR error. - standard_state_dG : list[tuple[Quantity, Quantity]] - The individual standard state corrections for each corresponding - complex leg. The first entry is the correction, the second - is an empty error value of 0. - - Returns - ------- - combined_dG : list[tuple[openff.units.Quantity,openff.units. Quantity]] - A list of dG estimates & MBAR errors for the combined - complex & standard state correction of each repeat. - - Notes - ----- - We assume that both list of items are in the right order. - """ - combined_dG: list[tuple[Quantity, Quantity]] = [] - for comp, corr in zip(complex_dG, standard_state_dG): - # No need to convert unit types, since pint takes care of that - # except that mypy hates it because pint isn't typed properly... - # No need to add errors since there's just the one - combined_dG.append((comp[0] + corr[0], comp[1])) # type: ignore[operator] - - return combined_dG - - def get_estimate(self) -> Quantity: - """Get the binding free energy estimate for this calculation. - - Returns - ------- - dG : openff.units.Quantity - The binding free energy. This is a Quantity defined with units. - """ - - def _get_average(estimates): - # Get the unit value of the first value in the estimates - u = estimates[0][0].u - # Loop through estimates and get the free energy values - # in the unit of the first estimate - dGs = [i[0].to(u).m for i in estimates] - - return np.average(dGs) * u - - individual_estimates = self.get_individual_estimates() - complex_dG = _get_average( - self._add_complex_standard_state_corr( - individual_estimates["complex"], - individual_estimates["standard_state_correction"], - ) - ) - solv_dG = _get_average(individual_estimates["solvent"]) - - return -complex_dG + solv_dG - - def get_uncertainty(self) -> Quantity: - """Get the binding free energy error for this calculation. - - Returns - ------- - err : openff.units.Quantity - The standard deviation between estimates of the binding free - energy. This is a Quantity defined with units. - """ - - def _get_stdev(estimates): - # Get the unit value of the first value in the estimates - u = estimates[0][0].u - # Loop through estimates and get the free energy values - # in the unit of the first estimate - dGs = [i[0].to(u).m for i in estimates] - - return np.std(dGs) * u - - individual_estimates = self.get_individual_estimates() - - complex_err = _get_stdev( - self._add_complex_standard_state_corr( - individual_estimates["complex"], individual_estimates["standard_state_correction"] - ) - ) - solv_err = _get_stdev(individual_estimates["solvent"]) - - # return the combined error - return np.sqrt(complex_err**2 + solv_err**2) - - def get_forward_and_reverse_energy_analysis( - self, - ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]: - """ - Get the reverse and forward analysis of the free energies. - - Returns - ------- - forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]] - A dictionary, keyed `solvent` and `complex` for each leg of the - thermodynamic cycle which each contain a list of dictionaries - containing the forward and reverse analysis of each repeat - of that simulation type. - - The forward and reverse analysis dictionaries contain: - - `fractions`: npt.NDArray - The fractions of data used for the estimates - - `forward_DGs`, `reverse_DGs`: openff.units.Quantity - The forward and reverse estimates for each fraction of data - - `forward_dDGs`, `reverse_dDGs`: openff.units.Quantity - The forward and reverse estimate uncertainty for each - fraction of data. - - If one of the cycle leg list entries is ``None``, this indicates - that the analysis could not be carried out for that repeat. This - is most likely caused by MBAR convergence issues when attempting to - calculate free energies from too few samples. - - Raises - ------ - UserWarning - * If any of the forward and reverse dictionaries are ``None`` in a - given thermodynamic cycle leg. - """ - - forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} - - for key in ["solvent", "complex"]: - forward_reverse[key] = [ - pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() - ] - - if None in forward_reverse[key]: - wmsg = ( - "One or more ``None`` entries were found in the forward " - f"and reverse dictionaries of the repeats of the {key} " - "calculations. This is likely caused by an MBAR convergence " - "failure caused by too few independent samples when " - "calculating the free energies of the 10% timeseries slice." - ) - warnings.warn(wmsg) - - return forward_reverse - - def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: - """ - Get a the MBAR overlap estimates for all legs of the simulation. - - Returns - ------- - overlap_stats : dict[str, list[dict[str, npt.NDArray]]] - A dictionary with keys `solvent` and `complex` for each - leg of the thermodynamic cycle, which each containing a - list of dictionaries with the MBAR overlap estimates of - each repeat of that simulation type. - - The underlying MBAR dictionaries contain the following keys: - * ``scalar``: One minus the largest nontrivial eigenvalue - * ``eigenvalues``: The sorted (descending) eigenvalues of the - overlap matrix - * ``matrix``: Estimated overlap matrix of observing a sample from - state i in state j - """ - # Loop through and get the repeats and get the matrices - overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - - for key in ["solvent", "complex"]: - overlap_stats[key] = [ - pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values() - ] - - return overlap_stats - - def get_replica_transition_statistics( - self, - ) -> dict[str, list[dict[str, npt.NDArray]]]: - """ - Get the replica exchange transition statistics for all - legs of the simulation. - - Note - ---- - This is currently only available in cases where a replica exchange - simulation was run. - - Returns - ------- - repex_stats : dict[str, list[dict[str, npt.NDArray]]] - A dictionary with keys `solvent` and `complex` for each - leg of the thermodynamic cycle, which each containing - a list of dictionaries containing the replica transition - statistics for each repeat of that simulation type. - - The replica transition statistics dictionaries contain the following: - * ``eigenvalues``: The sorted (descending) eigenvalues of the - lambda state transition matrix - * ``matrix``: The transition matrix estimate of a replica switching - from state i to state j. - """ - repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - try: - for key in ["solvent", "complex"]: - repex_stats[key] = [ - pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() - ] - except KeyError: - errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" - raise ValueError(errmsg) - - return repex_stats - - def get_replica_states(self) -> dict[str, list[npt.NDArray]]: - """ - Get the timeseries of replica states for all simulation legs. - - Returns - ------- - replica_states : dict[str, list[npt.NDArray]] - Dictionary keyed `solvent` and `complex` for each leg of - the thermodynamic cycle, with lists of replica states - timeseries for each repeat of that simulation type. - """ - replica_states: dict[str, list[npt.NDArray]] = {"solvent": [], "complex": []} - - def is_file(filename: str): - p = pathlib.Path(filename) - - if not p.exists(): - errmsg = f"File could not be found {p}" - raise ValueError(errmsg) - - return p - - def get_replica_state(nc, chk): - nc = is_file(nc) - dir_path = nc.parents[0] - chk = is_file(dir_path / chk).name - - reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode="r" - ) - - retval = np.asarray(reporter.read_replica_thermodynamic_states()) - reporter.close() - - return retval - - for key in ["solvent", "complex"]: - for pus in self.data[key].values(): - states = get_replica_state( - pus[0].outputs["nc"], - pus[0].outputs["last_checkpoint"], - ) - replica_states[key].append(states) - - return replica_states - - def equilibration_iterations(self) -> dict[str, list[float]]: - """ - Get the number of equilibration iterations for each simulation. - - Returns - ------- - equilibration_lengths : dict[str, list[float]] - Dictionary keyed `solvent` and `complex` for each leg - of the thermodynamic cycle, with lists containing the - number of equilibration iterations for each repeat - of that simulation type. - """ - equilibration_lengths: dict[str, list[float]] = {} - - for key in ["solvent", "complex"]: - equilibration_lengths[key] = [ - pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() - ] - - return equilibration_lengths - - def production_iterations(self) -> dict[str, list[float]]: - """ - Get the number of production iterations for each simulation. - Returns the number of uncorrelated production samples for each - repeat of the calculation. - - Returns - ------- - production_lengths : dict[str, list[float]] - Dictionary keyed `solvent` and `complex` for each leg of the - thermodynamic cycle, with lists with the number - of production iterations for each repeat of that simulation - type. - """ - production_lengths: dict[str, list[float]] = {} - - for key in ["solvent", "complex"]: - production_lengths[key] = [ - pus[0].outputs["production_iterations"] for pus in self.data[key].values() - ] - - return production_lengths - - def restraint_geometries(self) -> list[BoreschRestraintGeometry]: - """ - Get a list of the restraint geometries for the - complex simulations. These define the atoms that have - been restrained in the system. - - Returns - ------- - geometries : list[dict[str, Any]] - A list of dictionaries containing the details of the atoms - in the system that are involved in the restraint. - """ - geometries = [ - BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry"]) - for pus in self.data["complex"].values() - ] - - return geometries - - def selection_indices(self) -> dict[str, list[Optional[npt.NDArray]]]: - """ - Get the system selection indices used to write PDB and - trajectory files. - - Returns - ------- - indices : dict[str, list[npt.NDArray]] - A dictionary keyed as `complex` and `solvent` for each - state, each containing a list of NDArrays containing the corresponding - full system atom indices for each atom written in the production - trajectory files for each replica. - """ - indices: dict[str, list[Optional[npt.NDArray]]] = {} - - for key in ["complex", "solvent"]: - indices[key] = [] - for pus in self.data[key].values(): - indices[key].append(pus[0].outputs["selection_indices"]) - - return indices - - class AbsoluteBindingProtocol(gufe.Protocol): """ Absolute binding free energy calculations using OpenMM and OpenMMTools. diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 4ffb94770..50154e953 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -30,9 +30,7 @@ from __future__ import annotations -import itertools import logging -import pathlib import uuid import warnings from collections import defaultdict @@ -40,7 +38,6 @@ import gufe import numpy as np -import numpy.typing as npt from gufe import ( ChemicalSystem, ProteinComponent, @@ -48,9 +45,7 @@ SolventComponent, settings, ) -from gufe.components import Component -from openff.units import Quantity, unit -from openmmtools import multistate +from openff.units import offunit from openfe.due import Doi, due from openfe.protocols.openmm_afe.equil_afe_settings import ( @@ -103,305 +98,6 @@ logger = logging.getLogger(__name__) -class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): - """Dict-like container for the output of a AbsoluteSolvationProtocol""" - - def __init__(self, **data): - super().__init__(**data) - # TODO: Detect when we have extensions and stitch these together? - if any( - len(pur_list) > 2 - for pur_list in itertools.chain( - self.data["solvent"].values(), self.data["vacuum"].values() - ) - ): - raise NotImplementedError("Can't stitch together results yet") - - def get_individual_estimates(self) -> dict[str, list[tuple[Quantity, Quantity]]]: - """ - Get the individual estimate of the free energies. - - Returns - ------- - dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]] - A dictionary, keyed `solvent` and `vacuum` for each leg - of the thermodynamic cycle, with lists of tuples containing - the individual free energy estimates and associated MBAR - uncertainties for each repeat of that simulation type. - """ - vac_dGs = [] - solv_dGs = [] - - for pus in self.data["vacuum"].values(): - vac_dGs.append((pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])) - - for pus in self.data["solvent"].values(): - solv_dGs.append( - (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) - ) - - return {"solvent": solv_dGs, "vacuum": vac_dGs} - - def get_estimate(self): - """Get the solvation free energy estimate for this calculation. - - Returns - ------- - dG : openff.units.Quantity - The solvation free energy. This is a Quantity defined with units. - """ - - def _get_average(estimates): - # Get the unit value of the first value in the estimates - u = estimates[0][0].u - # Loop through estimates and get the free energy values - # in the unit of the first estimate - dGs = [i[0].to(u).m for i in estimates] - - return np.average(dGs) * u - - individual_estimates = self.get_individual_estimates() - vac_dG = _get_average(individual_estimates["vacuum"]) - solv_dG = _get_average(individual_estimates["solvent"]) - - return vac_dG - solv_dG - - def get_uncertainty(self): - """Get the solvation free energy error for this calculation. - - Returns - ------- - err : openff.units.Quantity - The standard deviation between estimates of the solvation free - energy. This is a Quantity defined with units. - """ - - def _get_stdev(estimates): - # Get the unit value of the first value in the estimates - u = estimates[0][0].u - # Loop through estimates and get the free energy values - # in the unit of the first estimate - dGs = [i[0].to(u).m for i in estimates] - - return np.std(dGs) * u - - individual_estimates = self.get_individual_estimates() - vac_err = _get_stdev(individual_estimates["vacuum"]) - solv_err = _get_stdev(individual_estimates["solvent"]) - - # return the combined error - return np.sqrt(vac_err**2 + solv_err**2) - - def get_forward_and_reverse_energy_analysis( - self, - ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]: - """ - Get the reverse and forward analysis of the free energies. - - Returns - ------- - forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]] - A dictionary, keyed `solvent` and `vacuum` for each leg of the - thermodynamic cycle which each contain a list of dictionaries - containing the forward and reverse analysis of each repeat - of that simulation type. - - The forward and reverse analysis dictionaries contain: - - `fractions`: npt.NDArray - The fractions of data used for the estimates - - `forward_DGs`, `reverse_DGs`: openff.units.Quantity - The forward and reverse estimates for each fraction of data - - `forward_dDGs`, `reverse_dDGs`: openff.units.Quantity - The forward and reverse estimate uncertainty for each - fraction of data. - - If one of the cycle leg list entries is ``None``, this indicates - that the analysis could not be carried out for that repeat. This - is most likely caused by MBAR convergence issues when attempting to - calculate free energies from too few samples. - - Raises - ------ - UserWarning - * If any of the forward and reverse dictionaries are ``None`` in a - given thermodynamic cycle leg. - """ - - forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} - - for key in ["solvent", "vacuum"]: - forward_reverse[key] = [ - pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() - ] - - if None in forward_reverse[key]: - wmsg = ( - "One or more ``None`` entries were found in the forward " - f"and reverse dictionaries of the repeats of the {key} " - "calculations. This is likely caused by an MBAR convergence " - "failure caused by too few independent samples when " - "calculating the free energies of the 10% timeseries slice." - ) - warnings.warn(wmsg) - - return forward_reverse - - def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: - """ - Get a the MBAR overlap estimates for all legs of the simulation. - - Returns - ------- - overlap_stats : dict[str, list[dict[str, npt.NDArray]]] - A dictionary with keys `solvent` and `vacuum` for each - leg of the thermodynamic cycle, which each containing a - list of dictionaries with the MBAR overlap estimates of - each repeat of that simulation type. - - The underlying MBAR dictionaries contain the following keys: - * ``scalar``: One minus the largest nontrivial eigenvalue - * ``eigenvalues``: The sorted (descending) eigenvalues of the - overlap matrix - * ``matrix``: Estimated overlap matrix of observing a sample from - state i in state j - """ - # Loop through and get the repeats and get the matrices - overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - - for key in ["solvent", "vacuum"]: - overlap_stats[key] = [ - pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values() - ] - - return overlap_stats - - def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]: - """ - Get the replica exchange transition statistics for all - legs of the simulation. - - Note - ---- - This is currently only available in cases where a replica exchange - simulation was run. - - Returns - ------- - repex_stats : dict[str, list[dict[str, npt.NDArray]]] - A dictionary with keys `solvent` and `vacuum` for each - leg of the thermodynamic cycle, which each containing - a list of dictionaries containing the replica transition - statistics for each repeat of that simulation type. - - The replica transition statistics dictionaries contain the following: - * ``eigenvalues``: The sorted (descending) eigenvalues of the - lambda state transition matrix - * ``matrix``: The transition matrix estimate of a replica switching - from state i to state j. - """ - repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - try: - for key in ["solvent", "vacuum"]: - repex_stats[key] = [ - pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() - ] - except KeyError: - errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" - raise ValueError(errmsg) - - return repex_stats - - def get_replica_states(self) -> dict[str, list[npt.NDArray]]: - """ - Get the timeseries of replica states for all simulation legs. - - Returns - ------- - replica_states : dict[str, list[npt.NDArray]] - Dictionary keyed `solvent` and `vacuum` for each leg of - the thermodynamic cycle, with lists of replica states - timeseries for each repeat of that simulation type. - """ - replica_states: dict[str, list[npt.NDArray]] = {"solvent": [], "vacuum": []} - - def is_file(filename: str): - p = pathlib.Path(filename) - - if not p.exists(): - errmsg = f"File could not be found {p}" - raise ValueError(errmsg) - - return p - - def get_replica_state(nc, chk): - nc = is_file(nc) - dir_path = nc.parents[0] - chk = is_file(dir_path / chk).name - - reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode="r" - ) - - retval = np.asarray(reporter.read_replica_thermodynamic_states()) - reporter.close() - - return retval - - for key in ["solvent", "vacuum"]: - for pus in self.data[key].values(): - states = get_replica_state( - pus[0].outputs["nc"], - pus[0].outputs["last_checkpoint"], - ) - replica_states[key].append(states) - - return replica_states - - def equilibration_iterations(self) -> dict[str, list[float]]: - """ - Get the number of equilibration iterations for each simulation. - - Returns - ------- - equilibration_lengths : dict[str, list[float]] - Dictionary keyed `solvent` and `vacuum` for each leg - of the thermodynamic cycle, with lists containing the - number of equilibration iterations for each repeat - of that simulation type. - """ - equilibration_lengths: dict[str, list[float]] = {} - - for key in ["solvent", "vacuum"]: - equilibration_lengths[key] = [ - pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() - ] - - return equilibration_lengths - - def production_iterations(self) -> dict[str, list[float]]: - """ - Get the number of production iterations for each simulation. - Returns the number of uncorrelated production samples for each - repeat of the calculation. - - Returns - ------- - production_lengths : dict[str, list[float]] - Dictionary keyed `solvent` and `vacuum` for each leg of the - thermodynamic cycle, with lists with the number - of production iterations for each repeat of that simulation - type. - """ - production_lengths: dict[str, list[float]] = {} - - for key in ["solvent", "vacuum"]: - production_lengths[key] = [ - pus[0].outputs["production_iterations"] for pus in self.data[key].values() - ] - - return production_lengths - - class AbsoluteSolvationProtocol(gufe.Protocol): """ Absolute solvation free energy calculations using OpenMM and OpenMMTools. @@ -439,8 +135,8 @@ def _default_settings(cls): nonbonded_method="nocutoff", ), thermo_settings=settings.ThermoSettings( - temperature=298.15 * unit.kelvin, - pressure=1 * unit.bar, + temperature=298.15 * offunit.kelvin, + pressure=1 * offunit.bar, ), alchemical_settings=AlchemicalSettings(), lambda_settings=LambdaSettings( @@ -460,9 +156,9 @@ def _default_settings(cls): solvent_engine_settings=OpenMMEngineSettings(), integrator_settings=IntegratorSettings(), solvent_equil_simulation_settings=MDSimulationSettings( - equilibration_length_nvt=0.1 * unit.nanosecond, - equilibration_length=0.2 * unit.nanosecond, - production_length=0.5 * unit.nanosecond, + equilibration_length_nvt=0.1 * offunit.nanosecond, + equilibration_length=0.2 * offunit.nanosecond, + production_length=0.5 * offunit.nanosecond, ), solvent_equil_output_settings=MDOutputSettings( equil_nvt_structure="equil_nvt_structure.pdb", @@ -472,8 +168,8 @@ def _default_settings(cls): ), solvent_simulation_settings=MultiStateSimulationSettings( n_replicas=14, - equilibration_length=1.0 * unit.nanosecond, - production_length=10.0 * unit.nanosecond, + equilibration_length=1.0 * offunit.nanosecond, + production_length=10.0 * offunit.nanosecond, ), solvent_output_settings=MultiStateOutputSettings( output_filename="solvent.nc", @@ -481,8 +177,8 @@ def _default_settings(cls): ), vacuum_equil_simulation_settings=MDSimulationSettings( equilibration_length_nvt=None, - equilibration_length=0.2 * unit.nanosecond, - production_length=0.5 * unit.nanosecond, + equilibration_length=0.2 * offunit.nanosecond, + production_length=0.5 * offunit.nanosecond, ), vacuum_equil_output_settings=MDOutputSettings( equil_nvt_structure=None, @@ -492,8 +188,8 @@ def _default_settings(cls): ), vacuum_simulation_settings=MultiStateSimulationSettings( n_replicas=14, - equilibration_length=0.5 * unit.nanosecond, - production_length=2.0 * unit.nanosecond, + equilibration_length=0.5 * offunit.nanosecond, + production_length=2.0 * offunit.nanosecond, ), vacuum_output_settings=MultiStateOutputSettings( output_filename="vacuum.nc", @@ -714,7 +410,7 @@ def _validate( # Check vacuum equilibration MD settings is 0 ns nvt_time = self.settings.vacuum_equil_simulation_settings.equilibration_length_nvt if nvt_time is not None: - if not np.allclose(nvt_time, 0 * unit.nanosecond): + if not np.allclose(nvt_time, 0 * offunit.nanosecond): errmsg = "NVT equilibration cannot be run in vacuum simulation" raise ValueError(errmsg) From 6b504a85e0e8cd0a59e12b2d8007c1c58e60bfd3 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 1 Jan 2026 09:25:02 -0500 Subject: [PATCH 02/11] rename base units file --- openfe/protocols/openmm_afe/{base.py => afe_units.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename openfe/protocols/openmm_afe/{base.py => afe_units.py} (100%) diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/afe_units.py similarity index 100% rename from openfe/protocols/openmm_afe/base.py rename to openfe/protocols/openmm_afe/afe_units.py From 8abf04fd76909d8d908cda855b5a9b9e428b6196 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 1 Jan 2026 09:29:46 -0500 Subject: [PATCH 03/11] fix some imports --- openfe/protocols/openmm_afe/afe_protocol_results.py | 3 --- openfe/protocols/openmm_afe/equil_binding_afe_method.py | 3 ++- openfe/protocols/openmm_afe/equil_solvation_afe_method.py | 4 +++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_afe/afe_protocol_results.py b/openfe/protocols/openmm_afe/afe_protocol_results.py index e193aebe1..25a16a59e 100644 --- a/openfe/protocols/openmm_afe/afe_protocol_results.py +++ b/openfe/protocols/openmm_afe/afe_protocol_results.py @@ -27,9 +27,6 @@ `espaloma_charge `_ """ - -from __future__ import annotations - import itertools import logging import pathlib diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py index fa644928a..56b85546e 100644 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -74,7 +74,8 @@ from openfe.protocols.restraint_utils.openmm import omm_restraints from openfe.protocols.restraint_utils.openmm.omm_restraints import BoreschRestraint -from .base import BaseAbsoluteUnit +from .afe_units import BaseAbsoluteUnit +from .afe_protocol_results import AbsoluteBindingProtocolResult due.cite( Doi("10.5281/zenodo.596504"), diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 50154e953..33eadaf86 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -64,7 +64,9 @@ ) from ..openmm_utils import settings_validation, system_validation -from .base import BaseAbsoluteUnit +from .afe_units import BaseAbsoluteUnit +from .afe_protocol_results import AbsoluteSolvationProtocolResult + due.cite( Doi("10.5281/zenodo.596504"), From a3e228f0e196dd9df493c38da3949e1bfa5d5390 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 1 Jan 2026 09:30:47 -0500 Subject: [PATCH 04/11] import fix --- openfe/protocols/openmm_afe/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_afe/__init__.py b/openfe/protocols/openmm_afe/__init__.py index 48919cd0c..05aa324c3 100644 --- a/openfe/protocols/openmm_afe/__init__.py +++ b/openfe/protocols/openmm_afe/__init__.py @@ -8,17 +8,19 @@ from .equil_binding_afe_method import ( AbsoluteBindingComplexUnit, AbsoluteBindingProtocol, - AbsoluteBindingProtocolResult, AbsoluteBindingSettings, AbsoluteBindingSolventUnit, ) from .equil_solvation_afe_method import ( AbsoluteSolvationProtocol, - AbsoluteSolvationProtocolResult, AbsoluteSolvationSettings, AbsoluteSolvationSolventUnit, AbsoluteSolvationVacuumUnit, ) +from .afe_protocol_results import ( + AbsoluteBindingProtocolResult, + AbsoluteSolvationProtocolResult, +) __all__ = [ "AbsoluteSolvationProtocol", From efcee64dc04ca87b6e6b90750bec6f2c680dbf52 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 1 Jan 2026 13:02:35 -0500 Subject: [PATCH 05/11] fix some more imports --- openfe/protocols/openmm_afe/afe_protocol_results.py | 5 ++++- openfe/protocols/openmm_afe/afe_units.py | 7 ++++++- openfe/protocols/openmm_afe/equil_solvation_afe_method.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/openfe/protocols/openmm_afe/afe_protocol_results.py b/openfe/protocols/openmm_afe/afe_protocol_results.py index 25a16a59e..c0cef7fc8 100644 --- a/openfe/protocols/openmm_afe/afe_protocol_results.py +++ b/openfe/protocols/openmm_afe/afe_protocol_results.py @@ -36,10 +36,13 @@ import gufe import numpy as np import numpy.typing as npt +from openff.units import unit as offunit from openff.units import Quantity from openmmtools import multistate -from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry +from openfe.protocols.restraint_utils.geometry.boresch import ( + BoreschRestraintGeometry +) logger = logging.getLogger(__name__) diff --git a/openfe/protocols/openmm_afe/afe_units.py b/openfe/protocols/openmm_afe/afe_units.py index d28a630ca..e01a87d16 100644 --- a/openfe/protocols/openmm_afe/afe_units.py +++ b/openfe/protocols/openmm_afe/afe_units.py @@ -30,7 +30,12 @@ import numpy.typing as npt import openmm import openmmtools -from gufe import ChemicalSystem, ProteinComponent, SmallMoleculeComponent, SolventComponent +from gufe import ( + ChemicalSystem, + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, +) from gufe.components import Component from openff.toolkit.topology import Molecule as OFFMolecule from openff.units import Quantity, unit diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 33eadaf86..941e16b38 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -45,7 +45,7 @@ SolventComponent, settings, ) -from openff.units import offunit +from openff.units import unit as offunit from openfe.due import Doi, due from openfe.protocols.openmm_afe.equil_afe_settings import ( From d11ab5a65918d6a08e702c53d1edfc3226658a9c Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 1 Jan 2026 13:06:46 -0500 Subject: [PATCH 06/11] update some docstring --- .../openmm_afe/afe_protocol_results.py | 32 +++++-------------- openfe/protocols/openmm_afe/afe_units.py | 6 ++-- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/openfe/protocols/openmm_afe/afe_protocol_results.py b/openfe/protocols/openmm_afe/afe_protocol_results.py index c0cef7fc8..c9564565d 100644 --- a/openfe/protocols/openmm_afe/afe_protocol_results.py +++ b/openfe/protocols/openmm_afe/afe_protocol_results.py @@ -1,31 +1,15 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -"""OpenMM Equilibrium Solvation AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_solvation_afe_method` -=============================================================================================================== - -This module implements the necessary methodology tooling to run calculate an -absolute solvation free energy using OpenMM tools and one of the following -alchemical sampling methods: - -* Hamiltonian Replica Exchange -* Self-adjusted mixture sampling -* Independent window sampling - -Current limitations -------------------- -* Alchemical species with a net charge are not currently supported. -* Disapearing molecules are only allowed in state A. Support for - appearing molecules will be added in due course. -* Only small molecules are allowed to act as alchemical molecules. - Alchemically changing protein or solvent components would induce - perturbations which are too large to be handled by this Protocol. - +""" +Result classes for the Absolute Free Energy Protocols +===================================================== -Acknowledgements ----------------- -* Originally based on hydration.py in - `espaloma_charge `_ +This module implements :class:`gufe.ProtocolResult` classes for the absolute +free energy Protocols. +Specifically it implements: + * AbsoluteBindingProtocolResult + * AbsoluteSolvationProtocolResult """ import itertools import logging diff --git a/openfe/protocols/openmm_afe/afe_units.py b/openfe/protocols/openmm_afe/afe_units.py index e01a87d16..d91f0afa8 100644 --- a/openfe/protocols/openmm_afe/afe_units.py +++ b/openfe/protocols/openmm_afe/afe_units.py @@ -1,9 +1,9 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -"""OpenMM Equilibrium AFE Protocol base classes -=============================================== +"""OpenMM AFE Protocol base classes +=================================== -Base classes for the equilibrium OpenMM absolute free energy ProtocolUnits. +Base classes for the OpenMM absolute free energy ProtocolUnits. Thist mostly implements BaseAbsoluteUnit whose methods can be overriden to define different types of alchemical transformations. From 0f0ce12567ad51d8828793b99f7c9ec5954bc597 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 3 Jan 2026 09:20:43 -0500 Subject: [PATCH 07/11] move units out of method files --- openfe/protocols/openmm_afe/abfe_units.py | 463 ++++++++++++++++++ openfe/protocols/openmm_afe/ahfe_units.py | 174 +++++++ .../{afe_units.py => base_afe_units.py} | 0 .../openmm_afe/equil_binding_afe_method.py | 459 +---------------- .../openmm_afe/equil_solvation_afe_method.py | 165 +------ 5 files changed, 653 insertions(+), 608 deletions(-) create mode 100644 openfe/protocols/openmm_afe/abfe_units.py create mode 100644 openfe/protocols/openmm_afe/ahfe_units.py rename openfe/protocols/openmm_afe/{afe_units.py => base_afe_units.py} (100%) diff --git a/openfe/protocols/openmm_afe/abfe_units.py b/openfe/protocols/openmm_afe/abfe_units.py new file mode 100644 index 000000000..cbbc11fab --- /dev/null +++ b/openfe/protocols/openmm_afe/abfe_units.py @@ -0,0 +1,463 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""ABFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.abfe_units` +======================================================================== +This module defines the ProtocolUnits for the +:class:`AbsoluteBindingProtocol`. +""" +import logging +import pathlib + +import MDAnalysis as mda +import numpy as np +import numpy.typing as npt +from gufe import ( + SolventComponent, +) +from gufe.components import Component +from openff.units import Quantity +from openff.units.openmm import to_openmm +from openmm import System +from openmm import unit as ommunit +from openmm.app import Topology as omm_topology +from openmmtools.states import GlobalParameterState, ThermodynamicState +from rdkit import Chem + +from openfe.protocols.openmm_afe.equil_afe_settings import ( + BoreschRestraintSettings, + SettingsBaseModel, +) +from openfe.protocols.openmm_utils import system_validation +from openfe.protocols.restraint_utils import geometry +from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry +from openfe.protocols.restraint_utils.openmm import omm_restraints +from openfe.protocols.restraint_utils.openmm.omm_restraints import BoreschRestraint + +from .base_afe_units import BaseAbsoluteUnit + + +logger = logging.getLogger(__name__) + + +class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): + """ + Protocol Unit for the complex phase of an absolute binding free energy + """ + + simtype = "complex" + + def _get_components(self): + """ + Get the relevant components for a complex transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A dict of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : ProteinComponent | None + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in small_mols} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp + return alchem_comps, solv_comp, prot_comp, off_comps + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a complex transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : ABFEPreEquilOutputSettings + * simulation_settings : SimulationSettings + * output_settings: MultiStateOutputSettings + * restraint_settings: BaseRestraintSettings + """ + prot_settings = self._inputs["protocol"].settings + + settings = {} + settings["forcefield_settings"] = prot_settings.forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.complex_solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.complex_lambda_settings + settings["engine_settings"] = prot_settings.engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.complex_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.complex_equil_output_settings + settings["simulation_settings"] = prot_settings.complex_simulation_settings + settings["output_settings"] = prot_settings.complex_output_settings + settings["restraint_settings"] = prot_settings.restraint_settings + + return settings + + @staticmethod + def _get_mda_universe( + topology: omm_topology, + positions: ommunit.Quantity, + trajectory: pathlib.Path | None, + ) -> mda.Universe: + """ + Helper method to get a Universe from an openmm Topology, + and either an input trajectory or a set of positions. + + Parameters + ---------- + topology : openmm.app.Topology + An OpenMM Topology that defines the System. + positions: openmm.unit.Quantity + The System's current positions. + Used if a trajectory file is None or is not a file. + trajectory: pathlib.Path + A Path to a trajectory file to read positions from. + + Returns + ------- + mda.Universe + An MDAnalysis Universe of the System. + """ + from MDAnalysis.coordinates.memory import MemoryReader + + # If the trajectory file doesn't exist, then we use positions + if trajectory is not None and trajectory.is_file(): + return mda.Universe( + topology, + trajectory, + topology_format="OPENMMTOPOLOGY", + ) + else: + # Positions is an openmm Quantity in nm we need + # to convert to angstroms + return mda.Universe( + topology, + np.array(positions._value) * 10, + topology_format="OPENMMTOPOLOGY", + trajectory_format=MemoryReader, + ) + + @staticmethod + def _get_idxs_from_residxs( + topology: omm_topology, + residxs: list[int], + ) -> list[int]: + """ + Helper method to get the a list of atom indices which belong to a list + of residues. + + Parameters + ---------- + topology : openmm.app.Topology + An OpenMM Topology that defines the System. + residxs : list[int] + A list of residue numbers who's atoms we should get atom indices. + + Returns + ------- + atom_ids : list[int] + A list of atom indices. + + TODO + ---- + * Check how this works when we deal with virtual sites. + """ + atom_ids = [] + + for r in topology.residues(): + if r.index in residxs: + atom_ids.extend([at.index for at in r.atoms()]) + + return atom_ids + + @staticmethod + def _get_boresch_restraint( + universe: mda.Universe, + guest_rdmol: Chem.Mol, + guest_atom_ids: list[int], + host_atom_ids: list[int], + temperature: Quantity, + settings: BoreschRestraintSettings, + ) -> tuple[BoreschRestraintGeometry, BoreschRestraint]: + """ + Get a Boresch-like restraint Geometry and OpenMM restraint force + supplier. + + Parameters + ---------- + universe : mda.Universe + An MDAnalysis Universe defining the system to get the restraint for. + guest_rdmol : Chem.Mol + An RDKit Molecule defining the guest molecule in the system. + guest_atom_ids: list[int] + A list of atom indices defining the guest molecule in the universe. + host_atom_ids : list[int] + A list of atom indices defining the host molecules in the universe. + temperature : openff.units.Quantity + The temperature of the simulation where the restraint will be added. + settings : BoreschRestraintSettings + Settings on how the Boresch-like restraint should be defined. + + Returns + ------- + geom : BoreschRestraintGeometry + A class defining the Boresch-like restraint. + restraint : BoreschRestraint + A factory class for generating Boresch restraints in OpenMM. + """ + # Take the minimum of the two possible force constants to check against + frc_const = min(settings.K_thetaA, settings.K_thetaB) + + geom = geometry.boresch.find_boresch_restraint( + universe=universe, + guest_rdmol=guest_rdmol, + guest_idxs=guest_atom_ids, + host_idxs=host_atom_ids, + host_selection=settings.host_selection, + anchor_finding_strategy=settings.anchor_finding_strategy, + dssp_filter=settings.dssp_filter, + rmsf_cutoff=settings.rmsf_cutoff, + host_min_distance=settings.host_min_distance, + host_max_distance=settings.host_max_distance, + angle_force_constant=frc_const, + temperature=temperature, + ) + + restraint = omm_restraints.BoreschRestraint(settings) + return geom, restraint + + def _add_restraints( + self, + system: System, + topology: omm_topology, + positions: ommunit.Quantity, + alchem_comps: dict[str, list[Component]], + comp_resids: dict[Component, npt.NDArray], + settings: dict[str, SettingsBaseModel], + ) -> tuple[ + GlobalParameterState, + Quantity, + System, + geometry.HostGuestRestraintGeometry, + ]: + """ + Find and add restraints to the OpenMM System. + + Notes + ----- + Currently, only Boresch-like restraints are supported. + + Parameters + ---------- + system : openmm.System + The System to add the restraint to. + topology : openmm.app.Topology + An OpenMM Topology that defines the System. + positions: openmm.unit.Quantity + The System's current positions. + Used if a trajectory file isn't found. + alchem_comps: dict[str, list[Component]] + A dictionary with a list of alchemical components + in both state A and B. + comp_resids: dict[Component, npt.NDArray] + A dictionary keyed by each Component in the System + which contains arrays with the residue indices that is contained + by that Component. + settings : dict[str, SettingsBaseModel] + A dictionary of settings that defines how to find and set + the restraint. + + Returns + ------- + restraint_parameter_state : RestraintParameterState + A RestraintParameterState object that defines the control + parameter for the restraint. + correction : openff.units.Quantity + The standard state correction for the restraint. + system : openmm.System + A copy of the System with the restraint added. + rest_geom : geometry.HostGuestRestraintGeometry + The restraint Geometry object. + """ + if self.verbose: + self.logger.info("Generating restraints") + + # Get the guest rdmol + guest_rdmol = alchem_comps["stateA"][0].to_rdkit() + + # sanitize the rdmol if possible - warn if you can't + err = Chem.SanitizeMol(guest_rdmol, catchErrors=True) + + if err: + msg = "restraint generation: could not sanitize ligand rdmol" + logger.warning(msg) + + # Get the guest idxs + # concatenate a list of residue indexes for all alchemical components + residxs = np.concatenate([comp_resids[key] for key in alchem_comps["stateA"]]) + + # get the alchemicical atom ids + guest_atom_ids = self._get_idxs_from_residxs(topology, residxs) + + # Now get the host idxs + # We assume this is everything but the alchemical component + # and the solvent. + solv_comps = [c for c in comp_resids if isinstance(c, SolventComponent)] + exclude_comps = [alchem_comps["stateA"]] + solv_comps + residxs = np.concatenate([v for i, v in comp_resids.items() if i not in exclude_comps]) + + host_atom_ids = self._get_idxs_from_residxs(topology, residxs) + + # Finally create an MDAnalysis Universe + # We try to pass the equilibration production file path through + # In some cases (debugging / dry runs) this won't be available + # so we'll default to using input positions. + univ = self._get_mda_universe( + topology, + positions, + self.shared_basepath / settings["equil_output_settings"].production_trajectory_filename, + ) + + if isinstance(settings["restraint_settings"], BoreschRestraintSettings): + rest_geom, restraint = self._get_boresch_restraint( + univ, + guest_rdmol, + guest_atom_ids, + host_atom_ids, + settings["thermo_settings"].temperature, + settings["restraint_settings"], + ) + else: + # TODO turn this into a direction for different restraint types supported? + raise NotImplementedError("Other restraint types are not yet available") + + if self.verbose: + self.logger.info(f"restraint geometry is: {rest_geom}") + + # We need a temporary thermodynamic state to add the restraint + # & get the correction + thermodynamic_state = ThermodynamicState( + system, + temperature=to_openmm(settings["thermo_settings"].temperature), + pressure=to_openmm(settings["thermo_settings"].pressure), + ) + + # Add the force to the thermodynamic state + restraint.add_force( + thermodynamic_state, + rest_geom, + controlling_parameter_name="lambda_restraints", + ) + # Get the standard state correction as a unit.Quantity + correction = restraint.get_standard_state_correction( + thermodynamic_state, + rest_geom, + ) + + # Get the GlobalParameterState for the restraint + restraint_parameter_state = omm_restraints.RestraintParameterState(lambda_restraints=1.0) + return ( + restraint_parameter_state, + correction, + # Remove the thermostat, otherwise you'll get an + # Andersen thermostat by default! + thermodynamic_state.get_system(remove_thermostat=True), + rest_geom, + ) + + +class AbsoluteBindingSolventUnit(BaseAbsoluteUnit): + """ + Protocol Unit for the solvent phase of an absolute binding free energy + """ + + simtype = "solvent" + + def _get_components(self): + """ + Get the relevant components for a solvent transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A list of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : ProteinComponent | None + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp just return None + return alchem_comps, solv_comp, None, off_comps + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a solvent transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : ABFEPreEquilOutputSettings + * simulation_settings : MultiStateSimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs["protocol"].settings + + settings = {} + settings["forcefield_settings"] = prot_settings.forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvent_solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.solvent_lambda_settings + settings["engine_settings"] = prot_settings.engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings + settings["simulation_settings"] = prot_settings.solvent_simulation_settings + settings["output_settings"] = prot_settings.solvent_output_settings + + return settings diff --git a/openfe/protocols/openmm_afe/ahfe_units.py b/openfe/protocols/openmm_afe/ahfe_units.py new file mode 100644 index 000000000..96269da9b --- /dev/null +++ b/openfe/protocols/openmm_afe/ahfe_units.py @@ -0,0 +1,174 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""AHFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.ahfe_units` +======================================================================== +This module defines the ProtocolUnits for the +:class:`AbsoluteSolvationProtocol`. +""" +import logging + +from openfe.protocols.openmm_afe.equil_afe_settings import ( + SettingsBaseModel, +) + +from ..openmm_utils import system_validation +from .base_afe_units import BaseAbsoluteUnit + + +logger = logging.getLogger(__name__) + + +class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): + """ + Protocol Unit for the vacuum phase of an absolute solvation free energy + """ + + simtype = "vacuum" + + def _get_components(self): + """ + Get the relevant components for a vacuum transformation. + + Returns + ------- + alchem_comps : dict[str, list[Component]] + A list of alchemical components + solv_comp : None + For the gas phase transformation, None will always be returned + for the solvent component of the chemical system. + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : dict[Component, OpenFF Molecule] + The openff Molecules to add to the system. This + is equivalent to the alchemical components in stateA (since + we only allow for disappearing ligands). + """ + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] + + off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} + + _, prot_comp, _ = system_validation.get_components(stateA) + + # Notes: + # 1. Our input state will contain a solvent, we ``None`` that out + # since this is the gas phase unit. + # 2. Our small molecules will always just be the alchemical components + # (of stateA since we enforce only one disappearing ligand) + return alchem_comps, None, prot_comp, off_comps + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a vacuum transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : SimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs["protocol"].settings + + settings = {} + settings["forcefield_settings"] = prot_settings.vacuum_forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.lambda_settings + settings["engine_settings"] = prot_settings.vacuum_engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.vacuum_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.vacuum_equil_output_settings + settings["simulation_settings"] = prot_settings.vacuum_simulation_settings + settings["output_settings"] = prot_settings.vacuum_output_settings + + return settings + + +class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): + """ + Protocol Unit for the solvent phase of an absolute solvation free energy + """ + + simtype = "solvent" + + def _get_components(self): + """ + Get the relevant components for a solvent transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A list of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in small_mols} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp since that's also + # disallowed on create + return alchem_comps, solv_comp, prot_comp, off_comps + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a solvent transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : MultiStateSimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs["protocol"].settings + + settings = {} + settings["forcefield_settings"] = prot_settings.solvent_forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.lambda_settings + settings["engine_settings"] = prot_settings.solvent_engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings + settings["simulation_settings"] = prot_settings.solvent_simulation_settings + settings["output_settings"] = prot_settings.solvent_output_settings + + return settings diff --git a/openfe/protocols/openmm_afe/afe_units.py b/openfe/protocols/openmm_afe/base_afe_units.py similarity index 100% rename from openfe/protocols/openmm_afe/afe_units.py rename to openfe/protocols/openmm_afe/base_afe_units.py diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py index 56b85546e..5bdca550e 100644 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -23,18 +23,13 @@ `Yank `_. """ - import logging -import pathlib import uuid import warnings from collections import defaultdict -from typing import Any, Iterable, Optional, Union +from typing import Any, Iterable import gufe -import MDAnalysis as mda -import numpy as np -import numpy.typing as npt from gufe import ( ChemicalSystem, ProteinComponent, @@ -42,15 +37,7 @@ SolventComponent, settings, ) -from gufe.components import Component -from openff.units import Quantity from openff.units import unit as offunit -from openff.units.openmm import to_openmm -from openmm import System -from openmm import unit as ommunit -from openmm.app import Topology as omm_topology -from openmmtools.states import GlobalParameterState, ThermodynamicState -from rdkit import Chem from openfe.due import Doi, due from openfe.protocols.openmm_afe.equil_afe_settings import ( @@ -66,16 +53,14 @@ OpenFFPartialChargeSettings, OpenMMEngineSettings, OpenMMSolvationSettings, - SettingsBaseModel, ) -from openfe.protocols.openmm_utils import settings_validation, system_validation -from openfe.protocols.restraint_utils import geometry -from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry -from openfe.protocols.restraint_utils.openmm import omm_restraints -from openfe.protocols.restraint_utils.openmm.omm_restraints import BoreschRestraint +from openfe.protocols.openmm_utils import ( + settings_validation, + system_validation +) -from .afe_units import BaseAbsoluteUnit from .afe_protocol_results import AbsoluteBindingProtocolResult +from .abfe_units import AbsoluteBindingComplexUnit, AbsoluteBindingSolventUnit due.cite( Doi("10.5281/zenodo.596504"), @@ -348,8 +333,8 @@ def _validate( *, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None, + extends: gufe.ProtocolDAGResult | None = None, ): # Check we're not extending if extends is not None: @@ -425,8 +410,8 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None, + extends: gufe.ProtocolDAGResult | None = None, ) -> list[gufe.ProtocolUnit]: # Validate inputs self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) @@ -496,426 +481,4 @@ def _gather( for k, v in unsorted_complex_repeats.items(): repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - return repeats - - -class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): - """ - Protocol Unit for the complex phase of an absolute binding free energy - """ - - simtype = "complex" - - def _get_components(self): - """ - Get the relevant components for a complex transformation. - - Returns - ------- - alchem_comps : dict[str, Component] - A dict of alchemical components - solv_comp : SolventComponent - The SolventComponent of the system - prot_comp : Optional[ProteinComponent] - The protein component of the system, if it exists. - small_mols : dict[SmallMoleculeComponent: OFFMolecule] - SmallMoleculeComponents to add to the system. - """ - stateA = self._inputs["stateA"] - alchem_comps = self._inputs["alchemical_components"] - - solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) - off_comps = {m: m.to_openff() for m in small_mols} - - # We don't need to check that solv_comp is not None, otherwise - # an error will have been raised when calling `validate_solvent` - # in the Protocol's `_create`. - # Similarly we don't need to check prot_comp - return alchem_comps, solv_comp, prot_comp, off_comps - - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Extract the relevant settings for a complex transformation. - - Returns - ------- - settings : dict[str, SettingsBaseModel] - A dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * charge_settings : OpenFFPartialChargeSettings - * solvation_settings : OpenMMSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : ABFEPreEquilOutputSettings - * simulation_settings : SimulationSettings - * output_settings: MultiStateOutputSettings - * restraint_settings: BaseRestraintSettings - """ - prot_settings = self._inputs["protocol"].settings - - settings = {} - settings["forcefield_settings"] = prot_settings.forcefield_settings - settings["thermo_settings"] = prot_settings.thermo_settings - settings["charge_settings"] = prot_settings.partial_charge_settings - settings["solvation_settings"] = prot_settings.complex_solvation_settings - settings["alchemical_settings"] = prot_settings.alchemical_settings - settings["lambda_settings"] = prot_settings.complex_lambda_settings - settings["engine_settings"] = prot_settings.engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = prot_settings.complex_equil_simulation_settings - settings["equil_output_settings"] = prot_settings.complex_equil_output_settings - settings["simulation_settings"] = prot_settings.complex_simulation_settings - settings["output_settings"] = prot_settings.complex_output_settings - settings["restraint_settings"] = prot_settings.restraint_settings - - return settings - - @staticmethod - def _get_mda_universe( - topology: omm_topology, - positions: ommunit.Quantity, - trajectory: Optional[pathlib.Path], - ) -> mda.Universe: - """ - Helper method to get a Universe from an openmm Topology, - and either an input trajectory or a set of positions. - - Parameters - ---------- - topology : openmm.app.Topology - An OpenMM Topology that defines the System. - positions: openmm.unit.Quantity - The System's current positions. - Used if a trajectory file is None or is not a file. - trajectory: pathlib.Path - A Path to a trajectory file to read positions from. - - Returns - ------- - mda.Universe - An MDAnalysis Universe of the System. - """ - from MDAnalysis.coordinates.memory import MemoryReader - - # If the trajectory file doesn't exist, then we use positions - if trajectory is not None and trajectory.is_file(): - return mda.Universe( - topology, - trajectory, - topology_format="OPENMMTOPOLOGY", - ) - else: - # Positions is an openmm Quantity in nm we need - # to convert to angstroms - return mda.Universe( - topology, - np.array(positions._value) * 10, - topology_format="OPENMMTOPOLOGY", - trajectory_format=MemoryReader, - ) - - @staticmethod - def _get_idxs_from_residxs( - topology: omm_topology, - residxs: list[int], - ) -> list[int]: - """ - Helper method to get the a list of atom indices which belong to a list - of residues. - - Parameters - ---------- - topology : openmm.app.Topology - An OpenMM Topology that defines the System. - residxs : list[int] - A list of residue numbers who's atoms we should get atom indices. - - Returns - ------- - atom_ids : list[int] - A list of atom indices. - - TODO - ---- - * Check how this works when we deal with virtual sites. - """ - atom_ids = [] - - for r in topology.residues(): - if r.index in residxs: - atom_ids.extend([at.index for at in r.atoms()]) - - return atom_ids - - @staticmethod - def _get_boresch_restraint( - universe: mda.Universe, - guest_rdmol: Chem.Mol, - guest_atom_ids: list[int], - host_atom_ids: list[int], - temperature: Quantity, - settings: BoreschRestraintSettings, - ) -> tuple[BoreschRestraintGeometry, BoreschRestraint]: - """ - Get a Boresch-like restraint Geometry and OpenMM restraint force - supplier. - - Parameters - ---------- - universe : mda.Universe - An MDAnalysis Universe defining the system to get the restraint for. - guest_rdmol : Chem.Mol - An RDKit Molecule defining the guest molecule in the system. - guest_atom_ids: list[int] - A list of atom indices defining the guest molecule in the universe. - host_atom_ids : list[int] - A list of atom indices defining the host molecules in the universe. - temperature : openff.units.Quantity - The temperature of the simulation where the restraint will be added. - settings : BoreschRestraintSettings - Settings on how the Boresch-like restraint should be defined. - - Returns - ------- - geom : BoreschRestraintGeometry - A class defining the Boresch-like restraint. - restraint : BoreschRestraint - A factory class for generating Boresch restraints in OpenMM. - """ - # Take the minimum of the two possible force constants to check against - frc_const = min(settings.K_thetaA, settings.K_thetaB) - - geom = geometry.boresch.find_boresch_restraint( - universe=universe, - guest_rdmol=guest_rdmol, - guest_idxs=guest_atom_ids, - host_idxs=host_atom_ids, - host_selection=settings.host_selection, - anchor_finding_strategy=settings.anchor_finding_strategy, - dssp_filter=settings.dssp_filter, - rmsf_cutoff=settings.rmsf_cutoff, - host_min_distance=settings.host_min_distance, - host_max_distance=settings.host_max_distance, - angle_force_constant=frc_const, - temperature=temperature, - ) - - restraint = omm_restraints.BoreschRestraint(settings) - return geom, restraint - - def _add_restraints( - self, - system: System, - topology: omm_topology, - positions: ommunit.Quantity, - alchem_comps: dict[str, list[Component]], - comp_resids: dict[Component, npt.NDArray], - settings: dict[str, SettingsBaseModel], - ) -> tuple[ - GlobalParameterState, - Quantity, - System, - geometry.HostGuestRestraintGeometry, - ]: - """ - Find and add restraints to the OpenMM System. - - Notes - ----- - Currently, only Boresch-like restraints are supported. - - Parameters - ---------- - system : openmm.System - The System to add the restraint to. - topology : openmm.app.Topology - An OpenMM Topology that defines the System. - positions: openmm.unit.Quantity - The System's current positions. - Used if a trajectory file isn't found. - alchem_comps: dict[str, list[Component]] - A dictionary with a list of alchemical components - in both state A and B. - comp_resids: dict[Component, npt.NDArray] - A dictionary keyed by each Component in the System - which contains arrays with the residue indices that is contained - by that Component. - settings : dict[str, SettingsBaseModel] - A dictionary of settings that defines how to find and set - the restraint. - - Returns - ------- - restraint_parameter_state : RestraintParameterState - A RestraintParameterState object that defines the control - parameter for the restraint. - correction : openff.units.Quantity - The standard state correction for the restraint. - system : openmm.System - A copy of the System with the restraint added. - rest_geom : geometry.HostGuestRestraintGeometry - The restraint Geometry object. - """ - if self.verbose: - self.logger.info("Generating restraints") - - # Get the guest rdmol - guest_rdmol = alchem_comps["stateA"][0].to_rdkit() - - # sanitize the rdmol if possible - warn if you can't - err = Chem.SanitizeMol(guest_rdmol, catchErrors=True) - - if err: - msg = "restraint generation: could not sanitize ligand rdmol" - logger.warning(msg) - - # Get the guest idxs - # concatenate a list of residue indexes for all alchemical components - residxs = np.concatenate([comp_resids[key] for key in alchem_comps["stateA"]]) - - # get the alchemicical atom ids - guest_atom_ids = self._get_idxs_from_residxs(topology, residxs) - - # Now get the host idxs - # We assume this is everything but the alchemical component - # and the solvent. - solv_comps = [c for c in comp_resids if isinstance(c, SolventComponent)] - exclude_comps = [alchem_comps["stateA"]] + solv_comps - residxs = np.concatenate([v for i, v in comp_resids.items() if i not in exclude_comps]) - - host_atom_ids = self._get_idxs_from_residxs(topology, residxs) - - # Finally create an MDAnalysis Universe - # We try to pass the equilibration production file path through - # In some cases (debugging / dry runs) this won't be available - # so we'll default to using input positions. - univ = self._get_mda_universe( - topology, - positions, - self.shared_basepath / settings["equil_output_settings"].production_trajectory_filename, - ) - - if isinstance(settings["restraint_settings"], BoreschRestraintSettings): - rest_geom, restraint = self._get_boresch_restraint( - univ, - guest_rdmol, - guest_atom_ids, - host_atom_ids, - settings["thermo_settings"].temperature, - settings["restraint_settings"], - ) - else: - # TODO turn this into a direction for different restraint types supported? - raise NotImplementedError("Other restraint types are not yet available") - - if self.verbose: - self.logger.info(f"restraint geometry is: {rest_geom}") - - # We need a temporary thermodynamic state to add the restraint - # & get the correction - thermodynamic_state = ThermodynamicState( - system, - temperature=to_openmm(settings["thermo_settings"].temperature), - pressure=to_openmm(settings["thermo_settings"].pressure), - ) - - # Add the force to the thermodynamic state - restraint.add_force( - thermodynamic_state, - rest_geom, - controlling_parameter_name="lambda_restraints", - ) - # Get the standard state correction as a unit.Quantity - correction = restraint.get_standard_state_correction( - thermodynamic_state, - rest_geom, - ) - - # Get the GlobalParameterState for the restraint - restraint_parameter_state = omm_restraints.RestraintParameterState(lambda_restraints=1.0) - return ( - restraint_parameter_state, - correction, - # Remove the thermostat, otherwise you'll get an - # Andersen thermostat by default! - thermodynamic_state.get_system(remove_thermostat=True), - rest_geom, - ) - - -class AbsoluteBindingSolventUnit(BaseAbsoluteUnit): - """ - Protocol Unit for the solvent phase of an absolute binding free energy - """ - - simtype = "solvent" - - def _get_components(self): - """ - Get the relevant components for a solvent transformation. - - Returns - ------- - alchem_comps : dict[str, Component] - A list of alchemical components - solv_comp : SolventComponent - The SolventComponent of the system - prot_comp : Optional[ProteinComponent] - The protein component of the system, if it exists. - small_mols : dict[SmallMoleculeComponent: OFFMolecule] - SmallMoleculeComponents to add to the system. - """ - stateA = self._inputs["stateA"] - alchem_comps = self._inputs["alchemical_components"] - - solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) - off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} - - # We don't need to check that solv_comp is not None, otherwise - # an error will have been raised when calling `validate_solvent` - # in the Protocol's `_create`. - # Similarly we don't need to check prot_comp just return None - return alchem_comps, solv_comp, None, off_comps - - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Extract the relevant settings for a solvent transformation. - - Returns - ------- - settings : dict[str, SettingsBaseModel] - A dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * charge_settings : OpenFFPartialChargeSettings - * solvation_settings : OpenMMSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : ABFEPreEquilOutputSettings - * simulation_settings : MultiStateSimulationSettings - * output_settings: MultiStateOutputSettings - """ - prot_settings = self._inputs["protocol"].settings - - settings = {} - settings["forcefield_settings"] = prot_settings.forcefield_settings - settings["thermo_settings"] = prot_settings.thermo_settings - settings["charge_settings"] = prot_settings.partial_charge_settings - settings["solvation_settings"] = prot_settings.solvent_solvation_settings - settings["alchemical_settings"] = prot_settings.alchemical_settings - settings["lambda_settings"] = prot_settings.solvent_lambda_settings - settings["engine_settings"] = prot_settings.engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings - settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings - settings["simulation_settings"] = prot_settings.solvent_simulation_settings - settings["output_settings"] = prot_settings.solvent_output_settings - - return settings + return repeats \ No newline at end of file diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 941e16b38..ecea8ff95 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -27,9 +27,6 @@ `espaloma_charge `_ """ - -from __future__ import annotations - import logging import uuid import warnings @@ -60,11 +57,13 @@ OpenFFPartialChargeSettings, OpenMMEngineSettings, OpenMMSolvationSettings, - SettingsBaseModel, ) from ..openmm_utils import settings_validation, system_validation -from .afe_units import BaseAbsoluteUnit +from .ahfe_units import ( + AbsoluteSolvationSolventUnit, + AbsoluteSolvationVacuumUnit, +) from .afe_protocol_results import AbsoluteSolvationProtocolResult @@ -503,158 +502,4 @@ def _gather( for k, v in unsorted_vacuum_repeats.items(): repeats["vacuum"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - return repeats - - -class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): - """ - Protocol Unit for the vacuum phase of an absolute solvation free energy - """ - - simtype = "vacuum" - - def _get_components(self): - """ - Get the relevant components for a vacuum transformation. - - Returns - ------- - alchem_comps : dict[str, list[Component]] - A list of alchemical components - solv_comp : None - For the gas phase transformation, None will always be returned - for the solvent component of the chemical system. - prot_comp : Optional[ProteinComponent] - The protein component of the system, if it exists. - small_mols : dict[Component, OpenFF Molecule] - The openff Molecules to add to the system. This - is equivalent to the alchemical components in stateA (since - we only allow for disappearing ligands). - """ - stateA = self._inputs["stateA"] - alchem_comps = self._inputs["alchemical_components"] - - off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} - - _, prot_comp, _ = system_validation.get_components(stateA) - - # Notes: - # 1. Our input state will contain a solvent, we ``None`` that out - # since this is the gas phase unit. - # 2. Our small molecules will always just be the alchemical components - # (of stateA since we enforce only one disappearing ligand) - return alchem_comps, None, prot_comp, off_comps - - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Extract the relevant settings for a vacuum transformation. - - Returns - ------- - settings : dict[str, SettingsBaseModel] - A dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * charge_settings : OpenFFPartialChargeSettings - * solvation_settings : OpenMMSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : MDOutputSettings - * simulation_settings : SimulationSettings - * output_settings: MultiStateOutputSettings - """ - prot_settings = self._inputs["protocol"].settings - - settings = {} - settings["forcefield_settings"] = prot_settings.vacuum_forcefield_settings - settings["thermo_settings"] = prot_settings.thermo_settings - settings["charge_settings"] = prot_settings.partial_charge_settings - settings["solvation_settings"] = prot_settings.solvation_settings - settings["alchemical_settings"] = prot_settings.alchemical_settings - settings["lambda_settings"] = prot_settings.lambda_settings - settings["engine_settings"] = prot_settings.vacuum_engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = prot_settings.vacuum_equil_simulation_settings - settings["equil_output_settings"] = prot_settings.vacuum_equil_output_settings - settings["simulation_settings"] = prot_settings.vacuum_simulation_settings - settings["output_settings"] = prot_settings.vacuum_output_settings - - return settings - - -class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): - """ - Protocol Unit for the solvent phase of an absolute solvation free energy - """ - - simtype = "solvent" - - def _get_components(self): - """ - Get the relevant components for a solvent transformation. - - Returns - ------- - alchem_comps : dict[str, Component] - A list of alchemical components - solv_comp : SolventComponent - The SolventComponent of the system - prot_comp : Optional[ProteinComponent] - The protein component of the system, if it exists. - small_mols : dict[SmallMoleculeComponent: OFFMolecule] - SmallMoleculeComponents to add to the system. - """ - stateA = self._inputs["stateA"] - alchem_comps = self._inputs["alchemical_components"] - - solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) - off_comps = {m: m.to_openff() for m in small_mols} - - # We don't need to check that solv_comp is not None, otherwise - # an error will have been raised when calling `validate_solvent` - # in the Protocol's `_create`. - # Similarly we don't need to check prot_comp since that's also - # disallowed on create - return alchem_comps, solv_comp, prot_comp, off_comps - - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Extract the relevant settings for a solvent transformation. - - Returns - ------- - settings : dict[str, SettingsBaseModel] - A dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * charge_settings : OpenFFPartialChargeSettings - * solvation_settings : OpenMMSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : MDOutputSettings - * simulation_settings : MultiStateSimulationSettings - * output_settings: MultiStateOutputSettings - """ - prot_settings = self._inputs["protocol"].settings - - settings = {} - settings["forcefield_settings"] = prot_settings.solvent_forcefield_settings - settings["thermo_settings"] = prot_settings.thermo_settings - settings["charge_settings"] = prot_settings.partial_charge_settings - settings["solvation_settings"] = prot_settings.solvation_settings - settings["alchemical_settings"] = prot_settings.alchemical_settings - settings["lambda_settings"] = prot_settings.lambda_settings - settings["engine_settings"] = prot_settings.solvent_engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings - settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings - settings["simulation_settings"] = prot_settings.solvent_simulation_settings - settings["output_settings"] = prot_settings.solvent_output_settings - - return settings + return repeats \ No newline at end of file From 1422cf9f06a8a7a2526e2ddbdd7b8f11b1d01e24 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 Jan 2026 14:59:22 +0000 Subject: [PATCH 08/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- openfe/protocols/openmm_afe/__init__.py | 8 ++++---- openfe/protocols/openmm_afe/abfe_units.py | 2 +- openfe/protocols/openmm_afe/afe_protocol_results.py | 10 ++++------ openfe/protocols/openmm_afe/ahfe_units.py | 2 +- .../protocols/openmm_afe/equil_binding_afe_method.py | 10 ++++------ .../protocols/openmm_afe/equil_solvation_afe_method.py | 6 +++--- 6 files changed, 17 insertions(+), 21 deletions(-) diff --git a/openfe/protocols/openmm_afe/__init__.py b/openfe/protocols/openmm_afe/__init__.py index 05aa324c3..ac74a7ae9 100644 --- a/openfe/protocols/openmm_afe/__init__.py +++ b/openfe/protocols/openmm_afe/__init__.py @@ -5,6 +5,10 @@ """ +from .afe_protocol_results import ( + AbsoluteBindingProtocolResult, + AbsoluteSolvationProtocolResult, +) from .equil_binding_afe_method import ( AbsoluteBindingComplexUnit, AbsoluteBindingProtocol, @@ -17,10 +21,6 @@ AbsoluteSolvationSolventUnit, AbsoluteSolvationVacuumUnit, ) -from .afe_protocol_results import ( - AbsoluteBindingProtocolResult, - AbsoluteSolvationProtocolResult, -) __all__ = [ "AbsoluteSolvationProtocol", diff --git a/openfe/protocols/openmm_afe/abfe_units.py b/openfe/protocols/openmm_afe/abfe_units.py index cbbc11fab..f24340aec 100644 --- a/openfe/protocols/openmm_afe/abfe_units.py +++ b/openfe/protocols/openmm_afe/abfe_units.py @@ -7,6 +7,7 @@ This module defines the ProtocolUnits for the :class:`AbsoluteBindingProtocol`. """ + import logging import pathlib @@ -37,7 +38,6 @@ from .base_afe_units import BaseAbsoluteUnit - logger = logging.getLogger(__name__) diff --git a/openfe/protocols/openmm_afe/afe_protocol_results.py b/openfe/protocols/openmm_afe/afe_protocol_results.py index c9564565d..13744099a 100644 --- a/openfe/protocols/openmm_afe/afe_protocol_results.py +++ b/openfe/protocols/openmm_afe/afe_protocol_results.py @@ -11,6 +11,7 @@ * AbsoluteBindingProtocolResult * AbsoluteSolvationProtocolResult """ + import itertools import logging import pathlib @@ -20,14 +21,11 @@ import gufe import numpy as np import numpy.typing as npt -from openff.units import unit as offunit from openff.units import Quantity +from openff.units import unit as offunit from openmmtools import multistate -from openfe.protocols.restraint_utils.geometry.boresch import ( - BoreschRestraintGeometry -) - +from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry logger = logging.getLogger(__name__) @@ -184,7 +182,7 @@ def get_replica_states(self) -> dict[str, list[npt.NDArray]]: """ replica_states: dict[str, list[npt.NDArray]] = { self.bound_state: [], - self.unbound_state: [] + self.unbound_state: [], } def is_file(filename: str): diff --git a/openfe/protocols/openmm_afe/ahfe_units.py b/openfe/protocols/openmm_afe/ahfe_units.py index 96269da9b..9fb9b03da 100644 --- a/openfe/protocols/openmm_afe/ahfe_units.py +++ b/openfe/protocols/openmm_afe/ahfe_units.py @@ -7,6 +7,7 @@ This module defines the ProtocolUnits for the :class:`AbsoluteSolvationProtocol`. """ + import logging from openfe.protocols.openmm_afe.equil_afe_settings import ( @@ -16,7 +17,6 @@ from ..openmm_utils import system_validation from .base_afe_units import BaseAbsoluteUnit - logger = logging.getLogger(__name__) diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py index 5bdca550e..c619f285c 100644 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -23,6 +23,7 @@ `Yank `_. """ + import logging import uuid import warnings @@ -54,13 +55,10 @@ OpenMMEngineSettings, OpenMMSolvationSettings, ) -from openfe.protocols.openmm_utils import ( - settings_validation, - system_validation -) +from openfe.protocols.openmm_utils import settings_validation, system_validation -from .afe_protocol_results import AbsoluteBindingProtocolResult from .abfe_units import AbsoluteBindingComplexUnit, AbsoluteBindingSolventUnit +from .afe_protocol_results import AbsoluteBindingProtocolResult due.cite( Doi("10.5281/zenodo.596504"), @@ -481,4 +479,4 @@ def _gather( for k, v in unsorted_complex_repeats.items(): repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - return repeats \ No newline at end of file + return repeats diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index ecea8ff95..7abcde6f7 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -27,6 +27,7 @@ `espaloma_charge `_ """ + import logging import uuid import warnings @@ -60,12 +61,11 @@ ) from ..openmm_utils import settings_validation, system_validation +from .afe_protocol_results import AbsoluteSolvationProtocolResult from .ahfe_units import ( AbsoluteSolvationSolventUnit, AbsoluteSolvationVacuumUnit, ) -from .afe_protocol_results import AbsoluteSolvationProtocolResult - due.cite( Doi("10.5281/zenodo.596504"), @@ -502,4 +502,4 @@ def _gather( for k, v in unsorted_vacuum_repeats.items(): repeats["vacuum"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - return repeats \ No newline at end of file + return repeats From dd56a7d4f2b22790db46c1297d9507a76d3603dd Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 3 Jan 2026 11:06:18 -0500 Subject: [PATCH 09/11] move a few things in init --- openfe/protocols/openmm_afe/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_afe/__init__.py b/openfe/protocols/openmm_afe/__init__.py index ac74a7ae9..c85755b58 100644 --- a/openfe/protocols/openmm_afe/__init__.py +++ b/openfe/protocols/openmm_afe/__init__.py @@ -10,14 +10,18 @@ AbsoluteSolvationProtocolResult, ) from .equil_binding_afe_method import ( - AbsoluteBindingComplexUnit, AbsoluteBindingProtocol, AbsoluteBindingSettings, +) +from .abfe_units import ( + AbsoluteBindingComplexUnit, AbsoluteBindingSolventUnit, ) from .equil_solvation_afe_method import ( AbsoluteSolvationProtocol, AbsoluteSolvationSettings, +) +from .ahfe_units import ( AbsoluteSolvationSolventUnit, AbsoluteSolvationVacuumUnit, ) From 228978adad62f5624432b7c7a9c4fbac6888deb9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 4 Jan 2026 19:31:07 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- openfe/protocols/openmm_afe/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/openfe/protocols/openmm_afe/__init__.py b/openfe/protocols/openmm_afe/__init__.py index c85755b58..314179c56 100644 --- a/openfe/protocols/openmm_afe/__init__.py +++ b/openfe/protocols/openmm_afe/__init__.py @@ -5,26 +5,26 @@ """ +from .abfe_units import ( + AbsoluteBindingComplexUnit, + AbsoluteBindingSolventUnit, +) from .afe_protocol_results import ( AbsoluteBindingProtocolResult, AbsoluteSolvationProtocolResult, ) +from .ahfe_units import ( + AbsoluteSolvationSolventUnit, + AbsoluteSolvationVacuumUnit, +) from .equil_binding_afe_method import ( AbsoluteBindingProtocol, AbsoluteBindingSettings, ) -from .abfe_units import ( - AbsoluteBindingComplexUnit, - AbsoluteBindingSolventUnit, -) from .equil_solvation_afe_method import ( AbsoluteSolvationProtocol, AbsoluteSolvationSettings, ) -from .ahfe_units import ( - AbsoluteSolvationSolventUnit, - AbsoluteSolvationVacuumUnit, -) __all__ = [ "AbsoluteSolvationProtocol", From 64b2187084502041aa670dd32776fd8b42119c49 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sun, 4 Jan 2026 14:35:51 -0500 Subject: [PATCH 11/11] make mypy happy --- .../openmm_afe/afe_protocol_results.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/openfe/protocols/openmm_afe/afe_protocol_results.py b/openfe/protocols/openmm_afe/afe_protocol_results.py index 13744099a..b2e44cc58 100644 --- a/openfe/protocols/openmm_afe/afe_protocol_results.py +++ b/openfe/protocols/openmm_afe/afe_protocol_results.py @@ -85,7 +85,8 @@ def get_forward_and_reverse_energy_analysis( for key in [self.bound_state, self.unbound_state]: forward_reverse[key] = [ - pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() + pus[0].outputs["forward_and_reverse_energies"] + for pus in self.data[key].values() # type: ignore[attr-defined] ] if None in forward_reverse[key]: @@ -125,7 +126,8 @@ def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: for key in [self.bound_state, self.unbound_state]: overlap_stats[key] = [ - pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values() + pus[0].outputs["unit_mbar_overlap"] + for pus in self.data[key].values() # type: ignore[attr-defined] ] return overlap_stats @@ -159,7 +161,8 @@ def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDAr try: for key in [self.bound_state, self.unbound_state]: repex_stats[key] = [ - pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() + pus[0].outputs["replica_exchange_statistics"] + for pus in self.data[key].values() # type: ignore[attr-defined] ] except KeyError: errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" @@ -209,7 +212,7 @@ def get_replica_state(nc, chk): return retval for key in [self.bound_state, self.unbound_state]: - for pus in self.data[key].values(): + for pus in self.data[key].values(): # type: ignore[attr-defined] states = get_replica_state( pus[0].outputs["nc"], pus[0].outputs["last_checkpoint"], @@ -235,7 +238,8 @@ def equilibration_iterations(self) -> dict[str, list[float]]: for key in [self.bound_state, self.unbound_state]: equilibration_lengths[key] = [ - pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() + pus[0].outputs["equilibration_iterations"] + for pus in self.data[key].values() # type: ignore[attr-defined] ] return equilibration_lengths @@ -259,7 +263,8 @@ def production_iterations(self) -> dict[str, list[float]]: for key in [self.bound_state, self.unbound_state]: production_lengths[key] = [ - pus[0].outputs["production_iterations"] for pus in self.data[key].values() + pus[0].outputs["production_iterations"] + for pus in self.data[key].values() # type: ignore[attr-defined] ] return production_lengths @@ -283,7 +288,7 @@ def selection_indices(self) -> dict[str, list[Optional[npt.NDArray]]]: for key in [self.bound_state, self.unbound_state]: indices[key] = [] - for pus in self.data[key].values(): + for pus in self.data[key].values(): # type: ignore[attr-defined] indices[key].append(pus[0].outputs["selection_indices"]) return indices