diff --git a/CHANGELOG.md b/CHANGELOG.md index c2d7d6f..fb2dcb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The rules for this file: ### Authors * @orbeckst +* @rsexton2 ### Fixed * Have cluster.ProcessProtein.reprocess() record "no result" if @@ -26,6 +27,9 @@ The rules for this file: number of samples. Otherwise `python -m cluster` fails to process whole proteins. +### Added +* Added command-line interface for basicrta workflow (Issue #20) + ## [1.1.3] - 2025-09-11 ### Authors diff --git a/basicrta/cli.py b/basicrta/cli.py new file mode 100644 index 0000000..b9e6822 --- /dev/null +++ b/basicrta/cli.py @@ -0,0 +1,65 @@ +""" +Command line functionality of basicrta. + +The `main()` function of this module gets the argument parser from each of the +scripts below and executes the `main()` function of the module called. The +function also collects help from the subparsers and provides it at the command +line. + +Modules callable from the cli: contacts.py, gibbs.py, cluster.py, kinetics.py, +combine.py. +""" + +from importlib.metadata import version +import basicrta +import argparse +import subprocess +import importlib +import sys + +__version__ = version("basicrta") + +# define which scripts can be ran from cli +# can easily add functionality to cli as modules are added +commands = ['contacts', 'gibbs', 'cluster', 'combine', 'kinetics'] + +def main(): + """ This module provides the functionality for a command line interface for + basicrta scripts. The scripts available to the cli are: + + * contacts.py + * gibbs.py + * cluster.py + * combine.py + * kinetics.py + + Each script is called and ran using the `main()` function of each module and + the parser is passed to the cli using the `get_parser()` function. Any + module added to the cli needs to have both functions. + """ + parser = argparse.ArgumentParser(prog='basicrta', add_help=True) + subparsers = parser.add_subparsers(help="""step in the basicrta workflow to + execute""") + + # collect parser from each script in `commands` + for command in commands: + subparser = importlib.import_module(f"basicrta.{command}").get_parser() + subparsers.add_parser(f'{command}', parents=[subparser], add_help=True, + description=subparser.description, + conflict_handler='resolve', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help=subparser.description) + + # print subparser help if no arguments given + if len(sys.argv) == 2 and sys.argv[1] in commands: + subparsers.choices[f'{sys.argv[1]}'].print_help() + sys.exit() + + # print basicrta help if no subcommand given + parser.parse_args(args=None if sys.argv[1:] else ['--help']) + + # execute basicrta script + importlib.import_module(f"basicrta.{sys.argv[1]}").main() + +if __name__ == "__main__": + main() diff --git a/basicrta/cluster.py b/basicrta/cluster.py index 8e1207d..83d4ed3 100644 --- a/basicrta/cluster.py +++ b/basicrta/cluster.py @@ -1,3 +1,7 @@ +"""This module provides the ProcessProtein class, which collects and processes +Gibbs sampler data. +""" + import os import gc import warnings @@ -11,9 +15,6 @@ from basicrta.gibbs import Gibbs gc.enable() -"""This module provides the ProcessProtein class, which collects and processes -Gibbs sampler data. -""" class ProcessProtein(object): r"""ProcessProtein is the class that collects and processes Gibbs sampler @@ -237,21 +238,30 @@ def b_color_structure(self, structure): u.select_atoms('protein').write('tau_bcolored.pdb') - -if __name__ == "__main__": #pragma: no cover - # the script is tested in the test_cluster.py but cannot be accounted for - # in the coverage report +def get_parser(): import argparse - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--nproc', type=int, default=1) - parser.add_argument('--cutoff', type=float) - parser.add_argument('--niter', type=int, default=110000) - parser.add_argument('--prot', type=str, default=None, nargs='?') - parser.add_argument('--label-cutoff', type=float, default=3, - dest='label_cutoff', - help='Only label residues with tau > ' - 'LABEL-CUTOFF * . ') - parser.add_argument('--structure', type=str, nargs='?') + parser = argparse.ArgumentParser(description="""perform clustering for each + residue located in basicrta-{cutoff}/""", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + required = parser.add_argument_group('required arguments') + + required.add_argument('--cutoff', required=True, type=float, help="""cutoff + used in contact analysis, will cluster results in + basicrta-{cutoff}/""") + parser.add_argument('--nproc', type=int, default=1, help="""number of + processes to use in multiprocessing""") + parser.add_argument('--niter', type=int, default=110000, help="""number of + iterations used in the gibbs sampler, used to load + gibbs_{niter}.pkl""") + parser.add_argument('--prot', type=str, nargs='?', help="""name of protein + in tm_dict.txt, used to draw TM bars in tau vs resid + plot""") + parser.add_argument('--label_cutoff', type=float, default=3, + dest='label_cutoff', + help="""Only label residues with tau > + LABEL-CUTOFF * .""") + parser.add_argument('--structure', type=str, nargs='?', help="""will add tau + as bfactors to the structure if provided""") # use for default values parser.add_argument('--gskip', type=int, default=100, help='Gibbs skip parameter for decorrelated samples;' @@ -259,7 +269,12 @@ def b_color_structure(self, structure): parser.add_argument('--burnin', type=int, default=10000, help='Burn-in parameter, drop first N samples as equilibration;' 'default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522') + # this is to make the cli work, should be just a temporary solution + parser.add_argument('cluster', nargs='?', help=argparse.SUPPRESS) + return parser +def main(): + parser = get_parser() args = parser.parse_args() pp = ProcessProtein(args.niter, args.prot, args.cutoff, @@ -267,3 +282,10 @@ def b_color_structure(self, structure): pp.reprocess(nproc=args.nproc) pp.write_data() pp.plot_protein(label_cutoff=args.label_cutoff) + + +if __name__ == "__main__": #pragma: no cover + # the script is tested in the test_cluster.py but cannot be accounted for + # in the coverage report + exit(main()) + diff --git a/basicrta/combine.py b/basicrta/combine.py index 826b159..8d27924 100644 --- a/basicrta/combine.py +++ b/basicrta/combine.py @@ -1,7 +1,7 @@ #!/usr/bin/env python """ -Command-line interface for combining contact timeseries from multiple repeat runs. +Combine contact timeseries from multiple repeat runs. This module provides functionality to combine contact files from multiple trajectory repeats, enabling pooled analysis of binding kinetics. @@ -9,22 +9,154 @@ import os import argparse -from basicrta.contacts import CombineContacts +class CombineContacts(object): + """Class to combine contact timeseries from multiple repeat runs. + + This class enables pooling data from multiple trajectory repeats and + calculating posteriors from all data together, rather than analyzing + each run separately. + + :param contact_files: List of contact pickle files to combine + :type contact_files: list of str + :param output_name: Name for the combined output file (default: 'combined_contacts.pkl') + :type output_name: str, optional + :param validate_compatibility: Whether to validate that files are compatible (default: True) + :type validate_compatibility: bool, optional + """ + + def __init__(self, contact_files, output_name='combined_contacts.pkl', + validate_compatibility=True): + self.contact_files = contact_files + self.output_name = output_name + self.validate_compatibility = validate_compatibility + + if len(contact_files) < 2: + raise ValueError("At least 2 contact files are required for combining") + + def _load_contact_file(self, filename): + """Load a contact pickle file and return data and metadata.""" + if not os.path.exists(filename): + raise FileNotFoundError(f"Contact file not found: {filename}") + + with open(filename, 'rb') as f: + contacts = pickle.load(f) + + metadata = contacts.dtype.metadata + return contacts, metadata + + def _validate_compatibility(self, metadatas): + """Validate that contact files are compatible for combining.""" + reference = metadatas[0] + + # Check that all files have the same atom groups + for i, meta in enumerate(metadatas[1:], 1): + # Compare cutoff + if meta['cutoff'] != reference['cutoff']: + raise ValueError(f"Incompatible cutoffs: file 0 has {reference['cutoff']}, " + f"file {i} has {meta['cutoff']}") + + # Compare atom group selections by checking if resids match + ref_ag1_resids = set(reference['ag1'].residues.resids) + ref_ag2_resids = set(reference['ag2'].residues.resids) + meta_ag1_resids = set(meta['ag1'].residues.resids) + meta_ag2_resids = set(meta['ag2'].residues.resids) + + if ref_ag1_resids != meta_ag1_resids: + raise ValueError(f"Incompatible ag1 residues between file 0 and file {i}") + if ref_ag2_resids != meta_ag2_resids: + raise ValueError(f"Incompatible ag2 residues between file 0 and file {i}") + + # Check timesteps and warn if different + timesteps = [meta['ts'] for meta in metadatas] + if not all(abs(ts - timesteps[0]) < 1e-6 for ts in timesteps): + print("WARNING: Different timesteps detected across runs:") + for i, (filename, ts) in enumerate(zip(self.contact_files, timesteps)): + print(f" File {i} ({filename}): dt = {ts} ns") + print("This may affect residence time estimates, especially for fast events.") + + def run(self): + """Combine contact files and save the result.""" + print(f"Combining {len(self.contact_files)} contact files...") + + all_contacts = [] + all_metadatas = [] + + # Load all contact files + for i, filename in enumerate(self.contact_files): + print(f"Loading file {i+1}/{len(self.contact_files)}: {filename}") + contacts, metadata = self._load_contact_file(filename) + all_contacts.append(contacts) + all_metadatas.append(metadata) + + # Validate compatibility if requested + if self.validate_compatibility: + print("Validating file compatibility...") + self._validate_compatibility(all_metadatas) + + # Combine contact data + print("Combining contact data...") + + # Calculate total size and create combined array + total_size = sum(len(contacts) for contacts in all_contacts) + reference_metadata = all_metadatas[0].copy() + + # Extend metadata to include trajectory source information + reference_metadata['source_files'] = self.contact_files + reference_metadata['n_trajectories'] = len(self.contact_files) + + # Determine number of columns (5 for raw contacts, 4 for processed) + n_cols = all_contacts[0].shape[1] + + # Create dtype with extended metadata + combined_dtype = np.dtype(np.float64, metadata=reference_metadata) + + # Add trajectory source column (will be last column) + combined_contacts = np.zeros((total_size, n_cols + 1), dtype=np.float64) + + # Combine data and add trajectory source information + offset = 0 + for traj_idx, contacts in enumerate(all_contacts): + n_contacts = len(contacts) + # Copy original contact data + combined_contacts[offset:offset+n_contacts, :n_cols] = contacts[:] + # Add trajectory source index + combined_contacts[offset:offset+n_contacts, n_cols] = traj_idx + offset += n_contacts + + # Create final memmap with proper dtype + final_contacts = combined_contacts.view(combined_dtype) + + # Save combined contacts + print(f"Saving combined contacts to {self.output_name}...") + final_contacts.dump(self.output_name, protocol=5) + + print(f"Successfully combined {len(self.contact_files)} files into {self.output_name}") + print(f"Total contacts: {total_size}") + print(f"Added trajectory source column (index {n_cols}) for kinetic clustering support") + + return self.output_name -def main(): - """Main function for combining contact files.""" +def get_parser(): + """Create parser, parse command line arguments, and return ArgumentParser + object. + + :return: An ArgumentParser instance with command line arguments stored. + :rtype: `ArgumentParser` object + """ parser = argparse.ArgumentParser( description="Combine contact timeseries from multiple repeat runs. " "This enables pooling data from multiple trajectory repeats " "and calculating posteriors from all data together." ) - parser.add_argument( + required = parser.add_argument_group('required arguments') + required.add_argument( '--contacts', nargs='+', required=True, - help="List of contact pickle files to combine (e.g., contacts_7.0.pkl from different runs)" + help="""List of contact pickle files to combine (e.g., contacts_7.0.pkl + from different runs)""", ) parser.add_argument( @@ -39,7 +171,14 @@ def main(): action='store_true', help="Skip compatibility validation (use with caution)" ) - + # this is to make the cli work, should be just a temporary solution + parser.add_argument('combine', nargs='?', help=argparse.SUPPRESS) + return parser + +def main(): + """Execute this function when this script is called from the command line. + """ + parser = get_parser() args = parser.parse_args() # Validate input files exist @@ -82,6 +221,5 @@ def main(): print(f"ERROR: {e}") return 1 - -if __name__ == '__main__': - exit(main()) +if __name__ == "__main__": + exit(main()) diff --git a/basicrta/contacts.py b/basicrta/contacts.py index 2117c31..f17503b 100644 --- a/basicrta/contacts.py +++ b/basicrta/contacts.py @@ -1,3 +1,13 @@ +""" +Create contact maps between two atom groups. + +This module provides the `MapContacts` class, which creates the initial contact +map between the two atom groups using a maximum cutoff (`max_cutoff`), which +provides for quicker processing if creating results for multiple cutoffs. The +`ProcessContacts` class takes the initial contact map and creates the processed +contact map based on the prescribed cutoff. +""" + from tqdm import tqdm from MDAnalysis.lib import distances from multiprocessing import Pool, Lock @@ -364,23 +374,9 @@ def run(self): return self.output_name - -if __name__ == '__main__': - """DOCSSS - """ - import argparse - parser = argparse.ArgumentParser(description="Create the primary contact \ - map and collect contacts based on the \ - desired cutoff distance") - parser.add_argument('--top', type=str, help="Topology") - parser.add_argument('--traj', type=str) - parser.add_argument('--sel1', type=str) - parser.add_argument('--sel2', type=str) - parser.add_argument('--cutoff', type=float) - parser.add_argument('--nproc', type=int, default=1) - parser.add_argument('--nslices', type=int, default=100) +def main(): + parser = get_parser() args = parser.parse_args() - u = mda.Universe(args.top, args.traj) cutoff, nproc, nslices = args.cutoff, args.nproc, args.nslices ag1 = u.select_atoms(args.sel1) @@ -396,3 +392,40 @@ def run(self): ProcessContacts(cutoff, mapname, nproc=nproc).run() + +def get_parser(): + import argparse + parser = argparse.ArgumentParser(description="""Create the initial contact + map and process it using a + prescribed cutoff""") + required = parser.add_argument_group('required arguments') + + required.add_argument('--top', type=str, help="Topology") + required.add_argument('--traj', type=str, help="Trajectory") + required.add_argument('--sel1', type=str, help="Primary atom selection, based \ + on MDAnalysis atom selection. basicrta will produce \ + tau for each residue in this atom group.") + required.add_argument('--sel2', type=str, help="Secondary atom selection, \ + based on MDAnalysis atom selection. basicrta will \ + collect contacts between each residue of this group \ + with each residue of `sel1`.") + required.add_argument('--cutoff', type=float, help="""Value to use for defining + a contact (in Angstrom). Any atom of `sel2` that is at + a distance less than or equal to `cutoff` of any atom + in `sel1` will be considered in contact.""", required=True) + parser.add_argument('--nproc', type=int, default=1, help="""Number of + processes to use in multiprocessing""") + parser.add_argument('--nslices', type=int, default=100, help="""Number of + slices to break the trajectory into. Increase this to + reduce the amount of memory needed for each process.""") + # this is to make the cli work, should be just a temporary solution + parser.add_argument('contacts', nargs='?', help=argparse.SUPPRESS) + return parser + + +if __name__ == '__main__': + exit(main()) + """DOCSSS + """ + + diff --git a/basicrta/gibbs.py b/basicrta/gibbs.py index d9b42f9..0c49382 100644 --- a/basicrta/gibbs.py +++ b/basicrta/gibbs.py @@ -1,3 +1,11 @@ +""" +Perform Gibbs samplers and process data. + +This module provides the `ParallelGibbs` class, which parallelizes the creation +of Gibbs samplers for each residue in the contact map. This module also provides +the `Gibbs` class, which allows for the loading and processing of the gibbs +sampler data, as well as plotting and saving processed results. +""" import os import gc import pickle @@ -335,7 +343,7 @@ def cluster(self, method="GaussianMixture", **kwargs): self.processed_results.indicator = pindicator self.processed_results.labels = all_labels - def process_gibbs(self, show=True): + def process_gibbs(self, show=False): r""" Process the samples collected from the Gibbs sampler. :meth:`process_gibbs` can be called multiple times to check the @@ -851,15 +859,33 @@ def plot_surv(self, scale=1, remove_noise=False, save=False, xlim=None, 's_vs_t.pdf', bbox_inches='tight') plt.show() - -if __name__ == '__main__': +def get_parser(): import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--contacts') - parser.add_argument('--resid', type=int, default=None) - parser.add_argument('--nproc', type=int, default=1) - parser.add_argument('--niter', type=int, default=110000) - parser.add_argument('--ncomp', type=int, default=15) + parser = argparse.ArgumentParser(description="""run gibbs samplers for all + or a specified residue present in the + contact map""", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + required = parser.add_argument_group('required arguments') + + required.add_argument('--contacts', required=True, help="""Contact file + produced from `basicrta contacts`, default is + contacts_{cutoff}.pkl""") + parser.add_argument('--resid', type=int, help="""run gibbs sampler for + this residue. Will collect cutoff from contact file + name.""") + parser.add_argument('--nproc', type=int, default=1, help="""number of + processes to use in multiprocessing""") + parser.add_argument('--niter', type=int, default=110000, help="""number of + iterations to use for the gibbs sampler""") + parser.add_argument('--ncomp', type=int, default=15, help="""number of + components to use for the exponential mixture + model""") + # this is to make the cli work, should be just a temporary solution + parser.add_argument('gibbs', nargs='?', help=argparse.SUPPRESS) + return parser + +def main(): + parser = get_parser() args = parser.parse_args() contact_path = os.path.abspath(args.contacts) @@ -867,3 +893,6 @@ def plot_surv(self, scale=1, remove_noise=False, save=False, xlim=None, ParallelGibbs(contact_path, nproc=args.nproc, ncomp=args.ncomp, niter=args.niter).run(run_resids=args.resid) + +if __name__ == '__main__': + exit(main()) diff --git a/basicrta/kinetics.py b/basicrta/kinetics.py index 8613092..818c390 100644 --- a/basicrta/kinetics.py +++ b/basicrta/kinetics.py @@ -1,3 +1,10 @@ +""" +Map kinetics from gibbs data to md trajectory. + +This module provides the `MapKinetics` class, which creates trajectories and +weighted densities based on the clustered gibbs data and original trajectory. +""" + from tqdm import tqdm from basicrta.util import get_start_stop_frames import numpy as np @@ -203,16 +210,30 @@ def weighted_densities(self, step=1, top_n=None, filterP=0): d.results.density.export(outname) - -if __name__ == "__main__": - from basicrta.gibbs import Gibbs +def get_parser(): import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--gibbs", type=str) - parser.add_argument("--contacts", type=str) - parser.add_argument("--top_n", type=int, nargs='?', default=None) - parser.add_argument("--step", type=int, nargs='?', default=1) - parser.add_argument("--wdensity", action='store_true') + parser = argparse.ArgumentParser(description="""map kinetics from clustered + results onto trajectory, create weighted + densities if flag is used""") + required = parser.add_argument_group('required arguments') + required.add_argument("--gibbs", type=str, required=True, help="""gibbs pickle + file to use for creating kinetic trajectories and + densities""") + required.add_argument("--contacts", type=str, required=True, help="""contacts + file used in creation of the gibbs sampler data""") + parser.add_argument("--top_n", type=int, nargs='?', help="""use the `top_n` + most likely frames to create trajectory or densities""") + parser.add_argument("--step", type=int, nargs='?', default=1, help="""write + out frame if frame%%step=0""") + parser.add_argument("--wdensity", action='store_true', help="""create + weighted densities""") + # this is to make the cli work, should be just a temporary solution + parser.add_argument('kinetics', nargs='?', help=argparse.SUPPRESS) + return parser + +def main(): + from basicrta.gibbs import Gibbs + parser = get_parser() args = parser.parse_args() g = Gibbs().load(args.gibbs) @@ -220,3 +241,6 @@ def weighted_densities(self, step=1, top_n=None, filterP=0): mk.create_traj(top_n=args.top_n) if args.wdensity: mk.weighted_densities(step=args.step, top_n=args.top_n) + +if __name__ == "__main__": + exit(main()) diff --git a/basicrta/tests/test_cli.py b/basicrta/tests/test_cli.py new file mode 100644 index 0000000..e466240 --- /dev/null +++ b/basicrta/tests/test_cli.py @@ -0,0 +1,70 @@ +""" +Tests for combining contact timeseries from multiple repeat runs. +""" +import basicrta +import os +import pytest +import numpy as np +import pickle +import subprocess +import basicrta.cli +import importlib +import argparse +from basicrta.contacts import CombineContacts + + +class TestCLI: + """Test class for cli.py functionality.""" + modules = basicrta.cli.commands + + def test_cli_modules(self): + """Test cli as module""" + for module in self.modules: + #help successfully printed + help_ret = subprocess.run(['python', '-m', f'basicrta.{module}', + '--help']) + assert help_ret.returncode == 0 + + # error if required arguments not given + nohelp = subprocess.run(['python', '-m', f'basicrta.{module}']) + assert nohelp.returncode == 2 + + def test_cli_entrypoint(self): + # print general help if no command given + assert subprocess.run('basicrta').returncode == 0 + + for module in self.modules: + #help successfully printed + help_ret = subprocess.run(['basicrta', f'{module}', '--help']) + assert help_ret.returncode == 0 + + # help is printed if no arguments given + nohelp = subprocess.run(['basicrta', f'{module}']) + assert nohelp.returncode == 0 + + def test_get_module_parsers(self): + for module in self.modules: + parser = importlib.import_module(f"basicrta.{module}").get_parser() + assert type(parser) == argparse.ArgumentParser + + def test_call_main_empty(self): + for module in self.modules: + with pytest.raises(SystemExit): + importlib.import_module(f"basicrta.{module}").main() + + def test_cli_script_call(self): + #help successfully printed + help_ret = subprocess.run(['python', '-m', 'basicrta.cli', + '--help']) + assert help_ret.returncode == 0 + + # print help if arguments not given + nohelp = subprocess.run(['python', '-m', 'basicrta.cli']) + assert nohelp.returncode == 0 + + +# def test_call_main_args(self): +# with mock.patch('sys.argv', ['cluster', '--cutoff', '6.9']): +# importlib.import_module(f"basicrta.cluster").main() + + diff --git a/docs/source/api.rst b/docs/source/api.rst index ea1acdb..2580a39 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,5 +9,6 @@ API Documentation contacts cluster kinetics + combine util - + cli diff --git a/docs/source/autosummary/cli.rst b/docs/source/autosummary/cli.rst new file mode 100644 index 0000000..4a04c36 --- /dev/null +++ b/docs/source/autosummary/cli.rst @@ -0,0 +1,6 @@ +cli +=== + +.. automodule:: cli + :members: + :undoc-members: \ No newline at end of file diff --git a/docs/source/autosummary/combine.rst b/docs/source/autosummary/combine.rst new file mode 100644 index 0000000..21eb381 --- /dev/null +++ b/docs/source/autosummary/combine.rst @@ -0,0 +1,6 @@ +combine +======= + +.. automodule:: combine + :members: + :undoc-members: \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 33b8c1c..f3991e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ doc = [ source = "https://github.com/becksteinlab/basicrta" documentation = "https://basicrta.readthedocs.io" +[project.scripts] +basicrta = "basicrta.cli:main" + [tool.setuptools] py-modules = []