diff --git a/src/openfe_analysis/cli.py b/src/openfe_analysis/cli.py
index 42826cc..7e761b6 100644
--- a/src/openfe_analysis/cli.py
+++ b/src/openfe_analysis/cli.py
@@ -1,8 +1,12 @@
 import click
+from click.exceptions import BadOptionUsage
 import json
+import netCDF4 as nc
 import pathlib
+import tqdm
 
 from . import rmsd
+from .reader import FEReader
 
 
 @click.group()
@@ -20,6 +24,7 @@ def cli():
                                           dir_okay=False,
                                           path_type=pathlib.Path))
 def rfe_analysis(loc, output):
+    """Perform structural analysis on OpenMM RFE simulation"""
     pdb = loc / "hybrid_system.pdb"
     trj = loc / "simulation.nc"
 
@@ -27,3 +32,78 @@ def rfe_analysis(loc, output):
 
     with click.open_file(output, 'w') as f:
         f.write(json.dumps(data))
+
+
+_statehelp = """\
+"""
+
+
+@cli.command(name='RFE_trjconv')
+@click.argument("loc", type=click.Path(exists=True,
+                                       readable=True,
+                                       file_okay=False,
+                                       dir_okay=True,
+                                       path_type=pathlib.Path))
+@click.argument('output', type=click.Path(writable=True,
+                                          dir_okay=False,
+                                          exists=False,
+                                          path_type=pathlib.Path),
+                required=True)
+@click.option('-s', '--state', required=True,
+              help="either an integer (0 and -1 giving endstates) or 'all'")
+def RFE_trjconv(loc, output, state):
+    """Convert .nc trajectory files from RBFE to new format for a single state
+
+    LOC is the directory where a simulation took place, it should contain the
+    simulation.nc and hybrid_system.pdb files that were produced.
+
+    OUTPUT is the name of the new trajectory file, e.g. "out.xtc". Any file
+    format supported by MDAnalysis can be specified, including XTC and DCD
+    formats.
+
+    The .nc trajectory file contains multiple states; a single state can be
+    specified for output. Negative indices are allowed and treated as in
+    Python, therefore ``--state=0`` or ``--state=-1`` will produce trajectories
+    of the two end states.
+
+    If ``--state='all'`` is given, all states are outputted, and the output
+    filename has the state number inserted before the file suffix,
+    e.g. ``--output=traj.dcd`` would produce files called ``traj_state0.dcd``
+    etc.
+    """
+    import MDAnalysis as mda
+
+    pdb = loc / "hybrid_system.pdb"
+    trj = loc / "simulation.nc"
+
+    # the context manager closes the dataset even if the conversion fails
+    with nc.Dataset(trj, mode='r') as ds:
+        if state == 'all':
+            # figure out how many states we need to output
+            nstates = ds.dimensions['state'].size
+
+            states = range(nstates)
+            # turn out.dcd -> out_state0.dcd
+            outputs = [
+                output.with_stem(output.stem + f'_state{i}')
+                for i in states
+            ]
+        else:
+            try:
+                states = [int(state)]
+            except ValueError:
+                # BadOptionUsage takes the option name before the message
+                raise BadOptionUsage(
+                    '--state', f"Invalid state specified: {state}"
+                ) from None
+            outputs = [output]
+
+        for s, o in zip(states, outputs):
+            u = mda.Universe(pdb, ds,
+                             format=FEReader,
+                             state_id=s)
+            ag = u.atoms  # todo, atom selections would be here
+
+            with mda.Writer(str(o), n_atoms=len(ag)) as w:
+                for _ in tqdm.tqdm(u.trajectory):
+                    w.write(ag)
diff --git a/src/openfe_analysis/reader.py b/src/openfe_analysis/reader.py
index 5eebf14..dd64242 100644
--- a/src/openfe_analysis/reader.py
+++ b/src/openfe_analysis/reader.py
@@ -55,12 +55,12 @@ def __init__(self, filename, convert_units=True, **kwargs):
         convert_units : bool
           convert positions to A
         """
-        self._state_id = kwargs.pop('state_id', None)
-        self._replica_id = kwargs.pop('replica_id', None)
-        if not ((self._state_id is None) ^ (self._replica_id is None)):
+        s_id = kwargs.pop('state_id', None)
+        r_id = kwargs.pop('replica_id', None)
+        if not ((s_id is None) ^ (r_id is None)):
             raise ValueError("Specify one and only one of state or replica, "
-                             f"got state id={self._state_id} "
-                             f"replica_id={self._replica_id}")
+                             f"got state id={s_id} "
+                             f"replica_id={r_id}")
 
         super().__init__(filename, convert_units, **kwargs)
 
@@ -70,6 +70,15 @@ def __init__(self, filename, convert_units=True, **kwargs):
         else:
             self._dataset = nc.Dataset(filename)
             self._dataset_owner = True
+
+        # if we have a negative indexed state_id or replica_id, convert this
+        if s_id is not None and s_id < 0:
+            s_id = range(self._dataset.dimensions['state'].size)[s_id]
+        elif r_id is not None and r_id < 0:
+            r_id = range(self._dataset.dimensions['replica'].size)[r_id]
+        self._state_id = s_id
+        self._replica_id = r_id
+
         self._n_atoms = self._dataset.dimensions['atom'].size
         self.ts = Timestep(self._n_atoms)
         self._dt = _determine_dt(self._dataset)
diff --git a/src/tests/test_reader.py b/src/tests/test_reader.py
index fe3f997..7b36225 100644
--- a/src/tests/test_reader.py
+++ b/src/tests/test_reader.py
@@ -24,3 +24,17 @@ def test_universe_from_nc_file(simulation_nc, hybrid_system_pdb):
     assert len(u.atoms) == 4782
 
 
+def test_universe_creation_negative_state(simulation_nc, hybrid_system_pdb):
+    u = mda.Universe(hybrid_system_pdb, simulation_nc,
+                     format='openfe rfe', state_id=-1)
+
+    assert u.trajectory._state_id == 10
+    assert u.trajectory._replica_id is None
+
+
+def test_universe_creation_negative_replica(simulation_nc, hybrid_system_pdb):
+    u = mda.Universe(hybrid_system_pdb, simulation_nc,
+                     format='openfe rfe', replica_id=-1)
+
+    assert u.trajectory._state_id is None
+    assert u.trajectory._replica_id == 10