Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions alphapulldown/folding_backend/alphafold2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import subprocess
import enum
import typing
from collections.abc import Mapping
from typing import Dict, Union, List, Any
import os
from absl import logging
Expand Down Expand Up @@ -145,6 +146,37 @@ def _jnp_to_np(output):
return output


def _normalize_prediction_result(
prediction_result: Any, *, model_name: str
) -> Dict[str, Any]:
"""Accept the supported AF2 runner return shapes and return a mutable dict."""
if isinstance(prediction_result, Mapping):
return dict(prediction_result)

if isinstance(prediction_result, tuple):
if len(prediction_result) == 2 and isinstance(prediction_result[0], Mapping):
auxiliary_output = prediction_result[1]
logging.warning(
"Model %s returned a (prediction_result, auxiliary_output) tuple; "
"ignoring auxiliary output of type %s.",
model_name,
type(auxiliary_output).__name__,
)
return dict(prediction_result[0])

first_type = type(prediction_result[0]).__name__ if prediction_result else "n/a"
raise TypeError(
"model_runner.predict must return a mapping or a (mapping, auxiliary) "
f"tuple; got tuple of length {len(prediction_result)} with first "
f"element type {first_type}."
)

raise TypeError(
"model_runner.predict must return a mapping or a (mapping, auxiliary) "
f"tuple; got {type(prediction_result).__name__}."
)


