diff --git a/news/validate-rfe.rst b/news/validate-rfe.rst new file mode 100644 index 000000000..d7036e8d4 --- /dev/null +++ b/news/validate-rfe.rst @@ -0,0 +1,26 @@ +**Added:** + +* The `validate` method for the RelativeHybridTopologyProtocol has been + implemented. This means that settings and system validation can mostly + be done prior to Protocol execution by calling + `RelativeHybridTopologyProtocol.validate(stateA, stateB, mapping)`. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 4e8cf604c..697ba16f8 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -53,7 +53,6 @@ from openff.units import Quantity, unit from openff.units.openmm import ensure_quantity, from_openmm, to_openmm from openmmtools import multistate -from rdkit import Chem from openfe.due import Doi, due from openfe.protocols.openmm_utils.omm_settings import ( @@ -117,161 +116,6 @@ def _get_resname(off_mol) -> str: return names[0] -def _get_alchemical_charge_difference( - mapping: LigandAtomMapping, - nonbonded_method: str, - explicit_charge_correction: bool, - solvent_component: SolventComponent, -) -> int: - """ - Checks and returns the difference in formal charge between state A and B. - - Raises - ------ - ValueError - * If an explicit charge correction is attempted and the - nonbonded method is not PME. - * If the absolute charge difference is greater than one - and an explicit charge correction is attempted. - UserWarning - If there is any charge difference. - - Parameters - ---------- - mapping : dict[str, ComponentMapping] - Dictionary of mappings between transforming components. - nonbonded_method : str - The OpenMM nonbonded method used for the simulation. - explicit_charge_correction : bool - Whether or not to use an explicit charge correction. 
- solvent_component : openfe.SolventComponent - The SolventComponent of the simulation. - - Returns - ------- - int - The formal charge difference between states A and B. - This is defined as sum(charge state A) - sum(charge state B) - """ - - difference = mapping.get_alchemical_charge_difference() - - if abs(difference) > 0: - if explicit_charge_correction: - if nonbonded_method.lower() != "pme": - errmsg = "Explicit charge correction when not using PME is not currently supported." - raise ValueError(errmsg) - if abs(difference) > 1: - errmsg = ( - f"A charge difference of {difference} is observed " - "between the end states and an explicit charge " - "correction has been requested. Unfortunately " - "only absolute differences of 1 are supported." - ) - raise ValueError(errmsg) - - ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[ - difference - ] - wmsg = ( - f"A charge difference of {difference} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion" - ) - logger.warning(wmsg) - warnings.warn(wmsg) - else: - wmsg = ( - f"A charge difference of {difference} is observed " - "between the end states. No charge correction has " - "been requested, please account for this in your " - "final results." - ) - logger.warning(wmsg) - warnings.warn(wmsg) - - return difference - - -def _validate_alchemical_components( - alchemical_components: dict[str, list[Component]], - mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], -): - """ - Checks that the alchemical components are suitable for the RFE protocol. - - Specifically we check: - 1. That all alchemical components are mapped. - 2. That all alchemical components are SmallMoleculeComponents. - 3. If the mappings involves element changes in core atoms - - Parameters - ---------- - alchemical_components : dict[str, list[Component]] - Dictionary contatining the alchemical components for - states A and B. 
- mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] - all mappings between transforming components. - - Raises - ------ - ValueError - * If there are more than one mapping or mapping is None - * If there are any unmapped alchemical components. - * If there are any alchemical components that are not - SmallMoleculeComponents. - UserWarning - * Mappings which involve element changes in core atoms - """ - if isinstance(mapping, ComponentMapping): - mapping = [mapping] - # Check mapping - # For now we only allow for a single mapping, this will likely change - if mapping is None or len(mapping) != 1: - errmsg = "A single LigandAtomMapping is expected for this Protocol" - raise ValueError(errmsg) - - # Check that all alchemical components are mapped & small molecules - mapped = { - "stateA": [m.componentA for m in mapping], - "stateB": [m.componentB for m in mapping], - } - - for idx in ["stateA", "stateB"]: - if len(alchemical_components[idx]) != len(mapped[idx]): - errmsg = f"missing alchemical components in {idx}" - raise ValueError(errmsg) - for comp in alchemical_components[idx]: - if comp not in mapped[idx]: - raise ValueError(f"Unmapped alchemical component {comp}") - if not isinstance(comp, SmallMoleculeComponent): # pragma: no-cover - errmsg = ( - "Transformations involving non " - "SmallMoleculeComponent species {comp} " - "are not currently supported" - ) - raise ValueError(errmsg) - - # Validate element changes in mappings - for m in mapping: - molA = m.componentA.to_rdkit() - molB = m.componentB.to_rdkit() - for i, j in m.componentA_to_componentB.items(): - atomA = molA.GetAtomWithIdx(i) - atomB = molB.GetAtomWithIdx(j) - if atomA.GetAtomicNum() != atomB.GetAtomicNum(): - wmsg = ( - f"Element change in mapping between atoms " - f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " - f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" - "No mass scaling is attempted in the hybrid topology, " - "the average mass of the two atoms will be used 
in the " - "simulation" - ) - logger.warning(wmsg) - warnings.warn(wmsg) # TODO: remove this once logging is fixed - - class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): """Dict-like container for the output of a RelativeHybridTopologyProtocol""" @@ -612,21 +456,337 @@ def _adaptive_settings( return protocol_settings - def _create( + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + Validates the end states for the RFE protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A. + stateB : ChemicalSystem + The chemical system of end state B. + + Raises + ------ + ValueError + * If either state contains more than one unique Component. + * If unique components are not SmallMoleculeComponents. + """ + # Get the difference in Components between each state + diff = stateA.component_diff(stateB) + + for i, entry in enumerate(diff): + state_label = "A" if i == 0 else "B" + + # Check that there is only one unique Component in each state + if len(entry) != 1: + errmsg = ( + "Only one alchemical component is allowed per end state. " + f"Found {len(entry)} in state {state_label}." + ) + raise ValueError(errmsg) + + # Check that the unique Component is a SmallMoleculeComponent + if not isinstance(entry[0], SmallMoleculeComponent): + errmsg = ( + f"Alchemical component in state {state_label} is of type " + f"{type(entry[0])}, but only SmallMoleculeComponents " + "transformations are currently supported." + ) + raise ValueError(errmsg) + + @staticmethod + def _validate_mapping( + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], + alchemical_components: dict[str, list[Component]], + ) -> None: + """ + Validates that the provided mapping(s) are suitable for the RFE protocol. + + Parameters + ---------- + mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] + all mappings between transforming components. 
+ alchemical_components : dict[str, list[Component]] + Dictionary contatining the alchemical components for + states A and B. + + Raises + ------ + ValueError + * If there are more than one mapping or mapping is None + * If the mapping components are not in the alchemical components. + UserWarning + * Mappings which involve element changes in core atoms + """ + # if a single mapping is provided, convert to list + if isinstance(mapping, ComponentMapping): + mapping = [mapping] + + # For now we only support a single mapping + if mapping is None or len(mapping) > 1: + errmsg = "A single LigandAtomMapping is expected for this Protocol" + raise ValueError(errmsg) + + # check that the mapping components are in the alchemical components + for m in mapping: + for state in ["A", "B"]: + comp = getattr(m, f"component{state}") + if comp not in alchemical_components[f"state{state}"]: + raise ValueError( + f"Mapping component{state} {comp} not " + f"in alchemical components of state{state}" + ) + + # TODO: remove - this is now the default behaviour? + # Check for element changes in mappings + for m in mapping: + molA = m.componentA.to_rdkit() + molB = m.componentB.to_rdkit() + for i, j in m.componentA_to_componentB.items(): + atomA = molA.GetAtomWithIdx(i) + atomB = molB.GetAtomWithIdx(j) + if atomA.GetAtomicNum() != atomB.GetAtomicNum(): + wmsg = ( + f"Element change in mapping between atoms " + f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " + f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" + "No mass scaling is attempted in the hybrid topology, " + "the average mass of the two atoms will be used in the " + "simulation" + ) + logger.warning(wmsg) + warnings.warn(wmsg) + + @staticmethod + def _validate_smcs( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + Validates the SmallMoleculeComponents. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A. 
+ stateB : ChemicalSystem + The chemical system of end state B. + + Raises + ------ + ValueError + * If there are isomorphic SmallMoleculeComponents with + different charges. + """ + smcs_A = stateA.get_components_of_type(SmallMoleculeComponent) + smcs_B = stateB.get_components_of_type(SmallMoleculeComponent) + smcs_all = list(set(smcs_A).union(set(smcs_B))) + offmols = [m.to_openff() for m in smcs_all] + + def _equal_charges(moli, molj): + # Base case, both molecules don't have charges + if (moli.partial_charges is None) & (molj.partial_charges is None): + return True + # If either is None but not the other + if (moli.partial_charges is None) ^ (molj.partial_charges is None): + return False + # Check if the charges are close to each other + return np.allclose(moli.partial_charges, molj.partial_charges) + + clashes = [] + + for i, moli in enumerate(offmols): + for molj in offmols: + if moli.is_isomorphic_with(molj): + if not _equal_charges(moli, molj): + clashes.append(smcs_all[i]) + + if len(clashes) > 0: + errmsg = ( + "Found SmallMoleculeComponents that are isomorphic " + "but with different charges, this is not currently allowed. " + f"Affected components: {clashes}" + ) + raise ValueError(errmsg) + + @staticmethod + def _validate_charge_difference( + mapping: LigandAtomMapping, + nonbonded_method: str, + explicit_charge_correction: bool, + solvent_component: SolventComponent | None, + ): + """ + Validates the net charge difference between the two states. + + Parameters + ---------- + mapping : dict[str, ComponentMapping] + Dictionary of mappings between transforming components. + nonbonded_method : str + The OpenMM nonbonded method used for the simulation. + explicit_charge_correction : bool + Whether or not to use an explicit charge correction. + solvent_component : openfe.SolventComponent | None + The SolventComponent of the simulation. + + Raises + ------ + ValueError + * If an explicit charge correction is attempted and the + nonbonded method is not PME. 
+ * If the absolute charge difference is greater than one + and an explicit charge correction is attempted. + * If an explicit charge correction is attempted and there is no solvent present. + UserWarning + * If there is any charge difference. + """ + difference = mapping.get_alchemical_charge_difference() + + if abs(difference) == 0: + return + + if not explicit_charge_correction: + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. No charge correction has " + "been requested, please account for this in your " + "final results." + ) + logger.warning(wmsg) + warnings.warn(wmsg) + return + + if solvent_component is None: + errmsg = "Cannot use explicit charge correction without solvent" + raise ValueError(errmsg) + + # We implicitly check earlier that we have to have pme for a solvated + # system, so we only need to check the nonbonded method here + if nonbonded_method.lower() != "pme": + errmsg = "Explicit charge correction when not using PME is not currently supported." + raise ValueError(errmsg) + + if abs(difference) > 1: + errmsg = ( + f"A charge difference of {difference} is observed " + "between the end states and an explicit charge " + "correction has been requested. Unfortunately " + "only absolute differences of 1 are supported." + ) + raise ValueError(errmsg) + + ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[difference] + + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) + logger.info(wmsg) + + @staticmethod + def _validate_simulation_settings( + simulation_settings: MultiStateSimulationSettings, + integrator_settings: IntegratorSettings, + output_settings: MultiStateOutputSettings, + ): + """ + Validate various simulation settings, including but not limited to + timestep conversions, and output file write frequencies. 
+ + Parameters + ---------- + simulation_settings : MultiStateSimulationSettings + The sampler simulation settings. + integrator_settings : IntegratorSettings + Settings defining the behaviour of the integrator. + output_settings : MultiStateOutputSettings + Settings defining the simulation file writing behaviour. + + Raises + ------ + ValueError + * If any of of the simulation control settings (e.g. + ``equilibration_length`` or ``production_length``) + are not divisible by the timestep or the number of + steps per iteration. + * If the output frequency for position, velocity, or + online analysis are not divisible by the time per + multistate iteration. + """ + + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings=simulation_settings, + integrator_settings=integrator_settings, + ) + + _ = settings_validation.get_simsteps( + sim_length=simulation_settings.equilibration_length, + timestep=integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + if output_settings.positions_write_frequency is not None: + _ = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' positions_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + + if output_settings.velocities_write_frequency is not None: + _ = settings_validation.divmod_time_and_check( + numerator=output_settings.velocities_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' 
velocities_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + + _, _ = settings_validation.convert_real_time_analysis_iterations( + simulation_settings=simulation_settings, + ) + + def _validate( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], - extends: Optional[gufe.ProtocolDAGResult] = None, - ) -> list[gufe.ProtocolUnit]: - # TODO: Extensions? + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None, + extends: gufe.ProtocolDAGResult | None = None, + ) -> None: + # Check we're not trying to extend if extends: - raise NotImplementedError("Can't extend simulations yet") + # This technically should be NotImplementedError + # but gufe.Protocol.validate calls `_validate` wrapped around an + # except for NotImplementedError, so we can't raise it here + raise ValueError("Can't extend simulations yet") - # Get alchemical components & validate them + mapping + # Validate the end states + self._validate_endstates(stateA, stateB) + + # Validate the mapping alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - _validate_alchemical_components(alchem_comps, mapping) - ligandmapping = mapping[0] if isinstance(mapping, list) else mapping + self._validate_mapping(mapping, alchem_comps) + + # Validate the small molecule components + self._validate_smcs(stateA, stateB) # Validate solvent component nonbond = self.settings.forcefield_settings.nonbonded_method @@ -638,11 +798,68 @@ def _create( # Validate protein component system_validation.validate_protein(stateA) + # Validate charge difference + # Note: validation depends on the mapping & solvent component checks + if stateA.contains(SolventComponent): + solv_comp = stateA.get_components_of_type(SolventComponent)[0] + else: + solv_comp = None + + self._validate_charge_difference( + mapping=mapping[0] if isinstance(mapping, list) else mapping, + 
nonbonded_method=self.settings.forcefield_settings.nonbonded_method, + explicit_charge_correction=self.settings.alchemical_settings.explicit_charge_correction, + solvent_component=solv_comp, + ) + + # Validate integrator things + settings_validation.validate_timestep( + self.settings.forcefield_settings.hydrogen_mass, + self.settings.integrator_settings.timestep, + ) + + # Validate simulation & output settings + self._validate_simulation_settings( + self.settings.simulation_settings, + self.settings.integrator_settings, + self.settings.output_settings, + ) + + # Validate alchemical settings + # PR #125 temporarily pin lambda schedule spacing to n_replicas + if ( + self.settings.simulation_settings.n_replicas + != self.settings.lambda_settings.lambda_windows + ): + errmsg = ( + "Number of replicas in ``simulation_settings``: " + f"{self.settings.simulation_settings.n_replicas} must equal " + "the number of lambda windows in lambda_settings: " + f"{self.settings.lambda_settings.lambda_windows}." 
+ ) + raise ValueError(errmsg) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], + extends: Optional[gufe.ProtocolDAGResult] = None, + ) -> list[gufe.ProtocolUnit]: + # validate inputs + self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) + + # get alchemical components and mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + ligandmapping = mapping[0] if isinstance(mapping, list) else mapping + # actually create and return Units Anames = ",".join(c.name for c in alchem_comps["stateA"]) Bnames = ",".join(c.name for c in alchem_comps["stateB"]) + # our DAG has no dependencies, so just list units n_repeats = self.settings.protocol_repeats + units = [ RelativeHybridTopologyProtocolUnit( protocol=self, @@ -816,10 +1033,6 @@ def run( output_settings: MultiStateOutputSettings = protocol_settings.output_settings integrator_settings: IntegratorSettings = protocol_settings.integrator_settings - # is the timestep good for the mass? - settings_validation.validate_timestep( - forcefield_settings.hydrogen_mass, integrator_settings.timestep - ) # TODO: Also validate various conversions? # Convert various time based inputs to steps/iterations steps_per_iteration = settings_validation.convert_steps_per_iteration( @@ -842,12 +1055,7 @@ def run( # Get the change difference between the end states # and check if the charge correction used is appropriate - charge_difference = _get_alchemical_charge_difference( - mapping, - forcefield_settings.nonbonded_method, - alchem_settings.explicit_charge_correction, - solvent_comp, - ) + charge_difference = mapping.get_alchemical_charge_difference() # 1. 
Create stateA system self.logger.info("Parameterizing molecules") diff --git a/openfe/protocols/openmm_utils/system_validation.py b/openfe/protocols/openmm_utils/system_validation.py index 0fd3c3518..3e8ed5c50 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/openfe/protocols/openmm_utils/system_validation.py @@ -95,23 +95,24 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): `nocutoff`. * If the SolventComponent solvent is not water. """ - solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)] + solv_comps = state.get_components_of_type(SolventComponent) - if len(solv) > 0 and nonbonded_method.lower() == "nocutoff": - errmsg = "nocutoff cannot be used for solvent transformations" - raise ValueError(errmsg) - - if len(solv) == 0 and nonbonded_method.lower() == "pme": - errmsg = "PME cannot be used for vacuum transform" - raise ValueError(errmsg) + if len(solv_comps) > 0: + if nonbonded_method.lower() == "nocutoff": + errmsg = "nocutoff cannot be used for solvent transformations" + raise ValueError(errmsg) - if len(solv) > 1: - errmsg = "Multiple SolventComponent found, only one is supported" - raise ValueError(errmsg) + if len(solv_comps) > 1: + errmsg = "Multiple SolventComponent found, only one is supported" + raise ValueError(errmsg) - if len(solv) > 0 and solv[0].smiles != "O": - errmsg = "Non water solvent is not currently supported" - raise ValueError(errmsg) + if solv_comps[0].smiles != "O": + errmsg = "Non water solvent is not currently supported" + raise ValueError(errmsg) + else: + if nonbonded_method.lower() == "pme": + errmsg = "PME cannot be used for vacuum transform" + raise ValueError(errmsg) def validate_protein(state: ChemicalSystem): @@ -129,9 +130,9 @@ def validate_protein(state: ChemicalSystem): ValueError If there are multiple ProteinComponent in the ChemicalSystem. 
""" - nprot = sum(1 for comp in state.values() if isinstance(comp, ProteinComponent)) + prot_comps = state.get_components_of_type(ProteinComponent) - if nprot > 1: + if len(prot_comps) > 1: errmsg = "Multiple ProteinComponent found, only one is supported" raise ValueError(errmsg) @@ -161,24 +162,18 @@ def get_components(state: ChemicalSystem) -> ParseCompRet: small_mols : list[SmallMoleculeComponent] """ - def _get_single_comps(comp_list, comptype): - ret_comps = [comp for comp in comp_list if isinstance(comp, comptype)] - if ret_comps: - return ret_comps[0] + def _get_single_comps(state, comptype): + comps = state.get_components_of_type(comptype) + + if len(comps) > 0: + return comps[0] else: return None - solvent_comp: Optional[SolventComponent] = _get_single_comps( - list(state.values()), SolventComponent - ) + solvent_comp: Optional[SolventComponent] = _get_single_comps(state, SolventComponent) - protein_comp: Optional[ProteinComponent] = _get_single_comps( - list(state.values()), ProteinComponent - ) + protein_comp: Optional[ProteinComponent] = _get_single_comps(state, ProteinComponent) - small_mols = [] - for comp in state.components.values(): - if isinstance(comp, SmallMoleculeComponent): - small_mols.append(comp) + small_mols = state.get_components_of_type(SmallMoleculeComponent) return solvent_comp, protein_comp, small_mols diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index fc11cf164..283104e52 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -30,10 +30,6 @@ from openfe import setup from openfe.protocols import openmm_rfe from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers -from openfe.protocols.openmm_rfe.equil_rfe_methods import ( - _get_alchemical_charge_difference, - _validate_alchemical_components, -) from openfe.protocols.openmm_utils import 
omm_compute, system_creation from openfe.protocols.openmm_utils.charge_generation import ( HAS_ESPALOMA_CHARGE, @@ -196,21 +192,6 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t assert len(repeat_ids) == 6 -@pytest.mark.parametrize( - "mapping", - [None, [], ["A", "B"]], -) -def test_validate_alchemical_components_wrong_mappings(mapping): - with pytest.raises(ValueError, match="A single LigandAtomMapping"): - _validate_alchemical_components({"stateA": [], "stateB": []}, mapping) - - -def test_validate_alchemical_components_missing_alchem_comp(benzene_to_toluene_mapping): - alchem_comps = {"stateA": [openfe.SolventComponent()], "stateB": []} - with pytest.raises(ValueError, match="Unmapped alchemical component"): - _validate_alchemical_components(alchem_comps, benzene_to_toluene_mapping) - - @pytest.mark.parametrize("method", ["repex", "sams", "independent", "InDePeNdENT"]) def test_dry_run_default_vacuum( benzene_vacuum_system, @@ -970,246 +951,6 @@ def test_lambda_schedule(windows): assert len(lambdas.lambda_schedule) == windows -def test_hightimestep( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - vac_settings, - tmpdir, -): - vac_settings.forcefield_settings.hydrogen_mass = 1.0 - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - - dag = p.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - - errmsg = "too large for hydrogen mass" - with tmpdir.as_cwd(): - with pytest.raises(ValueError, match=errmsg): - dag_unit.run(dry=True) - - -def test_n_replicas_not_n_windows( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - vac_settings, - tmpdir, -): - # For PR #125 we pin such that the number of lambda windows - # equals the numbers of replicas used - TODO: remove limitation - # default lambda windows is 11 - 
vac_settings.simulation_settings.n_replicas = 13 - - errmsg = "Number of replicas 13 does not equal the number of lambda windows 11" - - with tmpdir.as_cwd(): - with pytest.raises(ValueError, match=errmsg): - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - dag = p.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - dag_unit.run(dry=True) - - -def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): - # state B doesn't have a ligand component - stateB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - - match_str = "missing alchemical components in stateB" - with pytest.raises(ValueError, match=match_str): - _ = p.create( - stateA=benzene_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_vaccuum_PME_error( - benzene_vacuum_system, benzene_modifications, benzene_to_toluene_mapping -): - # state B doesn't have a solvent component (i.e. 
its vacuum) - stateB = openfe.ChemicalSystem({"ligand": benzene_modifications["toluene"]}) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = "PME cannot be used for vacuum transform" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_vacuum_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_incompatible_solvent(benzene_system, benzene_modifications, benzene_to_toluene_mapping): - # the solvents are different - stateB = openfe.ChemicalSystem( - { - "ligand": benzene_modifications["toluene"], - "solvent": openfe.SolventComponent(positive_ion="K", negative_ion="Cl"), - } - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - # We don't have a way to map non-ligand components so for now it - # just triggers that it's not a mapped component - errmsg = "missing alchemical components in stateA" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_mapping_mismatch_A(benzene_system, toluene_system, benzene_modifications): - # the atom mapping doesn't refer to the ligands in the systems - mapping = setup.LigandAtomMapping( - componentA=benzene_system.components["ligand"], - componentB=benzene_modifications["phenol"], - componentA_to_componentB=dict(), - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = ( - r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=toluene\)" - ) - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=mapping, - ) - - -def test_mapping_mismatch_B(benzene_system, toluene_system, benzene_modifications): - mapping = setup.LigandAtomMapping( - 
componentA=benzene_modifications["phenol"], - componentB=toluene_system.components["ligand"], - componentA_to_componentB=dict(), - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = ( - r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=benzene\)" - ) - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=mapping, - ) - - -def test_complex_mismatch(benzene_system, toluene_complex_system, benzene_to_toluene_mapping): - # only one complex - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.raises(ValueError): - _ = p.create( - stateA=benzene_system, - stateB=toluene_complex_system, - mapping=benzene_to_toluene_mapping, - ) - - -def test_too_many_specified_mappings(benzene_system, toluene_system, benzene_to_toluene_mapping): - # mapping dict requires 'ligand' key - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = "A single LigandAtomMapping is expected for this Protocol" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=[benzene_to_toluene_mapping, benzene_to_toluene_mapping], - ) - - -def test_protein_mismatch( - benzene_complex_system, toluene_complex_system, benzene_to_toluene_mapping -): - # hack one protein to be labelled differently - prot = toluene_complex_system["protein"] - alt_prot = openfe.ProteinComponent(prot.to_rdkit(), name="Mickey Mouse") - alt_toluene_complex_system = openfe.ChemicalSystem( - { - "ligand": toluene_complex_system["ligand"], - "solvent": toluene_complex_system["solvent"], - "protein": alt_prot, - } - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - 
settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.raises(ValueError): - _ = p.create( - stateA=benzene_complex_system, - stateB=alt_toluene_complex_system, - mapping=benzene_to_toluene_mapping, - ) - - -def test_element_change_warning(atom_mapping_basic_test_files): - # check a mapping with element change gets rejected early - l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] - l2 = atom_mapping_basic_test_files["2-naftanol"] - - # We use the 'old' lomap defaults because the - # basic test files inputs we use aren't fully aligned - mapper = setup.LomapAtomMapper( - time=20, threed=True, max3d=1000.0, element_change=True, seed="", shift=True - ) - - mapping = next(mapper.suggest_mappings(l1, l2)) - - sys1 = openfe.ChemicalSystem( - {"ligand": l1, "solvent": openfe.SolventComponent()}, - ) - sys2 = openfe.ChemicalSystem( - {"ligand": l2, "solvent": openfe.SolventComponent()}, - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.warns(UserWarning, match="Element change"): - _ = p.create( - stateA=sys1, - stateB=sys2, - mapping=mapping, - ) - - def test_ligand_overlap_warning( benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, vac_settings, tmpdir ): @@ -1752,68 +1493,6 @@ def test_filenotfound_replica_states(self, protocolresult): protocolresult.get_replica_states() -@pytest.mark.parametrize( - "mapping_name,result", - [ - ["benzene_to_toluene_mapping", 0], - ["benzene_to_benzoic_mapping", 1], - ["benzene_to_aniline_mapping", -1], - ["aniline_to_benzene_mapping", 1], - ], -) -def test_get_charge_difference(mapping_name, result, request): - mapping = request.getfixturevalue(mapping_name) - if result != 0: - ion = r"Na\+" if result == -1 else r"Cl\-" - wmsg = ( - f"A charge difference of {result} is observed " - "between the end states. 
This will be addressed by " - f"transforming a water into a {ion} ion" - ) - with pytest.warns(UserWarning, match=wmsg): - val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) - assert result == pytest.approx(val) - else: - val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) - assert result == pytest.approx(val) - - -def test_get_charge_difference_no_pme(benzene_to_benzoic_mapping): - errmsg = "Explicit charge correction when not using PME" - with pytest.raises(ValueError, match=errmsg): - _get_alchemical_charge_difference( - benzene_to_benzoic_mapping, - "nocutoff", - True, - openfe.SolventComponent(), - ) - - -def test_get_charge_difference_no_corr(benzene_to_benzoic_mapping): - wmsg = ( - "A charge difference of 1 is observed between the end states. " - "No charge correction has been requested" - ) - with pytest.warns(UserWarning, match=wmsg): - _get_alchemical_charge_difference( - benzene_to_benzoic_mapping, - "pme", - False, - openfe.SolventComponent(), - ) - - -def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): - errmsg = "A charge difference of 2" - with pytest.raises(ValueError, match=errmsg): - _get_alchemical_charge_difference( - aniline_to_benzoic_mapping, - "pme", - True, - openfe.SolventComponent(), - ) - - @pytest.fixture(scope="session") def benzene_solvent_openmm_system(benzene_modifications): smc = benzene_modifications["benzene"] @@ -2290,40 +1969,3 @@ def test_dry_run_vacuum_write_frequency( assert reporter.velocity_interval == velocities_write_frequency.m else: assert reporter.velocity_interval == 0 - - -@pytest.mark.parametrize( - "positions_write_frequency,velocities_write_frequency", - [ - [100.1 * unit.picosecond, 100 * unit.picosecond], - [100 * unit.picosecond, 100.1 * unit.picosecond], - ], -) -def test_pos_write_frequency_not_divisible( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - 
positions_write_frequency, - velocities_write_frequency, - tmpdir, - vac_settings, -): - vac_settings.output_settings.positions_write_frequency = positions_write_frequency - vac_settings.output_settings.velocities_write_frequency = velocities_write_frequency - - protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - - # create DAG from protocol and take first (and only) work unit from within - dag = protocol.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - - with tmpdir.as_cwd(): - errmsg = "The output settings' " - with pytest.raises(ValueError, match=errmsg): - dag_unit.run(dry=True)["debug"]["sampler"] diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py new file mode 100644 index 000000000..35ef57da0 --- /dev/null +++ b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -0,0 +1,594 @@ +# This code is part of OpenFE and is licensed under the MIT license. 
+# For details, see https://github.com/OpenFreeEnergy/openfe +import logging + +import pytest +from openff.units import unit as offunit + +import openfe +from openfe import setup +from openfe.protocols import openmm_rfe + + +@pytest.fixture() +def vac_settings(): + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + settings.forcefield_settings.nonbonded_method = "nocutoff" + settings.engine_settings.compute_platform = None + settings.protocol_repeats = 1 + return settings + + +@pytest.fixture() +def solv_settings(): + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + settings.engine_settings.compute_platform = None + settings.protocol_repeats = 1 + return settings + + +def test_invalid_protocol_repeats(): + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + with pytest.raises(ValueError, match="must be a positive value"): + settings.protocol_repeats = -1 + + +@pytest.mark.parametrize("state", ["A", "B"]) +def test_endstate_two_alchemcomp_stateA(state, benzene_modifications): + first_state = openfe.ChemicalSystem( + { + "ligandA": benzene_modifications["benzene"], + "ligandB": benzene_modifications["toluene"], + "solvent": openfe.SolventComponent(), + } + ) + other_state = openfe.ChemicalSystem( + { + "ligandC": benzene_modifications["phenol"], + "solvent": openfe.SolventComponent(), + } + ) + + if state == "A": + args = (first_state, other_state) + else: + args = (other_state, first_state) + + with pytest.raises(ValueError, match="Only one alchemical component"): + openmm_rfe.RelativeHybridTopologyProtocol._validate_endstates(*args) + + +@pytest.mark.parametrize("state", ["A", "B"]) +def test_endstates_not_smc(state, benzene_modifications): + first_state = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["benzene"], + "foo": openfe.SolventComponent(), + } + ) + other_state = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["benzene"], + "foo": 
benzene_modifications["toluene"], + } + ) + + if state == "A": + args = (first_state, other_state) + else: + args = (other_state, first_state) + + errmsg = "only SmallMoleculeComponents transformations" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_endstates(*args) + + +def test_validate_mapping_none_mapping(): + errmsg = "A single LigandAtomMapping is expected" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping(None, None) + + +def test_validate_mapping_multi_mapping(benzene_to_toluene_mapping): + errmsg = "A single LigandAtomMapping is expected" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping( + [benzene_to_toluene_mapping] * 2, None + ) + + +@pytest.mark.parametrize("state", ["A", "B"]) +def test_validate_mapping_alchem_not_in(state, benzene_to_toluene_mapping): + errmsg = f"not in alchemical components of state{state}" + + if state == "A": + alchem_comps = {"stateA": [], "stateB": [benzene_to_toluene_mapping.componentB]} + else: + alchem_comps = {"stateA": [benzene_to_toluene_mapping.componentA], "stateB": []} + + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping( + [benzene_to_toluene_mapping], + alchem_comps, + ) + + +def test_vaccuum_PME_error( + benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, solv_settings +): + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "PME cannot be used for vacuum transform" + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + ) + + +@pytest.mark.parametrize("charge", [None, "gasteiger"]) +def test_smcs_same_charge_passes(charge, benzene_modifications): + benzene = benzene_modifications["benzene"] + if charge is None: + smc = 
benzene + else: + offmol = benzene.to_openff() + offmol.assign_partial_charges(partial_charge_method="gasteiger") + smc = openfe.SmallMoleculeComponent.from_openff(offmol) + + # Just pass the same thing twice + state = openfe.ChemicalSystem({"l": smc}) + openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs(state, state) + + +def test_smcs_different_charges_none_not_none(benzene_modifications): + # smcA has no charges + smcA = benzene_modifications["benzene"] + + # smcB has charges + offmol = smcA.to_openff() + offmol.assign_partial_charges(partial_charge_method="gasteiger") + smcB = openfe.SmallMoleculeComponent.from_openff(offmol) + + stateA = openfe.ChemicalSystem({"l": smcA}) + stateB = openfe.ChemicalSystem({"l": smcB}) + + errmsg = "isomorphic but with different charges" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs(stateA, stateB) + + +def test_smcs_different_charges_all(benzene_modifications): + # For this test, we will assign both A and B to both states + # It wouldn't happen in real life, but it tests that within a state + # you can pick up isomorphic molecules with different charges + # create an offmol with gasteiger charges + offmol = benzene_modifications["benzene"].to_openff() + offmol.assign_partial_charges(partial_charge_method="gasteiger") + smcA = openfe.SmallMoleculeComponent.from_openff(offmol) + + # now alter the offmol charges, scaling by 0.1 + offmol.partial_charges *= 0.1 + smcB = openfe.SmallMoleculeComponent.from_openff(offmol) + + state = openfe.ChemicalSystem({"l1": smcA, "l2": smcB}) + + errmsg = "isomorphic but with different charges" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_smcs(state, state) + + +def test_solvent_nocutoff_error( + benzene_system, + toluene_system, + benzene_to_toluene_mapping, + vac_settings, +): + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = "nocutoff 
cannot be used for solvent transformation" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_system, + stateB=toluene_system, + mapping=benzene_to_toluene_mapping, + ) + + +def test_nonwater_solvent_error( + benzene_modifications, + benzene_to_toluene_mapping, + solv_settings, +): + solvent = openfe.SolventComponent(smiles="C") + stateA = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["benzene"], + "solvent": solvent, + } + ) + + stateB = openfe.ChemicalSystem({"ligand": benzene_modifications["toluene"], "solvent": solvent}) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "Non water solvent is not currently supported" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=stateA, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + +def test_too_many_solv_comps_error( + benzene_modifications, + benzene_to_toluene_mapping, + solv_settings, +): + stateA = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["benzene"], + "solvent!": openfe.SolventComponent(neutralize=True), + "solvent2": openfe.SolventComponent(neutralize=False), + } + ) + + stateB = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["toluene"], + "solvent!": openfe.SolventComponent(neutralize=True), + "solvent2": openfe.SolventComponent(neutralize=False), + } + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "Multiple SolventComponent found, only one is supported" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=stateA, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + +def test_bad_solv_settings( + benzene_system, + toluene_system, + benzene_to_toluene_mapping, + solv_settings, +): + """ + Test a case where the solvent settings would be wrong. + Not covering every case since those are covered under + ``test_openmmutils.py``.
+ """ + solv_settings.solvation_settings.solvent_padding = 1.2 * offunit.nanometer + solv_settings.solvation_settings.number_of_solvent_molecules = 20 + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "Only one of solvent_padding, number_of_solvent_molecules," + with pytest.raises(ValueError, match=errmsg): + p.validate(stateA=benzene_system, stateB=toluene_system, mapping=benzene_to_toluene_mapping) + + +def test_too_many_prot_comps_error( + benzene_modifications, + benzene_to_toluene_mapping, + T4_protein_component, + eg5_protein, + solv_settings, +): + stateA = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["benzene"], + "solvent": openfe.SolventComponent(), + "protein1": T4_protein_component, + "protein2": eg5_protein, + } + ) + + stateB = openfe.ChemicalSystem( + { + "ligand": benzene_modifications["toluene"], + "solvent": openfe.SolventComponent(), + "protein1": T4_protein_component, + "protein2": eg5_protein, + } + ) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=solv_settings) + + errmsg = "Multiple ProteinComponent found, only one is supported" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=stateA, + stateB=stateB, + mapping=benzene_to_toluene_mapping, + ) + + +def test_element_change_warning(atom_mapping_basic_test_files): + # check a mapping with element change gets rejected early + l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] + l2 = atom_mapping_basic_test_files["2-naftanol"] + + # We use the 'old' lomap defaults because the + # basic test files inputs we use aren't fully aligned + mapper = setup.LomapAtomMapper( + time=20, threed=True, max3d=1000.0, element_change=True, seed="", shift=True + ) + + mapping = next(mapper.suggest_mappings(l1, l2)) + + alchem_comps = {"stateA": [l1], "stateB": [l2]} + + with pytest.warns(UserWarning, match="Element change"): + openmm_rfe.RelativeHybridTopologyProtocol._validate_mapping( + [mapping], + alchem_comps, + ) + + 
+def test_charge_difference_no_corr(benzene_to_benzoic_mapping): + wmsg = ( + "A charge difference of 1 is observed between the end states. " + "No charge correction has been requested" + ) + + with pytest.warns(UserWarning, match=wmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + benzene_to_benzoic_mapping, + "pme", + False, + openfe.SolventComponent(), + ) + + +def test_charge_difference_no_solvent(benzene_to_benzoic_mapping): + errmsg = "Cannot use explicit charge correction without solvent" + + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + benzene_to_benzoic_mapping, + "pme", + True, + None, + ) + + +def test_charge_difference_no_pme(benzene_to_benzoic_mapping): + errmsg = "Explicit charge correction when not using PME" + + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + benzene_to_benzoic_mapping, + "nocutoff", + True, + openfe.SolventComponent(), + ) + + +def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): + errmsg = "A charge difference of 2" + with pytest.raises(ValueError, match=errmsg): + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + aniline_to_benzoic_mapping, + "pme", + True, + openfe.SolventComponent(), + ) + + +@pytest.mark.parametrize( + "mapping_name,result", + [ + ["benzene_to_toluene_mapping", 0], + ["benzene_to_benzoic_mapping", 1], + ["benzene_to_aniline_mapping", -1], + ["aniline_to_benzene_mapping", 1], + ], +) +def test_get_charge_difference(mapping_name, result, request, caplog): + mapping = request.getfixturevalue(mapping_name) + caplog.set_level(logging.INFO) + + ion = r"Na+" if result == -1 else r"Cl-" + msg = ( + f"A charge difference of {result} is observed " + "between the end states. 
This will be addressed by " + f"transforming a water into a {ion} ion" + ) + + openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference( + mapping, "pme", True, openfe.SolventComponent() + ) + + if result != 0: + assert msg in caplog.text + else: + assert msg not in caplog.text + + +def test_hightimestep( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, +): + vac_settings.forcefield_settings.hydrogen_mass = 1.0 + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = "too large for hydrogen mass" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None, + ) + + +def test_time_per_iteration_divmod( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, +): + vac_settings.simulation_settings.time_per_iteration = 10 * offunit.ps + vac_settings.integrator_settings.timestep = 4 * offunit.ps + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = "does not evenly divide by the timestep" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None, + ) + + +@pytest.mark.parametrize("attribute", ["equilibration_length", "production_length"]) +def test_simsteps_not_timestep_divisible( + attribute, + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, +): + setattr(vac_settings.simulation_settings, attribute, 102 * offunit.fs) + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = "Simulation time not divisible by timestep" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None, + ) + 
+ +@pytest.mark.parametrize("attribute", ["equilibration_length", "production_length"]) +def test_simsteps_not_mcstep_divisible( + attribute, + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, +): + setattr(vac_settings.simulation_settings, attribute, 102 * offunit.ps) + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = "should contain a number of steps divisible by the number of integrator timesteps" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None, + ) + + +def test_checkpoint_interval_not_divisible_time_per_iter( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, +): + vac_settings.output_settings.checkpoint_interval = 4 * offunit.ps + vac_settings.simulation_settings.time_per_iteration = 2.5 * offunit.ps + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = "does not evenly divide by the amount of time per state MCMC" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None, + ) + + +@pytest.mark.parametrize("attribute", ["positions_write_frequency", "velocities_write_frequency"]) +def test_pos_vel_write_frequency_not_divisible( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + attribute, + vac_settings, +): + setattr(vac_settings.output_settings, attribute, 100.1 * offunit.picosecond) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = f"The output settings' {attribute}" + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None, + ) + + +@pytest.mark.parametrize( + 
"attribute", ["real_time_analysis_interval", "real_time_analysis_interval"] +) +def test_real_time_analysis_not_divisible( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + attribute, + vac_settings, +): + setattr(vac_settings.simulation_settings, attribute, 100.1 * offunit.picosecond) + + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = f"The {attribute}" + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None, + ) + + +def test_n_replicas_not_n_windows( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + vac_settings, + tmpdir, +): + # For PR #125 we pin such that the number of lambda windows + # equals the numbers of replicas used - TODO: remove limitation + vac_settings.simulation_settings.n_replicas = 13 + p = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + errmsg = "Number of replicas in ``simulation_settings``:" + + with pytest.raises(ValueError, match=errmsg): + p.validate( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + extends=None, + )