diff --git a/alphapulldown/objects.py b/alphapulldown/objects.py index 72adcd51..26c610f3 100644 --- a/alphapulldown/objects.py +++ b/alphapulldown/objects.py @@ -194,6 +194,7 @@ def make_mmseq_features( compress_msa_files=False, use_precomputed_msa=False, use_templates=False, + custom_template_path=None, ): """ A method to use mmseq_remote to calculate MSA. @@ -201,6 +202,7 @@ def make_mmseq_features( """ os.makedirs(output_dir, exist_ok=True) using_zipped_msa_files = MonomericObject.unzip_msa_files(output_dir) + use_templates = use_templates or custom_template_path is not None msa_mode = "mmseqs2_uniref_env" keep_existing_results = True @@ -224,7 +226,7 @@ def make_mmseq_features( result_dir=plPath(result_dir), msa_mode='single_sequence', use_templates=True, - custom_template_path=None, + custom_template_path=custom_template_path, pair_mode="none", host_url=DEFAULT_API_SERVER, user_agent='alphapulldown') @@ -243,7 +245,7 @@ def make_mmseq_features( result_dir=plPath(result_dir), msa_mode=msa_mode, use_templates=use_templates, - custom_template_path=None, + custom_template_path=custom_template_path, pair_mode="none", host_url=DEFAULT_API_SERVER, user_agent='alphapulldown' diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py index b8fed124..58a1ee76 100644 --- a/alphapulldown/scripts/create_individual_features.py +++ b/alphapulldown/scripts/create_individual_features.py @@ -340,7 +340,7 @@ def create_individual_features(): monomer.uniprot_runner = uniprot_runner create_and_save_monomer_objects(monomer, pipeline) -def create_and_save_monomer_objects(monomer, pipeline): +def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None): """Save a MonomericObject after feature creation (pickled, optionally compressed).""" # Ensure output directory exists os.makedirs(FLAGS.output_dir, exist_ok=True) @@ -362,7 +362,13 @@ def create_and_save_monomer_objects(monomer, pipeline): with open(metadata_output_path, "w") as meta_data_outfile: json.dump(meta_dict, meta_data_outfile) if FLAGS.use_mmseqs2: - monomer.make_mmseq_features(DEFAULT_API_SERVER=DEFAULT_API_SERVER, output_dir=FLAGS.output_dir, use_precomputed_msa=FLAGS.use_precomputed_msas, use_templates=FLAGS.re_search_templates_mmseqs2) + monomer.make_mmseq_features( + DEFAULT_API_SERVER=DEFAULT_API_SERVER, + output_dir=FLAGS.output_dir, + use_precomputed_msa=FLAGS.use_precomputed_msas, + use_templates=FLAGS.re_search_templates_mmseqs2 or custom_template_path is not None, + custom_template_path=custom_template_path, + ) else: monomer.make_features( pipeline=pipeline, output_dir=FLAGS.output_dir, @@ -406,7 +412,12 @@ def process_multimeric_features(feat, idx): monomer = MonomericObject(protein, feat['sequence']) monomer.uniprot_runner = uniprot_runner - create_and_save_monomer_objects(monomer, pipeline) + custom_template_path = str(Path(local_path_to_custom_db) / "templates") + create_and_save_monomer_objects( + monomer, + pipeline, + custom_template_path=custom_template_path if FLAGS.use_mmseqs2 else None, + ) def create_custom_db(temp_dir, protein, template_paths, chains): """Create a local custom template DB for TrueMultimer/AF2.""" diff --git a/test/integration/test_create_individual_features.py b/test/integration/test_create_individual_features.py index 3f1297af..a5969cbb 100644 --- a/test/integration/test_create_individual_features.py +++ b/test/integration/test_create_individual_features.py @@ -1194,6 +1194,38 @@ def test_create_and_save_monomer_objects_uses_mmseqs_when_requested(tmp_flags, t "output_dir": str(tmp_path), "use_precomputed_msa": True, "use_templates": True, + "custom_template_path": None, + } + ] + assert (tmp_path / "protA.pkl").exists() + + +def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_flags, tmp_path): + create_features.FLAGS.output_dir = str(tmp_path) + create_features.FLAGS.compress_features = False + create_features.FLAGS.skip_existing = False + create_features.FLAGS.use_mmseqs2 = True + create_features.FLAGS.use_precomputed_msas = False + create_features.FLAGS.re_search_templates_mmseqs2 = False + + monomer = RecordingDummyMonomer("protA") + custom_template_path = str(tmp_path / "custom_db" / "templates") + + with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}): + create_features.create_and_save_monomer_objects( + monomer, + pipeline=None, + custom_template_path=custom_template_path, + ) + + assert monomer.feature_calls == [] + assert monomer.mmseq_calls == [ + { + "DEFAULT_API_SERVER": create_features.DEFAULT_API_SERVER, + "output_dir": str(tmp_path), + "use_precomputed_msa": False, + "use_templates": True, + "custom_template_path": custom_template_path, } ] assert (tmp_path / "protA.pkl").exists() @@ -1357,9 +1389,11 @@ def test_process_multimeric_features_uses_mmseqs_without_local_pipeline(tmp_flag mock_pipeline.assert_not_called() mock_runner.assert_not_called() saved_monomer, saved_pipeline = mock_save.call_args.args + saved_kwargs = mock_save.call_args.kwargs assert saved_pipeline is None assert saved_monomer.description == "complex_mmseqs" assert saved_monomer.uniprot_runner is None + assert saved_kwargs == {"custom_template_path": "/tmp/custom_db/templates"} def test_create_custom_db_passes_thresholds_to_builder(tmp_flags): diff --git a/test/unit/test_objects.py b/test/unit/test_objects.py index f99de4eb..97e6a74d 100644 --- a/test/unit/test_objects.py +++ b/test/unit/test_objects.py @@ -454,6 +454,73 @@ def fake_build(sequence, msa, template_features): ) +def test_make_mmseq_features_uses_custom_template_path_for_precomputed_msa( + monkeypatch, tmp_path +): + monomer = MonomericObject("proteinA", "ACDE") + calls = {} + (tmp_path / "proteinA.a3m").write_text(">101\nACDE\n", encoding="utf-8") + custom_template_path = str(tmp_path / "custom_templates") + + monkeypatch.setattr( + MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False) + ) + monkeypatch.setattr( + objects_mod, + "unserialize_msa", + lambda a3m_lines, sequence: ( + ["PRECOMP_MSA"], + ["PRECOMP_PAIRED"], + ["UNIQUE"], + ["CARD"], + ["PRECOMP_TEMPLATE"], + ), + ) + + def fake_get_msa_and_templates(**kwargs): + calls["get_msa_and_templates"] = kwargs + return ( + ["IGNORED_UNPAIRED"], + ["IGNORED_PAIRED"], + ["IGNORED_UNIQUE"], + ["IGNORED_CARD"], + ["CUSTOM_TEMPLATE"], + ) + + monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates) + monkeypatch.setattr( + objects_mod, + "build_monomer_feature", + lambda *_args, **_kwargs: { + "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32), + "deletion_matrix_int": np.asarray([[0, 0, 0, 0]], dtype=np.int32), + "template_confidence_scores": None, + "template_release_date": None, + }, + ) + monkeypatch.setattr( + objects_mod, + "enrich_mmseq_feature_dict_with_identifiers", + lambda feature_dict, *_args, **_kwargs: feature_dict.update( + { + "msa_species_identifiers": np.asarray([b"562"], dtype=object), + "msa_uniprot_accession_identifiers": np.asarray([b"A0A123"], dtype=object), + } + ), + ) + + monomer.make_mmseq_features( + DEFAULT_API_SERVER="https://fake.server", + output_dir=str(tmp_path), + use_precomputed_msa=True, + use_templates=False, + custom_template_path=custom_template_path, + ) + + assert calls["get_msa_and_templates"]["use_templates"] is True + assert calls["get_msa_and_templates"]["custom_template_path"] == custom_template_path + + def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run( monkeypatch, tmp_path ):