Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ script manually). Commonly used flags:
- `--data_pipeline {alphafold2,alphafold3}` – choose the feature format to emit.
- `--db_preset {full_dbs,reduced_dbs}` – switch between the full BFD stack or the reduced databases.
- `--use_mmseqs2` – rely on the remote MMseqs2 API; skips local jackhmmer/HHsearch database lookups.
- `--skip_msa` – generate query-only single-sequence features instead of running bulk MSA searches. Use these feature pickles with `run_structure_prediction.py --pair_msa=False`.
- `--use_precomputed_msas` / `--save_msa_files` – reuse stored MSAs or keep new ones for later runs.
- `--compress_features` – zip the generated `*.pkl` files (`.xz` extension) to save space.
- `--skip_existing` – leave existing feature files untouched (safe for reruns).
Expand Down
119 changes: 111 additions & 8 deletions alphapulldown/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@
strip_mmseq_comment_lines,
)


def _query_only_a3m(sequence: str, query_id: str = "query") -> str:
"""Return a single-sequence A3M string for query-only workflows."""
return f">{query_id}\n{sequence}\n"


def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
"""Return a single-sequence Stockholm alignment string."""
rf_annotation = "x" * len(sequence)
return (
"# STOCKHOLM 1.0\n"
f"{query_id} {sequence}\n"
f"#=GC RF {rf_annotation}\n"
"//\n"
)


