Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions alphapulldown/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,15 @@ def make_mmseq_features(
compress_msa_files=False,
use_precomputed_msa=False,
use_templates=False,
custom_template_path=None,
):
"""
A method to use mmseq_remote to calculate MSA.
Modified from ColabFold to allow reusing precomputed MSAs if available.
"""
os.makedirs(output_dir, exist_ok=True)
using_zipped_msa_files = MonomericObject.unzip_msa_files(output_dir)
use_templates = use_templates or custom_template_path is not None

msa_mode = "mmseqs2_uniref_env"
keep_existing_results = True
Expand All @@ -224,7 +226,7 @@ def make_mmseq_features(
result_dir=plPath(result_dir),
msa_mode='single_sequence',
use_templates=True,
custom_template_path=None,
custom_template_path=custom_template_path,
pair_mode="none",
host_url=DEFAULT_API_SERVER,
user_agent='alphapulldown')
Expand All @@ -243,7 +245,7 @@ def make_mmseq_features(
result_dir=plPath(result_dir),
msa_mode=msa_mode,
use_templates=use_templates,
custom_template_path=None,
custom_template_path=custom_template_path,
pair_mode="none",
host_url=DEFAULT_API_SERVER,
user_agent='alphapulldown'
Expand Down
17 changes: 14 additions & 3 deletions alphapulldown/scripts/create_individual_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def create_individual_features():
monomer.uniprot_runner = uniprot_runner
create_and_save_monomer_objects(monomer, pipeline)

def create_and_save_monomer_objects(monomer, pipeline):
def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None):
"""Save a MonomericObject after feature creation (pickled, optionally compressed)."""
# Ensure output directory exists
os.makedirs(FLAGS.output_dir, exist_ok=True)
Expand All @@ -362,7 +362,13 @@ def create_and_save_monomer_objects(monomer, pipeline):
with open(metadata_output_path, "w") as meta_data_outfile:
json.dump(meta_dict, meta_data_outfile)
if FLAGS.use_mmseqs2:
monomer.make_mmseq_features(DEFAULT_API_SERVER=DEFAULT_API_SERVER, output_dir=FLAGS.output_dir, use_precomputed_msa=FLAGS.use_precomputed_msas, use_templates=FLAGS.re_search_templates_mmseqs2)
monomer.make_mmseq_features(
DEFAULT_API_SERVER=DEFAULT_API_SERVER,
output_dir=FLAGS.output_dir,
use_precomputed_msa=FLAGS.use_precomputed_msas,
use_templates=FLAGS.re_search_templates_mmseqs2 or custom_template_path is not None,
custom_template_path=custom_template_path,
)
else:
monomer.make_features(
pipeline=pipeline, output_dir=FLAGS.output_dir,
Expand Down Expand Up @@ -406,7 +412,12 @@ def process_multimeric_features(feat, idx):

monomer = MonomericObject(protein, feat['sequence'])
monomer.uniprot_runner = uniprot_runner
create_and_save_monomer_objects(monomer, pipeline)
custom_template_path = str(Path(local_path_to_custom_db) / "templates")
create_and_save_monomer_objects(
monomer,
pipeline,
custom_template_path=custom_template_path if FLAGS.use_mmseqs2 else None,
)

def create_custom_db(temp_dir, protein, template_paths, chains):
"""Create a local custom template DB for TrueMultimer/AF2."""
Expand Down
34 changes: 34 additions & 0 deletions test/integration/test_create_individual_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,38 @@ def test_create_and_save_monomer_objects_uses_mmseqs_when_requested(tmp_flags, t
"output_dir": str(tmp_path),
"use_precomputed_msa": True,
"use_templates": True,
"custom_template_path": None,
}
]
assert (tmp_path / "protA.pkl").exists()


def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_flags, tmp_path):
create_features.FLAGS.output_dir = str(tmp_path)
create_features.FLAGS.compress_features = False
create_features.FLAGS.skip_existing = False
create_features.FLAGS.use_mmseqs2 = True
create_features.FLAGS.use_precomputed_msas = False
create_features.FLAGS.re_search_templates_mmseqs2 = False

monomer = RecordingDummyMonomer("protA")
custom_template_path = str(tmp_path / "custom_db" / "templates")

with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
create_features.create_and_save_monomer_objects(
monomer,
pipeline=None,
custom_template_path=custom_template_path,
)

assert monomer.feature_calls == []
assert monomer.mmseq_calls == [
{
"DEFAULT_API_SERVER": create_features.DEFAULT_API_SERVER,
"output_dir": str(tmp_path),
"use_precomputed_msa": False,
"use_templates": True,
"custom_template_path": custom_template_path,
}
]
assert (tmp_path / "protA.pkl").exists()
Expand Down Expand Up @@ -1357,9 +1389,11 @@ def test_process_multimeric_features_uses_mmseqs_without_local_pipeline(tmp_flag
mock_pipeline.assert_not_called()
mock_runner.assert_not_called()
saved_monomer, saved_pipeline = mock_save.call_args.args
saved_kwargs = mock_save.call_args.kwargs
assert saved_pipeline is None
assert saved_monomer.description == "complex_mmseqs"
assert saved_monomer.uniprot_runner is None
assert saved_kwargs == {"custom_template_path": "/tmp/custom_db/templates"}


def test_create_custom_db_passes_thresholds_to_builder(tmp_flags):
Expand Down
67 changes: 67 additions & 0 deletions test/unit/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,73 @@ def fake_build(sequence, msa, template_features):
)


def test_make_mmseq_features_uses_custom_template_path_for_precomputed_msa(
monkeypatch, tmp_path
):
monomer = MonomericObject("proteinA", "ACDE")
calls = {}
(tmp_path / "proteinA.a3m").write_text(">101\nACDE\n", encoding="utf-8")
custom_template_path = str(tmp_path / "custom_templates")

monkeypatch.setattr(
MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False)
)
monkeypatch.setattr(
objects_mod,
"unserialize_msa",
lambda a3m_lines, sequence: (
["PRECOMP_MSA"],
["PRECOMP_PAIRED"],
["UNIQUE"],
["CARD"],
["PRECOMP_TEMPLATE"],
),
)

def fake_get_msa_and_templates(**kwargs):
calls["get_msa_and_templates"] = kwargs
return (
["IGNORED_UNPAIRED"],
["IGNORED_PAIRED"],
["IGNORED_UNIQUE"],
["IGNORED_CARD"],
["CUSTOM_TEMPLATE"],
)

monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
monkeypatch.setattr(
objects_mod,
"build_monomer_feature",
lambda *_args, **_kwargs: {
"msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
"deletion_matrix_int": np.asarray([[0, 0, 0, 0]], dtype=np.int32),
"template_confidence_scores": None,
"template_release_date": None,
},
)
monkeypatch.setattr(
objects_mod,
"enrich_mmseq_feature_dict_with_identifiers",
lambda feature_dict, *_args, **_kwargs: feature_dict.update(
{
"msa_species_identifiers": np.asarray([b"562"], dtype=object),
"msa_uniprot_accession_identifiers": np.asarray([b"A0A123"], dtype=object),
}
),
)

monomer.make_mmseq_features(
DEFAULT_API_SERVER="https://fake.server",
output_dir=str(tmp_path),
use_precomputed_msa=True,
use_templates=False,
custom_template_path=custom_template_path,
)

assert calls["get_msa_and_templates"]["use_templates"] is True
assert calls["get_msa_and_templates"]["custom_template_path"] == custom_template_path


def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
monkeypatch, tmp_path
):
Expand Down
Loading