From 7c3c5cc92050667e4447945e7038c40b83ba6b40 Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:40:48 +0200 Subject: [PATCH 1/6] Tighten AF3 trimer pairing summary --- .../folding_backend/alphafold3_backend.py | 102 ++++++- test/unit/test_af2_to_af3_msa.py | 233 ++++++++++++++++ test/unit/test_alphafold3_backend_helpers.py | 261 +++++++++++++++++- 3 files changed, 585 insertions(+), 11 deletions(-) diff --git a/alphapulldown/folding_backend/alphafold3_backend.py b/alphapulldown/folding_backend/alphafold3_backend.py index ffe6c84d..ca5d375a 100644 --- a/alphapulldown/folding_backend/alphafold3_backend.py +++ b/alphapulldown/folding_backend/alphafold3_backend.py @@ -18,6 +18,7 @@ import re import time import typing +from collections import Counter from collections.abc import Sequence from typing import Any, List, Dict, Union, overload @@ -1450,7 +1451,85 @@ def _write_translated_msa_debug_artifacts( if not chain_records or not translation_results: return + def _extract_paired_species_identifier(description: str) -> str: + if not description.startswith(("tr|", "sp|")): + return "" + description_tail = description.split("|")[-1] + if "_" not in description_tail: + return "" + return description_tail.rsplit("_", maxsplit=1)[-1].strip() + + def _summarise_effective_af3_pairing( + records: list[_TranslatedMsaDebugRecord], + ) -> dict[str, object]: + per_chain_species_counts: dict[str, Counter[str]] = {} + effective_non_gap_rows_by_chain: dict[str, int] = {} + effective_gap_rows_by_chain: dict[str, int] = {} + + for record in records: + _, descriptions = af3_parsers.parse_fasta(record.paired_msa) + species_counts: Counter[str] = Counter() + for description in descriptions[1:]: + species_identifier = _extract_paired_species_identifier(description) + if species_identifier: + species_counts[species_identifier] += 1 + per_chain_species_counts[record.chain_id] = species_counts + effective_non_gap_rows_by_chain[record.chain_id] = 0 + effective_gap_rows_by_chain[record.chain_id] = 0 + + effective_paired_row_count = 0 + effective_paired_row_histogram_by_num_chains: Counter[str] = Counter() + all_species = sorted( + { + species_identifier + for species_counts in per_chain_species_counts.values() + for species_identifier in species_counts + } + ) + for species_identifier in all_species: + present_chain_ids = [ + chain_id + for chain_id, species_counts in per_chain_species_counts.items() + if species_counts.get(species_identifier, 0) > 0 + ] + if len(present_chain_ids) <= 1: + continue + + kept_rows = min( + per_chain_species_counts[chain_id][species_identifier] + for chain_id in present_chain_ids + ) + effective_paired_row_count += kept_rows + effective_paired_row_histogram_by_num_chains[ + str(len(present_chain_ids)) + ] += kept_rows + for chain_id in per_chain_species_counts: + if chain_id in present_chain_ids: + effective_non_gap_rows_by_chain[chain_id] += kept_rows + else: + effective_gap_rows_by_chain[chain_id] += kept_rows + + return { + "effective_paired_row_count": int(effective_paired_row_count), + "effective_paired_row_histogram_by_num_chains": { + key: int(value) + for key, value in sorted( + effective_paired_row_histogram_by_num_chains.items(), + key=lambda item: int(item[0]), + ) + }, + "effective_non_gap_rows_by_chain": { + key: int(value) + for key, value in sorted(effective_non_gap_rows_by_chain.items()) + }, + "effective_gap_rows_by_chain": { + key: int(value) + for key, value in sorted(effective_gap_rows_by_chain.items()) + }, + } + os.makedirs(output_dir, exist_ok=True) + effective_pairing_summary = _summarise_effective_af3_pairing(chain_records) summary = { "job_name": job_name, "translation_modes": sorted( @@ -1470,7 +1549,18 @@ def _write_translated_msa_debug_artifacts( sum(result.occupancy_histogram.get("ge_2", 0) for result in translation_results) ), }, - "paired_row_count": int(sum(result.paired_row_count for result in translation_results)), + "paired_row_count": int( + effective_pairing_summary["effective_paired_row_count"] + ), + "translated_paired_input_row_count": int( + sum(result.paired_row_count for result in translation_results) + ), + "effective_paired_row_count": int( + effective_pairing_summary["effective_paired_row_count"] + ), + "effective_paired_row_histogram_by_num_chains": dict( + effective_pairing_summary["effective_paired_row_histogram_by_num_chains"] + ), "invalid_paired_rows": int( sum(result.invalid_paired_rows for result in translation_results) ), @@ -1515,6 +1605,16 @@ def _write_translated_msa_debug_artifacts( "paired_rows_with_generated_accession_count": int( record.paired_rows_with_generated_accession_count ), + "effective_paired_msa_row_count": int( + effective_pairing_summary["effective_non_gap_rows_by_chain"].get( + chain_id, 0 + ) + ), + "effective_paired_gap_row_count": int( + effective_pairing_summary["effective_gap_rows_by_chain"].get( + chain_id, 0 + ) + ), } ) diff --git a/test/unit/test_af2_to_af3_msa.py b/test/unit/test_af2_to_af3_msa.py index fad76b17..0ee99f46 100644 --- a/test/unit/test_af2_to_af3_msa.py +++ b/test/unit/test_af2_to_af3_msa.py @@ -1,4 +1,8 @@ +import itertools + import numpy as np +from alphafold3.model import data_constants as af3_data_constants +from alphafold3.model import msa_pairing as af3_msa_pairing from alphapulldown.utils.af2_to_af3_msa import ( msa_rows_and_deletions_to_a3m, @@ -55,6 +59,110 @@ def _aligned_and_deletions(sequence: str) -> tuple[str, list[int]]: return "".join(aligned_chars), deletion_counts +def _make_af2_chain_feature_dict( + sequence: str, + *, + paired_rows: list[tuple[str, str]], + unpaired_rows: list[str] | None = None, +) -> dict[str, np.ndarray]: + if unpaired_rows is None: + unpaired_rows = [] + + return { + "msa_all_seq": np.stack( + [_encode(sequence)] + [_encode(row) for _, row in paired_rows] + ), + "deletion_matrix_int_all_seq": np.zeros( + (len(paired_rows) + 1, len(sequence)), dtype=np.int32 + ), + "msa_species_identifiers_all_seq": np.asarray( + [b""] + [species_id.encode("utf-8") for species_id, _ in paired_rows], + dtype=object, + ), + "msa": np.stack([_encode(sequence)] + [_encode(row) for row in unpaired_rows]), + "deletion_matrix_int": np.zeros( + (len(unpaired_rows) + 1, len(sequence)), dtype=np.int32 + ), + } + + +def _pair_translated_msas_with_af3( + *, + chain_ids: list[str], + chain_sequences: list[str], + chain_msas, +): + paired_chains = [] + for chain_id, sequence, chain_msa in zip( + chain_ids, chain_sequences, chain_msas, strict=True + ): + paired_sequences = [sequence] + paired_descriptions = ["query"] + if chain_msa.paired_msa: + paired_sequences = [ + _aligned_and_deletions(a3m_sequence)[0] + for a3m_sequence in _a3m_sequences(chain_msa.paired_msa) + ] + paired_descriptions = _a3m_descriptions(chain_msa.paired_msa) + + species_identifiers = [] + for description in paired_descriptions: + if description.startswith("tr|") and "_" in description: + species_identifiers.append(description.rsplit("_", maxsplit=1)[-1].encode()) + else: + species_identifiers.append(b"") + + paired_chains.append( + { + "chain_id": chain_id, + "msa_all_seq": np.stack( + [_encode(aligned_sequence) for aligned_sequence in paired_sequences] + ), + "deletion_matrix_all_seq": np.zeros( + (len(paired_sequences), len(sequence)), dtype=np.int32 + ), + "msa_species_identifiers_all_seq": np.asarray( + species_identifiers, dtype=object + ), + } + ) + + return af3_msa_pairing.create_paired_features( + chains=paired_chains, + max_paired_sequences=512, + nonempty_chain_ids=set(chain_ids), + max_hits_per_species=600, + ) + + +def _non_gap_payload_rows(msa_rows: np.ndarray) -> int: + return int( + np.sum( + np.any(msa_rows[1:] != af3_data_constants.MSA_GAP_IDX, axis=1), + ) + ) + + +def _all_gap_payload_rows(msa_rows: np.ndarray) -> int: + return int( + np.sum( + np.all(msa_rows[1:] == af3_data_constants.MSA_GAP_IDX, axis=1), + ) + ) + + +def _canonical_af3_paired_rows(chains) -> tuple[tuple[str, tuple[tuple[int, ...], ...]], ...]: + canonical_rows = [] + for chain in sorted(chains, key=lambda chain: chain["chain_id"]): + canonical_rows.append( + ( + chain["chain_id"], + tuple(tuple(int(token) for token in row) for row in chain["msa_all_seq"]), + ) + ) + return tuple(canonical_rows) + + def test_msa_rows_and_deletions_to_a3m_preserves_lowercase_compression(): a3m = msa_rows_and_deletions_to_a3m( msa_rows=np.stack([_encode("A-C")]), @@ -518,3 +626,128 @@ def test_translate_af2_individual_chain_features_tracks_missing_species_ids(): "tr|APA0000002|APA0000002_ECOLX", ] assert _a3m_payload_descriptions(result.chain_msas[1].paired_msa) == ["sequence_1"] + + +def test_translate_af2_individual_chain_features_supports_three_chain_sparse_middle_pairing(): + chain_ids = ["A", "B", "C"] + chain_sequences = ["AC", "GT", "MK"] + result = translate_af2_individual_chain_features_to_af3_msas_with_stats( + chain_feature_dicts=[ + _make_af2_chain_feature_dict( + "AC", + paired_rows=[("ECOLX", "A-"), ("ECOLX", "AA")], + unpaired_rows=["AA"], + ), + _make_af2_chain_feature_dict( + "GT", + paired_rows=[], + unpaired_rows=["G-"], + ), + _make_af2_chain_feature_dict( + "MK", + paired_rows=[("ECOLX", "M-"), ("ECOLX", "MM")], + unpaired_rows=["MM"], + ), + ], + chain_sequences=chain_sequences, + ) + + assert [stats.paired_msa_row_count for stats in result.chain_stats] == [2, 0, 2] + assert result.chain_msas[1].paired_msa == "" + + paired_chains = _pair_translated_msas_with_af3( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + chain_msas=result.chain_msas, + ) + + assert [chain["msa_all_seq"].shape[0] for chain in paired_chains] == [3, 3, 3] + assert _non_gap_payload_rows(paired_chains[0]["msa_all_seq"]) == 2 + assert _non_gap_payload_rows(paired_chains[1]["msa_all_seq"]) == 0 + assert _non_gap_payload_rows(paired_chains[2]["msa_all_seq"]) == 2 + assert _all_gap_payload_rows(paired_chains[1]["msa_all_seq"]) == 2 + + +def test_translate_af2_individual_chain_features_supports_three_chain_min_count_crop(): + chain_ids = ["A", "B", "C"] + chain_sequences = ["AC", "GT", "MK"] + result = translate_af2_individual_chain_features_to_af3_msas_with_stats( + chain_feature_dicts=[ + _make_af2_chain_feature_dict( + "AC", + paired_rows=[("ECOLX", "A-"), ("ECOLX", "AA"), ("ECOLX", "AC")], + ), + _make_af2_chain_feature_dict( + "GT", + paired_rows=[("ECOLX", "G-")], + ), + _make_af2_chain_feature_dict( + "MK", + paired_rows=[("ECOLX", "M-"), ("ECOLX", "MM"), ("ECOLX", "MK")], + ), + ], + chain_sequences=chain_sequences, + ) + + assert [stats.paired_msa_row_count for stats in result.chain_stats] == [3, 1, 3] + + paired_chains = _pair_translated_msas_with_af3( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + chain_msas=result.chain_msas, + ) + + assert [chain["msa_all_seq"].shape[0] for chain in paired_chains] == [2, 2, 2] + assert [_non_gap_payload_rows(chain["msa_all_seq"]) for chain in paired_chains] == [ + 1, + 1, + 1, + ] + assert [_all_gap_payload_rows(chain["msa_all_seq"]) for chain in paired_chains] == [ + 0, + 0, + 0, + ] + + +def test_translate_af2_individual_chain_features_is_permutation_invariant_for_three_chains(): + base_chain_ids = ["A", "B", "C"] + base_chain_sequences = { + "A": "AC", + "B": "GT", + "C": "MK", + } + base_chain_features = { + "A": _make_af2_chain_feature_dict( + "AC", + paired_rows=[("S1", "A-"), ("S1", "AA"), ("S2", "AC")], + ), + "B": _make_af2_chain_feature_dict( + "GT", + paired_rows=[("S1", "G-")], + ), + "C": _make_af2_chain_feature_dict( + "MK", + paired_rows=[("S1", "M-"), ("S1", "MM"), ("S2", "MK")], + ), + } + + canonical_rows = None + for permutation in itertools.permutations(base_chain_ids): + chain_sequences = [base_chain_sequences[chain_id] for chain_id in permutation] + result = translate_af2_individual_chain_features_to_af3_msas_with_stats( + chain_feature_dicts=[ + base_chain_features[chain_id] for chain_id in permutation + ], + chain_sequences=chain_sequences, + ) + paired_chains = _pair_translated_msas_with_af3( + chain_ids=list(permutation), + chain_sequences=chain_sequences, + chain_msas=result.chain_msas, + ) + + permutation_rows = _canonical_af3_paired_rows(paired_chains) + if canonical_rows is None: + canonical_rows = permutation_rows + assert permutation_rows == canonical_rows diff --git a/test/unit/test_alphafold3_backend_helpers.py b/test/unit/test_alphafold3_backend_helpers.py index 9e0dec46..6b6d9df2 100644 --- a/test/unit/test_alphafold3_backend_helpers.py +++ b/test/unit/test_alphafold3_backend_helpers.py @@ -1263,6 +1263,69 @@ def _translation_stats( ) +AF2_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX-" +AF2_TOKEN_BY_RESIDUE = { + residue: index for index, residue in enumerate(AF2_ALPHABET) +} + + +def _encode_af2_row(sequence: str) -> np.ndarray: + return np.array( + [AF2_TOKEN_BY_RESIDUE[residue] for residue in sequence], + dtype=np.int32, + ) + + +def _make_af2_chain_feature_dict( + sequence: str, + *, + paired_rows: list[tuple[str, str]], + unpaired_rows: list[str] | None = None, +) -> dict[str, np.ndarray]: + if unpaired_rows is None: + unpaired_rows = [] + + return { + "msa_all_seq": np.stack( + [_encode_af2_row(sequence)] + [_encode_af2_row(row) for _, row in paired_rows] + ), + "deletion_matrix_int_all_seq": np.zeros( + (len(paired_rows) + 1, len(sequence)), + dtype=np.int32, + ), + "msa_species_identifiers_all_seq": np.asarray( + [b""] + [species_id.encode("utf-8") for species_id, _ in paired_rows], + dtype=object, + ), + "msa": np.stack( + [_encode_af2_row(sequence)] + [_encode_af2_row(row) for row in unpaired_rows] + ), + "deletion_matrix_int": np.zeros( + (len(unpaired_rows) + 1, len(sequence)), + dtype=np.int32, + ), + } + + +def _make_stub_multimeric_object(af3_backend_module, *, description: str, interactors): + multimer = object.__new__(af3_backend_module.MultimericObject) + multimer.description = description + multimer.interactors = list(interactors) + + merged_sequence = "".join(interactor.sequence for interactor in interactors) + asym_ids = [] + for asym_id, interactor in enumerate(interactors, start=1): + asym_ids.extend([asym_id] * len(interactor.sequence)) + + multimer.feature_dict = { + "msa": np.stack([_encode_af2_row(merged_sequence)]), + "num_alignments": np.asarray([1], dtype=np.int32), + "deletion_matrix_int": np.zeros((1, len(merged_sequence)), dtype=np.int32), + "asym_id": np.asarray(asym_ids, dtype=np.int32), + } + return multimer + + def test_prepare_input_multimer_writes_translated_msa_debug_artifacts( af3_backend_module, monkeypatch, @@ -1282,24 +1345,32 @@ def test_prepare_input_multimer_writes_translated_msa_debug_artifacts( af3_backend_module, chain_msas=[ _translation_msas( - paired_msa=">pairA\nACDE\n", + paired_msa=( + ">query\nACDE\n" + ">tr|APA0000001|APA0000001_ECOLX\nACDE\n" + ">tr|APA0000002|APA0000002_SHIDY\nACDE\n" + ), unpaired_msa=">unpairedA\nACDE\n", ), _translation_msas( - paired_msa=">pairB\nFGHI\n", + paired_msa=( + ">query\nFGHI\n" + ">tr|APB0000001|APB0000001_ECOLX\nFGHI\n" + ">sequence_2\nFGHI\n" + ), unpaired_msa=">unpairedB\nFGHI\n", ), ], chain_stats=[ _translation_stats( - paired_msa_row_count=1, + paired_msa_row_count=2, unpaired_msa_row_count=1, paired_species_identifier_count=2, paired_rows_without_species_identifier_count=0, paired_rows_with_generated_accession_count=1, ), _translation_stats( - paired_msa_row_count=1, + paired_msa_row_count=2, unpaired_msa_row_count=2, paired_species_identifier_count=1, paired_rows_without_species_identifier_count=1, @@ -1339,9 +1410,17 @@ def test_prepare_input_multimer_writes_translated_msa_debug_artifacts( assert output_dir == str(tmp_path) assert resolve_msa_overlaps is False assert [chain.id for chain in fold_input.chains] == ["A", "B"] - assert fold_input.chains[0].paired_msa == ">pairA\nACDE\n" + assert fold_input.chains[0].paired_msa == ( + ">query\nACDE\n" + ">tr|APA0000001|APA0000001_ECOLX\nACDE\n" + ">tr|APA0000002|APA0000002_SHIDY\nACDE\n" + ) assert fold_input.chains[0].unpaired_msa == ">unpairedA\nACDE\n" - assert fold_input.chains[1].paired_msa == ">pairB\nFGHI\n" + assert fold_input.chains[1].paired_msa == ( + ">query\nFGHI\n" + ">tr|APB0000001|APB0000001_ECOLX\nFGHI\n" + ">sequence_2\nFGHI\n" + ) assert fold_input.chains[1].unpaired_msa == ">unpairedB\nFGHI\n" summary_path = tmp_path / f"{fold_input.sanitised_name()}_af2_to_af3_translation_summary.json" @@ -1350,32 +1429,43 @@ def test_prepare_input_multimer_writes_translated_msa_debug_artifacts( assert summary["translation_modes"] == [ "af3_species_pairing_from_af2_individual_msas" ] - assert summary["paired_row_count"] == 2 + assert summary["translated_paired_input_row_count"] == 2 + assert summary["paired_row_count"] == 1 + assert summary["effective_paired_row_count"] == 1 + assert summary["effective_paired_row_histogram_by_num_chains"] == {"2": 1} assert summary["chains"] == [ { "chain_id": "A", "chain_description": "protA", "chain_length": 4, - "paired_msa_row_count": 1, + "paired_msa_row_count": 2, "unpaired_msa_row_count": 1, "paired_species_identifier_count": 2, "paired_rows_without_species_identifier_count": 0, "paired_rows_with_generated_accession_count": 1, + "effective_paired_msa_row_count": 1, + "effective_paired_gap_row_count": 0, }, { "chain_id": "B", "chain_description": "protB", "chain_length": 4, - "paired_msa_row_count": 1, + "paired_msa_row_count": 2, "unpaired_msa_row_count": 2, "paired_species_identifier_count": 1, "paired_rows_without_species_identifier_count": 1, "paired_rows_with_generated_accession_count": 0, + "effective_paired_msa_row_count": 1, + "effective_paired_gap_row_count": 0, }, ] assert ( tmp_path / f"{fold_input.sanitised_name()}_chain-A_paired_input.a3m" - ).read_text(encoding="utf-8") == ">pairA\nACDE\n" + ).read_text(encoding="utf-8") == ( + ">query\nACDE\n" + ">tr|APA0000001|APA0000001_ECOLX\nACDE\n" + ">tr|APA0000002|APA0000002_SHIDY\nACDE\n" + ) assert ( tmp_path / f"{fold_input.sanitised_name()}_chain-B_unpaired_input.a3m" ).read_text(encoding="utf-8") == ">unpairedB\nFGHI\n" @@ -1548,6 +1638,157 @@ def test_prepare_input_multimer_skips_complex_fallback_after_duplicate_residue_n ] +def test_prepare_input_multimer_trimer_preserves_sparse_middle_chain_pairing( + af3_backend_module, + monkeypatch, + tmp_path, +): + interactor_a = MonomericObject("protA", "AC") + interactor_b = MonomericObject("protB", "GT") + interactor_c = MonomericObject("protC", "MK") + interactor_a.feature_dict = _make_af2_chain_feature_dict( + "AC", + paired_rows=[("ECOLX", "A-"), ("ECOLX", "AA")], + unpaired_rows=["AA"], + ) + interactor_b.feature_dict = _make_af2_chain_feature_dict( + "GT", + paired_rows=[], + unpaired_rows=["G-"], + ) + interactor_c.feature_dict = _make_af2_chain_feature_dict( + "MK", + paired_rows=[("ECOLX", "M-"), ("ECOLX", "MM")], + unpaired_rows=["MM"], + ) + + multimer = _make_stub_multimeric_object( + af3_backend_module, + description="protA_and_protB_and_protC", + interactors=[interactor_a, interactor_b, interactor_c], + ) + + def _unexpected_fallback(**kwargs): + raise AssertionError(f"Unexpected AF2 merged-MSA fallback: {kwargs}") + + monkeypatch.setattr( + af3_backend_module, + "translate_af2_complex_msa_to_af3_unpaired_chain_msas_with_stats", + _unexpected_fallback, + ) + + prepared_inputs = af3_backend_module.AlphaFold3Backend.prepare_input( + objects_to_model=[ + { + "object": multimer, + "output_dir": str(tmp_path), + } + ], + random_seed=43, + debug_msas=True, + ) + + fold_input, (_output_dir, resolve_msa_overlaps) = next(iter(prepared_inputs[0].items())) + assert resolve_msa_overlaps is False + assert [chain.id for chain in fold_input.chains] == ["A", "B", "C"] + assert fold_input.chains[0].paired_msa + assert fold_input.chains[1].paired_msa == "" + assert fold_input.chains[2].paired_msa + + summary_path = tmp_path / f"{fold_input.sanitised_name()}_af2_to_af3_translation_summary.json" + summary = json.loads(summary_path.read_text(encoding="utf-8")) + assert summary["translation_modes"] == ["af3_species_pairing_from_af2_individual_msas"] + assert summary["translated_paired_input_row_count"] == 2 + assert summary["paired_row_count"] == 2 + assert summary["effective_paired_row_count"] == 2 + assert summary["effective_paired_row_histogram_by_num_chains"] == {"2": 2} + + chain_summary = {chain["chain_id"]: chain for chain in summary["chains"]} + assert chain_summary["A"]["paired_msa_row_count"] == 2 + assert chain_summary["A"]["effective_paired_msa_row_count"] == 2 + assert chain_summary["A"]["effective_paired_gap_row_count"] == 0 + assert chain_summary["B"]["paired_msa_row_count"] == 0 + assert chain_summary["B"]["effective_paired_msa_row_count"] == 0 + assert chain_summary["B"]["effective_paired_gap_row_count"] == 2 + assert chain_summary["C"]["paired_msa_row_count"] == 2 + assert chain_summary["C"]["effective_paired_msa_row_count"] == 2 + assert chain_summary["C"]["effective_paired_gap_row_count"] == 0 + + +def test_prepare_input_multimer_trimer_reports_effective_min_count_pairing( + af3_backend_module, + monkeypatch, + tmp_path, +): + interactor_a = MonomericObject("protA", "AC") + interactor_b = MonomericObject("protB", "GT") + interactor_c = MonomericObject("protC", "MK") + interactor_a.feature_dict = _make_af2_chain_feature_dict( + "AC", + paired_rows=[("ECOLX", "A-"), ("ECOLX", "AA"), ("ECOLX", "AC")], + unpaired_rows=["AA"], + ) + interactor_b.feature_dict = _make_af2_chain_feature_dict( + "GT", + paired_rows=[("ECOLX", "G-")], + unpaired_rows=["GG"], + ) + interactor_c.feature_dict = _make_af2_chain_feature_dict( + "MK", + paired_rows=[("ECOLX", "M-"), ("ECOLX", "MM"), ("ECOLX", "MK")], + unpaired_rows=["MM"], + ) + + multimer = _make_stub_multimeric_object( + af3_backend_module, + description="protA_and_protB_and_protC", + interactors=[interactor_a, interactor_b, interactor_c], + ) + + def _unexpected_fallback(**kwargs): + raise AssertionError(f"Unexpected AF2 merged-MSA fallback: {kwargs}") + + monkeypatch.setattr( + af3_backend_module, + "translate_af2_complex_msa_to_af3_unpaired_chain_msas_with_stats", + _unexpected_fallback, + ) + + prepared_inputs = af3_backend_module.AlphaFold3Backend.prepare_input( + objects_to_model=[ + { + "object": multimer, + "output_dir": str(tmp_path), + } + ], + random_seed=47, + debug_msas=True, + ) + + fold_input, (_output_dir, resolve_msa_overlaps) = next(iter(prepared_inputs[0].items())) + assert resolve_msa_overlaps is False + assert all(chain.paired_msa for chain in fold_input.chains) + + summary_path = tmp_path / f"{fold_input.sanitised_name()}_af2_to_af3_translation_summary.json" + summary = json.loads(summary_path.read_text(encoding="utf-8")) + assert summary["translation_modes"] == ["af3_species_pairing_from_af2_individual_msas"] + assert summary["translated_paired_input_row_count"] == 3 + assert summary["paired_row_count"] == 1 + assert summary["effective_paired_row_count"] == 1 + assert summary["effective_paired_row_histogram_by_num_chains"] == {"3": 1} + + chain_summary = {chain["chain_id"]: chain for chain in summary["chains"]} + assert chain_summary["A"]["paired_msa_row_count"] == 3 + assert chain_summary["B"]["paired_msa_row_count"] == 1 + assert chain_summary["C"]["paired_msa_row_count"] == 3 + assert chain_summary["A"]["effective_paired_msa_row_count"] == 1 + assert chain_summary["B"]["effective_paired_msa_row_count"] == 1 + assert chain_summary["C"]["effective_paired_msa_row_count"] == 1 + assert chain_summary["A"]["effective_paired_gap_row_count"] == 0 + assert chain_summary["B"]["effective_paired_gap_row_count"] == 0 + assert chain_summary["C"]["effective_paired_gap_row_count"] == 0 + + def test_prepare_input_builds_template_mmcif_and_skips_zero_atom_templates( af3_backend_module, monkeypatch, From 99a1bf3f34673efd2920f586385361fae7824047 Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:45:25 +0200 Subject: [PATCH 2/6] Add AF3 trimer cluster smoke test --- test/cluster/check_alphafold3_predictions.py | 70 ++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/test/cluster/check_alphafold3_predictions.py b/test/cluster/check_alphafold3_predictions.py index f0aef0d2..fa53c46e 100755 --- a/test/cluster/check_alphafold3_predictions.py +++ b/test/cluster/check_alphafold3_predictions.py @@ -1768,6 +1768,76 @@ def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_inference(self f"Expected AF3 ipTM > 0.6, got {confidence_payload['iptm']}", ) + def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_trimer_inference(self): + """AF3 should accept trimer jobs built from AF2/mmseqs2 pkl features and report effective pairing.""" + self._require_mmseqs_functional_environment() + env = self._make_af3_test_env() + feature_dir = self._generate_issue_588_mmseq_features(env) + + flash_impl = self._af3_flash_attention_impl() + res = subprocess.run( + [ + sys.executable, + str(self.script_single), + "--input=A0ABD7FQG0+P18004+A0ABD7FQG0", + f"--output_directory={self.output_dir}", + f"--data_directory={DATA_DIR}", + f"--features_directory={feature_dir}", + "--fold_backend=alphafold3", + f"--flash_attention_implementation={flash_impl}", + "--num_diffusion_samples=1", + "--random_seed=42", + "--debug_msas", + ], + capture_output=True, + text=True, + env=env, + ) + self._runCommonTests(res) + + result_dir = self._resolve_single_af3_result_dir() + summary_paths = sorted( + result_dir.glob("*_af2_to_af3_translation_summary.json") + ) + self.assertLen(summary_paths, 1) + summary = json.loads(summary_paths[0].read_text(encoding="utf-8")) + self.assertEqual( + summary["translation_modes"], + ["af3_species_pairing_from_af2_individual_msas"], + ) + self.assertTrue(summary["paired_rows_valid"]) + self.assertTrue(summary["unpaired_rows_valid"]) + self.assertEqual(summary["paired_row_count"], summary["effective_paired_row_count"]) + self.assertGreater(summary["translated_paired_input_row_count"], 0) + self.assertGreater(summary["effective_paired_row_count"], 0) + self.assertGreaterEqual( + summary["translated_paired_input_row_count"], + summary["effective_paired_row_count"], + ) + histogram = summary["effective_paired_row_histogram_by_num_chains"] + self.assertTrue(histogram) + self.assertGreaterEqual(max(int(key) for key in histogram), 2) + self.assertLen(summary["chains"], 3) + for chain_summary in summary["chains"]: + self.assertGreater(chain_summary["paired_msa_row_count"], 0) + self.assertGreater(chain_summary["unpaired_msa_row_count"], 0) + self.assertGreater(chain_summary["effective_paired_msa_row_count"], 0) + + input_json_paths = sorted(result_dir.glob("*_data.json")) + self.assertLen(input_json_paths, 1) + written = json.loads(input_json_paths[0].read_text(encoding="utf-8")) + protein_entries = _protein_entries_from_af3_input(written) + self.assertLen(protein_entries, 3) + for protein_entry in protein_entries: + self.assertEqual( + _a3m_query_sequence(protein_entry["pairedMsa"]), + protein_entry["sequence"], + ) + self.assertEqual( + _a3m_query_sequence(protein_entry["unpairedMsa"]), + protein_entry["sequence"], + ) + # --------------------------------------------------------------------------- # # parameterised "run mode" tests # From 3d8fe2b70edf2f6c7367fb3842b789519a2462d6 Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:12:35 +0200 Subject: [PATCH 3/6] Force cluster AF3 tests to use checkout code --- test/cluster/check_alphafold3_predictions.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/cluster/check_alphafold3_predictions.py b/test/cluster/check_alphafold3_predictions.py index fa53c46e..5af82187 100755 --- a/test/cluster/check_alphafold3_predictions.py +++ b/test/cluster/check_alphafold3_predictions.py @@ -1088,6 +1088,14 @@ def _check_chain_counts_and_sequences(self, protein_list: str): def _make_af3_test_env(self) -> Dict[str, str]: flash_impl = self._af3_flash_attention_impl() env = os.environ.copy() + # Force subprocesses launched from test helpers to import this checkout, + # not an older AlphaPulldown installation from the cluster env. + existing_pythonpath = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = ( + f"{REPO_ROOT}:{existing_pythonpath}" + if existing_pythonpath + else str(REPO_ROOT) + ) env["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter --xla_gpu_force_compilation_parallelism=0" env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" env["XLA_CLIENT_MEM_FRACTION"] = "0.95" From eb2a71dbeaf4aced5062cb1ee940e5c0cd70937f Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:48:33 +0200 Subject: [PATCH 4/6] Handle grouped identical chains in AF3 trimer test --- test/cluster/check_alphafold3_predictions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/cluster/check_alphafold3_predictions.py b/test/cluster/check_alphafold3_predictions.py index 5af82187..84733cd8 100755 --- a/test/cluster/check_alphafold3_predictions.py +++ b/test/cluster/check_alphafold3_predictions.py @@ -1835,8 +1835,13 @@ def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_trimer_inferen self.assertLen(input_json_paths, 1) written = json.loads(input_json_paths[0].read_text(encoding="utf-8")) protein_entries = _protein_entries_from_af3_input(written) - self.assertLen(protein_entries, 3) + self.assertLen(protein_entries, 2) + all_chain_ids = [] for protein_entry in protein_entries: + entry_ids = protein_entry["id"] + if isinstance(entry_ids, str): + entry_ids = [entry_ids] + all_chain_ids.extend(entry_ids) self.assertEqual( _a3m_query_sequence(protein_entry["pairedMsa"]), protein_entry["sequence"], @@ -1845,6 +1850,7 @@ def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_trimer_inferen _a3m_query_sequence(protein_entry["unpairedMsa"]), protein_entry["sequence"], ) + self.assertCountEqual(all_chain_ids, ["A", "B", "C"]) # --------------------------------------------------------------------------- # From 6d62d271a2defc795e96607340c2f7af20b78b50 Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 19:45:49 +0200 Subject: [PATCH 5/6] Compare AF2 trimer translation with native AF3 pairing --- test/unit/test_af2_to_af3_msa.py | 173 +++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/test/unit/test_af2_to_af3_msa.py b/test/unit/test_af2_to_af3_msa.py index 0ee99f46..9f30b035 100644 --- a/test/unit/test_af2_to_af3_msa.py +++ b/test/unit/test_af2_to_af3_msa.py @@ -1,4 +1,5 @@ import itertools +from types import SimpleNamespace import numpy as np from alphafold3.model import data_constants as af3_data_constants @@ -163,6 +164,49 @@ def _canonical_af3_paired_rows(chains) -> tuple[tuple[str, tuple[tuple[int, ...] return tuple(canonical_rows) +def _build_native_af3_chain_msas( + *, + chain_ids: list[str], + chain_sequences: list[str], + paired_rows_by_chain: dict[str, list[tuple[str, str]]], + unpaired_rows_by_chain: dict[str, list[str]] | None = None, +): + if unpaired_rows_by_chain is None: + unpaired_rows_by_chain = {} + + chain_msas = [] + for chain_id, sequence in zip(chain_ids, chain_sequences, strict=True): + paired_rows = paired_rows_by_chain.get(chain_id, []) + paired_msa = "" + if paired_rows: + paired_lines = [">query", sequence] + for index, (species_id, row) in enumerate(paired_rows, start=1): + if species_id: + accession = f"AP{chain_id}{index:07d}" + description = f"tr|{accession}|{accession}_{species_id}" + else: + description = f"sequence_{index}" + paired_lines.extend([f">{description}", row]) + paired_msa = "\n".join(paired_lines) + "\n" + + unpaired_rows = unpaired_rows_by_chain.get(chain_id, []) + unpaired_msa = "" + if unpaired_rows: + unpaired_lines = [">query", sequence] + for index, row in enumerate(unpaired_rows, start=1): + unpaired_lines.extend([f">sequence_{index}", row]) + unpaired_msa = "\n".join(unpaired_lines) + "\n" + + chain_msas.append( + SimpleNamespace( + paired_msa=paired_msa, + unpaired_msa=unpaired_msa, + ) + ) + + return chain_msas + + def test_msa_rows_and_deletions_to_a3m_preserves_lowercase_compression(): a3m = msa_rows_and_deletions_to_a3m( msa_rows=np.stack([_encode("A-C")]), @@ -751,3 +795,132 @@ def test_translate_af2_individual_chain_features_is_permutation_invariant_for_th if canonical_rows is None: canonical_rows = permutation_rows assert permutation_rows == canonical_rows + + +def test_translate_af2_individual_chain_features_matches_native_af3_pairing_for_sparse_trimer_permutations(): + base_chain_ids = ["A", "B", "C"] + base_chain_sequences = { + "A": "AC", + "B": "GT", + "C": "MK", + } + base_chain_features = { + "A": _make_af2_chain_feature_dict( + "AC", + paired_rows=[("ECOLX", "A-"), ("ECOLX", "AA")], + unpaired_rows=["AA"], + ), + "B": _make_af2_chain_feature_dict( + "GT", + paired_rows=[], + unpaired_rows=["G-"], + ), + "C": _make_af2_chain_feature_dict( + "MK", + paired_rows=[("ECOLX", "M-"), ("ECOLX", "MM")], + unpaired_rows=["MM"], + ), + } + native_paired_rows = { + "A": [("ECOLX", "A-"), ("ECOLX", "AA")], + "B": [], + "C": [("ECOLX", "M-"), ("ECOLX", "MM")], + } + native_unpaired_rows = { + "A": ["AA"], + "B": ["G-"], + "C": ["MM"], + } + + canonical_rows = None + for permutation in itertools.permutations(base_chain_ids): + chain_ids = list(permutation) + chain_sequences = [base_chain_sequences[chain_id] for chain_id in chain_ids] + translated = translate_af2_individual_chain_features_to_af3_msas_with_stats( + chain_feature_dicts=[base_chain_features[chain_id] for chain_id in chain_ids], + chain_sequences=chain_sequences, + ) + translated_rows = _canonical_af3_paired_rows( + _pair_translated_msas_with_af3( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + chain_msas=translated.chain_msas, + ) + ) + native_rows = _canonical_af3_paired_rows( + _pair_translated_msas_with_af3( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + chain_msas=_build_native_af3_chain_msas( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + paired_rows_by_chain=native_paired_rows, + unpaired_rows_by_chain=native_unpaired_rows, + ), + ) + ) + + assert translated_rows == native_rows + if canonical_rows is None: + canonical_rows = translated_rows + assert translated_rows == canonical_rows + + +def test_translate_af2_individual_chain_features_matches_native_af3_pairing_for_min_count_trimer_permutations(): + base_chain_ids = ["A", "B", "C"] + base_chain_sequences = { + "A": "AC", + "B": "GT", + "C": "MK", + } + base_chain_features = { + "A": _make_af2_chain_feature_dict( + "AC", + paired_rows=[("S1", "A-"), ("S1", "AA"), ("S2", "AC")], + ), + "B": _make_af2_chain_feature_dict( + "GT", + paired_rows=[("S1", "G-")], + ), + "C": _make_af2_chain_feature_dict( + "MK", + paired_rows=[("S1", "M-"), ("S1", "MM"), ("S2", "MK")], + ), + } + native_paired_rows = { + "A": [("S1", "A-"), ("S1", "AA"), ("S2", "AC")], + "B": [("S1", "G-")], + "C": [("S1", "M-"), ("S1", "MM"), ("S2", "MK")], + } + + canonical_rows = None + for permutation in itertools.permutations(base_chain_ids): + chain_ids = list(permutation) + chain_sequences = [base_chain_sequences[chain_id] for chain_id in chain_ids] + translated = translate_af2_individual_chain_features_to_af3_msas_with_stats( + chain_feature_dicts=[base_chain_features[chain_id] for chain_id in chain_ids], + chain_sequences=chain_sequences, + ) + translated_rows = _canonical_af3_paired_rows( + _pair_translated_msas_with_af3( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + chain_msas=translated.chain_msas, + ) + ) + native_rows = _canonical_af3_paired_rows( + _pair_translated_msas_with_af3( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + chain_msas=_build_native_af3_chain_msas( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + paired_rows_by_chain=native_paired_rows, + ), + ) + ) + + assert translated_rows == native_rows + if canonical_rows is None: + canonical_rows = translated_rows + assert translated_rows == canonical_rows From 4052957363c60e0a99b2f8799426dbc1769a6b76 Mon Sep 17 00:00:00 2001 From: Dima <33123184+DimaMolod@users.noreply.github.com> Date: Fri, 10 Apr 2026 20:01:45 +0200 Subject: [PATCH 6/6] Trim release-only AF3 trimer test scaffolding --- test/cluster/check_alphafold3_predictions.py | 13 +- test/unit/test_af2_to_af3_msa.py | 118 +++++++++---------- 2 files changed, 57 insertions(+), 74 deletions(-) diff --git a/test/cluster/check_alphafold3_predictions.py b/test/cluster/check_alphafold3_predictions.py index 84733cd8..91059b9a 100755 --- a/test/cluster/check_alphafold3_predictions.py +++ b/test/cluster/check_alphafold3_predictions.py @@ -1088,14 +1088,6 @@ def _check_chain_counts_and_sequences(self, protein_list: str): def _make_af3_test_env(self) -> Dict[str, str]: flash_impl = self._af3_flash_attention_impl() env = os.environ.copy() - # Force subprocesses launched from test helpers to import this checkout, - # not an older AlphaPulldown installation from the cluster env. - existing_pythonpath = env.get("PYTHONPATH", "") - env["PYTHONPATH"] = ( - f"{REPO_ROOT}:{existing_pythonpath}" - if existing_pythonpath - else str(REPO_ROOT) - ) env["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter --xla_gpu_force_compilation_parallelism=0" env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" env["XLA_CLIENT_MEM_FRACTION"] = "0.95" @@ -1815,12 +1807,11 @@ def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_trimer_inferen ) self.assertTrue(summary["paired_rows_valid"]) self.assertTrue(summary["unpaired_rows_valid"]) - self.assertEqual(summary["paired_row_count"], summary["effective_paired_row_count"]) self.assertGreater(summary["translated_paired_input_row_count"], 0) - self.assertGreater(summary["effective_paired_row_count"], 0) + self.assertGreater(summary["paired_row_count"], 0) self.assertGreaterEqual( summary["translated_paired_input_row_count"], - summary["effective_paired_row_count"], + summary["paired_row_count"], ) histogram = summary["effective_paired_row_histogram_by_num_chains"] self.assertTrue(histogram) diff --git a/test/unit/test_af2_to_af3_msa.py b/test/unit/test_af2_to_af3_msa.py index 9f30b035..4d2769bb 100644 --- a/test/unit/test_af2_to_af3_msa.py +++ b/test/unit/test_af2_to_af3_msa.py @@ -207,6 +207,48 @@ def _build_native_af3_chain_msas( return chain_msas +def _assert_translated_pairing_matches_native_af3_across_permutations( + *, + base_chain_ids: list[str], + base_chain_sequences: dict[str, str], + base_chain_features: dict[str, dict[str, np.ndarray]], + native_paired_rows: dict[str, list[tuple[str, str]]], + native_unpaired_rows: dict[str, list[str]] | None = None, +) -> None: + canonical_rows = None + for permutation in itertools.permutations(base_chain_ids): + chain_ids = list(permutation) + chain_sequences = [base_chain_sequences[chain_id] for chain_id in chain_ids] + translated = translate_af2_individual_chain_features_to_af3_msas_with_stats( + chain_feature_dicts=[base_chain_features[chain_id] for chain_id in chain_ids], + chain_sequences=chain_sequences, + ) + translated_rows = _canonical_af3_paired_rows( + _pair_translated_msas_with_af3( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + chain_msas=translated.chain_msas, + ) + ) + native_rows = _canonical_af3_paired_rows( + _pair_translated_msas_with_af3( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + chain_msas=_build_native_af3_chain_msas( + chain_ids=chain_ids, + chain_sequences=chain_sequences, + paired_rows_by_chain=native_paired_rows, + unpaired_rows_by_chain=native_unpaired_rows, + ), + ) + ) + + assert translated_rows == native_rows + if canonical_rows is None: + canonical_rows = translated_rows + assert translated_rows == canonical_rows + + def test_msa_rows_and_deletions_to_a3m_preserves_lowercase_compression(): a3m = msa_rows_and_deletions_to_a3m( msa_rows=np.stack([_encode("A-C")]), @@ -832,38 +874,13 @@ def test_translate_af2_individual_chain_features_matches_native_af3_pairing_for_ "C": ["MM"], } - canonical_rows = None - for permutation in itertools.permutations(base_chain_ids): - chain_ids = list(permutation) - chain_sequences = [base_chain_sequences[chain_id] for chain_id in chain_ids] - translated = translate_af2_individual_chain_features_to_af3_msas_with_stats( - chain_feature_dicts=[base_chain_features[chain_id] for chain_id in chain_ids], - chain_sequences=chain_sequences, - ) - translated_rows = _canonical_af3_paired_rows( - _pair_translated_msas_with_af3( - chain_ids=chain_ids, - chain_sequences=chain_sequences, - chain_msas=translated.chain_msas, - ) - ) - native_rows = _canonical_af3_paired_rows( - _pair_translated_msas_with_af3( - chain_ids=chain_ids, - chain_sequences=chain_sequences, - chain_msas=_build_native_af3_chain_msas( - chain_ids=chain_ids, - chain_sequences=chain_sequences, - paired_rows_by_chain=native_paired_rows, - unpaired_rows_by_chain=native_unpaired_rows, - ), - ) - ) - - assert translated_rows == native_rows - if canonical_rows is None: - canonical_rows = translated_rows - assert translated_rows == canonical_rows + _assert_translated_pairing_matches_native_af3_across_permutations( + base_chain_ids=base_chain_ids, + base_chain_sequences=base_chain_sequences, + base_chain_features=base_chain_features, + native_paired_rows=native_paired_rows, + native_unpaired_rows=native_unpaired_rows, + ) def test_translate_af2_individual_chain_features_matches_native_af3_pairing_for_min_count_trimer_permutations(): @@ -893,34 +910,9 @@ def test_translate_af2_individual_chain_features_matches_native_af3_pairing_for_ "C": [("S1", "M-"), ("S1", "MM"), ("S2", "MK")], } - canonical_rows = None - for permutation in itertools.permutations(base_chain_ids): - chain_ids = list(permutation) - chain_sequences = [base_chain_sequences[chain_id] for chain_id in chain_ids] - translated = translate_af2_individual_chain_features_to_af3_msas_with_stats( - chain_feature_dicts=[base_chain_features[chain_id] for chain_id in chain_ids], - chain_sequences=chain_sequences, - ) - translated_rows = _canonical_af3_paired_rows( - _pair_translated_msas_with_af3( - chain_ids=chain_ids, - chain_sequences=chain_sequences, - chain_msas=translated.chain_msas, - ) - ) - native_rows = _canonical_af3_paired_rows( - _pair_translated_msas_with_af3( - chain_ids=chain_ids, - chain_sequences=chain_sequences, - chain_msas=_build_native_af3_chain_msas( - chain_ids=chain_ids, - chain_sequences=chain_sequences, - paired_rows_by_chain=native_paired_rows, - ), - ) - ) - - assert translated_rows == native_rows - if canonical_rows is None: - canonical_rows = translated_rows - assert translated_rows == canonical_rows + _assert_translated_pairing_matches_native_af3_across_permutations( + base_chain_ids=base_chain_ids, + base_chain_sequences=base_chain_sequences, + base_chain_features=base_chain_features, + native_paired_rows=native_paired_rows, + )