Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ The rules for this file:

### Authors
* @orbeckst
* @rsexton2

### Fixed
* Have cluster.ProcessProtein.reprocess() record "no result" if
the gibbs.Gibbs.process_gibbs() step fails due to insufficient
number of samples. Otherwise `python -m cluster` fails to process
whole proteins.

### Added
* Added command-line interface for basicrta workflow (Issue #20)

## [1.1.3] - 2025-09-11

### Authors
Expand Down
65 changes: 65 additions & 0 deletions basicrta/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Command line functionality of basicrta.

The `main()` function of this module gets the argument parser from each of the
scripts below and executes the `main()` function of the module called. The
function also collects help from the subparsers and provides it at the command
line.

Modules callable from the cli: contacts.py, gibbs.py, cluster.py, kinetics.py,
combine.py.
"""

from importlib.metadata import version
import basicrta
import argparse
import subprocess
import importlib
import sys

__version__ = version("basicrta")

# define which scripts can be ran from cli
# can easily add functionality to cli as modules are added
commands = ['contacts', 'gibbs', 'cluster', 'combine', 'kinetics']

def main():
""" This module provides the functionality for a command line interface for
basicrta scripts. The scripts available to the cli are:

* contacts.py
* gibbs.py
* cluster.py
* combine.py
* kinetics.py

Each script is called and ran using the `main()` function of each module and
the parser is passed to the cli using the `get_parser()` function. Any
module added to the cli needs to have both functions.
"""
parser = argparse.ArgumentParser(prog='basicrta', add_help=True)
subparsers = parser.add_subparsers(help="""step in the basicrta workflow to
execute""")

# collect parser from each script in `commands`
for command in commands:
subparser = importlib.import_module(f"basicrta.{command}").get_parser()
subparsers.add_parser(f'{command}', parents=[subparser], add_help=True,
description=subparser.description,
conflict_handler='resolve',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
help=subparser.description)

# print subparser help if no arguments given
if len(sys.argv) == 2 and sys.argv[1] in commands:
subparsers.choices[f'{sys.argv[1]}'].print_help()
sys.exit()

# print basicrta help if no subcommand given
parser.parse_args(args=None if sys.argv[1:] else ['--help'])

# execute basicrta script
importlib.import_module(f"basicrta.{sys.argv[1]}").main()

if __name__ == "__main__":
main()
56 changes: 39 additions & 17 deletions basicrta/cluster.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""This module provides the ProcessProtein class, which collects and processes
Gibbs sampler data.
"""

import os
import gc
import warnings
Expand All @@ -11,9 +15,6 @@
from basicrta.gibbs import Gibbs
gc.enable()

"""This module provides the ProcessProtein class, which collects and processes
Gibbs sampler data.
"""

class ProcessProtein(object):
r"""ProcessProtein is the class that collects and processes Gibbs sampler
Expand Down Expand Up @@ -237,33 +238,54 @@ def b_color_structure(self, structure):

u.select_atoms('protein').write('tau_bcolored.pdb')


if __name__ == "__main__": #pragma: no cover
# the script is tested in the test_cluster.py but cannot be accounted for
# in the coverage report
def get_parser():
import argparse
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--nproc', type=int, default=1)
parser.add_argument('--cutoff', type=float)
parser.add_argument('--niter', type=int, default=110000)
parser.add_argument('--prot', type=str, default=None, nargs='?')
parser.add_argument('--label-cutoff', type=float, default=3,
dest='label_cutoff',
help='Only label residues with tau > '
'LABEL-CUTOFF * <tau>. ')
parser.add_argument('--structure', type=str, nargs='?')
parser = argparse.ArgumentParser(description="""perform clustering for each
residue located in basicrta-{cutoff}/""",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
required = parser.add_argument_group('required arguments')

required.add_argument('--cutoff', required=True, type=float, help="""cutoff
used in contact analysis, will cluster results in
basicrta-{cutoff}/""")
parser.add_argument('--nproc', type=int, default=1, help="""number of
processes to use in multiprocessing""")
parser.add_argument('--niter', type=int, default=110000, help="""number of
iterations used in the gibbs sampler, used to load
gibbs_{niter}.pkl""")
parser.add_argument('--prot', type=str, nargs='?', help="""name of protein
in tm_dict.txt, used to draw TM bars in tau vs resid
plot""")
parser.add_argument('--label_cutoff', type=float, default=3,
dest='label_cutoff',
help="""Only label residues with tau >
LABEL-CUTOFF * <tau>.""")
parser.add_argument('--structure', type=str, nargs='?', help="""will add tau
as bfactors to the structure if provided""")
# use for default values
parser.add_argument('--gskip', type=int, default=100,
help='Gibbs skip parameter for decorrelated samples;'
'default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')
parser.add_argument('--burnin', type=int, default=10000,
help='Burn-in parameter, drop first N samples as equilibration;'
'default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')
# this is to make the cli work, should be just a temporary solution
parser.add_argument('cluster', nargs='?', help=argparse.SUPPRESS)
return parser

def main():
parser = get_parser()
args = parser.parse_args()

pp = ProcessProtein(args.niter, args.prot, args.cutoff,
gskip=args.gskip, burnin=args.burnin)
pp.reprocess(nproc=args.nproc)
pp.write_data()
pp.plot_protein(label_cutoff=args.label_cutoff)


if __name__ == "__main__": #pragma: no cover
# the script is tested in the test_cluster.py but cannot be accounted for
# in the coverage report
exit(main())

158 changes: 148 additions & 10 deletions basicrta/combine.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,162 @@
#!/usr/bin/env python

"""
Command-line interface for combining contact timeseries from multiple repeat runs.
Combine contact timeseries from multiple repeat runs.

This module provides functionality to combine contact files from multiple
trajectory repeats, enabling pooled analysis of binding kinetics.
"""

import os
import argparse
from basicrta.contacts import CombineContacts

class CombineContacts(object):
"""Class to combine contact timeseries from multiple repeat runs.

This class enables pooling data from multiple trajectory repeats and
calculating posteriors from all data together, rather than analyzing
each run separately.

:param contact_files: List of contact pickle files to combine
:type contact_files: list of str
:param output_name: Name for the combined output file (default: 'combined_contacts.pkl')
:type output_name: str, optional
:param validate_compatibility: Whether to validate that files are compatible (default: True)
:type validate_compatibility: bool, optional
"""

def __init__(self, contact_files, output_name='combined_contacts.pkl',
validate_compatibility=True):
self.contact_files = contact_files
self.output_name = output_name
self.validate_compatibility = validate_compatibility

if len(contact_files) < 2:
raise ValueError("At least 2 contact files are required for combining")

def _load_contact_file(self, filename):
"""Load a contact pickle file and return data and metadata."""
if not os.path.exists(filename):
raise FileNotFoundError(f"Contact file not found: {filename}")

with open(filename, 'rb') as f:
contacts = pickle.load(f)

metadata = contacts.dtype.metadata
return contacts, metadata

def _validate_compatibility(self, metadatas):
"""Validate that contact files are compatible for combining."""
reference = metadatas[0]

# Check that all files have the same atom groups
for i, meta in enumerate(metadatas[1:], 1):
# Compare cutoff
if meta['cutoff'] != reference['cutoff']:
raise ValueError(f"Incompatible cutoffs: file 0 has {reference['cutoff']}, "
f"file {i} has {meta['cutoff']}")

# Compare atom group selections by checking if resids match
ref_ag1_resids = set(reference['ag1'].residues.resids)
ref_ag2_resids = set(reference['ag2'].residues.resids)
meta_ag1_resids = set(meta['ag1'].residues.resids)
meta_ag2_resids = set(meta['ag2'].residues.resids)

if ref_ag1_resids != meta_ag1_resids:
raise ValueError(f"Incompatible ag1 residues between file 0 and file {i}")
if ref_ag2_resids != meta_ag2_resids:
raise ValueError(f"Incompatible ag2 residues between file 0 and file {i}")

# Check timesteps and warn if different
timesteps = [meta['ts'] for meta in metadatas]
if not all(abs(ts - timesteps[0]) < 1e-6 for ts in timesteps):
print("WARNING: Different timesteps detected across runs:")
for i, (filename, ts) in enumerate(zip(self.contact_files, timesteps)):
print(f" File {i} ({filename}): dt = {ts} ns")
print("This may affect residence time estimates, especially for fast events.")

def run(self):
"""Combine contact files and save the result."""
print(f"Combining {len(self.contact_files)} contact files...")

all_contacts = []
all_metadatas = []

# Load all contact files
for i, filename in enumerate(self.contact_files):
print(f"Loading file {i+1}/{len(self.contact_files)}: {filename}")
contacts, metadata = self._load_contact_file(filename)
all_contacts.append(contacts)
all_metadatas.append(metadata)

# Validate compatibility if requested
if self.validate_compatibility:
print("Validating file compatibility...")
self._validate_compatibility(all_metadatas)

# Combine contact data
print("Combining contact data...")

# Calculate total size and create combined array
total_size = sum(len(contacts) for contacts in all_contacts)
reference_metadata = all_metadatas[0].copy()

# Extend metadata to include trajectory source information
reference_metadata['source_files'] = self.contact_files
reference_metadata['n_trajectories'] = len(self.contact_files)

# Determine number of columns (5 for raw contacts, 4 for processed)
n_cols = all_contacts[0].shape[1]

# Create dtype with extended metadata
combined_dtype = np.dtype(np.float64, metadata=reference_metadata)

# Add trajectory source column (will be last column)
combined_contacts = np.zeros((total_size, n_cols + 1), dtype=np.float64)

# Combine data and add trajectory source information
offset = 0
for traj_idx, contacts in enumerate(all_contacts):
n_contacts = len(contacts)
# Copy original contact data
combined_contacts[offset:offset+n_contacts, :n_cols] = contacts[:]
# Add trajectory source index
combined_contacts[offset:offset+n_contacts, n_cols] = traj_idx
offset += n_contacts

# Create final memmap with proper dtype
final_contacts = combined_contacts.view(combined_dtype)

# Save combined contacts
print(f"Saving combined contacts to {self.output_name}...")
final_contacts.dump(self.output_name, protocol=5)

print(f"Successfully combined {len(self.contact_files)} files into {self.output_name}")
print(f"Total contacts: {total_size}")
print(f"Added trajectory source column (index {n_cols}) for kinetic clustering support")

return self.output_name

def main():
"""Main function for combining contact files."""
def get_parser():
"""Create parser, parse command line arguments, and return ArgumentParser
object.

:return: An ArgumentParser instance with command line arguments stored.
:rtype: `ArgumentParser` object
"""
parser = argparse.ArgumentParser(
description="Combine contact timeseries from multiple repeat runs. "
"This enables pooling data from multiple trajectory repeats "
"and calculating posteriors from all data together."
)

parser.add_argument(
required = parser.add_argument_group('required arguments')
required.add_argument(
'--contacts',
nargs='+',
required=True,
help="List of contact pickle files to combine (e.g., contacts_7.0.pkl from different runs)"
help="""List of contact pickle files to combine (e.g., contacts_7.0.pkl
from different runs)""",
)

parser.add_argument(
Expand All @@ -39,7 +171,14 @@ def main():
action='store_true',
help="Skip compatibility validation (use with caution)"
)

# this is to make the cli work, should be just a temporary solution
parser.add_argument('combine', nargs='?', help=argparse.SUPPRESS)
return parser

def main():
"""Execute this function when this script is called from the command line.
"""
parser = get_parser()
args = parser.parse_args()

# Validate input files exist
Expand Down Expand Up @@ -82,6 +221,5 @@ def main():
print(f"ERROR: {e}")
return 1


if __name__ == '__main__':
exit(main())
if __name__ == "__main__":
exit(main())
Loading
Loading