diff --git a/src/openfe_analysis/cli.py b/src/openfe_analysis/cli.py index c45c540..437aad7 100644 --- a/src/openfe_analysis/cli.py +++ b/src/openfe_analysis/cli.py @@ -12,18 +12,36 @@ def cli(): @cli.command(name="RFE_analysis") -@click.argument( - "loc", - type=click.Path( - exists=True, readable=True, file_okay=False, dir_okay=True, path_type=pathlib.Path - ), +@click.option( + "--pdb", + type=click.Path(exists=True, readable=True, dir_okay=False, path_type=pathlib.Path), + required=True, + help="Path to the topology PDB file.", ) -@click.argument("output", type=click.Path(writable=True, dir_okay=False, path_type=pathlib.Path)) -def rfe_analysis(loc, output): - pdb = loc / "hybrid_system.pdb" - trj = loc / "simulation.nc" - - data = rmsd.gather_rms_data(pdb, trj) - - with click.open_file(output, "w") as f: - f.write(json.dumps(data)) +@click.option( + "--nc", + type=click.Path(exists=True, readable=True, dir_okay=False, path_type=pathlib.Path), + required=True, + help="Path to the NetCDF trajectory file.", +) +@click.option( + "--output", + type=click.Path(writable=True, dir_okay=False, path_type=pathlib.Path), + required=True, + help="Path to save the JSON results.", +) +def rfe_analysis(pdb: pathlib.Path, nc: pathlib.Path, output: pathlib.Path): + """ + Perform RMSD analysis for an RBFE simulation. + + Arguments: + pdb: path to the topology PDB file. + nc: path to the trajectory file (NetCDF format). + output: path to save the JSON results. + """ + # Run RMSD analysis + data = rmsd.gather_rms_data(pdb, nc) + + # Write results + with output.open("w") as f: + json.dump(data, f, indent=2) \ No newline at end of file diff --git a/src/openfe_analysis/reader.py b/src/openfe_analysis/reader.py index b7677bc..884550b 100644 --- a/src/openfe_analysis/reader.py +++ b/src/openfe_analysis/reader.py @@ -193,5 +193,7 @@ def _reopen(self): self._frame_index = -1 def close(self): - if self._dataset_owner: - self._dataset.close() + if self._dataset is not None: + if self._dataset_owner: + self._dataset.close() + self._dataset = None diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 7d2af13..cb4cff3 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -7,12 +7,30 @@ import numpy as np import tqdm from MDAnalysis.analysis import rms +from MDAnalysis.transformations import unwrap, TransformationBase +from MDAnalysis.lib.mdamath import make_whole from numpy import typing as npt from .reader import FEReader from .transformations import Aligner, Minimiser, NoJump +class ShiftChains(TransformationBase): + """Shift all protein chains relative to the first chain to keep them in the same box.""" + def __init__(self, prot, max_threads=1): + self.prot = prot + self.max_threads = max_threads # required by MDAnalysis + super().__init__() + + def _transform(self, ts): + chains = [seg.atoms for seg in self.prot.segments] + ref_chain = chains[0] + for chain in chains[1:]: + vec = chain.center_of_mass() - ref_chain.center_of_mass() + chain.positions -= np.rint(vec / ts.dimensions[:3]) * ts.dimensions[:3] + return ts + + def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Universe: """Makes a Universe and applies some transformations @@ -42,15 +60,22 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers if prot: # if there's a protein in the system: - # - make the protein not jump periodic images between frames + # - make the protein whole across periodic images between frames # - put the ligand in the closest periodic image as the protein # - align everything to minimise protein RMSD - nope = NoJump(prot) + # Shift all chains relative to first chain to keep in same box + unwrap_tr = unwrap(prot) + shift = ShiftChains(prot) + + # Make each fragment whole internally + for frag in prot.fragments: + make_whole(frag, reference_atom=frag[0]) minnie = Minimiser(prot, ligand) align = Aligner(prot) u.trajectory.add_transformations( - nope, + unwrap_tr, + shift, minnie, align, ) @@ -128,9 +153,9 @@ def gather_rms_data( # TODO: Some smart guard to avoid allocating a silly amount of memory? prot2d = np.empty((len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32) - prot_start = prot.positions - # prot_weights = prot.masses / np.mean(prot.masses) - ligand_start = ligand.positions + # Would this copy be safer? + prot_start = prot.positions.copy() + ligand_start = ligand.positions.copy() ligand_initial_com = ligand.center_of_mass() ligand_weights = ligand.masses / np.mean(ligand.masses) @@ -178,7 +203,7 @@ def gather_rms_data( output["ligand_wander"].append(this_ligand_wander) output["time(ps)"] = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt) - + ds.close() return output diff --git a/src/openfe_analysis/tests/conftest.py b/src/openfe_analysis/tests/conftest.py index 5fa2db4..b237f47 100644 --- a/src/openfe_analysis/tests/conftest.py +++ b/src/openfe_analysis/tests/conftest.py @@ -30,6 +30,10 @@ def rbfe_skipped_data_dir() -> pathlib.Path: def simulation_nc(rbfe_output_data_dir) -> pathlib.Path: return rbfe_output_data_dir/"simulation.nc" +@pytest.fixture(scope="session") +def simulation_nc_multichain() -> pathlib.Path: + return "data/complex.nc" + @pytest.fixture(scope="session") def simulation_skipped_nc(rbfe_skipped_data_dir) -> pathlib.Path: @@ -40,6 +44,10 @@ def simulation_skipped_nc(rbfe_skipped_data_dir) -> pathlib.Path: def hybrid_system_pdb(rbfe_output_data_dir) -> pathlib.Path: return rbfe_output_data_dir/"hybrid_system.pdb" +@pytest.fixture(scope="session") +def system_pdb_multichain() -> pathlib.Path: + return "data/alchemical_system.pdb" + @pytest.fixture(scope="session") def hybrid_system_skipped_pdb(rbfe_skipped_data_dir)->pathlib.Path: diff --git a/src/openfe_analysis/tests/test_rmsd.py b/src/openfe_analysis/tests/test_rmsd.py index a108bbe..898d44d 100644 --- a/src/openfe_analysis/tests/test_rmsd.py +++ b/src/openfe_analysis/tests/test_rmsd.py @@ -1,12 +1,29 @@ import netCDF4 as nc import numpy as np import pytest +from itertools import islice from numpy.testing import assert_allclose +from MDAnalysis.analysis import rms +from openfe_analysis.rmsd import gather_rms_data, make_Universe -from openfe_analysis.rmsd import gather_rms_data +@pytest.fixture +def mda_universe(system_pdb_multichain, simulation_nc_multichain): + """ + Safely create and destroy an MDAnalysis Universe. + + Guarantees: + - NetCDF file is opened exactly once + """ + u = make_Universe( + system_pdb_multichain, + simulation_nc_multichain, + state=0, + ) + + yield u -@pytest.mark.flaky(reruns=3) +@pytest.mark.flaky(reruns=1) def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb): output = gather_rms_data( hybrid_system_pdb, @@ -43,7 +60,7 @@ def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb): ) -@pytest.mark.flaky(reruns=3) +@pytest.mark.flaky(reruns=1) def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_system_skipped_pdb): output = gather_rms_data( hybrid_system_skipped_pdb, @@ -78,3 +95,78 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst [1.176307, 1.203364, 1.486987, 1.17462, 1.143457, 1.244173], rtol=1e-3, ) + +def test_multichain_com_continuity(mda_universe): + u = mda_universe + prot = u.select_atoms("protein") + chains = [seg.atoms for seg in prot.segments] + assert len(chains) == 2 + + segments = prot.segments + assert len(segments) > 1, "Test requires multi-chain protein" + + chain_a = segments[0].atoms + chain_b = segments[1].atoms + + distances = [] + for ts in islice(u.trajectory, 20): + d = np.linalg.norm( + chain_a.center_of_mass() - chain_b.center_of_mass() + ) + distances.append(d) + + # No large frame-to-frame jumps (PBC artifacts) + jumps = np.abs(np.diff(distances)) + assert np.max(jumps) < 5.0 # Å + u.trajectory.close() + +def test_chain_radius_of_gyration_stable(simulation_nc_multichain, system_pdb_multichain): + u = make_Universe(system_pdb_multichain, simulation_nc_multichain, state=0) + + protein = u.select_atoms("protein") + chain = protein.segments[0].atoms + + rgs = [] + for ts in u.trajectory[:50]: + rgs.append(chain.radius_of_gyration()) + + # Chain should not explode or collapse due to PBC errors + assert np.std(rgs) < 2.0 + u.trajectory.close() + +def test_rmsd_continuity(mda_universe): + u = mda_universe + + prot = u.select_atoms("protein and name CA") + ref = prot.positions.copy() + + rmsds = [] + for ts in islice(u.trajectory, 20): + diff = prot.positions - ref + rmsd = np.sqrt((diff * diff).sum(axis=1).mean()) + rmsds.append(rmsd) + + jumps = np.abs(np.diff(rmsds)) + assert np.max(jumps) < 2.0 + u.trajectory.close() + +def test_rmsd_reference_is_first_frame(mda_universe): + u = mda_universe + prot = u.select_atoms("protein") + + ts = next(iter(u.trajectory)) # SAFE + ref = prot.positions.copy() + + rmsd = np.sqrt(((prot.positions - ref) ** 2).mean()) + assert rmsd == 0.0 + u.trajectory.close() + +def test_ligand_com_continuity(mda_universe): + u = mda_universe + ligand = u.select_atoms("resname UNK") + + coms = [ligand.center_of_mass() for ts in islice(u.trajectory, 20)] + jumps = [np.linalg.norm(coms[i+1] - coms[i]) for i in range(len(coms)-1)] + + assert max(jumps) < 5.0 + u.trajectory.close() diff --git a/src/openfe_analysis/transformations.py b/src/openfe_analysis/transformations.py index 1a4ef34..61db292 100644 --- a/src/openfe_analysis/transformations.py +++ b/src/openfe_analysis/transformations.py @@ -87,7 +87,8 @@ class Aligner(TransformationBase): def __init__(self, ref_ag: mda.AtomGroup): super().__init__() self.ref_idx = ref_ag.ix - self.ref_pos = ref_ag.positions + # Would this copy be safer? + self.ref_pos = ref_ag.positions.copy() self.weights = np.asarray(ref_ag.masses, dtype=np.float64) self.weights /= np.mean(self.weights) # normalise weights # remove COM shift from reference positions