diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index 8c6b4eddc..85981aae6 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -17,7 +17,6 @@ import openmmtools.states as states from openmm import unit from openmmtools import cache -from openmmtools.integrators import FIREMinimizationIntegrator from openmmtools.multistate import multistatesampler, replicaexchange, sams from openmmtools.states import CompoundThermodynamicState, SamplerState, ThermodynamicState @@ -32,14 +31,21 @@ class HybridCompatibilityMixin(object): unsampled endpoints have a different number of degrees of freedom. """ - def __init__(self, *args, hybrid_factory=None, **kwargs): - self._hybrid_factory = hybrid_factory + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): + self._hybrid_system = hybrid_system + self._hybrid_positions = hybrid_positions super(HybridCompatibilityMixin, self).__init__(*args, **kwargs) - def setup(self, reporter, lambda_protocol, - temperature=298.15 * unit.kelvin, n_replicas=None, - endstates=True, minimization_steps=100, - minimization_platform="CPU"): + def setup( + self, + reporter, + lambda_protocol, + temperature=298.15 * unit.kelvin, + n_replicas=None, + endstates=True, + minimization_steps=100, + minimization_platform="CPU" + ): """ Setup MultistateSampler based on the input lambda protocol and number of replicas. @@ -73,15 +79,17 @@ class creation of LambdaProtocol. """ n_states = len(lambda_protocol.lambda_schedule) - hybrid_system = self._factory.hybrid_system + lambda_zero_state = RelativeAlchemicalState.from_system(self._hybrid_system) - lambda_zero_state = RelativeAlchemicalState.from_system(hybrid_system) + thermostate = ThermodynamicState( + self._hybrid_system, + temperature=temperature + ) - thermostate = ThermodynamicState(hybrid_system, - temperature=temperature) compound_thermostate = CompoundThermodynamicState( - thermostate, - composable_states=[lambda_zero_state]) + thermostate, + composable_states=[lambda_zero_state] + ) # create lists for storing thermostates and sampler states thermodynamic_state_list = [] @@ -105,24 +113,30 @@ class creation of LambdaProtocol. raise ValueError(errmsg) # starting with the hybrid factory positions - box = hybrid_system.getDefaultPeriodicBoxVectors() - sampler_state = SamplerState(self._factory.hybrid_positions, - box_vectors=box) + box = self._hybrid_system.getDefaultPeriodicBoxVectors() + sampler_state = SamplerState( + self._hybrid_positions, + box_vectors=box + ) # Loop over the lambdas and create & store a compound thermostate at # that lambda value for lambda_val in lambda_schedule: compound_thermostate_copy = copy.deepcopy(compound_thermostate) compound_thermostate_copy.set_alchemical_parameters( - lambda_val, lambda_protocol) + lambda_val, lambda_protocol + ) thermodynamic_state_list.append(compound_thermostate_copy) # now generating a sampler_state for each thermodyanmic state, # with relaxed positions # Note: remove once choderalab/openmmtools#672 is completed - minimize(compound_thermostate_copy, sampler_state, - max_iterations=minimization_steps, - platform_name=minimization_platform) + minimize( + compound_thermostate_copy, + sampler_state, + max_iterations=minimization_steps, + platform_name=minimization_platform + ) sampler_state_list.append(copy.deepcopy(sampler_state)) del compound_thermostate, sampler_state @@ -131,11 +145,13 @@ class creation of LambdaProtocol. if len(sampler_state_list) != n_replicas: # picking roughly evenly spaced sampler states # if n_replicas == 1, then it will pick the first in the list - samples = np.linspace(0, len(sampler_state_list) - 1, - n_replicas) + samples = np.linspace(0, len(sampler_state_list) - 1, n_replicas) idx = np.round(samples).astype(int) - sampler_state_list = [state for i, state in - enumerate(sampler_state_list) if i in idx] + sampler_state_list = [ + state + for i, state in enumerate(sampler_state_list) + if i in idx + ] assert len(sampler_state_list) == n_replicas @@ -143,13 +159,20 @@ class creation of LambdaProtocol. # generating unsampled endstates unsampled_dispersion_endstates = create_endstates( copy.deepcopy(thermodynamic_state_list[0]), - copy.deepcopy(thermodynamic_state_list[-1])) - self.create(thermodynamic_states=thermodynamic_state_list, - sampler_states=sampler_state_list, storage=reporter, - unsampled_thermodynamic_states=unsampled_dispersion_endstates) + copy.deepcopy(thermodynamic_state_list[-1]) + ) + self.create( + thermodynamic_states=thermodynamic_state_list, + sampler_states=sampler_state_list, + storage=reporter, + unsampled_thermodynamic_states=unsampled_dispersion_endstates + ) else: - self.create(thermodynamic_states=thermodynamic_state_list, - sampler_states=sampler_state_list, storage=reporter) + self.create( + thermodynamic_states=thermodynamic_state_list, + sampler_states=sampler_state_list, + storage=reporter + ) class HybridRepexSampler(HybridCompatibilityMixin, @@ -158,11 +181,13 @@ class HybridRepexSampler(HybridCompatibilityMixin, ReplicaExchangeSampler that supports unsampled end states with a different number of positions """ - - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridRepexSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs) - self._factory = hybrid_factory + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs + ) class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler): @@ -170,12 +195,13 @@ class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler): SAMSSampler that supports unsampled end states with a different number of positions """ - - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridSAMSSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs ) - self._factory = hybrid_factory class HybridMultiStateSampler(HybridCompatibilityMixin, @@ -184,11 +210,13 @@ class HybridMultiStateSampler(HybridCompatibilityMixin, MultiStateSampler that supports unsample end states with a different number of positions """ - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridMultiStateSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs ) - self._factory = hybrid_factory def create_endstates(first_thermostate, last_thermostate): diff --git a/openfe/protocols/openmm_rfe/equil_rfe_protocols.py b/openfe/protocols/openmm_rfe/equil_rfe_protocols.py new file mode 100644 index 000000000..a9a9abd33 --- /dev/null +++ b/openfe/protocols/openmm_rfe/equil_rfe_protocols.py @@ -0,0 +1,522 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""Equilibrium Relative Free Energy methods using OpenMM and OpenMMTools in a +Perses-like manner. + +This module implements the necessary methodology toolking to run calculate a +ligand relative free energy transformation using OpenMM tools and one of the +following methods: + - Hamiltonian Replica Exchange + - Self-adjusted mixture sampling + - Independent window sampling + +TODO +---- +* Improve this docstring by adding an example use case. + +Acknowledgements +---------------- +This Protocol is based on, and leverages components originating from +the Perses toolkit (https://github.com/choderalab/perses). +""" + +from __future__ import annotations + +import logging +import uuid +import warnings +from collections import defaultdict +from typing import Any, Iterable, Optional, Union + +from openff.units import unit as offunit +import gufe +from gufe import ( + ChemicalSystem, + Component, + ComponentMapping, + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, + settings, + LigandAtomMapping +) + +from openfe.due import Doi, due +from ..openmm_utils import ( + settings_validation, + system_validation, +) +from .equil_rfe_settings import ( + AlchemicalSettings, + IntegratorSettings, + LambdaSettings, + MultiStateOutputSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, + RelativeHybridTopologyProtocolSettings, +) +from .equil_rfe_results import RelativeHybridTopologyProtocolResult + + +logger = logging.getLogger(__name__) + + +due.cite( + Doi("10.5281/zenodo.1297683"), + description="Perses", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) + +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) + +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) + + +class RelativeHybridTopologyProtocol(gufe.Protocol): + """ + Relative Free Energy calculations using OpenMM and OpenMMTools. + + Based on `Perses `_ + + See Also + -------- + :mod:`openfe.protocols` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologySettings` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyResult` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolUnit` + """ + + result_cls = RelativeHybridTopologyProtocolResult + _settings_cls = RelativeHybridTopologyProtocolSettings + _settings: RelativeHybridTopologyProtocolSettings + + @classmethod + def _default_settings(cls): + """A dictionary of initial settings for this creating this Protocol + + These settings are intended as a suitable starting point for creating + an instance of this protocol. It is recommended, however that care is + taken to inspect and customize these before performing a Protocol. + + Returns + ------- + Settings + a set of default settings + """ + return RelativeHybridTopologyProtocolSettings( + protocol_repeats=3, + forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), + thermo_settings=settings.ThermoSettings( + temperature=298.15 * offunit.kelvin, + pressure=1 * offunit.bar, + ), + partial_charge_settings=OpenFFPartialChargeSettings(), + solvation_settings=OpenMMSolvationSettings(), + alchemical_settings=AlchemicalSettings(softcore_LJ="gapsys"), + lambda_settings=LambdaSettings(), + simulation_settings=MultiStateSimulationSettings( + equilibration_length=1.0 * offunit.nanosecond, + production_length=5.0 * offunit.nanosecond, + ), + engine_settings=OpenMMEngineSettings(), + integrator_settings=IntegratorSettings(), + output_settings=MultiStateOutputSettings(), + ) + + @classmethod + def _adaptive_settings( + cls, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.LigandAtomMapping | list[gufe.LigandAtomMapping], + initial_settings: None | RelativeHybridTopologyProtocolSettings = None, + ) -> RelativeHybridTopologyProtocolSettings: + """ + Get the recommended OpenFE settings for this protocol based on the input states involved in the + transformation. + + These are intended as a suitable starting point for creating an instance of this protocol, which can be further + customized before performing a Protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The initial state of the transformation. + stateB : ChemicalSystem + The final state of the transformation. + mapping : LigandAtomMapping | list[LigandAtomMapping] + The mapping(s) between transforming components in stateA and stateB. + initial_settings : None | RelativeHybridTopologyProtocolSettings, optional + Initial settings to base the adaptive settings on. If None, default settings are used. + + Returns + ------- + RelativeHybridTopologyProtocolSettings + The recommended settings for this protocol based on the input states. + + Notes + ----- + - If the transformation involves a change in net charge, the settings are adapted to use a more expensive + protocol with 22 lambda windows and 20 ns production length per window. + - If both states contain a ProteinComponent, the solvation padding is set to 1 nm. + - If initial_settings is provided, the adaptive settings are based on a copy of these settings. + """ + # use initial settings or default settings + # this is needed for the CLI so we don't override user settings + if initial_settings is not None: + protocol_settings = initial_settings.copy(deep=True) + else: + protocol_settings = cls.default_settings() + + if isinstance(mapping, list): + mapping = mapping[0] + + if mapping.get_alchemical_charge_difference() != 0: + # apply the recommended charge change settings taken from the industry benchmarking as fast settings not validated + # + info = ( + "Charge changing transformation between ligands " + f"{mapping.componentA.name} and {mapping.componentB.name}. " + "A more expensive protocol with 22 lambda windows, sampled " + "for 20 ns each, will be used here." + ) + logger.info(info) + protocol_settings.alchemical_settings.explicit_charge_correction = True + protocol_settings.simulation_settings.production_length = 20 * unit.nanosecond + protocol_settings.simulation_settings.n_replicas = 22 + protocol_settings.lambda_settings.lambda_windows = 22 + + # adapt the solvation padding based on the system components + if stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent): + protocol_settings.solvation_settings.solvent_padding = 1 * unit.nanometer + + return protocol_settings + + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + Validates the end states for the RFE protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A. + stateB : ChemicalSystem + The chemical system of end state B. + + Raises + ------ + ValueError + * If either state contains more than one unique Component. + * If unique components are not SmallMoleculeComponents. + """ + # Get the difference in Components between each state + diff = stateA.component_diff(stateB) + + for i, entry in enumerate(diff): + state_label = "A" if i == 0 else "B" + + # Check that there is only one unique Component in each state + if len(entry) != 0: + errmsg = ( + "Only one alchemical component is allowed per end state. " + f"Found {len(entry)} in state {state_label}." + ) + raise ValueError(errmsg) + + # Check that the unique Component is a SmallMoleculeComponent + if not isinstance(entry[0], SmallMoleculeComponent): + errmsg = ( + f"Alchemical component in state {state_label} is of type " + f"{type(entry[0])}, but only SmallMoleculeComponents " + "transformations are currently supported." + ) + raise ValueError(errmsg) + + @staticmethod + def _validate_mapping( + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], + alchemical_components: dict[str, list[Component]], + ) -> None: + """ + Validates that the provided mapping(s) are suitable for the RFE protocol. + + Parameters + ---------- + mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] + all mappings between transforming components. + alchemical_components : dict[str, list[Component]] + Dictionary contatining the alchemical components for + states A and B. + + Raises + ------ + ValueError + * If there are more than one mapping or mapping is None + * If the mapping components are not in the alchemical components. + UserWarning + * Mappings which involve element changes in core atoms + """ + # if a single mapping is provided, convert to list + if isinstance(mapping, ComponentMapping): + mapping = [mapping] + + # For now we only support a single mapping + if mapping is None or len(mapping) > 1: + errmsg = "A single LigandAtomMapping is expected for this Protocol" + raise ValueError(errmsg) + + # check that the mapping components are in the alchemical components + for m in mapping: + if m.componentA not in alchemical_components["stateA"]: + raise ValueError(f"Mapping componentA {m.componentA} not in alchemical components of stateA") + if m.componentB not in alchemical_components["stateB"]: + raise ValueError(f"Mapping componentB {m.componentB} not in alchemical components of stateB") + + # TODO: remove - this is now the default behaviour? + # Check for element changes in mappings + for m in mapping: + molA = m.componentA.to_rdkit() + molB = m.componentB.to_rdkit() + for i, j in m.componentA_to_componentB.items(): + atomA = molA.GetAtomWithIdx(i) + atomB = molB.GetAtomWithIdx(j) + if atomA.GetAtomicNum() != atomB.GetAtomicNum(): + wmsg = ( + f"Element change in mapping between atoms " + f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " + f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" + "No mass scaling is attempted in the hybrid topology, " + "the average mass of the two atoms will be used in the " + "simulation" + ) + logger.warning(wmsg) + warnings.warn(wmsg) + + @staticmethod + def _validate_charge_difference( + mapping: LigandAtomMapping, + nonbonded_method: str, + explicit_charge_correction: bool, + solvent_component: SolventComponent | None, + ): + """ + Validates the net charge difference between the two states. + + Parameters + ---------- + mapping : dict[str, ComponentMapping] + Dictionary of mappings between transforming components. + nonbonded_method : str + The OpenMM nonbonded method used for the simulation. + explicit_charge_correction : bool + Whether or not to use an explicit charge correction. + solvent_component : openfe.SolventComponent | None + The SolventComponent of the simulation. + + Raises + ------ + ValueError + * If an explicit charge correction is attempted and the + nonbonded method is not PME. + * If the absolute charge difference is greater than one + and an explicit charge correction is attempted. + UserWarning + * If there is any charge difference. + """ + difference = mapping.get_alchemical_charge_difference() + + if abs(difference) == 0: + return + + if not explicit_charge_correction: + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. No charge correction has " + "been requested, please account for this in your " + "final results." + ) + logger.warning(wmsg) + warnings.warn(wmsg) + return + + # We implicitly check earlier that we have to have pme for a solvated + # system, so we only need to check the nonbonded method here + if nonbonded_method.lower() != "pme": + errmsg = "Explicit charge correction when not using PME is not currently supported." + raise ValueError(errmsg) + + if abs(difference) > 1: + errmsg = ( + f"A charge difference of {difference} is observed " + "between the end states and an explicit charge " + "correction has been requested. Unfortunately " + "only absolute differences of 1 are supported." + ) + raise ValueError(errmsg) + + ion = { + -1: solvent_component.positive_ion, + 1: solvent_component.negative_ion + }[difference] + + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) + logger.info(wmsg) + + def _validate( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None, + extends: gufe.ProtocolDAGResult | None = None, + ) -> None: + # Check we're not trying to extend + if extends: + # This technically should be NotImplementedError + # but gufe.Protocol.validate calls `_validate` wrapped around an + # except for NotImplementedError, so we can't raise it here + raise ValueError("Can't extend simulations yet") + + # Validate the end states + self._validate_endstates(stateA, stateB) + + # Valildate the mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + self._validate_mapping(mapping, alchem_comps) + + # Validate solvent component + nonbond = self.settings.forcefield_settings.nonbonded_method + system_validation.validate_solvent(stateA, nonbond) + + # Validate solvation settings + settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) + + # Validate protein component + system_validation.validate_protein(stateA) + + # Validate charge difference + # Note: validation depends on the mapping & solvent component checks + if stateA.contains(SolventComponent): + solv_comp = stateA.get_components_of_type(SolventComponent)[0] + else: + solv_comp = None + + self._validate_charge_difference( + mapping=mapping[0] if isinstance(mapping, list) else mapping, + nonbonded_method=self.settings.forcefield_settings.nonbonded_method, + explicit_charge_correction=self.settings.alchemical_settings.explicit_charge_correction, + solvent_component=solv_comp, + ) + + # Validate integrator things + settings_validation.validate_timestep( + self.settings.forcefield_settings.hydrogen_mass, + self.settings.integrator_settings.timestep, + ) + + _ = settings_validation.convert_steps_per_iteration( + simulation_settings=self.settings.simulation_settings, + integrator_settings=self.settings.integrator_settings, + ) + + _ = settings_validation.get_simsteps( + sim_length=self.settings.simulation_settings.equilibration_length, + timestep=self.settings.integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.get_simsteps( + sim_length=self.settings.simulation_settings.production_length, + timestep=self.settings.integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=self.settings.output_settings.checkpoint_interval, + time_per_iteration=self.settings.simulation_settings.time_per_iteration, + ) + + # Validate alchemical settings + # PR #125 temporarily pin lambda schedule spacing to n_replicas + if self.settings.simulation_settings.n_replicas != self.settings.lambda_settings.n_windows: + errmsg = ( + "Number of replicas in simulation_settings must equal " + "number of lambda windows in lambda_settings." + ) + raise ValueError(errmsg) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], + extends: Optional[gufe.ProtocolDAGResult] = None, + ) -> list[gufe.ProtocolUnit]: + # validate inputs + self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) + + # get alchemical components and mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + ligandmapping = mapping[0] if isinstance(mapping, list) else mapping + + # actually create and return Units + Anames = ",".join(c.name for c in alchem_comps["stateA"]) + Bnames = ",".join(c.name for c in alchem_comps["stateB"]) + + # our DAG has no dependencies, so just list units + n_repeats = self.settings.protocol_repeats + + units = [ + RelativeHybridTopologyProtocolUnit( + protocol=self, + stateA=stateA, + stateB=stateB, + ligandmapping=ligandmapping, + generation=0, + repeat_id=int(uuid.uuid4()), + name=f"{Anames} to {Bnames} repeat {i} generation 0", + ) + for i in range(n_repeats) + ] + + return units + + def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]: + # result units will have a repeat_id and generations within this repeat_id + # first group according to repeat_id + unsorted_repeats = defaultdict(list) + for d in protocol_dag_results: + pu: gufe.ProtocolUnitResult + for pu in d.protocol_unit_results: + if not pu.ok(): + continue + + unsorted_repeats[pu.outputs["repeat_id"]].append(pu) + + # then sort by generation within each repeat_id list + repeats: dict[str, list[gufe.ProtocolUnitResult]] = {} + for k, v in unsorted_repeats.items(): + repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) + + # returns a dict of repeat_id: sorted list of ProtocolUnitResult + return repeats \ No newline at end of file diff --git a/openfe/protocols/openmm_rfe/equil_rfe_results.py b/openfe/protocols/openmm_rfe/equil_rfe_results.py new file mode 100644 index 000000000..26e411c98 --- /dev/null +++ b/openfe/protocols/openmm_rfe/equil_rfe_results.py @@ -0,0 +1,242 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""ProtocolResult objects for the RelativeHybridTopologyProtocol. + +Acknowledgements +---------------- +This Protocol is based on, and leverages components originating from +the Perses toolkit (https://github.com/choderalab/perses). +""" +import logging +import pathlib +import warnings +from typing import Optional, Union + +import gufe +import numpy as np +import numpy.typing as npt +from openff.units import Quantity +from openmmtools import multistate + + +logger = logging.getLogger(__name__) + + +class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): + """Dict-like container for the output of a RelativeHybridTopologyProtocol""" + + def __init__(self, **data): + super().__init__(**data) + # data is mapping of str(repeat_id): list[protocolunitresults] + # TODO: Detect when we have extensions and stitch these together? + if any(len(pur_list) > 2 for pur_list in self.data.values()): + raise NotImplementedError("Can't stitch together results yet") + + @staticmethod + def compute_mean_estimate(dGs: list[Quantity]) -> Quantity: + u = dGs[0].u + # convert all values to units of the first value, then take average of magnitude + # this would avoid a screwy case where each value was in different units + vals = np.asarray([dG.to(u).m for dG in dGs]) + + return np.average(vals) * u + + def get_estimate(self) -> Quantity: + """Average free energy difference of this transformation + + Returns + ------- + dG : openff.units.Quantity + The free energy difference between the first and last states. This is + a Quantity defined with units. + """ + # TODO: Check this holds up completely for SAMS. + dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] + return self.compute_mean_estimate(dGs) + + @staticmethod + def compute_uncertainty(dGs: list[Quantity]) -> Quantity: + u = dGs[0].u + # convert all values to units of the first value, then take average of magnitude + # this would avoid a screwy case where each value was in different units + vals = np.asarray([dG.to(u).m for dG in dGs]) + + return np.std(vals) * u + + def get_uncertainty(self) -> Quantity: + """The uncertainty/error in the dG value: The std of the estimates of + each independent repeat + """ + + dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] + return self.compute_uncertainty(dGs) + + def get_individual_estimates(self) -> list[tuple[Quantity, Quantity]]: + """Return a list of tuples containing the individual free energy + estimates and associated MBAR errors for each repeat. + + Returns + ------- + dGs : list[tuple[openff.units.Quantity]] + n_replicate simulation list of tuples containing the free energy + estimates (first entry) and associated MBAR estimate errors + (second entry). + """ + dGs = [ + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + for pus in self.data.values() + ] + return dGs + + def get_forward_and_reverse_energy_analysis( + self, + ) -> list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]: + """ + Get a list of forward and reverse analysis of the free energies + for each repeat using uncorrelated production samples. + + The returned dicts have keys: + 'fractions' - the fraction of data used for this estimate + 'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate + 'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty + + The 'fractions' values are a numpy array, while the other arrays are + Quantity arrays, with units attached. + + If the list entry is ``None`` instead of a dictionary, this indicates + that the analysis could not be carried out for that repeat. This + is most likely caused by MBAR convergence issues when attempting to + calculate free energies from too few samples. + + + Returns + ------- + forward_reverse : list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]] + + + Raises + ------ + UserWarning + If any of the forward and reverse entries are ``None``. + """ + forward_reverse = [ + pus[0].outputs["forward_and_reverse_energies"] for pus in self.data.values() + ] + + if None in forward_reverse: + wmsg = ( + "One or more ``None`` entries were found in the list of " + "forward and reverse analyses. This is likely caused by " + "an MBAR convergence failure caused by too few independent " + "samples when calculating the free energies of the 10% " + "timeseries slice." + ) + warnings.warn(wmsg) + + return forward_reverse + + def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]: + """ + Return a list of dictionary containing the MBAR overlap estimates + calculated for each repeat. + + Returns + ------- + overlap_stats : list[dict[str, npt.NDArray]] + A list of dictionaries containing the following keys: + * ``scalar``: One minus the largest nontrivial eigenvalue + * ``eigenvalues``: The sorted (descending) eigenvalues of the + overlap matrix + * ``matrix``: Estimated overlap matrix of observing a sample from + state i in state j + """ + # Loop through and get the repeats and get the matrices + overlap_stats = [pus[0].outputs["unit_mbar_overlap"] for pus in self.data.values()] + + return overlap_stats + + def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]: + """The replica lambda state transition statistics for each repeat. + + Note + ---- + This is currently only available in cases where a replica exchange + simulation was run. + + Returns + ------- + repex_stats : list[dict[str, npt.NDArray]] + A list of dictionaries containing the following: + * ``eigenvalues``: The sorted (descending) eigenvalues of the + lambda state transition matrix + * ``matrix``: The transition matrix estimate of a replica switching + from state i to state j. + """ + try: + repex_stats = [ + pus[0].outputs["replica_exchange_statistics"] for pus in self.data.values() + ] + except KeyError: + errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" + raise ValueError(errmsg) + + return repex_stats + + def get_replica_states(self) -> list[npt.NDArray]: + """ + Returns the timeseries of replica states for each repeat. + + Returns + ------- + replica_states : List[npt.NDArray] + List of replica states for each repeat + """ + + def is_file(filename: str): + p = pathlib.Path(filename) + if not p.exists(): + errmsg = f"File could not be found {p}" + raise ValueError(errmsg) + return p + + replica_states = [] + + for pus in self.data.values(): + nc = is_file(pus[0].outputs["nc"]) + dir_path = nc.parents[0] + chk = is_file(dir_path / pus[0].outputs["last_checkpoint"]).name + reporter = multistate.MultiStateReporter( + storage=nc, checkpoint_storage=chk, open_mode="r" + ) + replica_states.append(np.asarray(reporter.read_replica_thermodynamic_states())) + reporter.close() + + return replica_states + + def equilibration_iterations(self) -> list[float]: + """ + Returns the number of equilibration iterations for each repeat + of the calculation. + + Returns + ------- + equilibration_lengths : list[float] + """ + equilibration_lengths = [ + pus[0].outputs["equilibration_iterations"] for pus in self.data.values() + ] + + return equilibration_lengths + + def production_iterations(self) -> list[float]: + """ + Returns the number of uncorrelated production samples for each + repeat of the calculation. + + Returns + ------- + production_lengths : list[float] + """ + production_lengths = [pus[0].outputs["production_iterations"] for pus in self.data.values()] + + return production_lengths diff --git a/openfe/protocols/openmm_rfe/equil_rfe_units.py b/openfe/protocols/openmm_rfe/equil_rfe_units.py new file mode 100644 index 000000000..253dd65cd --- /dev/null +++ b/openfe/protocols/openmm_rfe/equil_rfe_units.py @@ -0,0 +1,1091 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""Equilibrium Relative Free Energy methods using OpenMM and OpenMMTools in a +Perses-like manner. + +This module implements the necessary methodology toolking to run calculate a +ligand relative free energy transformation using OpenMM tools and one of the +following methods: + - Hamiltonian Replica Exchange + - Self-adjusted mixture sampling + - Independent window sampling + +TODO +---- +* Improve this docstring by adding an example use case. + +Acknowledgements +---------------- +This Protocol is based on, and leverages components originating from +the Perses toolkit (https://github.com/choderalab/perses). +""" + +from __future__ import annotations + +import json +import logging +import os +import pathlib +import subprocess +import warnings +from itertools import chain +from typing import Any, Optional + +import gufe +import matplotlib.pyplot as plt +import mdtraj +import numpy as np +import openmmtools +from gufe import ( + ChemicalSystem, + LigandAtomMapping, + SmallMoleculeComponent, + SolventComponent, + settings, +) +from openff.toolkit.topology import Molecule as OFFMolecule +from openff.units import unit as offunit +from openff.units.openmm import ensure_quantity, from_openmm, to_openmm +from openmmtools import multistate + +from openfe.protocols.openmm_utils.omm_settings import ( + BasePartialChargeSettings, +) + +from ...analysis import plotting +from ...utils import log_system_probe, without_oechem_backend +from ..openmm_utils import ( + charge_generation, + multistate_analysis, + omm_compute, + settings_validation, + system_creation, + system_validation, +) +from . import _rfe_utils +from .equil_rfe_settings import ( + AlchemicalSettings, + IntegratorSettings, + LambdaSettings, + MultiStateOutputSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMSolvationSettings, + RelativeHybridTopologyProtocolSettings, +) + +logger = logging.getLogger(__name__) + + +class HybridTopProtocolSetupUnit(gufe.ProtocolUnit): + + @staticmethod + def _assign_partial_charges( + partial_charge_settings: OpenFFPartialChargeSettings, + smc_components: dict[SmallMoleculeComponent, OFFMolecule], + ) -> None: + """ + Assign partial charges to SMCs. + + Parameters + ---------- + charge_settings : OpenFFPartialChargeSettings + Settings for controlling how the partial charges are assigned. + smc_components : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of OpenFF Molecules to add, keyed by + SmallMoleculeComponent. + """ + for mol in smc_components.values(): + charge_generation.assign_offmol_partial_charges( + offmol=mol, + overwrite=False, + method=partial_charge_settings.partial_charge_method, + toolkit_backend=partial_charge_settings.off_toolkit_backend, + generate_n_conformers=partial_charge_settings.number_of_conformers, + nagl_model=partial_charge_settings.nagl_model, + ) + + def _prepare( + self, + verbose, + scratch_basepath: pathlib.Path | None, + shared_basepath: pathlib.Path | None, + ): + """ + Set basepaths and do some initial logging. + + Parameters + ---------- + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + basepath : Optional[pathlib.Path] + Optional base path to write files to. + """ + self.verbose = verbose + + if self.verbose: + self.logger.info("Setting up the hybrid topology simulation") + + # set basepaths + def _set_optional_path(basepath): + if basepath is None: + return pathlib.Path(".") + return basepath + + self.scratch_basepath = _set_optional_path(scratch_basepath) + self.shared_basepath = _set_optional_path(shared_basepath) + + def _get_components(self): + """ + Get the components from the ChemicalSystem inputs. + + Returns + ------- + alchem_comps : dict[str, Component] + Dictionary of alchemical components. + solv_comp : SolventComponent + The solvent component. + protein_comp : ProteinComponent + The protein component. + small_mols : list[SmallMoleculeComponent: OFFMolecule] + List of small molecule components. + """ + stateA = self._inputs["stateA"] + stateB = self._inputs["stateB"] + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + + solvent_comp, protein_comp, smcs_A = system_validation.get_components(stateA) + _, _, smcs_B = system_validation.get_components(stateB) + + small_mols = { + m: m.to_openff() + for m in set(smcs_A).union(set(smcs_B)) + } + + return alchem_comps, solvent_comp, protein_comp, small_mols + + def _get_settings(self) -> dict[str, SettingsBaseModel]: + """ + Get the protocol settings from the inputs. + + Returns + ------- + protocol_settings : RelativeHybridTopologyProtocolSettings + The protocol settings. + """ + settings: RelativeHybridTopologyProtocolSettings = self._inputs["protocol"].settings + + protocol_settings: dict[str, SettingsBaseModel] = {} + protocol_settings["forcefield_settings"] = settings.forcefield_settings + protocol_settings["thermo_settings"] = settings.thermo_settings + protocol_settings["alchemical_settings"] = settings.alchemical_settings + protocol_settings["lambda_settings"] = settings.lambda_settings + protocol_settings["charge_settings"] = settings.partial_charge_settings + protocol_settings["solvation_settings"] = settings.solvation_settings + protocol_settings["simulation_settings"] = settings.simulation_settings + protocol_settings["output_settings"] = settings.output_settings + protocol_settings["integrator_settings"] = settings.integrator_settings + protocol_settings["engine_settings"] = settings.engine_settings + return protocol_settings + + @staticmethod + def _get_system_generator( + shared_basepath: pathlib.Path, + settings: dict[str, SettingsBaseModel], + solvent_comp: SolventComponent | None, + ) -> SystemGenerator: + """ + Get an OpenMM SystemGenerator. + + Parameters + ---------- + settings : dict[str, SettingsBaseModel] + A dictionary of protocol settings. + solvent_comp : SolventComponent | None + The solvent component of the system, if any. + + Returns + ------- + system_generator : openmmtools.SystemGenerator + The SystemGenerator for the protocol. + """ + ffcache = settings["output_settings"].forcefield_cache + if ffcache is not None: + ffcache = shared_basepath / ffcache + + # Block out oechem backend in system_generator calls to avoid + # any issues with smiles roundtripping between rdkit and oechem + with without_oechem_backend(): + system_generator = system_creation.get_system_generator( + forcefield_settings=settings["forcefield_settings"], + integrator_settings=settings["integrator_settings"], + thermo_settings=settings["thermo_settings"], + cache=ffcache, + has_solvent=solvent_comp is not None, + ) + + return system_generator + + @staticmethod + def _create_stateA_system( + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + small_mols_stateA: dict[SmallMoleculeComponent, OFFMolecule], + system_generator: SystemGenerator, + solvation_settings: OpenMMSolvationSettings, + ): + stateA_modeller, comp_resids = system_creation.get_omm_modeller( + protein_comp=protein_component, + solvent_comp=solvent_component, + small_mols=small_mols_stateA, + omm_forcefield=system_generator.forcefield, + solvent_settings=solvation_settings, + ) + + stateA_topology = stateA_modeller.getTopology() + # Note: roundtrip positions to remove vec3 issues + stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) + + with without_oechem_backend(): + stateA_system = system_generator.create_system( + stateA_modeller.topology, + molecules=list(small_mols_stateA.values()), + ) + + return stateA_system, stateA_topology, stateA_positions, comp_resids + + @staticmethod + def _create_stateB_system( + small_mols_stateB: dict[SmallMoleculeComponent, OFFMolecule], + mapping: LigandAtomMapping, + stateA_topology: app.Topology, + exclude_resids: np.ndarray, + system_generator: SystemGenerator, + ): + stateB_topology, stateB_alchem_resids = _rfe_utils.topologyhelpers.combined_topology( + topology1=stateA_topology, + topology2=small_mols_stateB[mapping.componentB].to_topology().to_openmm(), + exclude_resids=exclude_resids, + ) + + with without_oechem_backend(): + stateB_system = system_generator.create_system( + stateB_topology, + molecules=list(small_mols_stateB.values()), + ) + + return stateB_system, stateB_topology, stateB_alchem_resids + + @staticmethod + def _handle_alchemical_waters( + stateA_topology: app.Topology, + stateA_positions: npt.NDArray, + stateB_topology: app.Topology, + stateB_system: openmm.System, + charge_difference: int, + system_mappings: dict[str, dict[int, int]], + alchemical_settings: AlchemicalSettings, + solvent_component: SolventComponent | None, + ): + if charge_difference == 0: + return + + alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( + stateA_topology, + stateA_positions, + charge_difference, + alchemical_settings.explicit_charge_correction_cutoff, + ) + + _rfe_utils.topologyhelpers.handle_alchemical_waters( + alchem_water_resids, + stateB_topology, + stateB_system, + system_mappings, + charge_difference, + solvent_component, + ) + + def _get_omm_objects( + self, + stateA, + stateB, + mapping, + settings: dict[str, SettingsBaseModel], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + ): + if self.verbose: + self.logger.info("Parameterizing system") + + # Get the system generator and register the templates + system_generator = self._get_system_generator( + shared_basepath=self.shared_basepath, + settings=settings, + solvent_comp=solvent_component + ) + + system_generator.add_molecules( + molecules=list(small_mols.values()) + ) + + # State A system creation + small_mols_stateA = { + smc: offmol + for smc, offmol in small_mols.items() + if stateA.contains(smc) + } + + stateA_system, stateA_topology, stateA_positions, comp_resids = self._create_stateA_system( + protein_component=protein_component, + solvent_component=solvent_component, + small_mols_stateA=small_mols_stateA, + system_generator=system_generator, + solvation_settings=settings["solvation_settings"], + ) + + # State B system creation + small_mols_stateB = { + smc: offmol + for smc, offmol in small_mols.items() + if stateB.contains(smc) + } + + stateB_system, stateB_topology, stateB_alchem_resids = self._create_stateB_system( + small_mols_stateB=small_mols_stateB, + mapping=mapping, + stateA_topology=stateA_topology, + exclude_resids = comp_resids[mapping.componentA], + system_generator=system_generator, + ) + + # Get the mapping between the two systems + system_mappings = _rfe_utils.topologyhelpers.get_system_mappings( + old_to_new_atom_map=mapping.componentA_to_componentB, + old_system=stateA_system, + old_topology=stateA_topology, + old_resids=comp_resids[mapping.componentA], + new_system=stateB_system, + new_topology=stateB_topology, + new_resids=stateB_alchem_resids, + # These are non-optional settings for this method + fix_constraints=True, + ) + + # Handle alchemical waters if needed + if settings["alchemical_settings"].explicit_charge_correction: + self._handle_alchemical_waters( + stateA_topology=stateA_topology, + stateA_positions=stateA_positions, + stateB_topology=stateB_topology, + stateB_system=stateB_system, + charge_difference=mapping.get_alchemical_charge_difference(), + system_mappings=system_mappings, + alchemical_settings=settings["alchemical_settings"], + solvent_component=solvent_component, + ) + + stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions( + system_mappings, + stateA_topology, + stateB_topology, + old_positions=ensure_quantity(stateA_positions, "openmm"), + insert_positions=ensure_quantity( + small_mols[mapping.componentB].conformers[0], "openmm" + ), + ) + + return ( + stateA_system, stateA_topology, stateA_positions, + stateB_system, stateB_topology, stateB_positions, + system_mappings + ) + + @staticmethod + def _get_alchemical_system( + stateA_system, + stateA_positions, + stateA_topology, + stateB_system, + stateB_positions, + stateB_topology, + system_mappings, + alchemical_settings: AlchemicalSettings, + ): + if alchemical_settings.softcore_LJ.lower() == "gapsys": + softcore_LJ_v2 = True + elif alchemical_settings.softcore_LJ.lower() == "beutler": + softcore_LJ_v2 = False + + hybrid_factory = _rfe_utils.relative.HybridTopologyFactory( + stateA_system, + stateA_positions, + stateA_topology, + stateB_system, + stateB_positions, + stateB_topology, + old_to_new_atom_map=system_mappings["old_to_new_atom_map"], + old_to_new_core_atom_map=system_mappings["old_to_new_core_atom_map"], + use_dispersion_correction=alchemical_settings.use_dispersion_correction, + softcore_alpha=alchemical_settings.softcore_alpha, + softcore_LJ_v2=softcore_LJ_v2, + softcore_LJ_v2_alpha=alchemical_settings.softcore_alpha, + interpolate_old_and_new_14s=alchemical_settings.turn_off_core_unique_exceptions, + ) + + return hybrid_factory, hybrid_factory.hybrid_system + + def run( + self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None + ) -> dict[str, Any]: + """Set up a Hybrid Topology system. + + Parameters + ---------- + dry : bool + Do a dry run of the calculation, creating all necessary hybrid + system components (topology, system, sampler, etc...) but without + running the simulation. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + scratch_basepath: Pathlike, optional + Where to store temporary files, defaults to current working directory + shared_basepath : Pathlike, optional + Where to run the calculation, defaults to current working directory + + Returns + ------- + dict + Outputs created in the basepath directory or the debug objects + (i.e. sampler) if ``dry==True``. + + Raises + ------ + error + Exception if anything failed + """ + # set up logging and basepaths + self._prepare( + verbose=verbose, + scratch_basepath=scratch_basepath, + shared_basepath=shared_basepath, + ) + + # Get the components + mapping = self._inputs["ligandmapping"] + stateA = self._inputs["stateA"] + stateB = self._inputs["stateB"] + alchem_comps, solvent_comp, protein_comp, off_small_mols = self._get_components() + + # Get the settings + settings = self._get_settings() + + # Get the OpenMM objects + ( + stateA_system, stateA_topology, stateA_positions, + stateB_system, stateB_topology, stateB_positions, + ligand_mappings + ) = self._get_omm_objects( + stateA=stateA, + stateB=stateB, + mapping=mapping, + settings=settings, + protein_component=protein_comp, + solvent_component=solvent_comp, + small_mols=off_small_mols, + ) + + # Get the alchemical factory & system + hybrid_factory, hybrid_system = self._get_alchemical_system( + stateA_system, + stateA_positions, + stateA_topology, + stateB_system, + stateB_positions, + stateB_topology, + ligand_mappings, + alchemical_settings=settings["alchemical_settings"], + ) + + # Verify alchemical system + if hybrid_factory.has_virtual_sites: + if not settings["integrator_settings"].reassign_velocities: + errmsg = ( + "Simulations with virtual sites without velocity " + "reassignments are unstable in openmmtools" + ) + raise ValueError(errmsg) + + # Get the selection indices for the system + selection_indices = hybrid_factory.hybrid_topology.select( + settings["output_settings"].output_indices + ) + + # Write out a PDB containing the subsampled hybrid state + bfactors = np.zeros_like(selection_indices, dtype=float) # environment + bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_old_atoms']))] = 0.25 # lig A + bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['core_atoms']))] = 0.50 # core + bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_new_atoms']))] = 0.75 # lig B + + if len(selection_indices) > 0: + traj = mdtraj.Trajectory( + hybrid_factory.hybrid_positions[selection_indices, :], + hybrid_factory.hybrid_topology.subset(selection_indices), + ).save_pdb( + shared_basepath / settings["output_settings"].output_structure, + bfactors=bfactors, + ) + + # Serialize the hybrid system + system_outfile = self.shared_basepath / "hybrid_system.xml.bz2" + serialize(hybrid_system, system_outfile) + + # Serialize the positions + positions_outfile = self.shared_basepath / "hybrid_positions.npz" + npy_positions_nm = from_openmm(hybrid_factory.hybrid_positions).to("nanometer").m + np.savez(positions_outfile, npy_positions_nm) + + + unit_result_dict = { + "system": system_outfile, + "positions": positions_outfile, + "pdb_structure": shared_basepath / settings["output_settings"].output_structure + "selection_indices": selection_indices, + } + + # If this is a dry run, we return the objects directly + if dry: + unit_result_dict |= { + "hybrid_factory": hybrid_factory, + "hybrid_system": hybrid_system, + } + + return unit_result_dict + + def _execute( + self, + ctx: gufe.Context, + **inputs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) + + return { + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + **outputs, + } + + +class HybridTopProtocolSimulationUnit(gufe.ProtocolUnit): + def _prepare( + self, + verbose, + scratch_basepath: pathlib.Path | None, + shared_basepath: pathlib.Path | None, + ): + """ + Set basepaths and do some initial logging. + + Parameters + ---------- + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + basepath : Optional[pathlib.Path] + Optional base path to write files to. + """ + self.verbose = verbose + + if self.verbose: + self.logger.info("Setting up the hybrid topology simulation") + + # set basepaths + def _set_optional_path(basepath): + if basepath is None: + return pathlib.Path(".") + return basepath + + self.scratch_basepath = _set_optional_path(scratch_basepath) + self.shared_basepath = _set_optional_path(shared_basepath) + + def _get_settings(self) -> dict[str, SettingsBaseModel]: + """ + Get the protocol settings from the inputs. + + Returns + ------- + protocol_settings : RelativeHybridTopologyProtocolSettings + The protocol settings. + """ + settings: RelativeHybridTopologyProtocolSettings = self._inputs["protocol"].settings + + protocol_settings: dict[str, SettingsBaseModel] = {} + protocol_settings["forcefield_settings"] = settings.forcefield_settings + protocol_settings["thermo_settings"] = settings.thermo_settings + protocol_settings["alchemical_settings"] = settings.alchemical_settings + protocol_settings["lambda_settings"] = settings.lambda_settings + protocol_settings["charge_settings"] = settings.partial_charge_settings + protocol_settings["solvation_settings"] = settings.solvation_settings + protocol_settings["simulation_settings"] = settings.simulation_settings + protocol_settings["output_settings"] = settings.output_settings + protocol_settings["integrator_settings"] = settings.integrator_settings + protocol_settings["engine_settings"] = settings.engine_settings + return protocol_settings + + def _get_reporter( + self, + selection_indices: np.ndarray, + output_settings: MultiStateOutputSettings, + simulation_settings: MultiStateSimulationSettings, + ): + nc = self.shared_basepath / output_settings.output_filename + chk = output_settings.checkpoint_storage_filename + + if output_settings.positions_write_frequency is not None: + pos_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' position_write_frequency", + denominator_name="simulation settings' time_per_iteration", + ) + else: + pos_interval = 0 + + if output_settings.velocities_write_frequency is not None: + vel_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.velocities_write_frequency, + denominator=sampler_settings.time_per_iteration, + numerator_name="output settings' velocity_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + else: + vel_interval = 0 + + chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + return multistate.MultiStateReporter( + storage=nc, + analysis_particle_indices=selection_indices, + checkpoint_interval=chk_intervals, + checkpoint_storage=chk, + position_interval=pos_interval, + velocity_interval=vel_interval, + ) + + @staticmethod + def _get_sampler( + system: openmm.System, + positions: openmm.Quantity, + lambdas: _rfe_utils.lambdaprotocol.LambdaProtocol, + integrator: openmmtools.mcmc.MCMCMove, + reporter: multistate.MultiStateReporter, + simulation_settings: MultiStateSimulationSettings, + thermo_settings: ThermodynamicSettings, + alchem_settings: AlchemicalSettings, + platform: openmm.Platform, + dry: bool, + ): + + rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( + simulation_settings=simulation_settings, + ) + + # convert early_termination_target_error from kcal/mol to kT + early_termination_target_error = ( + settings_validation.convert_target_error_from_kcal_per_mole_to_kT( + thermo_settings.temperature, + simulation_settings.early_termination_target_error, + ) + ) + + if simulation_settings.sampler_method.lower() == "repex": + sampler = _rfe_utils.multistate.HybridRepexSampler( + mcmc_moves=integrator, + hybrid_system=system, + hybrid_positions=positions, + online_analysis_interval=rta_its, + online_analysis_target_error=early_termination_target_error, + online_analysis_minimum_iterations=rta_min_its, + ) + + elif simulation_settings.sampler_method.lower() == "sams": + sampler = _rfe_utils.multistate.HybridSAMSSampler( + mcmc_moves=integrator, + hybrid_system=system, + hybrid_positions=positions, + online_analysis_interval=rta_its, + online_analysis_minimum_iterations=rta_min_its, + flatness_criteria=simulation_settings.sams_flatness_criteria, + gamma0=simulation_settings.sams_gamma0, + ) + + elif simulation_settings.sampler_method.lower() == "independent": + sampler = _rfe_utils.multistate.HybridMultiStateSampler( + mcmc_moves=integrator, + hybrid_system=system, + hybrid_positions=positions, + online_analysis_interval=rta_its, + online_analysis_target_error=early_termination_target_error, + online_analysis_minimum_iterations=rta_min_its, + ) + + else: + raise AttributeError(f"Unknown sampler {simulation_settings.sampler_method}") + + sampler.setup( + n_replicas=simulation_settings.n_replicas, + reporter=reporter, + lambda_protocol=lambdas, + temperature=to_openmm(thermo_settings.temperature), + endstates=alchem_settings.endstate_dispersion_correction, + minimization_platform=platform.getName(), + # Set minimization steps to None when running in dry mode + # otherwise do a very small one to avoid NaNs + minimization_steps=100 if not dry else None, + ) + + sampler.energy_context_cache = energy_context_cache + sampler.sampler_context_cache = sampler_context_cache + + return sampler + + def _get_ctx_caches( + self, + platform: openmm.Platform, + ) -> tuple[openmmtools.cache.ContextCache, openmmtools.cache.ContextCache]: + """ + Set the context caches based on the chosen platform + + Parameters + ---------- + platform: openmm.Platform + The OpenMM compute platform. + + Returns + ------- + energy_context_cache : openmmtools.cache.ContextCache + The energy state context cache. + sampler_context_cache : openmmtools.cache.ContextCache + The sampler state context cache. + """ + energy_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) + + sampler_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) + + return energy_context_cache, sampler_context_cache + + def run(self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None): + + # Get relevant outputs from setup + system = deserialize(self._inputs["setup_results"]["system"]) + positions = to_openmm( + deserialize(self._inputs["setup_results"]["positions"]) * offunit.nm + ) + selection_indices = self._inputs["setup_results"]["selection_indices"] + + # Get the settings + settings = self._get_settings() + + # Get the lambda schedule + lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( + functions=lambda_settings.lambda_functions, + windows=lambda_settings.lambda_windows + ) + + # Define simulation steps + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings=settings["simulation_settings"], + integrator_settings=settings["integrator_settings"], + ) + + equilibration_steps = settings_validation.get_simsteps( + sim_length=settings["simulation_settings"].equilibration_length, + timestep=settings["integrator_settings"].timestep, + mc_steps=steps_per_iteration, + ) + + production_steps = settings_validation.get_simsteps( + sim_length=settings["simulation_settings"].production_length, + timestep=settings["integrator_settings"].timestep, + mc_steps=steps_per_iteration, + ) + + try: + # Get the reporter + reporter = self._get_reporter( + selection_indices=selection_indices, + output_settings=settings["output_settings"], + simulation_settings=settings["simulation_settings"], + ) + + # Get the compute platform + # restrict to a single CPU if running vacuum + restrict_cpu = settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" + platform = omm_compute.get_openmm_platform( + platform_name=settings["engine_settings"].compute_platform, + gpu_device_index=settings["engine_settings"].gpu_device_index, + restrict_cpu_count=restrict_cpu, + ) + + # Get the integrator + integrator = openmmtools.mcmc.LangevinDynamicsMove( + timestep=to_openmm(settings["integrator_settings"].timestep), + collision_rate=to_openmm(settings["integrator_settings"].langevin_collision_rate), + n_steps=steps_per_iteration, + reassign_velocities=settings["integrator_settings"].reassign_velocities, + n_restart_attempts=settings["integrator_settings"].n_restart_attempts, + constraint_tolerance=settings["integrator_settings"].constraint_tolerance, + ) + # Create context caches + energy_context_cache, sampler_context_cache = self._get_ctx_caches(platform) + + sampler = self._get_sampler( + system=system, + positions=positions, + lambdas=lambdas, + integrator=integrator, + reporter=reporter, + simulation_settings=settings["simulation_settings"], + thermo_settings=settings["thermo_settings"], + alchem_settings=settings["alchemical_settings"], + platform=platform, + dry=dry, + energy_context_cache=energy_context_cache, + sampler_context_cache=sampler_context_cache, + ) + + if not dry: # pragma: no-cover + # minimize + if verbose: + self.logger.info("Running minimization") + + sampler.minimize(max_iterations=sampler_settings.minimization_steps) + + # equilibrate + if verbose: + self.logger.info("Running equilibration phase") + + sampler.equilibrate(int(equil_steps / steps_per_iteration)) + + # production + if verbose: + self.logger.info("Running production phase") + + sampler.extend(int(prod_steps / steps_per_iteration)) + + self.logger.info("Production phase complete") + else: + # clean up the reporter file + fns = [ + shared_basepath / output_settings.output_filename, + shared_basepath / output_settings.checkpoint_storage_filename, + ] + for fn in fns: + os.remove(fn) + + finally: + # close reporter when you're done, prevent + # file handle clashes + reporter.close() + + # clear GPU contexts + # TODO: use cache.empty() calls when openmmtools #690 is resolved + # replace with above + for context in list(energy_context_cache._lru._data.keys()): + del energy_context_cache._lru._data[context] + for context in list(sampler_context_cache._lru._data.keys()): + del sampler_context_cache._lru._data[context] + + # cautiously clear out the global context cache too + for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): + del openmmtools.cache.global_context_cache._lru._data[context] + + del sampler_context_cache, energy_context_cache + + if not dry: + del integrator, sampler + + if not dry: # pragma: no-cover + return { + "nc": nc, + "last_checkpoint": chk + } + else: + return { + "debug": { + "sampler": sampler + } + } + + def _execute( + self, + ctx: gufe.Context, + **inputs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + outputs = self.run( + dry=False, + scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared, + ) + + return { + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + **outputs, + } + + +class HybridTopProtocolAnalysisUnit(gufe.ProtocolUnit): + def _prepare( + self, + verbose, + scratch_basepath: pathlib.Path | None, + shared_basepath: pathlib.Path | None, + ): + """ + Set basepaths and do some initial logging. + + Parameters + ---------- + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + basepath : Optional[pathlib.Path] + Optional base path to write files to. + """ + self.verbose = verbose + + if self.verbose: + self.logger.info("Setting up the hybrid topology simulation") + + # set basepaths + def _set_optional_path(basepath): + if basepath is None: + return pathlib.Path(".") + return basepath + + self.scratch_basepath = _set_optional_path(scratch_basepath) + self.shared_basepath = _set_optional_path(shared_basepath) + + @staticmethod + def structural_analysis(scratch, shared) -> dict: + # don't put energy analysis in here, it uses the open file reporter + # whereas structural stuff requires that the file handle is closed + # TODO: we should just make openfe_analysis write an npz instead! + analysis_out = scratch / "structural_analysis.json" + + ret = subprocess.run( + [ + "openfe_analysis", # CLI entry point + "RFE_analysis", # CLI option + str(shared), # Where the simulation.nc fille + str(analysis_out), # Where the analysis json file is written + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if ret.returncode: + return {"structural_analysis_error": ret.stderr} + + with open(analysis_out, "rb") as f: + data = json.load(f) + + savedir = pathlib.Path(shared) + if d := data["protein_2D_RMSD"]: + fig = plotting.plot_2D_rmsd(d) + fig.savefig(savedir / "protein_2D_RMSD.png") + plt.close(fig) + f2 = plotting.plot_ligand_COM_drift(data["time(ps)"], data["ligand_wander"]) + f2.savefig(savedir / "ligand_COM_drift.png") + plt.close(f2) + + f3 = plotting.plot_ligand_RMSD(data["time(ps)"], data["ligand_RMSD"]) + f3.savefig(savedir / "ligand_RMSD.png") + plt.close(f3) + + # Save to numpy compressed format (~ 6x more space efficient than JSON) + np.savez_compressed( + shared / "structural_analysis.npz", + protein_RMSD=np.asarray(data["protein_RMSD"], dtype=np.float32), + ligand_RMSD=np.asarray(data["ligand_RMSD"], dtype=np.float32), + ligand_COM_drift=np.asarray(data["ligand_wander"], dtype=np.float32), + protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], dtype=np.float32), + time_ps=np.asarray(data["time(ps)"], dtype=np.float32), + ) + + return {"structural_analysis": shared / "structural_analysis.npz"} + + def run(self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None): + # set up logging and basepaths + trajectory = self._inputs["simulation_results"]["nc"] + checkpoint = self._inputs["simulation_results"]["last_checkpoint"] + + self._prepare( + verbose=verbose, + scratch_basepath=scratch_basepath, + shared_basepath=shared_basepath, + ) + + # Get energies + try: + reporter = multistate.MultiStateReporter( + storage=trajectory, + checkpoint_storage=checkpoint, + ) + + analyzer = multistate_analysis.MultistateEquilFEAnalysis( + reporter, + sampling_method=self._inputs["protocol"].settings.simulation_settings.sampler_method.lower(), + result_units=offunit.kilocalorie_per_mole, + ) + analyzer.plot(filepath=self.shared_basepath, filename_prefix="") + analyzer.close() + + # analyzer.unit_results_dict + finally: + reporter.close() + + # Get structural analysis -- todo: switch this away from the CLI + structural_analysis_outputs = self.structural_analysis( + scratch=self.scratch_basepath, + shared=self.shared_basepath, + ) + + return analyzer.unit_results_dict | structural_analysis_outputs + + def _execute( + self, + ctx: gufe.Context, + **inputs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + outputs = self.run( + dry=False, + scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared, + ) + + return { + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + **outputs, + } diff --git a/openfe/protocols/openmm_utils/system_validation.py b/openfe/protocols/openmm_utils/system_validation.py index 0fd3c3518..7b19b5d09 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/openfe/protocols/openmm_utils/system_validation.py @@ -14,7 +14,6 @@ SmallMoleculeComponent, SolventComponent, ) -from openff.toolkit import Molecule as OFFMol def get_alchemical_components( @@ -42,35 +41,22 @@ def get_alchemical_components( ValueError If there are any duplicate components in states A or B. """ - matched_components: dict[Component, Component] = {} + # Check if there are any duplicate components in either state + for state in [stateA, stateB]: + comp_list = list(state.components.values()) + unique_comp_list = list(set(comp_list)) + if len(comp_list) != len(unique_comp_list): + errmsg = f"Duplicate components found in ChemicalSystem: {state}" + raise ValueError(errmsg) + alchemical_components: dict[str, list[Component]] = { "stateA": [], "stateB": [], } - for keyA, valA in stateA.components.items(): - for keyB, valB in stateB.components.items(): - if valA == valB: - if valA not in matched_components.keys(): - matched_components[valA] = valB - else: - # Could be that either we have a duplicate component - # in stateA or in stateB - errmsg = ( - f"state A components {keyA}: {valA} matches " - "multiple components in stateA or stateB" - ) - raise ValueError(errmsg) - - # populate stateA alchemical components - for valA in stateA.components.values(): - if valA not in matched_components.keys(): - alchemical_components["stateA"].append(valA) - - # populate stateB alchemical components - for valB in stateB.components.values(): - if valB not in matched_components.values(): - alchemical_components["stateB"].append(valB) + diff = stateA.component_diff(stateB) + alchemical_components["stateA"].extend(diff[0]) + alchemical_components["stateB"].extend(diff[1]) return alchemical_components @@ -95,7 +81,7 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): `nocutoff`. * If the SolventComponent solvent is not water. """ - solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)] + solv = state.get_components_of_type(SolventComponent) if len(solv) > 0 and nonbonded_method.lower() == "nocutoff": errmsg = "nocutoff cannot be used for solvent transformations" @@ -129,9 +115,9 @@ def validate_protein(state: ChemicalSystem): ValueError If there are multiple ProteinComponent in the ChemicalSystem. """ - nprot = sum(1 for comp in state.values() if isinstance(comp, ProteinComponent)) + prots = state.get_components_of_type(ProteinComponent) - if nprot > 1: + if len(prots) > 1: errmsg = "Multiple ProteinComponent found, only one is supported" raise ValueError(errmsg)