def _save_pae_json_file(pae: np.ndarray, max_pae: float, output_dir: str, model_name: str) -> None:
"""
Check prediction result for PAE data and save to a JSON file if present.
Expand Down Expand Up @@ -709,6 +741,9 @@ def predict_individual_job(
prediction_result = model_runner.predict(
processed_feature_dict, random_seed=model_random_seed
)
prediction_result = _normalize_prediction_result(
prediction_result, model_name=model_name
)
t_diff = time.time() - t_0
timings[f"predict_and_compile_{model_name}"] = t_diff
logging.info(f"prediction costs : {t_diff} s")
Expand Down
73 changes: 70 additions & 3 deletions alphapulldown/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
)


def _ensure_identifier_feature_arrays(
feature_dict: Dict[str, np.ndarray],
feature_groups: Tuple[Tuple[str, Tuple[str, ...]], ...],
) -> Dict[str, np.ndarray]:
"""Backfill missing identifier arrays to match the corresponding MSA rows."""
normalized = dict(feature_dict)
for msa_key, identifier_keys in feature_groups:
msa = normalized.get(msa_key)
if msa is None:
continue
num_rows = int(np.asarray(msa).shape[0])
for key in identifier_keys:
if key not in normalized:
normalized[key] = np.array([b""] * num_rows, dtype=object)
return normalized


class MonomericObject:
"""
monomeric objects
Expand Down Expand Up @@ -145,7 +162,18 @@ def all_seq_msa_features(

msa = parsers.parse_stockholm(result["sto"])
msa = msa.truncate(max_seqs=50000)
all_seq_features = pipeline.make_msa_features([msa])
all_seq_features = _ensure_identifier_feature_arrays(
pipeline.make_msa_features([msa]),
(
(
"msa",
(
"msa_species_identifiers",
"msa_uniprot_accession_identifiers",
),
),
),
)
valid_feats = msa_pairing.MSA_FEATURES + (
"msa_species_identifiers",
"msa_uniprot_accession_identifiers",
Expand Down Expand Up @@ -362,9 +390,12 @@ def make_mmseq_features(
# Remove header lines starting with '#' if present.
a3m_lines[0] = strip_mmseq_comment_lines(a3m_lines[0])
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
# Enrich from the same A3M string that build_monomer_feature parsed, so
# the identifier rows go through the same parse_a3m dedup as msa_features
# and their count matches feature_dict['msa'] exactly.
enrich_mmseq_feature_dict_with_identifiers(
self.feature_dict,
a3m_lines[0],
unpaired_msa[0],
cache_path=os.path.join(
result_dir, f"{self.description}.mmseq_ids.json"
),
Expand Down Expand Up @@ -811,12 +842,48 @@ def remove_all_seq_features(np_chain_list: List[Dict]) -> List[Dict]:
output_list.append(new_chain)
return output_list

@staticmethod
def normalize_all_seq_identifier_features(np_chain_list: List[Dict]) -> List[Dict]:
"""Ensure identifier arrays exist consistently across chains.

Some feature sources provide species identifiers but omit UniProt
accession IDs, while DeepMind's multimer pairing and merge code assumes
both unpaired and `_all_seq` identifier keys exist consistently across
chains.
"""
output_list = []
for feat_dict in np_chain_list:
output_list.append(
_ensure_identifier_feature_arrays(
feat_dict,
(
(
"msa",
(
"msa_species_identifiers",
"msa_uniprot_accession_identifiers",
),
),
(
"msa_all_seq",
(
"msa_species_identifiers_all_seq",
"msa_uniprot_accession_identifiers_all_seq",
),
),
),
)
)
return output_list

def pair_and_merge(self, all_chain_features):
"""merge all chain features"""
feature_processing.process_unmerged_features(all_chain_features)
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
np_chains_list = list(all_chain_features.values())
np_chains_list = MultimericObject.normalize_all_seq_identifier_features(
list(all_chain_features.values())
)
pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(
np_chains_list)
logging.debug(f"pair_msa_sequences is type : {type(pair_msa_sequences)} value: {pair_msa_sequences}")
Expand Down
180 changes: 174 additions & 6 deletions test/cluster/check_alphafold2_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,23 @@ def _load_feature_dict(feature_path: Path) -> dict:
return payload


def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict]:
matches = sorted(feature_dir.glob(f"{protein_id}_feature_metadata_*.json*"))
if len(matches) != 1:
raise FileNotFoundError(
f"Expected one feature metadata file for {protein_id} in {feature_dir}, "
f"found {matches}"
)
metadata_path = matches[0]
opener = lzma.open if metadata_path.suffix == ".xz" else open
with opener(metadata_path, "rt", encoding="utf-8") as handle:
return metadata_path, json.load(handle)


def _metadata_bool(value) -> bool:
return str(value).strip().lower() in {"1", "true", "yes"}


def _non_empty_identifier_count(values) -> int:
count = 0
for value in values:
Expand All @@ -105,19 +122,21 @@ def _non_empty_identifier_count(values) -> int:
def _af2_subprocess_env() -> dict[str, str]:
"""Return stable GPU/JAX defaults for AF2 functional subprocesses."""
env = os.environ.copy()
env.setdefault("OMP_NUM_THREADS", "4")
env.setdefault("MKL_NUM_THREADS", "4")
env.setdefault("NUMEXPR_NUM_THREADS", "4")
env.setdefault("TF_NUM_INTEROP_THREADS", "4")
env.setdefault("TF_NUM_INTRAOP_THREADS", "4")
env.setdefault("OMP_NUM_THREADS", "1")
env.setdefault("OPENBLAS_NUM_THREADS", "1")
env.setdefault("MKL_NUM_THREADS", "1")
env.setdefault("NUMEXPR_NUM_THREADS", "1")
env.setdefault("TF_NUM_INTEROP_THREADS", "1")
env.setdefault("TF_NUM_INTRAOP_THREADS", "1")
env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
env.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.8")
env.setdefault("JAX_PLATFORM_NAME", "gpu")
env.setdefault(
"XLA_FLAGS",
"--xla_gpu_force_compilation_parallelism=0 "
"--xla_gpu_force_compilation_parallelism=1 "
"--xla_force_host_platform_device_count=1 "
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
)
return env
Expand Down Expand Up @@ -544,6 +563,87 @@ def _generate_issue_588_mmseq_features(self) -> Path:
)
return feature_dir

    def _generate_issue_588_precomputed_mmseq_features(self) -> Path:
        """Produce MMseqs2 features twice and return the precomputed-MSA dir.

        First run: generate features from scratch with MMseqs2 into a
        "source" directory (saving MSA files). The resulting A3Ms (and, when
        present, the ``.mmseq_ids.json`` sidecars) are then copied into a
        second directory, and the feature-creation script is re-run there
        with ``--use_precomputed_msas=True`` so it must consume the cached
        A3Ms rather than query MMseqs2 again.

        Returns:
            Path to the directory holding features built from precomputed MSAs.
        """
        source_dir = self.output_dir / "issue_588_mmseq_source_features"
        precomputed_dir = self.output_dir / "issue_588_mmseq_precomputed_features"
        source_dir.mkdir(parents=True, exist_ok=True)
        precomputed_dir.mkdir(parents=True, exist_ok=True)
        # One comma-separated --fasta_paths value covering every test protein.
        fasta_paths = ",".join(
            str(self.test_data_dir / "fastas" / f"{protein_id}.fasta")
            for protein_id in self.ISSUE_588_IDS
        )

        # First pass: fresh MMseqs2 feature generation, keeping the MSA files
        # (--save_msa_files=True) so they can be reused below.
        source_res = self._run_prediction_subprocess(
            [
                sys.executable,
                str(self.script_create_features),
                f"--fasta_paths={fasta_paths}",
                f"--output_dir={source_dir}",
                f"--data_dir={DATA_DIR}",
                "--max_template_date=2024-05-02",
                "--use_mmseqs2=True",
                "--data_pipeline=alphafold2",
                "--save_msa_files=True",
                "--compress_features=True",
                "--skip_existing=False",
            ]
        )
        self.assertEqual(
            source_res.returncode,
            0,
            "MMseqs source feature generation failed.\n"
            f"STDOUT:\n{source_res.stdout}\nSTDERR:\n{source_res.stderr}",
        )

        # Verify the first pass produced its artifacts, then seed the
        # precomputed directory with only the A3Ms (and optional identifier
        # sidecars) — NOT the feature pickles, which must be rebuilt there.
        for protein_id in self.ISSUE_588_IDS:
            self.assertTrue(
                (source_dir / f"{protein_id}.a3m").is_file(),
                f"Expected MMseq A3M {source_dir / f'{protein_id}.a3m'} to be created.",
            )
            self.assertTrue(
                (source_dir / f"{protein_id}.pkl.xz").is_file(),
                f"Expected compressed feature pickle {source_dir / f'{protein_id}.pkl.xz'} to be created.",
            )
            shutil.copy2(
                source_dir / f"{protein_id}.a3m",
                precomputed_dir / f"{protein_id}.a3m",
            )
            # The identifier sidecar is optional; copy it only if it exists.
            sidecar = source_dir / f"{protein_id}.mmseq_ids.json"
            if sidecar.is_file():
                shutil.copy2(sidecar, precomputed_dir / sidecar.name)

        # Second pass: same pipeline, but --use_precomputed_msas=True forces
        # the script to consume the copied A3Ms instead of re-running MMseqs2.
        precomputed_res = self._run_prediction_subprocess(
            [
                sys.executable,
                str(self.script_create_features),
                f"--fasta_paths={fasta_paths}",
                f"--output_dir={precomputed_dir}",
                f"--data_dir={DATA_DIR}",
                "--max_template_date=2024-05-02",
                "--use_mmseqs2=True",
                "--use_precomputed_msas=True",
                "--data_pipeline=alphafold2",
                "--compress_features=True",
                "--skip_existing=False",
            ]
        )
        self.assertEqual(
            precomputed_res.returncode,
            0,
            "Precomputed-MMseq feature generation failed.\n"
            f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}",
        )
        for protein_id in self.ISSUE_588_IDS:
            self.assertTrue(
                (precomputed_dir / f"{protein_id}.a3m").is_file(),
                f"Expected copied MMseq A3M {precomputed_dir / f'{protein_id}.a3m'} to be present.",
            )
            self.assertTrue(
                (precomputed_dir / f"{protein_id}.pkl.xz").is_file(),
                f"Expected precomputed feature pickle {precomputed_dir / f'{protein_id}.pkl.xz'} to be created.",
            )
        return precomputed_dir

def _resolve_af2_result_dir(self, root: Path) -> Path:
if (root / "ranking_debug.json").exists():
return root
Expand Down Expand Up @@ -642,5 +742,73 @@ def test_issue_588_mmseqs_generated_features_enable_af2_multimer_inference(self)
f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}",
)

    def test_issue_614_precomputed_mmseqs_features_enable_af2_multimer_inference(self):
        """Issue #614 regression: AF2 should fold successfully from precomputed MMseq A3Ms."""
        self._require_mmseqs_functional_environment()
        # Build features from cached MMseq A3Ms (see the generator helper).
        feature_dir = self._generate_issue_588_precomputed_mmseq_features()

        # Per-protein sanity checks before running any inference: the metadata
        # must record that precomputed MSAs were used, and the `_all_seq`
        # identifier arrays must contain recovered (non-empty) entries —
        # losing them was the failure mode behind this regression.
        for protein_id in self.ISSUE_588_IDS:
            metadata_path, metadata = _load_feature_metadata(feature_dir, protein_id)
            self.assertTrue(
                _metadata_bool(metadata["other"]["use_precomputed_msas"]),
                f"{metadata_path} should record use_precomputed_msas=True",
            )
            feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
            self.assertGreater(
                _non_empty_identifier_count(
                    feature_dict["msa_species_identifiers_all_seq"]
                ),
                0,
                f"{protein_id} should keep recovered species IDs from cached MMseq A3Ms",
            )
            self.assertGreater(
                _non_empty_identifier_count(
                    feature_dict["msa_uniprot_accession_identifiers_all_seq"]
                ),
                0,
                f"{protein_id} should keep recovered accession IDs from cached MMseq A3Ms",
            )

        # Run one AF2-multimer prediction (single model, single cycle, fixed
        # seed) against the dimer built from the two precomputed feature sets.
        prediction_dir = self.output_dir / "af2_precomputed_prediction"
        prediction_dir.mkdir(parents=True, exist_ok=True)
        res = self._run_prediction_subprocess(
            [
                sys.executable,
                str(self.script_single),
                "--input=A0ABD7FQG0+P18004",
                f"--output_directory={prediction_dir}",
                "--num_cycle=1",
                "--num_predictions_per_model=1",
                "--model_names=model_4_multimer_v3",
                f"--data_directory={DATA_DIR}",
                f"--features_directory={feature_dir}",
                "--random_seed=42",
            ]
        )
        self.assertEqual(
            res.returncode,
            0,
            "AF2 inference from precomputed MMseq features failed.\n"
            f"STDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
        )

        # The ranking file must exist and list at least one ranked model.
        result_dir = self._resolve_af2_result_dir(prediction_dir)
        ranking_payload = json.loads(
            (result_dir / "ranking_debug.json").read_text(encoding="utf-8")
        )
        self.assertTrue(ranking_payload["order"])

        # Exactly one result pickle is expected (one model, one prediction);
        # it must carry confidence metrics, and ipTM > 0.6 is used as the
        # "fold actually succeeded" quality bar, matching the issue-588 test.
        result_pickles = sorted(result_dir.glob("result_*.pkl"))
        self.assertLen(result_pickles, 1)
        with result_pickles[0].open("rb") as handle:
            result_payload = pickle.load(handle)
        self.assertIn("iptm", result_payload)
        self.assertIn("ranking_confidence", result_payload)
        self.assertGreater(
            result_payload["iptm"],
            0.6,
            f"Expected AF2 ipTM > 0.6 from precomputed MMseq features, got {result_payload['iptm']}",
        )

# Allow running this functional test module directly as a script.
if __name__ == "__main__":
    absltest.main()
Loading
Loading