From f3bb3037a69d7dc982ac82081ef28b64942c5cb1 Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 11:35:27 +0200 Subject: [PATCH 1/5] fix(#42): add skip-MSA feature generation mode --- alphapulldown/objects.py | 113 +++++++++++- .../scripts/create_individual_features.py | 49 ++++- .../scripts/run_structure_prediction.py | 6 + alphapulldown/utils/modelling_setup.py | 1 + conftest.py | 1 + .../test_create_individual_features.py | 173 +++++++++++++++++- test/unit/test_modelling_setup.py | 36 ++++ test/unit/test_objects.py | 146 ++++++++++++++- test/unit/test_script_entrypoints.py | 70 +++++++ 9 files changed, 579 insertions(+), 16 deletions(-) diff --git a/alphapulldown/objects.py b/alphapulldown/objects.py index 26c610f3..70127479 100644 --- a/alphapulldown/objects.py +++ b/alphapulldown/objects.py @@ -26,6 +26,17 @@ strip_mmseq_comment_lines, ) + +def _query_only_a3m(sequence: str, query_id: str = "query") -> str: + """Return a single-sequence A3M string for query-only workflows.""" + return f">{query_id}\n{sequence}\n" + + +def _query_only_stockholm(sequence: str, query_id: str = "query") -> str: + """Return a single-sequence Stockholm alignment string.""" + return f"# STOCKHOLM 1.0\n{query_id} {sequence}\n//\n" + + class MonomericObject: """ monomeric objects @@ -41,6 +52,7 @@ def __init__(self, description, sequence) -> None: self.sequence = sequence self.feature_dict = dict() self._uniprot_runner = None + self.skip_msa = False pass @property @@ -140,7 +152,8 @@ def all_seq_msa_features( def make_features( self, pipeline, output_dir: str, use_precomputed_msa: bool = False, - save_msa: bool = True, compress_msa_files: bool = False + save_msa: bool = True, compress_msa_files: bool = False, + skip_msa: bool = False, ): """a method that make msa and template features""" os.makedirs(os.path.join(output_dir, self.description), exist_ok=True) @@ -155,13 +168,20 @@ def make_features( logging.info( "will save msa files in :{}".format(msa_output_dir)) plPath(msa_output_dir).mkdir(parents=True, exist_ok=True) - with temp_fasta_file(sequence_str) as fasta_file: - self.feature_dict = pipeline.process( - fasta_file, msa_output_dir) - pairing_results = self.all_seq_msa_features( - fasta_file, self._uniprot_runner, msa_output_dir, use_precomputed_msa + self.skip_msa = skip_msa + if skip_msa: + self.feature_dict = self._build_query_only_feature_dict() + self.feature_dict.update( + self._search_templates_with_query_only_msa(pipeline, msa_output_dir) ) - self.feature_dict.update(pairing_results) + else: + with temp_fasta_file(sequence_str) as fasta_file: + self.feature_dict = pipeline.process( + fasta_file, msa_output_dir) + pairing_results = self.all_seq_msa_features( + fasta_file, self._uniprot_runner, msa_output_dir, use_precomputed_msa + ) + self.feature_dict.update(pairing_results) # Add extra features to make it compatible with pickle features obtaiend from mmseqs2 template_confidence_scores = self.feature_dict.get('template_confidence_scores', None) @@ -187,6 +207,60 @@ def make_features( MonomericObject.zip_msa_files( os.path.join(output_dir, self.description)) + def _build_query_only_feature_dict(self) -> Dict[str, Any]: + """Build AF2-compatible features with the query as the only MSA row.""" + query_only_msa = parsers.parse_a3m(_query_only_a3m(self.sequence)) + sequence_features = pipeline.make_sequence_features( + sequence=self.sequence, + description=self.description, + num_res=len(self.sequence), + ) + msa_features = pipeline.make_msa_features((query_only_msa,)) + all_seq_features = { + f"{key}_all_seq": np.array(value, copy=True) + for key, value in msa_features.items() + } + all_seq_features["msa_uniprot_accession_identifiers_all_seq"] = np.array( + [b""], dtype=object + ) + return {**sequence_features, **msa_features, **all_seq_features} + + def _search_templates_with_query_only_msa( + self, af2_pipeline: pipeline.DataPipeline, msa_output_dir: str + ) -> Dict[str, Any]: + """Run template search from a synthetic single-sequence alignment.""" + template_searcher = getattr(af2_pipeline, "template_searcher", None) + template_featurizer = getattr(af2_pipeline, "template_featurizer", None) + if template_searcher is None or template_featurizer is None: + return {} + + stockholm_msa = _query_only_stockholm(self.sequence) + if template_searcher.input_format == "sto": + template_query = stockholm_msa + elif template_searcher.input_format == "a3m": + template_query = _query_only_a3m(self.sequence) + else: + raise ValueError( + "Unrecognized template input format: " + f"{template_searcher.input_format}" + ) + + pdb_templates_result = template_searcher.query(template_query) + pdb_hits_out_path = os.path.join( + msa_output_dir, f"pdb_hits.{template_searcher.output_format}" + ) + with open(pdb_hits_out_path, "w") as handle: + handle.write(pdb_templates_result) + + pdb_template_hits = template_searcher.get_template_hits( + output_string=pdb_templates_result, + input_sequence=self.sequence, + ) + templates_result = template_featurizer.get_templates( + query_sequence=self.sequence, + hits=pdb_template_hits, + ) + return dict(templates_result.features) def make_mmseq_features( self, DEFAULT_API_SERVER, @@ -195,6 +269,7 @@ def make_mmseq_features( use_precomputed_msa=False, use_templates=False, custom_template_path=None, + skip_msa: bool = False, ): """ A method to use mmseq_remote to calculate MSA. @@ -212,7 +287,29 @@ def make_mmseq_features( logging.info(f"Skipping {self.description} (result.zip)") a3m_path = os.path.join(result_dir, self.description + ".a3m") - if use_precomputed_msa and os.path.isfile(a3m_path): + self.skip_msa = skip_msa + if skip_msa: + a3m_lines = [_query_only_a3m(self.sequence, query_id="101")] + plPath(a3m_path).write_text(a3m_lines[0]) + ( + unpaired_msa, + paired_msa, + query_seqs_unique, + query_seqs_cardinality, + template_features, + ) = get_msa_and_templates( + jobname=self.description, + query_sequences=self.sequence, + a3m_lines=a3m_lines, + result_dir=plPath(result_dir), + msa_mode="single_sequence", + use_templates=use_templates, + custom_template_path=custom_template_path, + pair_mode="none", + host_url=DEFAULT_API_SERVER, + user_agent="alphapulldown", + ) + elif use_precomputed_msa and os.path.isfile(a3m_path): logging.info(f"Using precomputed MSA from {a3m_path}") a3m_lines = [plPath(a3m_path).read_text()] (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py index 867550d1..d7d074a8 100644 --- a/alphapulldown/scripts/create_individual_features.py +++ b/alphapulldown/scripts/create_individual_features.py @@ -142,6 +142,7 @@ flags.DEFINE_boolean('re_search_templates_mmseqs2', False, '') flags.DEFINE_bool("use_mmseqs2", False, "") flags.DEFINE_bool("save_msa_files", False, "") +flags.DEFINE_bool("skip_msa", False, "") flags.DEFINE_bool("skip_existing", False, "") flags.DEFINE_string("new_uniclust_dir", None, "") flags.DEFINE_integer("seq_index", None, "") @@ -272,11 +273,40 @@ def get_af3_chain_kind(description, sequence): def create_af3_chain(sequence, description, chain_id): """Construct an AF3 chain object for the provided sequence.""" chain_kind = get_af3_chain_kind(description, sequence) + query_only_a3m = f">query\n{sequence}\n" if FLAGS.skip_msa else None if chain_kind == "dna": - return folding_input.DnaChain(sequence=sequence, id=chain_id, modifications=[]) + return folding_input.DnaChain( + sequence=sequence, + id=chain_id, + modifications=[], + description=description, + ) if chain_kind == "rna": - return folding_input.RnaChain(sequence=sequence, id=chain_id, modifications=[]) - return folding_input.ProteinChain(sequence=sequence, id=chain_id, ptms=[]) + kwargs = { + "sequence": sequence, + "id": chain_id, + "modifications": [], + "description": description, + } + if FLAGS.skip_msa: + kwargs["unpaired_msa"] = query_only_a3m + return folding_input.RnaChain(**kwargs) + + kwargs = { + "sequence": sequence, + "id": chain_id, + "ptms": [], + "description": description, + } + if FLAGS.skip_msa: + kwargs.update( + { + "paired_msa": "", + "unpaired_msa": query_only_a3m, + "templates": None, + } + ) + return folding_input.ProteinChain(**kwargs) # =================== AlphaFold 2 Feature Creation =================== @@ -444,6 +474,13 @@ def _reuse_truemultimer_monomer_features(feat): monomer = _load_existing_monomer_from_output_dir(source_name) if monomer is None: return None + if FLAGS.skip_msa and not getattr(monomer, "skip_msa", False): + logging.info( + "Existing monomer features for %s were generated with bulk MSAs. " + "Recomputing query-only features for --skip_msa.", + source_name, + ) + return None if monomer.sequence != feat["sequence"]: logging.warning( "Existing monomer features for %s use sequence %s, but the current " @@ -489,6 +526,7 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None if _should_skip_monomer_output(monomer.description): return + monomer.skip_msa = FLAGS.skip_msa if FLAGS.use_mmseqs2: monomer.make_mmseq_features( DEFAULT_API_SERVER=DEFAULT_API_SERVER, @@ -496,12 +534,15 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None use_precomputed_msa=FLAGS.use_precomputed_msas, use_templates=FLAGS.re_search_templates_mmseqs2 or custom_template_path is not None, custom_template_path=custom_template_path, + skip_msa=FLAGS.skip_msa, ) else: monomer.make_features( pipeline=pipeline, output_dir=FLAGS.output_dir, use_precomputed_msa=FLAGS.use_precomputed_msas, - save_msa=FLAGS.save_msa_files) + save_msa=FLAGS.save_msa_files, + skip_msa=FLAGS.skip_msa, + ) _persist_monomer_outputs(monomer) def create_individual_features_truemultimer(): diff --git a/alphapulldown/scripts/run_structure_prediction.py b/alphapulldown/scripts/run_structure_prediction.py index 6fc8efd9..ee9d5ec5 100644 --- a/alphapulldown/scripts/run_structure_prediction.py +++ b/alphapulldown/scripts/run_structure_prediction.py @@ -348,6 +348,12 @@ def pre_modelling_setup( A MultimericObject or MonomericObject output_directory for this particular modelling job """ + if FLAGS.pair_msa and any(getattr(interactor, "skip_msa", False) for interactor in interactors): + raise ValueError( + "--skip_msa generates query-only MSAs and cannot be combined with " + "--pair_msa=True. Re-run structure prediction with --pair_msa=False." + ) + if len(interactors) > 1: # this means it's going to be a MultimericObject object_to_model = MultimericObject( diff --git a/alphapulldown/utils/modelling_setup.py b/alphapulldown/utils/modelling_setup.py index 1833471a..88d9c136 100644 --- a/alphapulldown/utils/modelling_setup.py +++ b/alphapulldown/utils/modelling_setup.py @@ -271,6 +271,7 @@ def process_each_dict(data,monomer_objects_dir): monomer.feature_dict, curr_interactor_region, ) + chopped_object.skip_msa = getattr(monomer, "skip_msa", False) chopped_object.prepare_final_sliced_feature_dict() interactors.append(chopped_object) return interactors diff --git a/conftest.py b/conftest.py index ad1e0167..ae06d0c4 100755 --- a/conftest.py +++ b/conftest.py @@ -163,6 +163,7 @@ def _flag_values_dict(): use_mmseqs2=False, use_precomputed_msas=False, save_msa_files=False, + skip_msa=False, skip_existing=False, compress_features=False, db_preset="full_dbs", diff --git a/test/integration/test_create_individual_features.py b/test/integration/test_create_individual_features.py index 0b695bb7..5718c89d 100644 --- a/test/integration/test_create_individual_features.py +++ b/test/integration/test_create_individual_features.py @@ -78,22 +78,50 @@ def build_af3_stub_modules(): mmcif_mod = types.ModuleType("alphafold3.structure.mmcif") class ProteinChain: - def __init__(self, sequence, id, ptms=None): + def __init__( + self, + sequence, + id, + ptms=None, + residue_ids=None, + description=None, + paired_msa=None, + unpaired_msa=None, + templates=None, + ): self.sequence = sequence self.id = id self.ptms = [] if ptms is None else list(ptms) + self.residue_ids = residue_ids + self.description = description + self.paired_msa = paired_msa + self.unpaired_msa = unpaired_msa + self.templates = templates class RnaChain: - def __init__(self, sequence, id, modifications=None): + def __init__( + self, + sequence, + id, + modifications=None, + residue_ids=None, + description=None, + unpaired_msa=None, + ): self.sequence = sequence self.id = id self.modifications = [] if modifications is None else list(modifications) + self.residue_ids = residue_ids + self.description = description + self.unpaired_msa = unpaired_msa class DnaChain: - def __init__(self, sequence, id, modifications=None): + def __init__(self, sequence, id, modifications=None, residue_ids=None, description=None): self.sequence = sequence self.id = id self.modifications = [] if modifications is None else list(modifications) + self.residue_ids = residue_ids + self.description = description class Input: def __init__(self, name, chains, rng_seeds): @@ -598,6 +626,62 @@ def test_process_multimeric_features_falls_back_when_source_sequence_mismatches( assert saved_monomer.sequence == "ACDF" assert saved_monomer.uniprot_runner == "runner" + def test_process_multimeric_features_does_not_reuse_bulk_msa_pickle_for_skip_msa( + self, tmp_flags + ): + template_path = Path(self.test_dir) / "template1.cif" + template_path.write_text("data_template\n", encoding="utf-8") + + from absl import flags + + FLAGS = flags.FLAGS + FLAGS(["test"]) + FLAGS.output_dir = os.path.join(self.test_dir, "skip_msa_truemultimer_output") + FLAGS.use_mmseqs2 = False + FLAGS.compress_features = False + FLAGS.skip_existing = False + FLAGS.skip_msa = True + FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer" + FLAGS.uniprot_database_path = "/db/uniprot.fasta" + + output_dir = Path(FLAGS.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + source = MonomericObject("proteinA", "ACDE") + source.feature_dict = { + "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32), + "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32), + "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32), + "msa_species_identifiers": np.asarray([b"9606"], dtype=object), + } + with open(output_dir / "proteinA.pkl", "wb") as handle: + pickle.dump(source, handle) + + feat = { + "protein": "proteinA.template1.cif.A", + "chains": ["A"], + "templates": [str(template_path)], + "sequence": "ACDE", + } + + with patch.object( + create_features, + "extract_multimeric_template_features_for_single_chain", + ) as mock_extract, \ + patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \ + patch.object(create_features, "create_arguments") as mock_arguments, \ + patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \ + patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \ + patch.object(create_features, "create_and_save_monomer_objects") as mock_save: + create_features.process_multimeric_features(feat, 1) + + mock_extract.assert_not_called() + mock_custom_db.assert_called_once() + mock_arguments.assert_called_once_with("/tmp/custom_db") + mock_pipeline.assert_called_once_with() + mock_runner.assert_called_once_with("/usr/bin/jackhmmer", "/db/uniprot.fasta") + mock_save.assert_called_once() + def test_main_dispatches_to_truemultimer_for_af2_template_runs(self): """The main entrypoint should route AF2 template jobs to the TrueMultimer path.""" from absl import flags @@ -1318,6 +1402,7 @@ def test_create_and_save_monomer_objects_writes_compressed_af2_outputs(tmp_flags "output_dir": str(tmp_path), "use_precomputed_msa": True, "save_msa": True, + "skip_msa": False, } ] assert monomer.mmseq_calls == [] @@ -1359,11 +1444,61 @@ def test_create_and_save_monomer_objects_uses_mmseqs_when_requested(tmp_flags, t "use_precomputed_msa": True, "use_templates": True, "custom_template_path": None, + "skip_msa": False, } ] assert (tmp_path / "protA.pkl").exists() +def test_create_and_save_monomer_objects_passes_skip_msa_to_af2_builder(tmp_flags, tmp_path): + create_features.FLAGS.output_dir = str(tmp_path) + create_features.FLAGS.compress_features = False + create_features.FLAGS.skip_existing = False + create_features.FLAGS.use_mmseqs2 = False + create_features.FLAGS.use_precomputed_msas = False + create_features.FLAGS.save_msa_files = False + create_features.FLAGS.skip_msa = True + + monomer = RecordingDummyMonomer("protA") + with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}): + create_features.create_and_save_monomer_objects(monomer, pipeline="pipeline") + + assert monomer.feature_calls == [ + { + "pipeline": "pipeline", + "output_dir": str(tmp_path), + "use_precomputed_msa": False, + "save_msa": False, + "skip_msa": True, + } + ] + + +def test_create_and_save_monomer_objects_passes_skip_msa_to_mmseqs_builder(tmp_flags, tmp_path): + create_features.FLAGS.output_dir = str(tmp_path) + create_features.FLAGS.compress_features = False + create_features.FLAGS.skip_existing = False + create_features.FLAGS.use_mmseqs2 = True + create_features.FLAGS.use_precomputed_msas = False + create_features.FLAGS.re_search_templates_mmseqs2 = False + create_features.FLAGS.skip_msa = True + + monomer = RecordingDummyMonomer("protA") + with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}): + create_features.create_and_save_monomer_objects(monomer, pipeline=None) + + assert monomer.mmseq_calls == [ + { + "DEFAULT_API_SERVER": create_features.DEFAULT_API_SERVER, + "output_dir": str(tmp_path), + "use_precomputed_msa": False, + "use_templates": False, + "custom_template_path": None, + "skip_msa": True, + } + ] + + def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_flags, tmp_path): create_features.FLAGS.output_dir = str(tmp_path) create_features.FLAGS.compress_features = False @@ -1390,6 +1525,7 @@ def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_f "use_precomputed_msa": False, "use_templates": True, "custom_template_path": custom_template_path, + "skip_msa": False, } ] assert (tmp_path / "protA.pkl").exists() @@ -1661,6 +1797,37 @@ def test_create_af3_individual_features_skips_existing_outputs(tmp_flags, tmp_pa assert existing_output.read_text(encoding="utf-8") == "{}" +def test_create_af3_individual_features_prefills_query_only_msas_when_skip_msa( + tmp_flags, tmp_path +): + create_features.FLAGS.output_dir = str(tmp_path) + create_features.FLAGS.data_pipeline = "alphafold3" + create_features.FLAGS.skip_msa = True + + af3_modules, folding_input_stub = build_af3_stub_modules() + pipeline = MagicMock(process=MagicMock(return_value=DummyJsonObj())) + with patch.dict(sys.modules, af3_modules), \ + patch.object(create_features, "create_pipeline_af3", return_value=pipeline), \ + patch.object(create_features, "folding_input", folding_input_stub), \ + patch.object( + create_features, + "iter_seqs", + return_value=[("ACDE", "protein_chain protein"), ("AUGA", "rna_chain RNA")], + ), \ + patch("pathlib.Path.write_text", new=real_write_text): + create_features.create_af3_individual_features() + + protein_input = pipeline.process.call_args_list[0].args[0] + protein_chain = protein_input.chains[0] + assert protein_chain.unpaired_msa == ">query\nACDE\n" + assert protein_chain.paired_msa == "" + assert protein_chain.templates is None + + rna_input = pipeline.process.call_args_list[1].args[0] + rna_chain = rna_input.chains[0] + assert rna_chain.unpaired_msa == ">query\nAUGA\n" + + def test_main_dispatches_to_af3_feature_creation(tmp_flags, tmp_path): create_features.FLAGS.data_pipeline = "alphafold3" create_features.FLAGS.output_dir = str(tmp_path / "af3_out") diff --git a/test/unit/test_modelling_setup.py b/test/unit/test_modelling_setup.py index d57ec6d0..0d8a3606 100644 --- a/test/unit/test_modelling_setup.py +++ b/test/unit/test_modelling_setup.py @@ -220,6 +220,42 @@ def prepare_final_sliced_feature_dict(self): assert calls["args"] == ("proteinA", "ACDEFG", monomer.feature_dict, [(2, 4)]) +def test_create_interactors_propagates_skip_msa_marker_to_chopped_objects(monkeypatch): + monomer = MonomericObject("proteinA", "ACDEFG") + monomer.feature_dict = {"template_aatype": np.ones((1,), dtype=np.float32)} + monomer.skip_msa = True + + class FakeChoppedObject: + def __init__(self, description, sequence, feature_dict, regions): + self.description = description + self.sequence = sequence + self.feature_dict = feature_dict + self.regions = regions + self.prepared = False + + def prepare_final_sliced_feature_dict(self): + self.prepared = True + + monkeypatch.setattr( + modelling_setup, + "make_dir_monomer_dictionary", + lambda _: {"proteinA.pkl": "/unused"}, + ) + monkeypatch.setattr(modelling_setup, "load_monomer_objects", lambda *_: monomer) + monkeypatch.setattr(modelling_setup, "check_empty_templates", lambda _: False) + monkeypatch.setattr(modelling_setup, "ChoppedObject", FakeChoppedObject) + + result = modelling_setup.create_interactors( + [{"col_1": [{"proteinA": [(2, 4)]}]}], + ["/unused"], + ) + + chopped = result[0][0] + assert isinstance(chopped, FakeChoppedObject) + assert chopped.prepared is True + assert chopped.skip_msa is True + + def test_create_interactors_currently_skips_append_when_templates_are_empty(monkeypatch): monomer = MonomericObject("proteinA", "ACDE") monomer.feature_dict = {} diff --git a/test/unit/test_objects.py b/test/unit/test_objects.py index 97e6a74d..1ce01ad6 100644 --- a/test/unit/test_objects.py +++ b/test/unit/test_objects.py @@ -145,7 +145,9 @@ def process(self, *_args, **_kwargs): staticmethod(lambda path: zip_calls.append(path)), ) monkeypatch.setattr( - MonomericObject, "remove_msa_files", staticmethod(lambda _path: None) + MonomericObject, + "remove_msa_files", + staticmethod(lambda msa_output_path=None, **_kwargs: None), ) monomer.make_features( @@ -196,6 +198,86 @@ def process(self, *_args, **_kwargs): assert remove_calls == [str(tmp_path / "proteinA")] +def test_make_features_skip_msa_builds_query_only_features_and_templates( + monkeypatch, tmp_path +): + monomer = MonomericObject("proteinA", "ACDE") + calls = {} + + class FakeTemplateSearcher: + input_format = "a3m" + output_format = "hhr" + + def query(self, alignment): + calls["template_query"] = alignment + return "template_hits" + + def get_template_hits(self, output_string, input_sequence): + calls["template_hits"] = (output_string, input_sequence) + return ["hitA"] + + class FakeTemplateFeaturizer: + def get_templates(self, query_sequence, hits): + calls["template_features"] = (query_sequence, hits) + return SimpleNamespace( + features={ + "template_aatype": np.ones((1, 4, 22), dtype=np.float32), + "template_all_atom_masks": np.ones((1, 4, 37), dtype=np.float32), + "template_all_atom_positions": np.ones( + (1, 4, 37, 3), dtype=np.float32 + ), + "template_domain_names": np.asarray([b"1abc_A"], dtype=object), + "template_sequence": np.asarray([b"ACDE"], dtype=object), + "template_sum_probs": np.asarray([0.5], dtype=np.float32), + } + ) + + class FakePipeline: + template_searcher = FakeTemplateSearcher() + template_featurizer = FakeTemplateFeaturizer() + + def process(self, *_args, **_kwargs): + raise AssertionError("skip_msa should bypass pipeline.process") + + monkeypatch.setattr( + MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False) + ) + monkeypatch.setattr( + monomer, + "all_seq_msa_features", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + AssertionError("skip_msa should bypass all_seq_msa_features") + ), + ) + monkeypatch.setattr( + MonomericObject, + "remove_msa_files", + staticmethod(lambda msa_output_path=None, **_kwargs: None), + ) + monkeypatch.setattr( + MonomericObject, "zip_msa_files", staticmethod(lambda _path: None) + ) + + monomer.make_features( + pipeline=FakePipeline(), + output_dir=str(tmp_path), + save_msa=False, + skip_msa=True, + ) + + assert calls["template_query"] == ">query\nACDE\n" + assert calls["template_hits"] == ("template_hits", "ACDE") + assert calls["template_features"] == ("ACDE", ["hitA"]) + assert monomer.skip_msa is True + assert monomer.feature_dict["msa"].shape == (1, 4) + assert monomer.feature_dict["msa_all_seq"].shape == (1, 4) + assert np.array_equal( + monomer.feature_dict["num_alignments"], np.asarray([1, 1, 1, 1], dtype=np.int32) + ) + assert monomer.feature_dict["msa_species_identifiers_all_seq"].tolist() == [b""] + assert monomer.feature_dict["template_domain_names"].tolist() == [b"1abc_A"] + + def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m( monkeypatch, tmp_path ): @@ -265,6 +347,68 @@ def fake_enrich(feature_dict, a3m, **kwargs): assert monomer.feature_dict["template_release_date"] == ["none"] +def test_make_mmseq_features_skip_msa_uses_single_sequence_mode( + monkeypatch, tmp_path +): + monomer = MonomericObject("proteinA", "ACDE") + calls = {} + + monkeypatch.setattr( + MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False) + ) + + def fake_get_msa_and_templates(**kwargs): + calls["get_msa_and_templates"] = kwargs + return (["UNPAIRED"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"]) + + monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates) + monkeypatch.setattr( + objects_mod, + "build_monomer_feature", + lambda sequence, msa, template_features: { + "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32), + "deletion_matrix_int": np.asarray([[0, 0, 0, 0]], dtype=np.int32), + "template_confidence_scores": None, + "template_release_date": None, + }, + ) + + def fake_enrich(feature_dict, a3m, **kwargs): + calls["enrich"] = {"a3m": a3m, "kwargs": kwargs} + feature_dict["msa_species_identifiers"] = np.asarray([b""], dtype=object) + feature_dict["msa_uniprot_accession_identifiers"] = np.asarray( + [b""], dtype=object + ) + + monkeypatch.setattr( + objects_mod, + "enrich_mmseq_feature_dict_with_identifiers", + fake_enrich, + ) + monkeypatch.setattr( + MonomericObject, "zip_msa_files", staticmethod(lambda _path: None) + ) + + monomer.make_mmseq_features( + DEFAULT_API_SERVER="https://fake.server", + output_dir=str(tmp_path), + use_templates=True, + skip_msa=True, + ) + + assert calls["get_msa_and_templates"]["msa_mode"] == "single_sequence" + assert calls["get_msa_and_templates"]["pair_mode"] == "none" + assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"] + assert calls["get_msa_and_templates"]["use_templates"] is True + assert calls["enrich"]["a3m"] == ">101\nACDE" + assert monomer.skip_msa is True + assert monomer.feature_dict["msa"].shape == (1, 4) + assert monomer.feature_dict["msa_all_seq"].shape == (1, 4) + assert monomer.feature_dict["msa_uniprot_accession_identifiers_all_seq"].tolist() == [ + b"" + ] + + def test_make_mmseq_features_compresses_fresh_mmseq_result_dir( monkeypatch, tmp_path ): diff --git a/test/unit/test_script_entrypoints.py b/test/unit/test_script_entrypoints.py index 321fabbd..fea87969 100644 --- a/test/unit/test_script_entrypoints.py +++ b/test/unit/test_script_entrypoints.py @@ -790,6 +790,76 @@ def test_pre_modelling_setup_warns_for_long_paths_and_uses_chopped_metadata_name assert any("No feature metadata found for fragmentA" in message for message in warnings) +def test_pre_modelling_setup_rejects_pair_msa_for_skip_msa_interactors( + run_structure_prediction_module, + tmp_path, +): + _set_flag(run_structure_prediction_module.FLAGS, "pair_msa", True) + _set_flag(run_structure_prediction_module.FLAGS, "multimeric_template", False) + _set_flag(run_structure_prediction_module.FLAGS, "description_file", None) + _set_flag(run_structure_prediction_module.FLAGS, "path_to_mmt", None) + _set_flag(run_structure_prediction_module.FLAGS, "save_features_for_multimeric_object", False) + _set_flag( + run_structure_prediction_module.FLAGS, + "features_directory", + [str(tmp_path / "features")], + ) + _set_flag(run_structure_prediction_module.FLAGS, "use_ap_style", False) + + feature_dir = tmp_path / "features" + feature_dir.mkdir() + (feature_dir / "protA_feature_metadata_2026-03-30.json").write_text( + '{"meta": 1}', + encoding="utf-8", + ) + + monomer = run_structure_prediction_module.MonomericObject("protA", "ACDE") + monomer.skip_msa = True + + with pytest.raises(ValueError, match="--pair_msa=False"): + run_structure_prediction_module.pre_modelling_setup( + [monomer], + output_dir=str(tmp_path / "outputs"), + ) + + +def test_pre_modelling_setup_allows_skip_msa_when_pairing_disabled( + run_structure_prediction_module, + tmp_path, +): + _set_flag(run_structure_prediction_module.FLAGS, "pair_msa", False) + _set_flag(run_structure_prediction_module.FLAGS, "multimeric_template", False) + _set_flag(run_structure_prediction_module.FLAGS, "description_file", None) + _set_flag(run_structure_prediction_module.FLAGS, "path_to_mmt", None) + _set_flag(run_structure_prediction_module.FLAGS, "save_features_for_multimeric_object", False) + _set_flag( + run_structure_prediction_module.FLAGS, + "features_directory", + [str(tmp_path / "features")], + ) + _set_flag(run_structure_prediction_module.FLAGS, "use_ap_style", False) + + feature_dir = tmp_path / "features" + feature_dir.mkdir() + for description in ("protA", "protB"): + (feature_dir / f"{description}_feature_metadata_2026-03-30.json").write_text( + '{"meta": 1}', + encoding="utf-8", + ) + + monomer_a = run_structure_prediction_module.MonomericObject("protA", "AAAA") + monomer_a.skip_msa = True + monomer_b = run_structure_prediction_module.MonomericObject("protB", "BBBB") + + returned_object, _ = run_structure_prediction_module.pre_modelling_setup( + [monomer_a, monomer_b], + output_dir=str(tmp_path / "outputs"), + ) + + assert isinstance(returned_object, run_structure_prediction_module.MultimericObject) + assert returned_object.pair_msa is False + + def test_main_routes_protein_and_json_jobs_to_predict_structure( run_structure_prediction_module, monkeypatch, From 6eaa0922e86acd9bed986d99718455c31a73db6e Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 11:39:24 +0200 Subject: [PATCH 2/5] docs(#42): document the skip-MSA flag --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e1dfadfc..5d7a1a53 100644 --- a/README.md +++ b/README.md @@ -253,6 +253,7 @@ script manually). Commonly used flags: - `--data_pipeline {alphafold2,alphafold3}` – choose the feature format to emit. - `--db_preset {full_dbs,reduced_dbs}` – switch between the full BFD stack or the reduced databases. - `--use_mmseqs2` – rely on the remote MMseqs2 API; skips local jackhmmer/HHsearch database lookups. +- `--skip_msa` – generate query-only single-sequence features instead of running bulk MSA searches. Use these feature pickles with `run_structure_prediction.py --pair_msa=False`. - `--use_precomputed_msas` / `--save_msa_files` – reuse stored MSAs or keep new ones for later runs. - `--compress_features` – zip the generated `*.pkl` files (`.xz` extension) to save space. - `--skip_existing` – leave existing feature files untouched (safe for reruns). From e1737ac6a696221a10315ce9786aeb61e034c7e6 Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 11:52:59 +0200 Subject: [PATCH 3/5] fix: address skip-MSA review edge cases --- .../scripts/create_individual_features.py | 11 +++- .../scripts/run_structure_prediction.py | 6 +- .../test_create_individual_features.py | 57 +++++++++++++++++++ test/unit/test_script_entrypoints.py | 41 ++++++++++++- 4 files changed, 109 insertions(+), 6 deletions(-) diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py index d7d074a8..2aee6f13 100644 --- a/alphapulldown/scripts/create_individual_features.py +++ b/alphapulldown/scripts/create_individual_features.py @@ -474,11 +474,16 @@ def _reuse_truemultimer_monomer_features(feat): monomer = _load_existing_monomer_from_output_dir(source_name) if monomer is None: return None - if FLAGS.skip_msa and not getattr(monomer, "skip_msa", False): + cached_skip_msa = getattr(monomer, "skip_msa", False) + if FLAGS.skip_msa != cached_skip_msa: + requested_mode = "--skip_msa" if FLAGS.skip_msa else "full-MSA" + cached_mode = "--skip_msa" if cached_skip_msa else "full-MSA" logging.info( - "Existing monomer features for %s were generated with bulk MSAs. " - "Recomputing query-only features for --skip_msa.", + "Existing monomer features for %s were generated in %s mode, but the " + "current TrueMultimer entry requested %s mode. Recomputing features.", source_name, + cached_mode, + requested_mode, ) return None if monomer.sequence != feat["sequence"]: diff --git a/alphapulldown/scripts/run_structure_prediction.py b/alphapulldown/scripts/run_structure_prediction.py index ee9d5ec5..38a190ab 100644 --- a/alphapulldown/scripts/run_structure_prediction.py +++ b/alphapulldown/scripts/run_structure_prediction.py @@ -348,7 +348,11 @@ def pre_modelling_setup( A MultimericObject or MonomericObject output_directory for this particular modelling job """ - if FLAGS.pair_msa and any(getattr(interactor, "skip_msa", False) for interactor in interactors): + if ( + len(interactors) > 1 + and FLAGS.pair_msa + and any(getattr(interactor, "skip_msa", False) for interactor in interactors) + ): raise ValueError( "--skip_msa generates query-only MSAs and cannot be combined with " "--pair_msa=True. Re-run structure prediction with --pair_msa=False." diff --git a/test/integration/test_create_individual_features.py b/test/integration/test_create_individual_features.py index 5718c89d..7dcb830d 100644 --- a/test/integration/test_create_individual_features.py +++ b/test/integration/test_create_individual_features.py @@ -682,6 +682,63 @@ def test_process_multimeric_features_does_not_reuse_bulk_msa_pickle_for_skip_msa mock_runner.assert_called_once_with("/usr/bin/jackhmmer", "/db/uniprot.fasta") mock_save.assert_called_once() + def test_process_multimeric_features_does_not_reuse_skip_msa_pickle_for_full_msa( + self, tmp_flags + ): + template_path = Path(self.test_dir) / "template1.cif" + template_path.write_text("data_template\n", encoding="utf-8") + + from absl import flags + + FLAGS = flags.FLAGS + FLAGS(["test"]) + FLAGS.output_dir = os.path.join(self.test_dir, "full_msa_truemultimer_output") + FLAGS.use_mmseqs2 = False + FLAGS.compress_features = False + FLAGS.skip_existing = False + FLAGS.skip_msa = False + FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer" + FLAGS.uniprot_database_path = "/db/uniprot.fasta" + + output_dir = Path(FLAGS.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + source = MonomericObject("proteinA", "ACDE") + source.skip_msa = True + source.feature_dict = { + "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32), + "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32), + "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32), + "msa_species_identifiers": np.asarray([b""], dtype=object), + } + with open(output_dir / "proteinA.pkl", "wb") as handle: + pickle.dump(source, handle) + + feat = { + "protein": "proteinA.template1.cif.A", + "chains": ["A"], + "templates": [str(template_path)], + "sequence": "ACDE", + } + + with patch.object( + create_features, + "extract_multimeric_template_features_for_single_chain", + ) as mock_extract, \ + patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \ + patch.object(create_features, "create_arguments") as mock_arguments, \ + patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \ + patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \ + patch.object(create_features, "create_and_save_monomer_objects") as mock_save: + create_features.process_multimeric_features(feat, 1) + + mock_extract.assert_not_called() + mock_custom_db.assert_called_once() + mock_arguments.assert_called_once_with("/tmp/custom_db") + mock_pipeline.assert_called_once_with() + mock_runner.assert_called_once_with("/usr/bin/jackhmmer", "/db/uniprot.fasta") + mock_save.assert_called_once() + def test_main_dispatches_to_truemultimer_for_af2_template_runs(self): """The main entrypoint should route AF2 template jobs to the TrueMultimer path.""" from absl import flags diff --git a/test/unit/test_script_entrypoints.py b/test/unit/test_script_entrypoints.py index fea87969..97aebf62 100644 --- a/test/unit/test_script_entrypoints.py +++ b/test/unit/test_script_entrypoints.py @@ -790,7 +790,7 @@ def test_pre_modelling_setup_warns_for_long_paths_and_uses_chopped_metadata_name assert any("No feature metadata found for fragmentA" in message for message in warnings) -def test_pre_modelling_setup_rejects_pair_msa_for_skip_msa_interactors( +def test_pre_modelling_setup_allows_skip_msa_monomers_with_default_pair_flag( run_structure_prediction_module, tmp_path, ): @@ -816,9 +816,46 @@ def test_pre_modelling_setup_rejects_pair_msa_for_skip_msa_interactors( monomer = run_structure_prediction_module.MonomericObject("protA", "ACDE") monomer.skip_msa = True + returned_object, _ = run_structure_prediction_module.pre_modelling_setup( + [monomer], + output_dir=str(tmp_path / "outputs"), + ) + + assert returned_object is monomer + assert returned_object.input_seqs == ["ACDE"] + + +def test_pre_modelling_setup_rejects_pair_msa_for_skip_msa_multimers( + run_structure_prediction_module, + tmp_path, +): + _set_flag(run_structure_prediction_module.FLAGS, "pair_msa", True) + _set_flag(run_structure_prediction_module.FLAGS, "multimeric_template", False) + _set_flag(run_structure_prediction_module.FLAGS, "description_file", None) + _set_flag(run_structure_prediction_module.FLAGS, "path_to_mmt", None) + _set_flag(run_structure_prediction_module.FLAGS, "save_features_for_multimeric_object", False) + _set_flag( + run_structure_prediction_module.FLAGS, + "features_directory", + [str(tmp_path / "features")], + ) + _set_flag(run_structure_prediction_module.FLAGS, "use_ap_style", False) + + feature_dir = tmp_path / "features" + feature_dir.mkdir() + for description in ("protA", "protB"): + (feature_dir / f"{description}_feature_metadata_2026-03-30.json").write_text( + '{"meta": 1}', + encoding="utf-8", + ) + + monomer_a = run_structure_prediction_module.MonomericObject("protA", "ACDE") + monomer_a.skip_msa = True + monomer_b = run_structure_prediction_module.MonomericObject("protB", "BCDE") + with pytest.raises(ValueError, match="--pair_msa=False"): run_structure_prediction_module.pre_modelling_setup( - [monomer], + [monomer_a, monomer_b], output_dir=str(tmp_path / "outputs"), ) From 363b0934239af6d9514545ab704c67733de7001c Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 12:06:54 +0200 Subject: [PATCH 4/5] fix(#42): avoid AF2 bulk-db checks in skip-MSA mode --- .../scripts/create_individual_features.py | 71 +++++++++----- .../test_create_individual_features.py | 98 ++++++++++++++++++- 2 files changed, 141 insertions(+), 28 deletions(-) diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py index 2aee6f13..f8ef6d4b 100644 --- a/alphapulldown/scripts/create_individual_features.py +++ b/alphapulldown/scripts/create_individual_features.py @@ -15,6 +15,7 @@ import tempfile from datetime import datetime from pathlib import Path +from types import SimpleNamespace import numpy as np from absl import logging, app, flags @@ -310,6 +311,31 @@ def create_af3_chain(sequence, description, chain_id): # =================== AlphaFold 2 Feature Creation =================== +def _create_af2_template_stack(): + """Create the AF2 template searcher and featurizer.""" + if FLAGS.use_hhsearch: + template_searcher = hhsearch.HHSearch( + binary_path=FLAGS.hhsearch_binary_path, databases=[FLAGS.pdb70_database_path] + ) + template_featuriser = templates.HhsearchHitFeaturizer( + mmcif_dir=FLAGS.template_mmcif_dir, max_template_date=FLAGS.max_template_date, + max_hits=20, kalign_binary_path=FLAGS.kalign_binary_path, + release_dates_path=None, obsolete_pdbs_path=FLAGS.obsolete_pdbs_path + ) + else: + template_featuriser = templates.HmmsearchHitFeaturizer( + mmcif_dir=FLAGS.template_mmcif_dir, max_template_date=FLAGS.max_template_date, + max_hits=20, kalign_binary_path=FLAGS.kalign_binary_path, + obsolete_pdbs_path=FLAGS.obsolete_pdbs_path, release_dates_path=None + ) + template_searcher = hmmsearch.Hmmsearch( + binary_path=FLAGS.hmmsearch_binary_path, + hmmbuild_binary_path=FLAGS.hmmbuild_binary_path, + database_path=FLAGS.pdb_seqres_database_path + ) + return template_searcher, template_featuriser + + def create_pipeline_af2(): """Create and configure the AlphaFold2 data pipeline.""" use_small_bfd = FLAGS.db_preset == "reduced_dbs" @@ -319,26 +345,13 @@ def create_pipeline_af2(): template_searcher = None template_featuriser = None else: - if FLAGS.use_hhsearch: - template_searcher = hhsearch.HHSearch( - binary_path=FLAGS.hhsearch_binary_path, databases=[FLAGS.pdb70_database_path] - ) - template_featuriser = templates.HhsearchHitFeaturizer( - mmcif_dir=FLAGS.template_mmcif_dir, max_template_date=FLAGS.max_template_date, - max_hits=20, kalign_binary_path=FLAGS.kalign_binary_path, - release_dates_path=None, obsolete_pdbs_path=FLAGS.obsolete_pdbs_path - ) - else: - template_featuriser = templates.HmmsearchHitFeaturizer( - mmcif_dir=FLAGS.template_mmcif_dir, max_template_date=FLAGS.max_template_date, - max_hits=20, kalign_binary_path=FLAGS.kalign_binary_path, - obsolete_pdbs_path=FLAGS.obsolete_pdbs_path, release_dates_path=None - ) - template_searcher = hmmsearch.Hmmsearch( - binary_path=FLAGS.hmmsearch_binary_path, - hmmbuild_binary_path=FLAGS.hmmbuild_binary_path, - database_path=FLAGS.pdb_seqres_database_path - ) + template_searcher, template_featuriser = _create_af2_template_stack() + + if FLAGS.skip_msa: + return SimpleNamespace( + template_searcher=template_searcher, + template_featurizer=template_featuriser, + ) return AF2DataPipeline( jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, @@ -364,9 +377,12 @@ def create_individual_features(): uniprot_runner = None else: pipeline = create_pipeline_af2() - uniprot_runner = create_uniprot_runner( - FLAGS.jackhmmer_binary_path, FLAGS.uniprot_database_path - ) + if FLAGS.skip_msa: + uniprot_runner = None + else: + uniprot_runner = create_uniprot_runner( + FLAGS.jackhmmer_binary_path, FLAGS.uniprot_database_path + ) for seq_idx, (seq, desc) in enumerate(iter_seqs(FLAGS.fasta_paths), 1): if FLAGS.seq_index is None or seq_idx == FLAGS.seq_index: @@ -581,9 +597,12 @@ def process_multimeric_features(feat, idx): uniprot_runner = None else: pipeline = create_pipeline_af2() - uniprot_runner = create_uniprot_runner( - FLAGS.jackhmmer_binary_path, FLAGS.uniprot_database_path - ) + if FLAGS.skip_msa: + uniprot_runner = None + else: + uniprot_runner = create_uniprot_runner( + FLAGS.jackhmmer_binary_path, FLAGS.uniprot_database_path + ) monomer = MonomericObject(protein, feat['sequence']) monomer.uniprot_runner = uniprot_runner diff --git a/test/integration/test_create_individual_features.py b/test/integration/test_create_individual_features.py index 7dcb830d..128dc84a 100644 --- a/test/integration/test_create_individual_features.py +++ b/test/integration/test_create_individual_features.py @@ -679,8 +679,10 @@ def test_process_multimeric_features_does_not_reuse_bulk_msa_pickle_for_skip_msa mock_custom_db.assert_called_once() mock_arguments.assert_called_once_with("/tmp/custom_db") mock_pipeline.assert_called_once_with() - mock_runner.assert_called_once_with("/usr/bin/jackhmmer", "/db/uniprot.fasta") - mock_save.assert_called_once() + mock_runner.assert_not_called() + saved_monomer, saved_pipeline = mock_save.call_args.args + assert saved_pipeline == "pipeline" + assert saved_monomer.uniprot_runner is None def test_process_multimeric_features_does_not_reuse_skip_msa_pickle_for_full_msa( self, tmp_flags @@ -1416,6 +1418,41 @@ def test_create_pipeline_af2_uses_hmmsearch_template_stack(tmp_flags): assert mock_pipeline.call_args.kwargs["template_featurizer"] == "featurizer" +def test_create_pipeline_af2_skip_msa_returns_template_only_pipeline(tmp_flags): + create_features.FLAGS.use_mmseqs2 = False + create_features.FLAGS.use_hhsearch = False + create_features.FLAGS.skip_msa = True + create_features.FLAGS.hmmsearch_binary_path = "/bin/hmmsearch" + create_features.FLAGS.hmmbuild_binary_path = "/bin/hmmbuild" + create_features.FLAGS.pdb_seqres_database_path = "/db/pdb_seqres.txt" + create_features.FLAGS.template_mmcif_dir = "/db/mmcif" + create_features.FLAGS.max_template_date = "2021-09-30" + create_features.FLAGS.kalign_binary_path = "/bin/kalign" + create_features.FLAGS.obsolete_pdbs_path = "/db/obsolete.dat" + + with patch.object(create_features.hmmsearch, "Hmmsearch", return_value="searcher") as mock_searcher, \ + patch.object(create_features.templates, "HmmsearchHitFeaturizer", return_value="featurizer") as mock_featurizer, \ + patch.object(create_features, "AF2DataPipeline") as mock_pipeline: + pipeline = create_features.create_pipeline_af2() + + mock_searcher.assert_called_once_with( + binary_path="/bin/hmmsearch", + hmmbuild_binary_path="/bin/hmmbuild", + database_path="/db/pdb_seqres.txt", + ) + mock_featurizer.assert_called_once_with( + mmcif_dir="/db/mmcif", + max_template_date="2021-09-30", + max_hits=20, + kalign_binary_path="/bin/kalign", + obsolete_pdbs_path="/db/obsolete.dat", + release_dates_path=None, + ) + mock_pipeline.assert_not_called() + assert pipeline.template_searcher == "searcher" + assert pipeline.template_featurizer == "featurizer" + + def test_create_individual_features_only_saves_selected_sequence(tmp_flags): create_features.FLAGS.seq_index = 2 @@ -1436,6 +1473,28 @@ def test_create_individual_features_only_saves_selected_sequence(tmp_flags): assert saved_monomer.uniprot_runner == "runner" +def test_create_individual_features_skip_msa_avoids_uniprot_runner(tmp_flags): + create_features.FLAGS.seq_index = None + create_features.FLAGS.use_mmseqs2 = False + create_features.FLAGS.skip_msa = True + + with patch.object(create_features, "create_arguments") as mock_arguments, \ + patch.object(create_features, "create_pipeline_af2", return_value="template-only-pipeline") as mock_pipeline, \ + patch.object(create_features, "create_uniprot_runner") as mock_runner, \ + patch.object(create_features, "MonomericObject", DummyMonomer), \ + patch.object(create_features, "iter_seqs", return_value=[("AAAA", "first")]), \ + patch.object(create_features, "create_and_save_monomer_objects") as mock_save: + create_features.create_individual_features() + + mock_arguments.assert_called_once_with() + mock_pipeline.assert_called_once_with() + mock_runner.assert_not_called() + saved_monomer, saved_pipeline = mock_save.call_args.args + assert saved_pipeline == "template-only-pipeline" + assert saved_monomer.description == "first" + assert saved_monomer.uniprot_runner is None + + def test_create_and_save_monomer_objects_writes_compressed_af2_outputs(tmp_flags, tmp_path): create_features.FLAGS.output_dir = str(tmp_path) create_features.FLAGS.compress_features = True @@ -1753,6 +1812,41 @@ def test_process_multimeric_features_uses_mmseqs_without_local_pipeline(tmp_flag assert saved_kwargs == {"custom_template_path": "/tmp/custom_db/templates"} +def test_process_multimeric_features_skip_msa_avoids_uniprot_runner(tmp_flags, tmp_path): + template_path = tmp_path / "template.cif" + template_path.write_text("data_template\n", encoding="utf-8") + + create_features.FLAGS.output_dir = str(tmp_path / "out") + create_features.FLAGS.use_mmseqs2 = False + create_features.FLAGS.skip_msa = True + + feat = { + "protein": "complex_local", + "chains": ["A"], + "templates": [str(template_path)], + "sequence": "ACDE", + } + + with patch.object(create_features, "MonomericObject", RecordingDummyMonomer), \ + patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \ + patch.object(create_features, "create_arguments") as mock_arguments, \ + patch.object(create_features, "create_pipeline_af2", return_value="template-only-pipeline") as mock_pipeline, \ + patch.object(create_features, "create_uniprot_runner") as mock_runner, \ + patch.object(create_features, "create_and_save_monomer_objects") as mock_save: + create_features.process_multimeric_features(feat, 1) + + mock_custom_db.assert_called_once() + mock_arguments.assert_called_once_with("/tmp/custom_db") + mock_pipeline.assert_called_once_with() + mock_runner.assert_not_called() + saved_monomer, saved_pipeline = mock_save.call_args.args + saved_kwargs = mock_save.call_args.kwargs + assert saved_pipeline == "template-only-pipeline" + assert saved_monomer.description == "complex_local" + assert saved_monomer.uniprot_runner is None + assert saved_kwargs == {"custom_template_path": None} + + def test_create_custom_db_passes_thresholds_to_builder(tmp_flags): create_features.FLAGS.threshold_clashes = 12.5 create_features.FLAGS.hb_allowance = 0.7 From 9176d6ee78b0944bf3fbf7c23820347614fd165b Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 12:14:14 +0200 Subject: [PATCH 5/5] fix(#42): add RF annotation to skip-MSA stockholm --- alphapulldown/objects.py | 8 ++++- test/unit/test_objects.py | 62 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/alphapulldown/objects.py b/alphapulldown/objects.py index 70127479..feb1a10c 100644 --- a/alphapulldown/objects.py +++ b/alphapulldown/objects.py @@ -34,7 +34,13 @@ def _query_only_a3m(sequence: str, query_id: str = "query") -> str: def _query_only_stockholm(sequence: str, query_id: str = "query") -> str: """Return a single-sequence Stockholm alignment string.""" - return f"# STOCKHOLM 1.0\n{query_id} {sequence}\n//\n" + rf_annotation = "x" * len(sequence) + return ( + "# STOCKHOLM 1.0\n" + f"{query_id} {sequence}\n" + f"#=GC RF {rf_annotation}\n" + "//\n" + ) class MonomericObject: diff --git a/test/unit/test_objects.py b/test/unit/test_objects.py index 1ce01ad6..edf8e7f3 100644 --- a/test/unit/test_objects.py +++ b/test/unit/test_objects.py @@ -278,6 +278,68 @@ def process(self, *_args, **_kwargs): assert monomer.feature_dict["template_domain_names"].tolist() == [b"1abc_A"] +def test_make_features_skip_msa_builds_stockholm_with_rf_for_hmmsearch( + monkeypatch, tmp_path +): + monomer = MonomericObject("proteinA", "ACDE") + calls = {} + + class FakeTemplateSearcher: + input_format = "sto" + output_format = "sto" + + def query(self, alignment): + calls["template_query"] = alignment + return "template_hits" + + def get_template_hits(self, output_string, input_sequence): + return ["hitA"] + + class FakeTemplateFeaturizer: + def get_templates(self, query_sequence, hits): + return SimpleNamespace( + features={ + "template_aatype": np.ones((1, 4, 22), dtype=np.float32), + "template_all_atom_masks": np.ones((1, 4, 37), dtype=np.float32), + "template_all_atom_positions": np.ones( + (1, 4, 37, 3), dtype=np.float32 + ), + "template_domain_names": np.asarray([b"1abc_A"], dtype=object), + "template_sequence": np.asarray([b"ACDE"], dtype=object), + "template_sum_probs": np.asarray([0.5], dtype=np.float32), + } + ) + + class FakePipeline: + template_searcher = FakeTemplateSearcher() + template_featurizer = FakeTemplateFeaturizer() + + def process(self, *_args, **_kwargs): + raise AssertionError("skip_msa should bypass pipeline.process") + + monkeypatch.setattr( + MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False) + ) + monkeypatch.setattr( + MonomericObject, + "remove_msa_files", + staticmethod(lambda msa_output_path=None, **_kwargs: None), + ) + monkeypatch.setattr( + MonomericObject, "zip_msa_files", staticmethod(lambda _path: None) + ) + + monomer.make_features( + pipeline=FakePipeline(), + output_dir=str(tmp_path), + save_msa=False, + skip_msa=True, + ) + + assert "#=GC RF xxxx" in calls["template_query"] + assert calls["template_query"].startswith("# STOCKHOLM 1.0\nquery ACDE\n") + + def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m( monkeypatch, tmp_path ):