diff --git a/alphapulldown/folding_backend/alphafold2_backend.py b/alphapulldown/folding_backend/alphafold2_backend.py
index 9c25e70e..6268f6e3 100644
--- a/alphapulldown/folding_backend/alphafold2_backend.py
+++ b/alphapulldown/folding_backend/alphafold2_backend.py
@@ -12,6 +12,7 @@
 import subprocess
 import enum
 import typing
+from collections.abc import Mapping
 from typing import Dict, Union, List, Any
 import os
 from absl import logging
@@ -145,6 +146,37 @@ def _jnp_to_np(output):
     return output
 
 
+def _normalize_prediction_result(
+    prediction_result: Any, *, model_name: str
+) -> Dict[str, Any]:
+    """Accept the supported AF2 runner return shapes and return a mutable dict."""
+    if isinstance(prediction_result, Mapping):
+        return dict(prediction_result)
+
+    if isinstance(prediction_result, tuple):
+        if len(prediction_result) == 2 and isinstance(prediction_result[0], Mapping):
+            auxiliary_output = prediction_result[1]
+            logging.warning(
+                "Model %s returned a (prediction_result, auxiliary_output) tuple; "
+                "ignoring auxiliary output of type %s.",
+                model_name,
+                type(auxiliary_output).__name__,
+            )
+            return dict(prediction_result[0])
+
+        first_type = type(prediction_result[0]).__name__ if prediction_result else "n/a"
+        raise TypeError(
+            "model_runner.predict must return a mapping or a (mapping, auxiliary) "
+            f"tuple; got tuple of length {len(prediction_result)} with first "
+            f"element type {first_type}."
+        )
+
+    raise TypeError(
+        "model_runner.predict must return a mapping or a (mapping, auxiliary) "
+        f"tuple; got {type(prediction_result).__name__}."
+    )
+
+
 def _save_pae_json_file(pae: np.ndarray, max_pae: float, output_dir: str, model_name: str) -> None:
     """
     Check prediction result for PAE data and save to a JSON file if present.
@@ -709,6 +741,9 @@ def predict_individual_job(
         prediction_result = model_runner.predict(
             processed_feature_dict, random_seed=model_random_seed
         )
+        prediction_result = _normalize_prediction_result(
+            prediction_result, model_name=model_name
+        )
         t_diff = time.time() - t_0
         timings[f"predict_and_compile_{model_name}"] = t_diff
         logging.info(f"prediction costs : {t_diff} s")
diff --git a/alphapulldown/objects.py b/alphapulldown/objects.py
index feb1a10c..5754bad1 100644
--- a/alphapulldown/objects.py
+++ b/alphapulldown/objects.py
@@ -43,6 +43,23 @@ def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
     )
 
 
+def _ensure_identifier_feature_arrays(
+    feature_dict: Dict[str, np.ndarray],
+    feature_groups: Tuple[Tuple[str, Tuple[str, ...]], ...],
+) -> Dict[str, np.ndarray]:
+    """Backfill missing identifier arrays to match the corresponding MSA rows."""
+    normalized = dict(feature_dict)
+    for msa_key, identifier_keys in feature_groups:
+        msa = normalized.get(msa_key)
+        if msa is None:
+            continue
+        num_rows = int(np.asarray(msa).shape[0])
+        for key in identifier_keys:
+            if key not in normalized:
+                normalized[key] = np.array([b""] * num_rows, dtype=object)
+    return normalized
+
+
 class MonomericObject:
     """
     monomeric objects
@@ -145,7 +162,18 @@ def all_seq_msa_features(
 
         msa = parsers.parse_stockholm(result["sto"])
         msa = msa.truncate(max_seqs=50000)
-        all_seq_features = pipeline.make_msa_features([msa])
+        all_seq_features = _ensure_identifier_feature_arrays(
+            pipeline.make_msa_features([msa]),
+            (
+                (
+                    "msa",
+                    (
+                        "msa_species_identifiers",
+                        "msa_uniprot_accession_identifiers",
+                    ),
+                ),
+            ),
+        )
         valid_feats = msa_pairing.MSA_FEATURES + (
             "msa_species_identifiers",
             "msa_uniprot_accession_identifiers",
@@ -362,9 +390,12 @@ def make_mmseq_features(
         # 
Remove header lines starting with '#' if present. a3m_lines[0] = strip_mmseq_comment_lines(a3m_lines[0]) self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0]) + # Enrich from the same A3M string that build_monomer_feature parsed, so + # the identifier rows go through the same parse_a3m dedup as msa_features + # and their count matches feature_dict['msa'] exactly. enrich_mmseq_feature_dict_with_identifiers( self.feature_dict, - a3m_lines[0], + unpaired_msa[0], cache_path=os.path.join( result_dir, f"{self.description}.mmseq_ids.json" ), @@ -811,12 +842,48 @@ def remove_all_seq_features(np_chain_list: List[Dict]) -> List[Dict]: output_list.append(new_chain) return output_list + @staticmethod + def normalize_all_seq_identifier_features(np_chain_list: List[Dict]) -> List[Dict]: + """Ensure identifier arrays exist consistently across chains. + + Some feature sources provide species identifiers but omit UniProt + accession IDs, while DeepMind's multimer pairing and merge code assumes + both unpaired and `_all_seq` identifier keys exist consistently across + chains. + """ + output_list = [] + for feat_dict in np_chain_list: + output_list.append( + _ensure_identifier_feature_arrays( + feat_dict, + ( + ( + "msa", + ( + "msa_species_identifiers", + "msa_uniprot_accession_identifiers", + ), + ), + ( + "msa_all_seq", + ( + "msa_species_identifiers_all_seq", + "msa_uniprot_accession_identifiers_all_seq", + ), + ), + ), + ) + ) + return output_list + def pair_and_merge(self, all_chain_features): """merge all chain features""" feature_processing.process_unmerged_features(all_chain_features) MAX_TEMPLATES = 4 MSA_CROP_SIZE = 2048 - np_chains_list = list(all_chain_features.values()) + np_chains_list = MultimericObject.normalize_all_seq_identifier_features( + list(all_chain_features.values()) + ) pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer( np_chains_list) logging.debug(f"pair_msa_sequences is type : {type(pair_msa_sequences)} value: {pair_msa_sequences}") diff --git a/test/cluster/check_alphafold2_predictions.py b/test/cluster/check_alphafold2_predictions.py index 440d869f..0b7423f9 100755 --- a/test/cluster/check_alphafold2_predictions.py +++ b/test/cluster/check_alphafold2_predictions.py @@ -92,6 +92,23 @@ def _load_feature_dict(feature_path: Path) -> dict: return payload +def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict]: + matches = sorted(feature_dir.glob(f"{protein_id}_feature_metadata_*.json*")) + if len(matches) != 1: + raise FileNotFoundError( + f"Expected one feature metadata file for {protein_id} in {feature_dir}, " + f"found {matches}" + ) + metadata_path = matches[0] + opener = lzma.open if metadata_path.suffix == ".xz" else open + with opener(metadata_path, "rt", encoding="utf-8") as handle: + return metadata_path, json.load(handle) + + +def _metadata_bool(value) -> bool: + return str(value).strip().lower() in {"1", "true", "yes"} + + def _non_empty_identifier_count(values) -> int: count = 0 for value in values: @@ -105,11 +122,12 @@ def _non_empty_identifier_count(values) -> int: def _af2_subprocess_env() -> dict[str, str]: """Return stable GPU/JAX defaults for AF2 functional subprocesses.""" env = os.environ.copy() - env.setdefault("OMP_NUM_THREADS", "4") - env.setdefault("MKL_NUM_THREADS", "4") - env.setdefault("NUMEXPR_NUM_THREADS", "4") - env.setdefault("TF_NUM_INTEROP_THREADS", "4") - env.setdefault("TF_NUM_INTRAOP_THREADS", "4") + env.setdefault("OMP_NUM_THREADS", "1") + 
env.setdefault("OPENBLAS_NUM_THREADS", "1") + env.setdefault("MKL_NUM_THREADS", "1") + env.setdefault("NUMEXPR_NUM_THREADS", "1") + env.setdefault("TF_NUM_INTEROP_THREADS", "1") + env.setdefault("TF_NUM_INTRAOP_THREADS", "1") env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true") env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2") env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") @@ -117,7 +135,8 @@ def _af2_subprocess_env() -> dict[str, str]: env.setdefault("JAX_PLATFORM_NAME", "gpu") env.setdefault( "XLA_FLAGS", - "--xla_gpu_force_compilation_parallelism=0 " + "--xla_gpu_force_compilation_parallelism=1 " + "--xla_force_host_platform_device_count=1 " "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1", ) return env @@ -544,6 +563,87 @@ def _generate_issue_588_mmseq_features(self) -> Path: ) return feature_dir + def _generate_issue_588_precomputed_mmseq_features(self) -> Path: + source_dir = self.output_dir / "issue_588_mmseq_source_features" + precomputed_dir = self.output_dir / "issue_588_mmseq_precomputed_features" + source_dir.mkdir(parents=True, exist_ok=True) + precomputed_dir.mkdir(parents=True, exist_ok=True) + fasta_paths = ",".join( + str(self.test_data_dir / "fastas" / f"{protein_id}.fasta") + for protein_id in self.ISSUE_588_IDS + ) + + source_res = self._run_prediction_subprocess( + [ + sys.executable, + str(self.script_create_features), + f"--fasta_paths={fasta_paths}", + f"--output_dir={source_dir}", + f"--data_dir={DATA_DIR}", + "--max_template_date=2024-05-02", + "--use_mmseqs2=True", + "--data_pipeline=alphafold2", + "--save_msa_files=True", + "--compress_features=True", + "--skip_existing=False", + ] + ) + self.assertEqual( + source_res.returncode, + 0, + "MMseqs source feature generation failed.\n" + f"STDOUT:\n{source_res.stdout}\nSTDERR:\n{source_res.stderr}", + ) + + for protein_id in self.ISSUE_588_IDS: + self.assertTrue( + (source_dir / f"{protein_id}.a3m").is_file(), + f"Expected MMseq A3M {source_dir / f'{protein_id}.a3m'} to be created.", + ) + self.assertTrue( + (source_dir / f"{protein_id}.pkl.xz").is_file(), + f"Expected compressed feature pickle {source_dir / f'{protein_id}.pkl.xz'} to be created.", + ) + shutil.copy2( + source_dir / f"{protein_id}.a3m", + precomputed_dir / f"{protein_id}.a3m", + ) + sidecar = source_dir / f"{protein_id}.mmseq_ids.json" + if sidecar.is_file(): + shutil.copy2(sidecar, precomputed_dir / sidecar.name) + + precomputed_res = self._run_prediction_subprocess( + [ + sys.executable, + str(self.script_create_features), + f"--fasta_paths={fasta_paths}", + f"--output_dir={precomputed_dir}", + f"--data_dir={DATA_DIR}", + "--max_template_date=2024-05-02", + "--use_mmseqs2=True", + "--use_precomputed_msas=True", + "--data_pipeline=alphafold2", + "--compress_features=True", + "--skip_existing=False", + ] + ) + self.assertEqual( + precomputed_res.returncode, + 0, + "Precomputed-MMseq feature generation failed.\n" + f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}", + ) + for protein_id in self.ISSUE_588_IDS: + self.assertTrue( + (precomputed_dir / f"{protein_id}.a3m").is_file(), + f"Expected copied MMseq A3M {precomputed_dir / f'{protein_id}.a3m'} to be present.", + ) + self.assertTrue( + (precomputed_dir / f"{protein_id}.pkl.xz").is_file(), + f"Expected precomputed feature pickle {precomputed_dir / f'{protein_id}.pkl.xz'} to be created.", + ) + return precomputed_dir + def _resolve_af2_result_dir(self, root: Path) -> Path: if (root / "ranking_debug.json").exists(): return root @@ -642,5 +742,73 @@ def 
test_issue_588_mmseqs_generated_features_enable_af2_multimer_inference(self) f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}", ) + def test_issue_614_precomputed_mmseqs_features_enable_af2_multimer_inference(self): + """Issue #614 regression: AF2 should fold successfully from precomputed MMseq A3Ms.""" + self._require_mmseqs_functional_environment() + feature_dir = self._generate_issue_588_precomputed_mmseq_features() + + for protein_id in self.ISSUE_588_IDS: + metadata_path, metadata = _load_feature_metadata(feature_dir, protein_id) + self.assertTrue( + _metadata_bool(metadata["other"]["use_precomputed_msas"]), + f"{metadata_path} should record use_precomputed_msas=True", + ) + feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz") + self.assertGreater( + _non_empty_identifier_count( + feature_dict["msa_species_identifiers_all_seq"] + ), + 0, + f"{protein_id} should keep recovered species IDs from cached MMseq A3Ms", + ) + self.assertGreater( + _non_empty_identifier_count( + feature_dict["msa_uniprot_accession_identifiers_all_seq"] + ), + 0, + f"{protein_id} should keep recovered accession IDs from cached MMseq A3Ms", + ) + + prediction_dir = self.output_dir / "af2_precomputed_prediction" + prediction_dir.mkdir(parents=True, exist_ok=True) + res = self._run_prediction_subprocess( + [ + sys.executable, + str(self.script_single), + "--input=A0ABD7FQG0+P18004", + f"--output_directory={prediction_dir}", + "--num_cycle=1", + "--num_predictions_per_model=1", + "--model_names=model_4_multimer_v3", + f"--data_directory={DATA_DIR}", + f"--features_directory={feature_dir}", + "--random_seed=42", + ] + ) + self.assertEqual( + res.returncode, + 0, + "AF2 inference from precomputed MMseq features failed.\n" + f"STDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}", + ) + + result_dir = self._resolve_af2_result_dir(prediction_dir) + ranking_payload = json.loads( + (result_dir / "ranking_debug.json").read_text(encoding="utf-8") + ) + self.assertTrue(ranking_payload["order"]) + + result_pickles = sorted(result_dir.glob("result_*.pkl")) + self.assertLen(result_pickles, 1) + with result_pickles[0].open("rb") as handle: + result_payload = pickle.load(handle) + self.assertIn("iptm", result_payload) + self.assertIn("ranking_confidence", result_payload) + self.assertGreater( + result_payload["iptm"], + 0.6, + f"Expected AF2 ipTM > 0.6 from precomputed MMseq features, got {result_payload['iptm']}", + ) + if __name__ == "__main__": absltest.main() diff --git a/test/cluster/check_alphafold3_predictions.py b/test/cluster/check_alphafold3_predictions.py index 91059b9a..9aeee410 100755 --- a/test/cluster/check_alphafold3_predictions.py +++ b/test/cluster/check_alphafold3_predictions.py @@ -20,7 +20,9 @@ import numpy as np import re import unittest +from types import SimpleNamespace from typing import Dict, List, Tuple, Any +from unittest import mock from absl.testing import absltest, parameterized @@ -151,14 +153,24 @@ def _non_empty_a3m_payload_rows(a3m_text: str) -> list[str]: def _load_feature_dict(feature_path: Path) -> dict[str, Any]: - opener = lzma.open if feature_path.suffix == ".xz" else open - with opener(feature_path, "rb") as handle: - payload = pickle.load(handle) + payload = _load_feature_payload(feature_path) if hasattr(payload, "feature_dict"): return payload.feature_dict return payload +def _load_feature_payload(feature_path: Path) -> Any: + opener = lzma.open if feature_path.suffix == ".xz" else open + with opener(feature_path, "rb") as handle: + return pickle.load(handle) + + +def 
_write_feature_payload(feature_path: Path, payload: Any) -> None: + opener = lzma.open if feature_path.suffix == ".xz" else open + with opener(feature_path, "wb") as handle: + pickle.dump(payload, handle) + + def _non_empty_identifier_count(values) -> int: count = 0 for value in values: @@ -1407,6 +1419,106 @@ def _prepare_fold_input( fold_input_obj, _ = next(iter(mappings[0].items())) return fold_input_obj + def _copy_real_feature_fixture( + self, + *, + source_dir: Path, + protein_id: str, + target_dir: Path, + ) -> Path: + copied_feature_path = None + for pattern in ( + f"{protein_id}.pkl", + f"{protein_id}.pkl.xz", + f"{protein_id}.a3m", + f"{protein_id}_feature_metadata_*.json*", + ): + for source_path in sorted(source_dir.glob(pattern)): + target_path = target_dir / source_path.name + shutil.copy2(source_path, target_path) + if source_path.name.startswith(f"{protein_id}.pkl"): + copied_feature_path = target_path + + self.assertIsNotNone( + copied_feature_path, + f"Missing real feature fixture for {protein_id} in {source_dir}", + ) + return copied_feature_path + + @staticmethod + def _synthetic_accession_ids(species_ids: np.ndarray) -> np.ndarray: + identifiers = [] + for index, value in enumerate(species_ids): + if isinstance(value, bytes): + value = value.decode("utf-8") + identifiers.append( + f"ACC{index:05d}".encode("utf-8") if str(value).strip() else b"" + ) + return np.asarray(identifiers, dtype=object) + + def _prepare_mixed_identifier_fixture_dir(self) -> Path: + """Materialize real AF2 fixtures with mixed identifier enrichment. + + The underlying MSA rows come from repo fixtures in `test/test_data`. + We only adjust the identifier sidecars so one chain looks enriched while + the other reproduces the "no species enrichment / no accession IDs" + failure mode from issue #614's AF3 follow-up comment. 
+ """ + feature_dir = self.output_dir / "mixed_identifier_features" + feature_dir.mkdir(parents=True, exist_ok=True) + source_dir = self.test_features_dir / "af2_features" / "protein" + + enriched_feature_path = self._copy_real_feature_fixture( + source_dir=source_dir, + protein_id="A0A024R1R8", + target_dir=feature_dir, + ) + unenriched_feature_path = self._copy_real_feature_fixture( + source_dir=source_dir, + protein_id="P61626", + target_dir=feature_dir, + ) + + enriched_payload = _load_feature_payload(enriched_feature_path) + enriched_feature_dict = ( + enriched_payload.feature_dict + if hasattr(enriched_payload, "feature_dict") + else enriched_payload + ) + enriched_feature_dict["msa_uniprot_accession_identifiers"] = ( + self._synthetic_accession_ids( + np.asarray(enriched_feature_dict["msa_species_identifiers"]) + ) + ) + enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"] = ( + self._synthetic_accession_ids( + np.asarray(enriched_feature_dict["msa_species_identifiers_all_seq"]) + ) + ) + _write_feature_payload(enriched_feature_path, enriched_payload) + + unenriched_payload = _load_feature_payload(unenriched_feature_path) + unenriched_feature_dict = ( + unenriched_payload.feature_dict + if hasattr(unenriched_payload, "feature_dict") + else unenriched_payload + ) + unenriched_feature_dict["msa_species_identifiers"] = np.asarray( + [b""] * int(np.asarray(unenriched_feature_dict["msa"]).shape[0]), + dtype=object, + ) + unenriched_feature_dict["msa_species_identifiers_all_seq"] = np.asarray( + [b""] * int(np.asarray(unenriched_feature_dict["msa_all_seq"]).shape[0]), + dtype=object, + ) + unenriched_feature_dict.pop("msa_uniprot_accession_identifiers", None) + unenriched_feature_dict.pop( + "msa_uniprot_accession_identifiers_all_seq", None + ) + _write_feature_payload(unenriched_feature_path, unenriched_payload) + + return feature_dir + def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self): """Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures.""" from alphapulldown.folding_backend.alphafold3_backend import process_fold_input @@ -1638,6 +1750,117 @@ def test_af3_prepare_input_preserves_templates_for_templated_af2_pkl_features(se all(template["mmcif"] for template in protein_entries[0]["templates"]) ) + def test_af3_real_fixture_pipeline_tolerates_mixed_missing_accession_ids(self): + """AF3 prep should tolerate a real mixed-enrichment multimer feature set.""" + from alphapulldown.folding_backend.alphafold3_backend import ( + AlphaFold3Backend, + process_fold_input, + ) + from alphapulldown.scripts import run_structure_prediction + + feature_dir = self._prepare_mixed_identifier_fixture_dir() + + enriched_feature_dict = _load_feature_dict(feature_dir / "A0A024R1R8.pkl") + self.assertGreater( + _non_empty_identifier_count( + enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"] + ), + 0, + ) + unenriched_feature_dict = _load_feature_dict(feature_dir / "P61626.pkl") + self.assertEqual( + _non_empty_identifier_count( + unenriched_feature_dict["msa_species_identifiers_all_seq"] + ), + 0, + ) + self.assertNotIn( + "msa_uniprot_accession_identifiers_all_seq", + unenriched_feature_dict, + ) + + script_flags = SimpleNamespace( + pair_msa=True, + multimeric_template=False, + description_file=None, + path_to_mmt=None, + threshold_clashes=1000, + hb_allowance=0.4, + plddt_threshold=0, + save_features_for_multimeric_object=False, + features_directory=[str(feature_dir)], + use_ap_style=False, + ) + + with 
mock.patch.object(run_structure_prediction, "FLAGS", script_flags): + parsed = run_structure_prediction.parse_fold( + ["A0A024R1R8+P61626"], + [str(feature_dir)], + "+", + ) + data = run_structure_prediction.create_custom_info(parsed) + all_interactors = run_structure_prediction.create_interactors( + data, + [str(feature_dir)], + ) + self.assertLen(all_interactors, 1) + self.assertLen(all_interactors[0], 2) + object_to_model, prepared_output_dir = ( + run_structure_prediction.pre_modelling_setup( + all_interactors[0], + output_dir=str(self.output_dir / "mixed_identifier_prediction"), + ) + ) + + mappings = AlphaFold3Backend.prepare_input( + objects_to_model=[ + {"object": object_to_model, "output_dir": prepared_output_dir} + ], + random_seed=42, + debug_msas=True, + ) + self.assertLen(mappings, 1) + fold_input_obj, ( + prepared_output_dir, + resolve_msa_overlaps, + ) = next(iter(mappings[0].items())) + + process_fold_input( + fold_input=fold_input_obj, + model_runner=None, + output_dir=prepared_output_dir, + buckets=(512,), + resolve_msa_overlaps=resolve_msa_overlaps, + ) + + job_name = fold_input_obj.sanitised_name() + summary_path = ( + Path(prepared_output_dir) + / f"{job_name}_af2_to_af3_translation_summary.json" + ) + self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}") + summary = json.loads(summary_path.read_text(encoding="utf-8")) + self.assertLen(summary["chains"], 2) + self.assertTrue(summary["unpaired_rows_valid"]) + + input_json = Path(prepared_output_dir) / f"{job_name}_data.json" + self.assertTrue(input_json.is_file(), f"Missing AF3 input JSON {input_json}") + written = json.loads(input_json.read_text(encoding="utf-8")) + protein_entries = { + protein_entry["id"]: protein_entry + for protein_entry in _protein_entries_from_af3_input(written) + } + self.assertEqual(set(protein_entries), {"A", "B"}) + for chain in fold_input_obj.chains: + if not hasattr(chain, "sequence"): + continue + protein_entry = protein_entries[chain.id] + self.assertEqual(protein_entry["sequence"], chain.sequence) + self.assertEqual( + _a3m_query_sequence(protein_entry["unpairedMsa"]), + chain.sequence, + ) + class TestAlphaFold3MmseqsIssue588Inference(_TestBase): """Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features.""" diff --git a/test/cluster/run_alphafold2_predictions.py b/test/cluster/run_alphafold2_predictions.py index 1086e3dd..a921ec11 100644 --- a/test/cluster/run_alphafold2_predictions.py +++ b/test/cluster/run_alphafold2_predictions.py @@ -121,10 +121,13 @@ def _timestamp() -> str: def _default_gpu_env_lines(*, cpus_per_task: int) -> list[str]: - thread_count = max(1, min(cpus_per_task, 4)) + # Keep host-side BLAS/TF work single-threaded to avoid oversubscribing + # CPUs while AF2 functional tests are already pinning a single GPU job. 
+ thread_count = 1 return [ "export PYTHONUNBUFFERED=1", f'export OMP_NUM_THREADS="${{OMP_NUM_THREADS:-{thread_count}}}"', + f'export OPENBLAS_NUM_THREADS="${{OPENBLAS_NUM_THREADS:-{thread_count}}}"', f'export MKL_NUM_THREADS="${{MKL_NUM_THREADS:-{thread_count}}}"', f'export NUMEXPR_NUM_THREADS="${{NUMEXPR_NUM_THREADS:-{thread_count}}}"', f'export TF_NUM_INTEROP_THREADS="${{TF_NUM_INTEROP_THREADS:-{thread_count}}}"', @@ -134,7 +137,7 @@ def _default_gpu_env_lines(*, cpus_per_task: int) -> list[str]: 'export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}"', 'export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.8}"', 'export JAX_PLATFORM_NAME="${JAX_PLATFORM_NAME:-gpu}"', - 'if [ -z "${XLA_FLAGS:-}" ]; then export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=0 --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"; fi', + 'if [ -z "${XLA_FLAGS:-}" ]; then export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1 --xla_force_host_platform_device_count=1 --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"; fi', ] diff --git a/test/integration/test_create_individual_features.py b/test/integration/test_create_individual_features.py index 128dc84a..149bbd70 100644 --- a/test/integration/test_create_individual_features.py +++ b/test/integration/test_create_individual_features.py @@ -1678,7 +1678,7 @@ def test_create_and_save_monomer_objects_reuses_mmseq_identifier_sidecar( objects_mod, "get_msa_and_templates", lambda **_kwargs: ( - ["UNPAIRED"], + [a3m_text], ["PAIRED"], ["UNIQUE"], ["CARD"], @@ -1690,7 +1690,7 @@ def test_create_and_save_monomer_objects_reuses_mmseq_identifier_sidecar( objects_mod, "unserialize_msa", lambda a3m_lines, sequence: ( - ["PRECOMP_MSA"], + [a3m_text], ["PRECOMP_PAIRED"], ["UNIQUE"], ["CARD"], diff --git a/test/test_data/fastas/P04737.fasta b/test/test_data/fastas/P04737.fasta new file mode 100644 index 00000000..b05febd1 --- /dev/null +++ b/test/test_data/fastas/P04737.fasta @@ -0,0 +1,4 @@ +>sp|P04737|PIL1_ECOLI Pilin OS=Escherichia coli (strain K12) OX=83333 GN=traA PE=1 SV=1 +MNAVLSVQGASAPVKKKSFFSKFTRLNMLRLARAVIPAAVLMMFFPQLAMAAGSSGQDLM +ASGNTTVKATFGKDSSVVKWVVLAEVLVGAVMYMMTKNVKFLAGFAIISVFIAVGMAVVG +L diff --git a/test/test_data/fastas/P15069.fasta b/test/test_data/fastas/P15069.fasta new file mode 100644 index 00000000..8d0c7b22 --- /dev/null +++ b/test/test_data/fastas/P15069.fasta @@ -0,0 +1,9 @@ +>sp|P15069|TRAH1_ECOLI Protein TraH OS=Escherichia coli (strain K12) OX=83333 GN=traH PE=3 SV=2 +MMPRIKPLLVLCAALLTVTPAASADVNSDMNQFFNKLGFASNTTQPGVWQGQAAGYAYGG +SLYARTQVKNVQLISMTLPDINAGCGGIDAYLGSFSFINGEQLQRFVKQIMSNAAGYFFD +LALQTTVPEIKTAKDFLQKMASDINSMNLSSCQAAQGIIGGLFPRTQVSQQKVCQDIAGE +SNIFADWAASRQGCTVGGKSDSVRDKASDKDKERVTKNINIMWNALSKNRMFDGNKELKE +FVMTLTGSLVFGPNGEITPLSARTTDRSIIRAMMEGGTAKISHCNDSDKCLKVVADTPVT +ISRDNALKSQITKLLASIQNKAVSDTPLDDKEKGFISSTTIPVFKYLVDPQMLGVSNSMI +YQLTDYIGYDILLQYIQELIQQARAMVATGNYDEAVIGHINDNMNDATRQIAAFQSQVQV +QQDALLVVDRQMSYMRQQLSARMLSRYQNNYHFGGSTL diff --git a/test/unit/test_alphafold2_backend_helpers.py b/test/unit/test_alphafold2_backend_helpers.py index 565689c3..18eb65ac 100644 --- a/test/unit/test_alphafold2_backend_helpers.py +++ b/test/unit/test_alphafold2_backend_helpers.py @@ -569,6 +569,73 @@ def fake_pad(feature_dict, desired_num_msa, desired_num_res): assert payload["plddt"].tolist() == [91.0, 88.0] +def test_predict_individual_job_accepts_tuple_prediction_results( + af2_backend_module, + tmp_path, +): + monomer = af2_backend_module.MonomericObject("single", "AB") 
+ monomer.feature_dict = {"residue_index": np.array([0, 1], dtype=np.int32)} + + fake_runner = SimpleNamespace( + multimer_mode=False, + process_features=lambda feature_dict, random_seed: dict(feature_dict), + predict=lambda processed_feature_dict, random_seed: ( + { + "plddt": np.array([91.0, 88.0], dtype=np.float32), + "predicted_aligned_error": np.zeros((2, 2), dtype=np.float32), + "max_predicted_aligned_error": 31.0, + }, + {"auxiliary": "ignored"}, + ), + ) + + results = af2_backend_module.AlphaFold2Backend.predict_individual_job( + model_runners={"modelA": fake_runner}, + multimeric_object=monomer, + allow_resume=False, + skip_templates=False, + output_dir=tmp_path, + random_seed=13, + ) + + assert results["modelA"]["plddt"].tolist() == [91.0, 88.0] + assert results["modelA"]["seqs"] == ["AB"] + assert results["modelA"]["unrelaxed_protein"].name == "predicted" + with open(tmp_path / "result_modelA.pkl", "rb") as handle: + payload = pickle.load(handle) + assert payload["plddt"].tolist() == [91.0, 88.0] + + +def test_predict_individual_job_rejects_tuple_without_mapping_payload( + af2_backend_module, + tmp_path, +): + monomer = af2_backend_module.MonomericObject("single", "AB") + monomer.feature_dict = {"residue_index": np.array([0, 1], dtype=np.int32)} + + fake_runner = SimpleNamespace( + multimer_mode=False, + process_features=lambda feature_dict, random_seed: dict(feature_dict), + predict=lambda processed_feature_dict, random_seed: ( + "not-a-mapping", + {"auxiliary": "ignored"}, + ), + ) + + with pytest.raises( + TypeError, + match=r"model_runner\.predict must return a mapping or a \(mapping, auxiliary\) tuple", + ): + af2_backend_module.AlphaFold2Backend.predict_individual_job( + model_runners={"modelA": fake_runner}, + multimeric_object=monomer, + allow_resume=False, + skip_templates=False, + output_dir=tmp_path, + random_seed=17, + ) + + def test_predict_individual_job_rejects_skipped_templates_in_multimer_mode( af2_backend_module, tmp_path, diff --git a/test/unit/test_cluster_wrapper_helpers.py b/test/unit/test_cluster_wrapper_helpers.py index fa5b192f..d1bb99a8 100644 --- a/test/unit/test_cluster_wrapper_helpers.py +++ b/test/unit/test_cluster_wrapper_helpers.py @@ -24,6 +24,7 @@ def test_af2_cluster_subprocess_env_sets_safe_gpu_defaults(monkeypatch): for name in ( "OMP_NUM_THREADS", + "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "TF_NUM_INTEROP_THREADS", @@ -39,17 +40,19 @@ def test_af2_cluster_subprocess_env_sets_safe_gpu_defaults(monkeypatch): env = module._af2_subprocess_env() - assert env["OMP_NUM_THREADS"] == "4" - assert env["MKL_NUM_THREADS"] == "4" - assert env["NUMEXPR_NUM_THREADS"] == "4" - assert env["TF_NUM_INTEROP_THREADS"] == "4" - assert env["TF_NUM_INTRAOP_THREADS"] == "4" + assert env["OMP_NUM_THREADS"] == "1" + assert env["OPENBLAS_NUM_THREADS"] == "1" + assert env["MKL_NUM_THREADS"] == "1" + assert env["NUMEXPR_NUM_THREADS"] == "1" + assert env["TF_NUM_INTEROP_THREADS"] == "1" + assert env["TF_NUM_INTRAOP_THREADS"] == "1" assert env["TF_FORCE_GPU_ALLOW_GROWTH"] == "true" assert env["TF_CPP_MIN_LOG_LEVEL"] == "2" assert env["XLA_PYTHON_CLIENT_PREALLOCATE"] == "false" assert env["XLA_PYTHON_CLIENT_MEM_FRACTION"] == "0.8" assert env["JAX_PLATFORM_NAME"] == "gpu" - assert "--xla_gpu_force_compilation_parallelism=0" in env["XLA_FLAGS"] + assert "--xla_gpu_force_compilation_parallelism=1" in env["XLA_FLAGS"] + assert "--xla_force_host_platform_device_count=1" in env["XLA_FLAGS"] def 
test_af2_cluster_wrapper_job_script_exports_gpu_defaults(tmp_path): @@ -72,9 +75,12 @@ def test_af2_cluster_wrapper_job_script_exports_gpu_defaults(tmp_path): ) script_text = job.script_path.read_text(encoding="utf-8") - assert 'OMP_NUM_THREADS="${OMP_NUM_THREADS:-4}"' in script_text - assert 'TF_NUM_INTRAOP_THREADS="${TF_NUM_INTRAOP_THREADS:-4}"' in script_text + assert 'OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"' in script_text + assert 'OPENBLAS_NUM_THREADS="${OPENBLAS_NUM_THREADS:-1}"' in script_text + assert 'TF_NUM_INTRAOP_THREADS="${TF_NUM_INTRAOP_THREADS:-1}"' in script_text assert 'JAX_PLATFORM_NAME="${JAX_PLATFORM_NAME:-gpu}"' in script_text + assert "--xla_gpu_force_compilation_parallelism=1" in script_text + assert "--xla_force_host_platform_device_count=1" in script_text assert "addopts=-ra --strict-markers" in script_text assert "--use-temp-dir" in script_text diff --git a/test/unit/test_mmseqs_species_identifiers.py b/test/unit/test_mmseqs_species_identifiers.py index b8f30012..d28074d7 100644 --- a/test/unit/test_mmseqs_species_identifiers.py +++ b/test/unit/test_mmseqs_species_identifiers.py @@ -1,4 +1,5 @@ import json +import os from urllib import error import numpy as np @@ -11,6 +12,32 @@ from alphapulldown.utils import mmseqs_species_identifiers +_FASTAS_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + 'test_data', + 'fastas', +) + + +def _read_uniprot_fasta(uniprot_id: str) -> str: + path = os.path.join(_FASTAS_DIR, f'{uniprot_id}.fasta') + with open(path, encoding='utf-8') as handle: + lines = handle.read().splitlines() + return ''.join(line for line in lines[1:] if line and not line.startswith('>')) + + +def _build_colabfold_server_a3m( + query_sequence: str, hits: list[tuple[str, str]] +) -> str: + """Assemble a ColabFold-server-style A3M ('#\\t1' header, one-line rows).""" + lines = [f'#{len(query_sequence)}\t1', '>101', query_sequence] + for header, aligned in hits: + lines.append(f'>{header}') + lines.append(aligned) + lines.append('') + return '\n'.join(lines) + + @pytest.fixture(autouse=True) def clear_species_id_cache(): mmseqs_species_identifiers._SPECIES_ID_CACHE.clear() @@ -419,7 +446,7 @@ def fake_enrich(feature_dict, a3m, **kwargs): 'template_feature': 'TEMPLATE_FROM_RESEARCH', } assert calls['enrich_mmseq_feature_dict_with_identifiers'] == { - 'a3m': '>101\nACDE', + 'a3m': 'PRECOMPUTED_UNPAIRED', 'kwargs': {'cache_path': str(tmp_path / 'dummy.mmseq_ids.json')}, } assert isinstance(monomer.feature_dict['template_confidence_scores'], np.ndarray) @@ -579,3 +606,110 @@ def fake_uniparc_query(accessions, *, urlopen): } assert uniprot_calls == [('A0A636IKY3',)] assert uniparc_calls == [('UPI001118B830',)] + + +@pytest.mark.parametrize( + 'uniprot_id,accession_species', + [ + ( + 'P04737', + { + 'A0A636IKY3': '562', + 'A0A743YDY2': '573', + 'UPI001118B830': '562', + }, + ), + ( + 'P15069', + { + 'A0A636IKY3': '562', + 'A0A743YDY2': '573', + 'UPI001118B830': '562', + }, + ), + ], +) +def test_make_mmseq_features_precomputed_colabfold_a3m_enriches_identifiers( + monkeypatch, tmp_path, uniprot_id, accession_species +): + """Regression for issue #613: precomputed ColabFold-server A3Ms must enrich. + + Before the fix, make_mmseq_features parsed the raw A3M for identifiers but + fed a different processed string to build_monomer_feature, so dedup rules + disagreed and identifier rows did not match MSA rows. This drove + enrich_mmseq_feature_dict_with_identifiers to log a warning and skip + enrichment, leaving species pairing unusable. 
+ """ + + query = _read_uniprot_fasta(uniprot_id) + assert len(query) > 0 + + # Realistic ColabFold-server-style A3M: '#\t1' header, one header/seq per + # line pair, a mix of exact-duplicate hits, insertion-variant hits (lowercase + # letters) that strip to the query, an all-gap row, and point-mutation hits + # with real-format UniProt accessions so enrichment has something to resolve. + hits = [ + (f'sp|{uniprot_id}|QUERY_DUP', query), + ('UniRef100_A0A636IKY3', query[:10] + 'a' + query[10:]), + ('UniRef100_A0A743YDY2', query[:15] + 'bc' + query[15:]), + ('UniRef100_UPI001118B830', query[:20] + 'def' + query[20:]), + ('UniRef100_A0A100XYZ0', query[:5] + 'X' + query[6:]), + ('UniRef100_A0A200ABC5', query[:8] + 'Y' + query[9:]), + ('UniRef100_GAP_ROW', '-' * len(query)), + ('UniRef100_ALL_LOWER', 'a' * len(query)), + ] + accession_species = { + **accession_species, + 'A0A100XYZ0': '9606', + 'A0A200ABC5': '10090', + } + a3m = _build_colabfold_server_a3m(query, hits) + + precomputed_a3m = tmp_path / f'{uniprot_id}.a3m' + precomputed_a3m.write_text(a3m, encoding='utf-8') + + monkeypatch.setattr( + MonomericObject, 'unzip_msa_files', staticmethod(lambda _path: False) + ) + monkeypatch.setattr( + mmseqs_species_identifiers, + 'resolve_species_ids_by_accession', + lambda accessions, **_: { + accession: accession_species.get(accession, '') + for accession in accessions + }, + ) + + monomer = MonomericObject(uniprot_id, query) + monomer.make_mmseq_features( + DEFAULT_API_SERVER='https://unused.example', + output_dir=str(tmp_path), + use_precomputed_msa=True, + use_templates=False, + ) + + msa = monomer.feature_dict['msa'] + species = monomer.feature_dict['msa_species_identifiers'] + accessions = monomer.feature_dict['msa_uniprot_accession_identifiers'] + + assert species.shape[0] == msa.shape[0], ( + f'enrichment row count {species.shape[0]} != msa rows {msa.shape[0]}' + ) + assert accessions.shape[0] == msa.shape[0] + # '_all_seq' mirrors the enriched rows, used for pairing downstream. + assert ( + monomer.feature_dict['msa_species_identifiers_all_seq'].shape[0] + == msa.shape[0] + ) + # The resolver is called with the real UniProt-format accessions only — + # insertion-variant hits collapse onto the query row but the point-mutation + # hit survives, so at least one of the resolvable accessions lands in the + # deduped identifier rows. 
+ resolvable = { + a.decode('utf-8') + for a in accessions.tolist() + if a.decode('utf-8') in accession_species + } + assert resolvable, ( + f'expected at least one resolvable accession in {accessions.tolist()}' + ) diff --git a/test/unit/test_objects.py b/test/unit/test_objects.py index edf8e7f3..32c60ae1 100644 --- a/test/unit/test_objects.py +++ b/test/unit/test_objects.py @@ -397,7 +397,7 @@ def fake_enrich(feature_dict, a3m, **kwargs): assert calls["build_monomer_feature"] == ("ACDE", "UNPAIRED", "TEMPLATE") assert calls["enrich"] == { - "a3m": ">101\nACDE\n>hit\nAC-E", + "a3m": "UNPAIRED", "kwargs": {"cache_path": str(tmp_path / "proteinA.mmseq_ids.json")}, } assert (tmp_path / "proteinA.a3m").read_text(encoding="utf-8").startswith(">101") @@ -421,7 +421,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode( def fake_get_msa_and_templates(**kwargs): calls["get_msa_and_templates"] = kwargs - return (["UNPAIRED"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"]) + return ([">101\nACDE\n"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"]) monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates) monkeypatch.setattr( @@ -462,7 +462,7 @@ def fake_enrich(feature_dict, a3m, **kwargs): assert calls["get_msa_and_templates"]["pair_mode"] == "none" assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"] assert calls["get_msa_and_templates"]["use_templates"] is True - assert calls["enrich"]["a3m"] == ">101\nACDE" + assert calls["enrich"]["a3m"] == ">101\nACDE\n" assert monomer.skip_msa is True assert monomer.feature_dict["msa"].shape == (1, 4) assert monomer.feature_dict["msa_all_seq"].shape == (1, 4) @@ -757,7 +757,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run( objects_mod, "get_msa_and_templates", lambda **_kwargs: ( - ["UNPAIRED"], + [a3m_text], ["PAIRED"], ["UNIQUE"], ["CARD"], @@ -774,7 +774,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run( objects_mod, "unserialize_msa", lambda a3m_lines, sequence: ( - ["PRECOMP_MSA"], + [a3m_text], ["PRECOMP_PAIRED"], ["UNIQUE"], ["CARD"], @@ -1417,6 +1417,81 @@ def fake_merge_chain_features(**kwargs): assert output == {"processed": {"merged": True}} +def test_pair_and_merge_backfills_missing_all_seq_accession_ids_before_pairing( + monkeypatch, +): + multimer = MultimericObject.__new__(MultimericObject) + multimer.pair_msa = True + calls = {} + + chain_a = _feature_dict(sequence="ACDE", msa_rows=1, all_seq_rows=2, template_count=0) + chain_b = _feature_dict(sequence="FGHI", msa_rows=1, all_seq_rows=2, template_count=0) + chain_a["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object) + chain_b["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object) + chain_a["msa_uniprot_accession_identifiers"] = np.asarray([b"P12345"], dtype=object) + chain_a["msa_uniprot_accession_identifiers_all_seq"] = np.asarray( + [b"", b"P12345"], + dtype=object, + ) + + real_create_paired_features = objects_mod.msa_pairing.create_paired_features + + monkeypatch.setattr( + objects_mod.feature_processing, + "process_unmerged_features", + lambda _features: None, + ) + monkeypatch.setattr( + objects_mod.feature_processing, + "_is_homomer_or_monomer", + lambda _chains: False, + ) + + def wrapped_create_paired_features(*, chains): + calls["create_paired_features"] = chains + return real_create_paired_features(chains) + + monkeypatch.setattr( + objects_mod.msa_pairing, + "create_paired_features", + wrapped_create_paired_features, + ) + 
monkeypatch.setattr( + objects_mod.msa_pairing, + "deduplicate_unpaired_sequences", + lambda chains: chains, + ) + monkeypatch.setattr( + objects_mod.feature_processing, + "crop_chains", + lambda chains, **kwargs: chains, + ) + monkeypatch.setattr( + objects_mod.msa_pairing, + "merge_chain_features", + lambda **kwargs: {"chains": kwargs["np_chains_list"]}, + ) + monkeypatch.setattr( + objects_mod.feature_processing, + "process_final", + lambda example: example, + ) + + output = multimer.pair_and_merge({"A": chain_a, "B": chain_b}) + + assert calls["create_paired_features"][1][ + "msa_uniprot_accession_identifiers" + ].tolist() == [b""] + assert calls["create_paired_features"][1][ + "msa_uniprot_accession_identifiers_all_seq" + ].tolist() == [b"", b""] + assert output["chains"][1]["msa_uniprot_accession_identifiers"].tolist() == [b""] + assert output["chains"][1]["msa_uniprot_accession_identifiers_all_seq"].tolist() == [ + b"", + b"", + ] + + def test_pair_and_merge_removes_all_seq_features_when_pairing_disabled(monkeypatch): multimer = MultimericObject.__new__(MultimericObject) multimer.pair_msa = False diff --git a/test/unit/test_objects_helpers.py b/test/unit/test_objects_helpers.py index 048e5e41..f820bf26 100644 --- a/test/unit/test_objects_helpers.py +++ b/test/unit/test_objects_helpers.py @@ -135,3 +135,46 @@ def fake_run_msa_tool(*args, **kwargs): True, ) assert run_kwargs == {} + + +def test_all_seq_msa_features_backfills_missing_uniprot_accession_identifiers( + monkeypatch, tmp_path +): + monomer = MonomericObject("desc", "ACDE") + input_fasta_path = str(tmp_path / "input.fasta") + Path(input_fasta_path).write_text(">x\nACDE\n", encoding="utf-8") + + class FakeMsa: + def truncate(self, max_seqs): + assert max_seqs == 50000 + return self + + monkeypatch.setattr( + "alphapulldown.objects.pipeline.run_msa_tool", + lambda *args, **kwargs: {"sto": "fake"}, + ) + monkeypatch.setattr( + "alphapulldown.objects.parsers.parse_stockholm", + lambda sto: FakeMsa(), + ) + monkeypatch.setattr( + "alphapulldown.objects.pipeline.make_msa_features", + lambda _msas: { + "msa": np.asarray([[1, 2], [1, 3]], dtype=np.int32), + "msa_species_identifiers": np.asarray([b"", b"9606"], dtype=object), + "deletion_matrix_int": np.asarray([[0, 0], [0, 0]], dtype=np.int32), + }, + ) + + features = monomer.all_seq_msa_features( + input_fasta_path=input_fasta_path, + uniprot_msa_runner="runner", + output_dir=str(tmp_path), + use_precomputed_msa=False, + ) + + assert features["msa_species_identifiers_all_seq"].tolist() == [b"", b"9606"] + assert features["msa_uniprot_accession_identifiers_all_seq"].tolist() == [ + b"", + b"", + ]
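Note: below is a minimal, self-contained sketch of the behaviour the new `_ensure_identifier_feature_arrays` helper in `alphapulldown/objects.py` is relied on for, namely backfilling any missing identifier array with empty byte strings so its row count matches the MSA. The standalone function name and the toy feature dict are illustrative only and are not part of the patch.

    import numpy as np

    def ensure_identifier_feature_arrays(feature_dict, feature_groups):
        # Mirrors the patched helper: for each MSA array, make sure every
        # companion identifier array exists with one (empty) entry per row.
        normalized = dict(feature_dict)
        for msa_key, identifier_keys in feature_groups:
            msa = normalized.get(msa_key)
            if msa is None:
                continue
            num_rows = int(np.asarray(msa).shape[0])
            for key in identifier_keys:
                if key not in normalized:
                    normalized[key] = np.array([b""] * num_rows, dtype=object)
        return normalized

    # Toy example: two MSA rows, species IDs present, accession IDs missing.
    features = {
        "msa": np.zeros((2, 4), dtype=np.int32),
        "msa_species_identifiers": np.array([b"", b"9606"], dtype=object),
    }
    filled = ensure_identifier_feature_arrays(
        features,
        (("msa", ("msa_species_identifiers", "msa_uniprot_accession_identifiers")),),
    )
    assert filled["msa_uniprot_accession_identifiers"].tolist() == [b"", b""]
    assert filled["msa_species_identifiers"].tolist() == [b"", b"9606"]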