class MonomericObject:
"""
monomeric objects
Expand All @@ -41,6 +58,7 @@ def __init__(self, description, sequence) -> None:
self.sequence = sequence
self.feature_dict = dict()
self._uniprot_runner = None
self.skip_msa = False
pass

@property
Expand Down Expand Up @@ -140,7 +158,8 @@ def all_seq_msa_features(
def make_features(
self, pipeline, output_dir: str,
use_precomputed_msa: bool = False,
save_msa: bool = True, compress_msa_files: bool = False
save_msa: bool = True, compress_msa_files: bool = False,
skip_msa: bool = False,
):
"""a method that make msa and template features"""
os.makedirs(os.path.join(output_dir, self.description), exist_ok=True)
Expand All @@ -155,13 +174,20 @@ def make_features(
logging.info(
"will save msa files in :{}".format(msa_output_dir))
plPath(msa_output_dir).mkdir(parents=True, exist_ok=True)
with temp_fasta_file(sequence_str) as fasta_file:
self.feature_dict = pipeline.process(
fasta_file, msa_output_dir)
pairing_results = self.all_seq_msa_features(
fasta_file, self._uniprot_runner, msa_output_dir, use_precomputed_msa
self.skip_msa = skip_msa
if skip_msa:
self.feature_dict = self._build_query_only_feature_dict()
self.feature_dict.update(
self._search_templates_with_query_only_msa(pipeline, msa_output_dir)
)
self.feature_dict.update(pairing_results)
else:
with temp_fasta_file(sequence_str) as fasta_file:
self.feature_dict = pipeline.process(
fasta_file, msa_output_dir)
pairing_results = self.all_seq_msa_features(
fasta_file, self._uniprot_runner, msa_output_dir, use_precomputed_msa
)
self.feature_dict.update(pairing_results)

# Add extra features to make it compatible with pickle features obtained from mmseqs2
template_confidence_scores = self.feature_dict.get('template_confidence_scores', None)
Expand All @@ -187,6 +213,60 @@ def make_features(
MonomericObject.zip_msa_files(
os.path.join(output_dir, self.description))

def _build_query_only_feature_dict(self) -> Dict[str, Any]:
    """Assemble AF2-style features in which the query is the sole MSA row."""
    # A synthetic one-row alignment stands in for the usual deep MSA search.
    single_row_msa = parsers.parse_a3m(_query_only_a3m(self.sequence))
    seq_feats = pipeline.make_sequence_features(
        sequence=self.sequence,
        description=self.description,
        num_res=len(self.sequence),
    )
    msa_feats = pipeline.make_msa_features((single_row_msa,))
    # Mirror the MSA features under *_all_seq keys so downstream pairing
    # code finds the fields it expects; copy so the two sets stay independent.
    paired_feats: Dict[str, Any] = {}
    for key, value in msa_feats.items():
        paired_feats[f"{key}_all_seq"] = np.array(value, copy=True)
    paired_feats["msa_uniprot_accession_identifiers_all_seq"] = np.array(
        [b""], dtype=object
    )
    merged: Dict[str, Any] = dict(seq_feats)
    merged.update(msa_feats)
    merged.update(paired_feats)
    return merged

def _search_templates_with_query_only_msa(
    self, af2_pipeline: pipeline.DataPipeline, msa_output_dir: str
) -> Dict[str, Any]:
    """Run the template search from a synthetic single-sequence alignment.

    Args:
        af2_pipeline: Pipeline (or pipeline-like object) that may expose a
            ``template_searcher`` and ``template_featurizer``.
        msa_output_dir: Directory where the raw template hits are written.

    Returns:
        Template feature dict, or an empty dict when the pipeline has no
        template machinery configured.

    Raises:
        ValueError: If the searcher expects an input format other than
            ``sto`` or ``a3m``.
    """
    searcher = getattr(af2_pipeline, "template_searcher", None)
    featurizer = getattr(af2_pipeline, "template_featurizer", None)
    # No template machinery on this pipeline: nothing to search.
    if searcher is None or featurizer is None:
        return {}

    input_format = searcher.input_format
    if input_format == "sto":
        query_alignment = _query_only_stockholm(self.sequence)
    elif input_format == "a3m":
        query_alignment = _query_only_a3m(self.sequence)
    else:
        raise ValueError(
            f"Unrecognized template input format: {input_format}"
        )

    hits_string = searcher.query(query_alignment)
    # Persist raw hits next to the (query-only) MSA for later inspection.
    hits_path = os.path.join(
        msa_output_dir, f"pdb_hits.{searcher.output_format}"
    )
    with open(hits_path, "w") as out_file:
        out_file.write(hits_string)

    template_hits = searcher.get_template_hits(
        output_string=hits_string,
        input_sequence=self.sequence,
    )
    search_result = featurizer.get_templates(
        query_sequence=self.sequence,
        hits=template_hits,
    )
    return dict(search_result.features)

def make_mmseq_features(
self, DEFAULT_API_SERVER,
Expand All @@ -195,6 +275,7 @@ def make_mmseq_features(
use_precomputed_msa=False,
use_templates=False,
custom_template_path=None,
skip_msa: bool = False,
):
"""
A method to use mmseq_remote to calculate MSA.
Expand All @@ -212,7 +293,29 @@ def make_mmseq_features(
logging.info(f"Skipping {self.description} (result.zip)")

a3m_path = os.path.join(result_dir, self.description + ".a3m")
if use_precomputed_msa and os.path.isfile(a3m_path):
self.skip_msa = skip_msa
if skip_msa:
a3m_lines = [_query_only_a3m(self.sequence, query_id="101")]
plPath(a3m_path).write_text(a3m_lines[0])
(
unpaired_msa,
paired_msa,
query_seqs_unique,
query_seqs_cardinality,
template_features,
) = get_msa_and_templates(
jobname=self.description,
query_sequences=self.sequence,
a3m_lines=a3m_lines,
result_dir=plPath(result_dir),
msa_mode="single_sequence",
use_templates=use_templates,
custom_template_path=custom_template_path,
pair_mode="none",
host_url=DEFAULT_API_SERVER,
user_agent="alphapulldown",
)
elif use_precomputed_msa and os.path.isfile(a3m_path):
logging.info(f"Using precomputed MSA from {a3m_path}")
a3m_lines = [plPath(a3m_path).read_text()]
(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality,
Expand Down
125 changes: 95 additions & 30 deletions alphapulldown/scripts/create_individual_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tempfile
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
import numpy as np

from absl import logging, app, flags
Expand Down Expand Up @@ -142,6 +143,7 @@
flags.DEFINE_boolean('re_search_templates_mmseqs2', False, '')
flags.DEFINE_bool("use_mmseqs2", False, "")
flags.DEFINE_bool("save_msa_files", False, "")
flags.DEFINE_bool("skip_msa", False, "")
flags.DEFINE_bool("skip_existing", False, "")
flags.DEFINE_string("new_uniclust_dir", None, "")
flags.DEFINE_integer("seq_index", None, "")
Expand Down Expand Up @@ -272,14 +274,68 @@ def get_af3_chain_kind(description, sequence):
def create_af3_chain(sequence, description, chain_id):
"""Construct an AF3 chain object for the provided sequence."""
chain_kind = get_af3_chain_kind(description, sequence)
query_only_a3m = f">query\n{sequence}\n" if FLAGS.skip_msa else None
if chain_kind == "dna":
return folding_input.DnaChain(sequence=sequence, id=chain_id, modifications=[])
return folding_input.DnaChain(
sequence=sequence,
id=chain_id,
modifications=[],
description=description,
)
if chain_kind == "rna":
return folding_input.RnaChain(sequence=sequence, id=chain_id, modifications=[])
return folding_input.ProteinChain(sequence=sequence, id=chain_id, ptms=[])
kwargs = {
"sequence": sequence,
"id": chain_id,
"modifications": [],
"description": description,
}
if FLAGS.skip_msa:
kwargs["unpaired_msa"] = query_only_a3m
return folding_input.RnaChain(**kwargs)

kwargs = {
"sequence": sequence,
"id": chain_id,
"ptms": [],
"description": description,
}
if FLAGS.skip_msa:
kwargs.update(
{
"paired_msa": "",
"unpaired_msa": query_only_a3m,
"templates": None,
}
)
return folding_input.ProteinChain(**kwargs)

# =================== AlphaFold 2 Feature Creation ===================

def _create_af2_template_stack():
    """Create the AF2 template searcher and featurizer pair.

    Returns:
        Tuple of ``(template_searcher, template_featuriser)`` built from
        either the HHSearch/pdb70 stack (when ``--use_hhsearch`` is set) or
        the Hmmsearch/pdb_seqres stack otherwise.
    """
    if FLAGS.use_hhsearch:
        searcher = hhsearch.HHSearch(
            binary_path=FLAGS.hhsearch_binary_path,
            databases=[FLAGS.pdb70_database_path],
        )
        featuriser = templates.HhsearchHitFeaturizer(
            mmcif_dir=FLAGS.template_mmcif_dir,
            max_template_date=FLAGS.max_template_date,
            max_hits=20,
            kalign_binary_path=FLAGS.kalign_binary_path,
            release_dates_path=None,
            obsolete_pdbs_path=FLAGS.obsolete_pdbs_path,
        )
        return searcher, featuriser

    featuriser = templates.HmmsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=20,
        kalign_binary_path=FLAGS.kalign_binary_path,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path,
        release_dates_path=None,
    )
    searcher = hmmsearch.Hmmsearch(
        binary_path=FLAGS.hmmsearch_binary_path,
        hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
        database_path=FLAGS.pdb_seqres_database_path,
    )
    return searcher, featuriser


def create_pipeline_af2():
"""Create and configure the AlphaFold2 data pipeline."""
use_small_bfd = FLAGS.db_preset == "reduced_dbs"
Expand All @@ -289,26 +345,13 @@ def create_pipeline_af2():
template_searcher = None
template_featuriser = None
else:
if FLAGS.use_hhsearch:
template_searcher = hhsearch.HHSearch(
binary_path=FLAGS.hhsearch_binary_path, databases=[FLAGS.pdb70_database_path]
)
template_featuriser = templates.HhsearchHitFeaturizer(
mmcif_dir=FLAGS.template_mmcif_dir, max_template_date=FLAGS.max_template_date,
max_hits=20, kalign_binary_path=FLAGS.kalign_binary_path,
release_dates_path=None, obsolete_pdbs_path=FLAGS.obsolete_pdbs_path
)
else:
template_featuriser = templates.HmmsearchHitFeaturizer(
mmcif_dir=FLAGS.template_mmcif_dir, max_template_date=FLAGS.max_template_date,
max_hits=20, kalign_binary_path=FLAGS.kalign_binary_path,
obsolete_pdbs_path=FLAGS.obsolete_pdbs_path, release_dates_path=None
)
template_searcher = hmmsearch.Hmmsearch(
binary_path=FLAGS.hmmsearch_binary_path,
hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
database_path=FLAGS.pdb_seqres_database_path
)
template_searcher, template_featuriser = _create_af2_template_stack()

if FLAGS.skip_msa:
return SimpleNamespace(
template_searcher=template_searcher,
template_featurizer=template_featuriser,
)

return AF2DataPipeline(
jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
Expand All @@ -334,9 +377,12 @@ def create_individual_features():
uniprot_runner = None
else:
pipeline = create_pipeline_af2()
uniprot_runner = create_uniprot_runner(
FLAGS.jackhmmer_binary_path, FLAGS.uniprot_database_path
)
if FLAGS.skip_msa:
uniprot_runner = None
else:
uniprot_runner = create_uniprot_runner(
FLAGS.jackhmmer_binary_path, FLAGS.uniprot_database_path
)

for seq_idx, (seq, desc) in enumerate(iter_seqs(FLAGS.fasta_paths), 1):
if FLAGS.seq_index is None or seq_idx == FLAGS.seq_index:
Expand Down Expand Up @@ -444,6 +490,18 @@ def _reuse_truemultimer_monomer_features(feat):
monomer = _load_existing_monomer_from_output_dir(source_name)
if monomer is None:
return None
cached_skip_msa = getattr(monomer, "skip_msa", False)
if FLAGS.skip_msa != cached_skip_msa:
requested_mode = "--skip_msa" if FLAGS.skip_msa else "full-MSA"
cached_mode = "--skip_msa" if cached_skip_msa else "full-MSA"
logging.info(
"Existing monomer features for %s were generated in %s mode, but the "
"current TrueMultimer entry requested %s mode. Recomputing features.",
source_name,
cached_mode,
requested_mode,
)
return None
if monomer.sequence != feat["sequence"]:
logging.warning(
"Existing monomer features for %s use sequence %s, but the current "
Expand Down Expand Up @@ -489,19 +547,23 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None

if _should_skip_monomer_output(monomer.description):
return
monomer.skip_msa = FLAGS.skip_msa
if FLAGS.use_mmseqs2:
monomer.make_mmseq_features(
DEFAULT_API_SERVER=DEFAULT_API_SERVER,
output_dir=FLAGS.output_dir,
use_precomputed_msa=FLAGS.use_precomputed_msas,
use_templates=FLAGS.re_search_templates_mmseqs2 or custom_template_path is not None,
custom_template_path=custom_template_path,
skip_msa=FLAGS.skip_msa,
)
else:
monomer.make_features(
pipeline=pipeline, output_dir=FLAGS.output_dir,
use_precomputed_msa=FLAGS.use_precomputed_msas,
save_msa=FLAGS.save_msa_files)
save_msa=FLAGS.save_msa_files,
skip_msa=FLAGS.skip_msa,
)
_persist_monomer_outputs(monomer)

def create_individual_features_truemultimer():
Expand Down Expand Up @@ -535,9 +597,12 @@ def process_multimeric_features(feat, idx):
uniprot_runner = None
else:
pipeline = create_pipeline_af2()
uniprot_runner = create_uniprot_runner(
FLAGS.jackhmmer_binary_path, FLAGS.uniprot_database_path
)
if FLAGS.skip_msa:
uniprot_runner = None
else:
uniprot_runner = create_uniprot_runner(
FLAGS.jackhmmer_binary_path, FLAGS.uniprot_database_path
)

monomer = MonomericObject(protein, feat['sequence'])
monomer.uniprot_runner = uniprot_runner
Expand Down
10 changes: 10 additions & 0 deletions alphapulldown/scripts/run_structure_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,16 @@ def pre_modelling_setup(
A MultimericObject or MonomericObject
output_directory for this particular modelling job
"""
if (
len(interactors) > 1
and FLAGS.pair_msa
and any(getattr(interactor, "skip_msa", False) for interactor in interactors)
):
raise ValueError(
"--skip_msa generates query-only MSAs and cannot be combined with "
"--pair_msa=True. Re-run structure prediction with --pair_msa=False."
)

if len(interactors) > 1:
# this means it's going to be a MultimericObject
object_to_model = MultimericObject(
Expand Down
1 change: 1 addition & 0 deletions alphapulldown/utils/modelling_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def process_each_dict(data,monomer_objects_dir):
monomer.feature_dict,
curr_interactor_region,
)
chopped_object.skip_msa = getattr(monomer, "skip_msa", False)
chopped_object.prepare_final_sliced_feature_dict()
interactors.append(chopped_object)
return interactors
Expand Down
Loading
Loading