Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions src/openfe_analysis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,36 @@ def cli():


@cli.command(name="RFE_analysis")
@click.option(
    "--pdb",
    type=click.Path(exists=True, readable=True, dir_okay=False, path_type=pathlib.Path),
    required=True,
    help="Path to the topology PDB file.",
)
@click.option(
    "--nc",
    type=click.Path(exists=True, readable=True, dir_okay=False, path_type=pathlib.Path),
    required=True,
    help="Path to the NetCDF trajectory file.",
)
@click.option(
    "--output",
    type=click.Path(writable=True, dir_okay=False, path_type=pathlib.Path),
    required=True,
    help="Path to save the JSON results.",
)
def rfe_analysis(pdb: pathlib.Path, nc: pathlib.Path, output: pathlib.Path):
    """
    Perform RMSD analysis for an RBFE simulation.

    Arguments:
    pdb: path to the topology PDB file.
    nc: path to the trajectory file (NetCDF format).
    output: path to save the JSON results.
    """
    # Only the option-based form of the command is kept: the positional
    # ``loc``/``output`` variant that was interleaved here left an
    # unterminated decorator and a second definition of the same function.

    # Run RMSD analysis
    data = rmsd.gather_rms_data(pdb, nc)

    # Write results as indented JSON to the requested path
    with output.open("w") as f:
        json.dump(data, f, indent=2)
6 changes: 4 additions & 2 deletions src/openfe_analysis/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,7 @@ def _reopen(self):
self._frame_index = -1

def close(self):
    """Release the reader's dataset reference; repeated calls are no-ops.

    The dataset's own ``close()`` is invoked only when ``_dataset_owner``
    is truthy.  In either case ``_dataset`` is reset to ``None`` so that a
    later call returns immediately instead of closing a second time.
    """
    if self._dataset is None:
        # Nothing left to release (already closed).
        return
    if self._dataset_owner:
        self._dataset.close()
    self._dataset = None
39 changes: 32 additions & 7 deletions src/openfe_analysis/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,30 @@
import numpy as np
import tqdm
from MDAnalysis.analysis import rms
from MDAnalysis.transformations import unwrap, TransformationBase
from MDAnalysis.lib.mdamath import make_whole
from numpy import typing as npt

from .reader import FEReader
from .transformations import Aligner, Minimiser, NoJump


class ShiftChains(TransformationBase):
    """Shift protein chains into the periodic image nearest the first chain.

    For every segment of ``prot`` after the first, the centre-of-mass
    offset from the first segment is rounded to a whole number of box
    lengths and that many box lengths are subtracted from the segment's
    positions.

    Parameters
    ----------
    prot
        Atom selection whose ``segments`` define the chains.
    max_threads : int, optional
        Forwarded to ``TransformationBase`` (thread limit used when the
        transformation is called).
    """

    def __init__(self, prot, max_threads=1):
        # TransformationBase.__init__ assigns ``self.max_threads`` from its
        # keyword arguments, so the value has to be forwarded; setting the
        # attribute before calling super().__init__() gets overwritten.
        super().__init__(max_threads=max_threads)
        self.prot = prot

    def _transform(self, ts):
        chains = [seg.atoms for seg in self.prot.segments]
        if len(chains) < 2:
            # Nothing to shift relative to the reference chain.
            return ts

        # Uses the first three entries of ts.dimensions as per-axis box
        # lengths (presumably an orthorhombic box -- TODO confirm).
        box = ts.dimensions[:3]
        # The reference chain is never moved, so its centre of mass is
        # constant for the whole frame and can be computed once.
        ref_com = chains[0].center_of_mass()
        for chain in chains[1:]:
            vec = chain.center_of_mass() - ref_com
            chain.positions -= np.rint(vec / box) * box
        return ts


