Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion alphapulldown/folding_backend/alphafold3_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import re
import time
import typing
from collections import Counter
from collections.abc import Sequence
from typing import Any, List, Dict, Union, overload

Expand Down Expand Up @@ -1450,7 +1451,85 @@ def _write_translated_msa_debug_artifacts(
if not chain_records or not translation_results:
return

def _extract_paired_species_identifier(description: str) -> str:
if not description.startswith(("tr|", "sp|")):
return ""
description_tail = description.split("|")[-1]
if "_" not in description_tail:
return ""
return description_tail.rsplit("_", maxsplit=1)[-1].strip()

def _summarise_effective_af3_pairing(
    records: list[_TranslatedMsaDebugRecord],
) -> dict[str, object]:
    """Estimate how many paired-MSA rows AF3 can actually align across chains.

    Each chain's paired MSA is parsed and its rows are bucketed by UniProt
    species mnemonic (the query row, ``descriptions[0]``, carries no species
    tag and is skipped).  A species contributes effective paired rows only
    when it occurs in at least two chains; the number of rows it contributes
    is the minimum of its per-chain counts.  Chains that lack the species
    are charged the same number of gap rows instead.

    Returns a JSON-serialisable summary with the total effective row count,
    a histogram keyed by how many chains each species spanned, and per-chain
    non-gap/gap row tallies.
    """
    species_counts_by_chain: dict[str, Counter[str]] = {}
    non_gap_rows: dict[str, int] = {}
    gap_rows: dict[str, int] = {}

    for record in records:
        _, descriptions = af3_parsers.parse_fasta(record.paired_msa)
        counts: Counter[str] = Counter()
        # Skip the query row; only tagged alignment rows are countable.
        for desc in descriptions[1:]:
            species = _extract_paired_species_identifier(desc)
            if species:
                counts[species] += 1
        species_counts_by_chain[record.chain_id] = counts
        non_gap_rows[record.chain_id] = 0
        gap_rows[record.chain_id] = 0

    total_effective_rows = 0
    histogram: Counter[str] = Counter()
    observed_species = sorted(set().union(*species_counts_by_chain.values()))

    for species in observed_species:
        chains_with_species = [
            chain_id
            for chain_id, counts in species_counts_by_chain.items()
            if counts.get(species, 0) > 0
        ]
        # A species seen in a single chain cannot pair across chains.
        if len(chains_with_species) < 2:
            continue

        kept = min(
            species_counts_by_chain[chain_id][species]
            for chain_id in chains_with_species
        )
        total_effective_rows += kept
        histogram[str(len(chains_with_species))] += kept
        for chain_id in species_counts_by_chain:
            if chain_id in chains_with_species:
                non_gap_rows[chain_id] += kept
            else:
                gap_rows[chain_id] += kept

    return {
        "effective_paired_row_count": int(total_effective_rows),
        "effective_paired_row_histogram_by_num_chains": {
            num_chains: int(rows)
            for num_chains, rows in sorted(
                histogram.items(), key=lambda item: int(item[0])
            )
        },
        "effective_non_gap_rows_by_chain": {
            chain_id: int(rows)
            for chain_id, rows in sorted(non_gap_rows.items())
        },
        "effective_gap_rows_by_chain": {
            chain_id: int(rows)
            for chain_id, rows in sorted(gap_rows.items())
        },
    }

os.makedirs(output_dir, exist_ok=True)
effective_pairing_summary = _summarise_effective_af3_pairing(chain_records)
summary = {
"job_name": job_name,
"translation_modes": sorted(
Expand All @@ -1470,7 +1549,18 @@ def _write_translated_msa_debug_artifacts(
sum(result.occupancy_histogram.get("ge_2", 0) for result in translation_results)
),
},
"paired_row_count": int(sum(result.paired_row_count for result in translation_results)),
"paired_row_count": int(
effective_pairing_summary["effective_paired_row_count"]
),
"translated_paired_input_row_count": int(
sum(result.paired_row_count for result in translation_results)
),
"effective_paired_row_count": int(
effective_pairing_summary["effective_paired_row_count"]
),
"effective_paired_row_histogram_by_num_chains": dict(
effective_pairing_summary["effective_paired_row_histogram_by_num_chains"]
),
"invalid_paired_rows": int(
sum(result.invalid_paired_rows for result in translation_results)
),
Expand Down Expand Up @@ -1515,6 +1605,16 @@ def _write_translated_msa_debug_artifacts(
"paired_rows_with_generated_accession_count": int(
record.paired_rows_with_generated_accession_count
),
"effective_paired_msa_row_count": int(
effective_pairing_summary["effective_non_gap_rows_by_chain"].get(
chain_id, 0
)
),
"effective_paired_gap_row_count": int(
effective_pairing_summary["effective_gap_rows_by_chain"].get(
chain_id, 0
)
),
}
)

Expand Down
75 changes: 75 additions & 0 deletions test/cluster/check_alphafold3_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,6 +1768,81 @@ def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_inference(self
f"Expected AF3 ipTM > 0.6, got {confidence_payload['iptm']}",
)

def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_trimer_inference(self):
    """AF3 should accept trimer jobs built from AF2/mmseqs2 pkl features and report effective pairing."""
    self._require_mmseqs_functional_environment()
    env = self._make_af3_test_env()
    feature_dir = self._generate_issue_588_mmseq_features(env)

    command = [
        sys.executable,
        str(self.script_single),
        "--input=A0ABD7FQG0+P18004+A0ABD7FQG0",
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={feature_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={self._af3_flash_attention_impl()}",
        "--num_diffusion_samples=1",
        "--random_seed=42",
        "--debug_msas",
    ]
    completed = subprocess.run(
        command,
        capture_output=True,
        text=True,
        env=env,
    )
    self._runCommonTests(completed)

    # The translation summary JSON is the debug artifact under test.
    result_dir = self._resolve_single_af3_result_dir()
    summary_paths = sorted(
        result_dir.glob("*_af2_to_af3_translation_summary.json")
    )
    self.assertLen(summary_paths, 1)
    summary = json.loads(summary_paths[0].read_text(encoding="utf-8"))

    self.assertEqual(
        summary["translation_modes"],
        ["af3_species_pairing_from_af2_individual_msas"],
    )
    self.assertTrue(summary["paired_rows_valid"])
    self.assertTrue(summary["unpaired_rows_valid"])
    self.assertGreater(summary["translated_paired_input_row_count"], 0)
    self.assertGreater(summary["paired_row_count"], 0)
    # Effective pairing can only discard rows, never add them.
    self.assertGreaterEqual(
        summary["translated_paired_input_row_count"],
        summary["paired_row_count"],
    )

    histogram = summary["effective_paired_row_histogram_by_num_chains"]
    self.assertTrue(histogram)
    self.assertGreaterEqual(max(int(num_chains) for num_chains in histogram), 2)

    self.assertLen(summary["chains"], 3)
    for chain_summary in summary["chains"]:
        self.assertGreater(chain_summary["paired_msa_row_count"], 0)
        self.assertGreater(chain_summary["unpaired_msa_row_count"], 0)
        self.assertGreater(chain_summary["effective_paired_msa_row_count"], 0)

    # The written AF3 input must collapse the duplicated sequence into two
    # protein entries that together cover chains A, B and C.
    input_json_paths = sorted(result_dir.glob("*_data.json"))
    self.assertLen(input_json_paths, 1)
    af3_input = json.loads(input_json_paths[0].read_text(encoding="utf-8"))
    protein_entries = _protein_entries_from_af3_input(af3_input)
    self.assertLen(protein_entries, 2)

    collected_chain_ids = []
    for protein_entry in protein_entries:
        entry_ids = protein_entry["id"]
        collected_chain_ids.extend(
            [entry_ids] if isinstance(entry_ids, str) else entry_ids
        )
        for msa_key in ("pairedMsa", "unpairedMsa"):
            self.assertEqual(
                _a3m_query_sequence(protein_entry[msa_key]),
                protein_entry["sequence"],
            )
    self.assertCountEqual(collected_chain_ids, ["A", "B", "C"])


# --------------------------------------------------------------------------- #
# parameterised "run mode" tests #
Expand Down
Loading
Loading