def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Universe:
"""Makes a Universe and applies some transformations

Expand Down Expand Up @@ -42,15 +60,22 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers

if prot:
# if there's a protein in the system:
# - make the protein not jump periodic images between frames
# - make the protein whole across periodic images between frames
# - put the ligand in the closest periodic image as the protein
# - align everything to minimise protein RMSD
nope = NoJump(prot)
# Shift all chains relative to first chain to keep in same box
unwrap_tr = unwrap(prot)
shift = ShiftChains(prot)

# Make each fragment whole internally
for frag in prot.fragments:
make_whole(frag, reference_atom=frag[0])
minnie = Minimiser(prot, ligand)
align = Aligner(prot)

u.trajectory.add_transformations(
nope,
unwrap_tr,
shift,
minnie,
align,
)
Expand Down Expand Up @@ -128,9 +153,9 @@ def gather_rms_data(
# TODO: Some smart guard to avoid allocating a silly amount of memory?
prot2d = np.empty((len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32)

prot_start = prot.positions
# prot_weights = prot.masses / np.mean(prot.masses)
ligand_start = ligand.positions
# Would this copy be safer?
prot_start = prot.positions.copy()
ligand_start = ligand.positions.copy()
ligand_initial_com = ligand.center_of_mass()
ligand_weights = ligand.masses / np.mean(ligand.masses)

Expand Down Expand Up @@ -178,7 +203,7 @@ def gather_rms_data(
output["ligand_wander"].append(this_ligand_wander)

output["time(ps)"] = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt)

ds.close()
return output


Expand Down
8 changes: 8 additions & 0 deletions src/openfe_analysis/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def rbfe_skipped_data_dir() -> pathlib.Path:
def simulation_nc(rbfe_output_data_dir) -> pathlib.Path:
return rbfe_output_data_dir/"simulation.nc"

@pytest.fixture(scope="session")
def simulation_nc_multichain() -> pathlib.Path:
    """Path to the ``complex.nc`` trajectory used by the multi-chain tests.

    Returned as a ``pathlib.Path`` to match the annotation and the other
    path fixtures.

    NOTE(review): the path is relative, so it resolves against the
    directory pytest is launched from -- confirm that is intended.
    """
    return pathlib.Path("data/complex.nc")


@pytest.fixture(scope="session")
def simulation_skipped_nc(rbfe_skipped_data_dir) -> pathlib.Path:
Expand All @@ -40,6 +44,10 @@ def simulation_skipped_nc(rbfe_skipped_data_dir) -> pathlib.Path:
def hybrid_system_pdb(rbfe_output_data_dir) -> pathlib.Path:
return rbfe_output_data_dir/"hybrid_system.pdb"

@pytest.fixture(scope="session")
def system_pdb_multichain() -> pathlib.Path:
    """Path to the ``alchemical_system.pdb`` topology used by the multi-chain tests.

    Returned as a ``pathlib.Path`` to match the annotation and the other
    path fixtures.

    NOTE(review): the path is relative, so it resolves against the
    directory pytest is launched from -- confirm that is intended.
    """
    return pathlib.Path("data/alchemical_system.pdb")


@pytest.fixture(scope="session")
def hybrid_system_skipped_pdb(rbfe_skipped_data_dir)->pathlib.Path:
Expand Down
98 changes: 95 additions & 3 deletions src/openfe_analysis/tests/test_rmsd.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
import netCDF4 as nc
import numpy as np
import pytest
from itertools import islice
from numpy.testing import assert_allclose
from MDAnalysis.analysis import rms
from openfe_analysis.rmsd import gather_rms_data, make_Universe

from openfe_analysis.rmsd import gather_rms_data
@pytest.fixture
def mda_universe(system_pdb_multichain, simulation_nc_multichain):
    """Yield a Universe for the multi-chain system and close it afterwards.

    The Universe is built by ``make_Universe`` for state 0.  The statement
    after ``yield`` runs as fixture teardown, so the trajectory is closed
    even when the test body raises.
    """
    u = make_Universe(
        system_pdb_multichain,
        simulation_nc_multichain,
        state=0,
    )

    yield u

    # A test may already have closed the trajectory itself; this presumes
    # the reader's close() tolerates repeated calls -- TODO confirm.
    u.trajectory.close()


@pytest.mark.flaky(reruns=3)
@pytest.mark.flaky(reruns=1)
def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb):
output = gather_rms_data(
hybrid_system_pdb,
Expand Down Expand Up @@ -43,7 +60,7 @@ def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb):
)


@pytest.mark.flaky(reruns=3)
@pytest.mark.flaky(reruns=1)
def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_system_skipped_pdb):
output = gather_rms_data(
hybrid_system_skipped_pdb,
Expand Down Expand Up @@ -78,3 +95,78 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst
[1.176307, 1.203364, 1.486987, 1.17462, 1.143457, 1.244173],
rtol=1e-3,
)

def test_multichain_com_continuity(mda_universe):
    """The distance between the two chains' centres of mass must not jump.

    Over the first 20 frames, the frame-to-frame change of the
    inter-chain centre-of-mass distance has to stay below 5 Å; a larger
    change would indicate a periodic-boundary artifact.
    """
    u = mda_universe
    try:
        segments = u.select_atoms("protein").segments
        # Single check replacing the earlier pair of "== 2" and "> 1".
        assert len(segments) == 2, "Test requires multi-chain protein"

        chain_a = segments[0].atoms
        chain_b = segments[1].atoms

        # Advancing the trajectory updates the atom groups in place, so
        # the distance is re-evaluated once per frame.
        distances = [
            np.linalg.norm(chain_a.center_of_mass() - chain_b.center_of_mass())
            for _ in islice(u.trajectory, 20)
        ]

        # No large frame-to-frame jumps (PBC artifacts)
        jumps = np.abs(np.diff(distances))
        assert np.max(jumps) < 5.0  # Å
    finally:
        # Close the trajectory even when an assertion above fails.
        u.trajectory.close()

def test_chain_radius_of_gyration_stable(simulation_nc_multichain, system_pdb_multichain):
    """The first chain's radius of gyration must stay stable over 50 frames.

    The standard deviation of the per-frame radius of gyration has to stay
    below 2.0.
    """
    u = make_Universe(system_pdb_multichain, simulation_nc_multichain, state=0)
    try:
        chain = u.select_atoms("protein").segments[0].atoms

        rgs = [chain.radius_of_gyration() for _ in u.trajectory[:50]]

        # Chain should not explode or collapse due to PBC errors
        assert np.std(rgs) < 2.0
    finally:
        # This test created the Universe, so it must release it even when
        # the assertion fails.
        u.trajectory.close()

def test_rmsd_continuity(mda_universe):
    """C-alpha RMSD to the starting positions must change smoothly.

    The reference is a copy of the positions held before iteration starts;
    over the first 20 frames the RMSD may not change by 2.0 or more from
    one frame to the next.
    """
    universe = mda_universe
    ca_atoms = universe.select_atoms("protein and name CA")
    reference = ca_atoms.positions.copy()

    def _rmsd_to_reference():
        # Per-atom squared displacement, averaged over atoms.
        delta = ca_atoms.positions - reference
        return np.sqrt(np.mean(np.sum(delta * delta, axis=1)))

    per_frame = [_rmsd_to_reference() for _ in islice(universe.trajectory, 20)]

    largest_jump = np.max(np.abs(np.diff(per_frame)))
    assert largest_jump < 2.0
    universe.trajectory.close()

def test_rmsd_reference_is_first_frame(mda_universe):
    """Protein positions on the first frame deviate by zero from their copy.

    NOTE(review): the reference is copied from the same frame it is then
    compared against, so this check looks tautological -- confirm which
    reference it was meant to exercise.
    """
    universe = mda_universe
    protein = universe.select_atoms("protein")

    # Step onto the first frame; the timestep object itself is not used.
    next(iter(universe.trajectory))
    reference = protein.positions.copy()

    deviation = protein.positions - reference
    assert np.sqrt(np.mean(deviation ** 2)) == 0.0
    universe.trajectory.close()

def test_ligand_com_continuity(mda_universe):
    """The ligand's centre of mass must not jump between consecutive frames.

    The ligand is selected as ``resname UNK``; over the first 20 frames
    every frame-to-frame displacement has to stay below 5.0.
    """
    universe = mda_universe
    ligand = universe.select_atoms("resname UNK")

    coms = []
    for _ in islice(universe.trajectory, 20):
        coms.append(ligand.center_of_mass())

    # Pair each centre of mass with its successor.
    jumps = [np.linalg.norm(later - earlier) for earlier, later in zip(coms, coms[1:])]

    assert max(jumps) < 5.0
    universe.trajectory.close()
3 changes: 2 additions & 1 deletion src/openfe_analysis/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class Aligner(TransformationBase):
def __init__(self, ref_ag: mda.AtomGroup):
super().__init__()
self.ref_idx = ref_ag.ix
self.ref_pos = ref_ag.positions
# Would this copy be safer?
self.ref_pos = ref_ag.positions.copy()
self.weights = np.asarray(ref_ag.masses, dtype=np.float64)
self.weights /= np.mean(self.weights) # normalise weights
# remove COM shift from reference positions
Expand Down
Loading