diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py index d97b3be3..0bd98745 100644 --- a/models/rfd3/src/rfd3/inference/input_parsing.py +++ b/models/rfd3/src/rfd3/inference/input_parsing.py @@ -9,8 +9,9 @@ from typing import Any, Dict, List, Optional, Union import numpy as np -from atomworks.constants import STANDARD_AA +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.io.parser import parse_atom_array +from atomworks.io.utils.bonds import get_inferred_polymer_bonds # from atomworks.ml.datasets.datasets import BaseDataset from atomworks.ml.transforms.base import TransformedDict @@ -48,7 +49,6 @@ ) from rfd3.transforms.util_transforms import assign_types_ from rfd3.utils.inference import ( - _restore_bonds_for_nonstandard_residues, extract_ligand_array, inference_load_, set_com, @@ -58,6 +58,7 @@ from foundry.common import exists from foundry.utils.components import ( + fetch_mask_from_idx, get_design_pattern_with_constraints, get_motif_components_and_breaks, ) @@ -968,6 +969,259 @@ def create_motif_residue( return token +def _polymer_link_atoms_for_residue(res_name: str) -> set[frozenset[str]]: + """ + Return the atom-name pairs that represent canonical polymerization atoms for a residue. + + Only standard AA/DNA/RNA residues are treated as having canonical backbone links; PTMs + and other chem comp types fall back to nonstandard handling. + """ + if res_name in STANDARD_AA: + return {frozenset({"C", "N"})} + if res_name in STANDARD_DNA or res_name in STANDARD_RNA: + return {frozenset({"O3'", "P"}), frozenset({"O3*", "P"})} # allow legacy O3* + return set() + + +def _is_standard_polymer_backbone_bond(atom_a: struc.Atom, atom_b: struc.Atom) -> bool: + if atom_a.chain_id != atom_b.chain_id: + return False + if abs(atom_a.res_id - atom_b.res_id) != 1: + return False + atom_pair = frozenset({atom_a.atom_name, atom_b.atom_name}) + + pairs_a = _polymer_link_atoms_for_residue(atom_a.res_name) + pairs_b = _polymer_link_atoms_for_residue(atom_b.res_name) + shared_pairs = pairs_a & pairs_b + return atom_pair in shared_pairs + + +def _is_polymer_backbone_like(atom_a: struc.Atom, atom_b: struc.Atom) -> bool: + """ + Broader backbone check that treats canonical polymer atoms (C/N for peptide, + O3'/O3*–P for nucleic) as backbone links even when the residue itself is + non-standard (e.g., PTR/SEP). + """ + if atom_a.chain_id != atom_b.chain_id: + return False + if abs(atom_a.res_id - atom_b.res_id) != 1: + return False + atom_pair = frozenset({atom_a.atom_name, atom_b.atom_name}) + return atom_pair in { + frozenset({"C", "N"}), + frozenset({"O3'", "P"}), + frozenset({"O3*", "P"}), + } + + +def _restore_component_bonds( + atom_array_accum: struc.AtomArray, + src_atom_array: Optional[struc.AtomArray], + source_to_accum_idx: Dict[int, int], + source_idx_to_component: Dict[int, str], + unindexed_components: set[str], +) -> struc.AtomArray: + """ + Rehydrate bonds from the input structure onto the accumulated array. + + - Replays bonds from `src_atom_array` using the provided source→accum mappings. + - Skips canonical polymer backbone bonds (peptide/nucleic) that are reconstructed elsewhere. + - Protects unindexed components by disallowing cross-residue bonds that involve + an unindexed residue (except standard backbone). + - Emits warnings when a source bond cannot be remapped because one endpoint was + dropped during accumulation. + """ + if atom_array_accum.bonds is None: + atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length()) + + if ( + src_atom_array is None + or not hasattr(src_atom_array, "bonds") + or src_atom_array.bonds is None + or not source_to_accum_idx + ): + return atom_array_accum + + bonds_to_add: List[List[int]] = [] + seen_pairs: set[tuple[int, int]] = set() + src_bonds = np.asarray(src_atom_array.bonds.as_array(), dtype=np.int64) + + def _is_unindexed_source(idx: int) -> bool: + component = source_idx_to_component.get(idx) + return component in unindexed_components if component is not None else False + + def _fmt_atom(atom: struc.Atom) -> str: + # Use a readable residue/atom separator to avoid names running together + return f"{atom.chain_id}{atom.res_id}:{atom.res_name}_{atom.atom_name}" + + for atom_i_idx, atom_j_idx, bond_type in src_bonds: + atom_i_idx = int(atom_i_idx) + atom_j_idx = int(atom_j_idx) + bond_type = int(bond_type) + mapped_i = source_to_accum_idx.get(atom_i_idx) + mapped_j = source_to_accum_idx.get(atom_j_idx) + + atom_i = src_atom_array[atom_i_idx] + atom_j = src_atom_array[atom_j_idx] + + # If we only have one side of the bond, assert if the mapped atom is from an + # unindexed component and the bond would connect across residues. + if mapped_i is None or mapped_j is None: + if _is_standard_polymer_backbone_bond( + atom_i, atom_j + ) or _is_polymer_backbone_like(atom_i, atom_j): + continue + if mapped_i is not None and _is_unindexed_source(atom_i_idx): + if ( + atom_i.chain_id != atom_j.chain_id + or atom_i.res_id != atom_j.res_id + or not ( + _is_standard_polymer_backbone_bond(atom_i, atom_j) + or _is_polymer_backbone_like(atom_i, atom_j) + ) + ): + raise AssertionError( + f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} " + f"and omitted residue {atom_j.chain_id}{atom_j.res_id}." + ) + if mapped_j is not None and _is_unindexed_source(atom_j_idx): + if ( + atom_i.chain_id != atom_j.chain_id + or atom_i.res_id != atom_j.res_id + or not ( + _is_standard_polymer_backbone_bond(atom_i, atom_j) + or _is_polymer_backbone_like(atom_i, atom_j) + ) + ): + raise AssertionError( + f"Unsupported bond between unindexed component {atom_j.chain_id}{atom_j.res_id} " + f"and omitted residue {atom_i.chain_id}{atom_i.res_id}." + ) + # Only warn when we retained one side of a cross-residue/chain linkage + # (e.g., glycan partner missing), not for missing intra-residue atoms. + if (mapped_i is not None or mapped_j is not None) and ( + atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id + ): + logger.warning( + ( + "Skipping non-backbone bond from source structure between %s and %s (type %d): " + "one atom is not present in accumulated components. " + "Bond cannot be inferred automatically; set it manually if needed." + ) + % (_fmt_atom(atom_i), _fmt_atom(atom_j), bond_type) + ) + continue + + # Do not connect unindexed residues to anything else for now. + comp_i = source_idx_to_component.get(atom_i_idx) + comp_j = source_idx_to_component.get(atom_j_idx) + if (comp_i in unindexed_components or comp_j in unindexed_components) and ( + atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id + ): + if not ( + _is_standard_polymer_backbone_bond(atom_i, atom_j) + or _is_polymer_backbone_like(atom_i, atom_j) + ): + raise AssertionError( + "Bonds involving unindexed residues are not yet supported." + ) + continue + + if _is_standard_polymer_backbone_bond( + atom_i, atom_j + ) or _is_polymer_backbone_like(atom_i, atom_j): + continue + + pair = (min(mapped_i, mapped_j), max(mapped_i, mapped_j)) + if pair in seen_pairs: + continue + seen_pairs.add(pair) + bonds_to_add.append([mapped_i, mapped_j, bond_type]) + + bond_array = ( + np.array(bonds_to_add, dtype=np.int64) + if bonds_to_add + else np.empty((0, 3), dtype=np.int64) + ) + new_bonds = struc.BondList(atom_array_accum.array_length(), bond_array) + atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds) + return atom_array_accum + + +def _add_backbone_bonds_for_nonstandard_residues( + atom_array_accum: struc.AtomArray, +) -> struc.AtomArray: + """ + Add backbone/polymer bonds for cases where at least one residue is non-standard. + + Uses `atomworks.io.utils.bonds.get_inferred_polymer_bonds`, which consults CCD + chem-comp metadata to decide the correct polymerization atoms (C/N, CG/N, etc.). + Only bonds involving at least one non-standard residue are added; standard + AA/DNA/RNA pairs are assumed to already carry their backbone bonds. + """ + if atom_array_accum.bonds is None: + atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length()) + + unindexed_mask = ( + atom_array_accum.get_annotation("is_motif_atom_unindexed") + if "is_motif_atom_unindexed" in atom_array_accum.get_annotation_categories() + else np.zeros(atom_array_accum.array_length(), dtype=bool) + ) + + existing_pairs = { + (min(a, b), max(a, b)) for a, b, _ in atom_array_accum.bonds.as_array() + } + bonds_to_add: List[List[int]] = [] + + inferred_bonds, _ = get_inferred_polymer_bonds(atom_array_accum) + for atom_i_idx, atom_j_idx, bond_type in inferred_bonds: + atom_i_idx = int(atom_i_idx) + atom_j_idx = int(atom_j_idx) + + # Do not connect unindexed residues across residue boundaries + if (unindexed_mask[atom_i_idx] or unindexed_mask[atom_j_idx]) and ( + atom_array_accum.chain_id[atom_i_idx] + != atom_array_accum.chain_id[atom_j_idx] + or atom_array_accum.res_id[atom_i_idx] + != atom_array_accum.res_id[atom_j_idx] + ): + continue + + # Only synthesize bonds when at least one residue is non-standard; standard + # backbone bonds should already exist. + if _is_standard_polymer_backbone_bond( + atom_array_accum[atom_i_idx], atom_array_accum[atom_j_idx] + ): + continue + + pair = (min(atom_i_idx, atom_j_idx), max(atom_i_idx, atom_j_idx)) + if pair in existing_pairs: + continue + existing_pairs.add(pair) + bonds_to_add.append([pair[0], pair[1], int(bond_type)]) + + if bonds_to_add: + new_bonds = struc.BondList( + atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64) + ) + atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds) + return atom_array_accum + + +def _sort_bonds(atom_array_accum: struc.AtomArray) -> struc.AtomArray: + """Sort bonds deterministically by atom indices then bond type.""" + bonds_arr = atom_array_accum.bonds.as_array().copy() + # ensure lower index first + swap_mask = bonds_arr[:, 0] > bonds_arr[:, 1] + bonds_arr[swap_mask, :2] = bonds_arr[swap_mask][:, [1, 0]] + order = np.lexsort((bonds_arr[:, 2], bonds_arr[:, 1], bonds_arr[:, 0])) + bonds_arr = bonds_arr[order] + atom_array_accum.bonds = struc.BondList( + atom_array_accum.array_length(), bonds_arr.astype(np.int64) + ) + return atom_array_accum + + def accumulate_components( components_to_accumulate: List[Union[str, int]], *, @@ -1013,6 +1267,8 @@ def accumulate_components( res_id = start_resid molecule_id = 0 source_to_accum_idx: Dict[int, int] = {} + source_idx_to_component: Dict[int, str] = {} + unindexed_component_names = set(unindexed_tokens.keys()) current_accum_idx = sum(len(arr) for arr in atom_array_accum) # ... Insert contig information one- by one- @@ -1095,6 +1351,7 @@ def accumulate_components( ): for i, src_idx in enumerate(src_indices): source_to_accum_idx[int(src_idx)] = current_accum_idx + i + source_idx_to_component[int(src_idx)] = str(component) # ... Insert & Increment residue ID atom_array_accum.append(token) @@ -1104,23 +1361,17 @@ def accumulate_components( # ... Concatenate all components atom_array_accum = struc.concatenate(atom_array_accum) atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id) - - should_restore_bonds = ( - src_atom_array is not None - and bool(source_to_accum_idx) - and _check_has_backbone_connections_to_nonstandard_residues( - atom_array_accum, src_atom_array - ) + atom_array_accum = _restore_component_bonds( + atom_array_accum=atom_array_accum, + src_atom_array=src_atom_array, + source_to_accum_idx=source_to_accum_idx, + source_idx_to_component=source_idx_to_component, + unindexed_components=unindexed_component_names, ) - if should_restore_bonds: - assert not unindexed_tokens, ( - "PTM backbone bond restoration is not compatible with unindexed components. " - "PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). " - f"Found unindexed components: {list(unindexed_tokens.keys())}" - ) - atom_array_accum = _restore_bonds_for_nonstandard_residues( - atom_array_accum, src_atom_array, source_to_accum_idx - ) + atom_array_accum = _add_backbone_bonds_for_nonstandard_residues( + atom_array_accum=atom_array_accum + ) + atom_array_accum = _sort_bonds(atom_array_accum) # Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later) if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all( diff --git a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py index a74d24eb..22d7ed29 100644 --- a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py +++ b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py @@ -2,6 +2,7 @@ import functools import logging from os import PathLike +from typing import Dict import biotite.structure as struc import numpy as np @@ -138,6 +139,8 @@ def fetch_motif_residue_( subarray = set_default_conditioning_annotations( subarray, motif=True, unindexed=False, dtype=int ) # all values init to True (fix all) + if "is_motif_atom" not in subarray.get_annotation_categories(): + subarray.set_annotation("is_motif_atom", np.ones(subarray.shape[0], dtype=int)) to_unindex = f"{src_chain}{src_resid}" in unindexed_components to_index = f"{src_chain}{src_resid}" in components @@ -268,9 +271,190 @@ def create_diffused_residues_(n): ) array = set_default_conditioning_annotations(array, motif=False) array = set_common_annotations(array) + if "is_motif_atom" not in array.get_annotation_categories(): + array.set_annotation("is_motif_atom", np.zeros(array.array_length(), dtype=int)) return array +def _check_has_backbone_connections_to_nonstandard_residues( + atom_array_accum, src_atom_array +): + """ + Check if the source atom array has backbone C-N bonds involving non-standard residues. + This is used to determine if we need to restore bonds in accumulate_components. + Only backbone peptide bonds (C-N) are considered. + Returns: + True if there are backbone C-N bonds involving at least one non-standard residue + """ + + if atom_array_accum is None or src_atom_array is None: + return False + if ( + not hasattr(atom_array_accum, "bonds") + or atom_array_accum.bonds is None + or not hasattr(src_atom_array, "bonds") + or src_atom_array.bonds is None + ): + return False + + bonds = src_atom_array.bonds.as_array() + if len(bonds) == 0: + return False + + unique_res_names = np.unique(atom_array_accum.res_name) + has_nonstandard = any(res_name not in STANDARD_AA for res_name in unique_res_names) + + if not has_nonstandard: + return False + + for bond in bonds: + atom_i, atom_j, _ = bond + atom_i_obj = src_atom_array[int(atom_i)] + atom_j_obj = src_atom_array[int(atom_j)] + + is_backbone_bond = ( + atom_i_obj.atom_name == "C" and atom_j_obj.atom_name == "N" + ) or (atom_i_obj.atom_name == "N" and atom_j_obj.atom_name == "C") + + if not is_backbone_bond: + continue + + res_i_is_standard = atom_i_obj.res_name in STANDARD_AA + res_j_is_standard = atom_j_obj.res_name in STANDARD_AA + + if not (res_i_is_standard and res_j_is_standard): + ranked_logger.debug( + "Found backbone C-N bond involving non-standard residue: %s%s:%s - %s%s:%s", + atom_i_obj.res_name, + atom_i_obj.res_id, + atom_i_obj.atom_name, + atom_j_obj.res_name, + atom_j_obj.res_id, + atom_j_obj.atom_name, + ) + return True + + return False + + +def _restore_bonds_for_nonstandard_residues( + atom_array_accum: struc.AtomArray, + src_atom_array: struc.AtomArray | None, + source_to_accum_idx: Dict[int, int], +) -> struc.AtomArray: + """ + Restores and creates bonds for non-standard residues (PTMs, modified AAs, etc.) + from source structure and between consecutive residues. + + This function: + 1. Preserves inter-residue bonds from the source structure (if available) + 2. Adds backbone C-N bonds between consecutive residues where at least one is non-standard + """ + + if atom_array_accum.bonds is None: + atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length()) + + if ( + src_atom_array is not None + and hasattr(src_atom_array, "bonds") + and src_atom_array.bonds is not None + ): + original_bonds = src_atom_array.bonds.as_array() + if len(original_bonds) > 0: + bonds_to_add = [] + for bond in original_bonds: + atom_i, atom_j, bond_type = bond + if ( + int(atom_i) in source_to_accum_idx + and int(atom_j) in source_to_accum_idx + ): + src_res_i = src_atom_array[int(atom_i)].res_name + src_res_j = src_atom_array[int(atom_j)].res_name + + if src_res_i not in STANDARD_AA or src_res_j not in STANDARD_AA: + new_i = source_to_accum_idx[int(atom_i)] + new_j = source_to_accum_idx[int(atom_j)] + bonds_to_add.append([new_i, new_j, int(bond_type)]) + + if bonds_to_add: + new_bonds = struc.BondList( + atom_array_accum.array_length(), + np.array(bonds_to_add, dtype=np.int64), + ) + atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds) + ranked_logger.info( + "Preserved %s inter-residue bonds involving non-standard residues from source structure", + len(bonds_to_add), + ) + + bonds_to_add = [] + token_starts = get_token_starts(atom_array_accum, add_exclusive_stop=True) + + for i in range(len(token_starts) - 2): + curr_start, curr_end = token_starts[i], token_starts[i + 1] + next_start, next_end = token_starts[i + 1], token_starts[i + 2] + + curr_residue = atom_array_accum[curr_start:curr_end] + next_residue = atom_array_accum[next_start:next_end] + + curr_is_nonstandard = curr_residue.res_name[0] not in STANDARD_AA + next_is_nonstandard = next_residue.res_name[0] not in STANDARD_AA + + if not (curr_is_nonstandard or next_is_nonstandard): + continue + + if curr_residue.chain_id[0] != next_residue.chain_id[0]: + continue + if next_residue.res_id[0] - curr_residue.res_id[0] != 1: + continue + + c_mask = curr_residue.atom_name == "C" + if not np.any(c_mask): + if curr_is_nonstandard and next_is_nonstandard: + ranked_logger.debug( + "Non-standard residue %s (res_id %s) has no C atom - cannot form backbone bond to next residue", + curr_residue.res_name[0], + curr_residue.res_id[0], + ) + continue + c_idx = curr_start + np.where(c_mask)[0][0] + + n_mask = next_residue.atom_name == "N" + if not np.any(n_mask): + if curr_is_nonstandard and next_is_nonstandard: + ranked_logger.debug( + "Non-standard residue %s (res_id %s) has no N atom - cannot form backbone bond from previous residue", + next_residue.res_name[0], + next_residue.res_id[0], + ) + continue + n_idx = next_start + np.where(n_mask)[0][0] + + existing_bonds = atom_array_accum.bonds.as_array() + bond_exists = False + if len(existing_bonds) > 0: + for existing_bond in existing_bonds: + if (existing_bond[0] == c_idx and existing_bond[1] == n_idx) or ( + existing_bond[0] == n_idx and existing_bond[1] == c_idx + ): + bond_exists = True + break + + if not bond_exists: + bonds_to_add.append([c_idx, n_idx, struc.BondType.SINGLE]) + + if bonds_to_add: + new_bonds = struc.BondList( + atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64) + ) + atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds) + ranked_logger.info( + "Added %s backbone bonds involving non-standard residues", len(bonds_to_add) + ) + + return atom_array_accum + + def accumulate_components( components, src_atom_array, @@ -330,6 +514,8 @@ def accumulate_components( chain = start_chain res_id = start_resid molecule_id = 0 + source_to_accum_idx: Dict[int, int] = {} + current_accum_idx = 0 # 2) Insert contig information one- by one- for component, is_break in zip(components, breaks): if component == "/0": @@ -343,6 +529,13 @@ def accumulate_components( if str(component)[0].isalpha(): # motif (e.g. "A22") atom_array_insert = fetch_motif_residue(*split_contig(component)) n = 1 + src_indices = None + if src_atom_array is not None: + try: + src_mask = fetch_mask_from_idx(component, atom_array=src_atom_array) + src_indices = np.where(src_mask)[0] + except Exception: + src_indices = None if exists(is_break) and is_break: if not unindexed_components_started: chain = start_chain @@ -380,13 +573,40 @@ def accumulate_components( len(get_token_starts(atom_array_insert)) == n ), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(atom_array_insert))} in \n{atom_array_insert}" + if ( + src_atom_array is not None + and str(component)[0].isalpha() + and src_indices is not None + and len(src_indices) == len(atom_array_insert) + ): + for i, src_idx in enumerate(src_indices): + source_to_accum_idx[int(src_idx)] = current_accum_idx + i + # ... Insert & Increment residue ID atom_array_accum.append(atom_array_insert) res_id += n + current_accum_idx += len(atom_array_insert) atom_array_accum = struc.concatenate(atom_array_accum) atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id) + should_restore_bonds = ( + src_atom_array is not None + and bool(source_to_accum_idx) + and _check_has_backbone_connections_to_nonstandard_residues( + atom_array_accum, src_atom_array + ) + ) + if should_restore_bonds: + assert not unindexed_components, ( + "PTM backbone bond restoration is not compatible with unindexed components. " + "PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). " + f"Found unindexed components: {unindexed_components}" + ) + atom_array_accum = _restore_bonds_for_nonstandard_residues( + atom_array_accum, src_atom_array, source_to_accum_idx + ) + # Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later) if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all( atom_array_accum.is_motif_atom_unindexed.astype(bool) @@ -405,6 +625,9 @@ def accumulate_components( atom_array_accum.is_motif_atom_unindexed.astype(bool) ] += max_id - min_id_udx + 1 + if atom_array_accum.bonds is None: + atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length()) + return atom_array_accum diff --git a/models/rfd3/tests/test_bond_preservation_cases.py b/models/rfd3/tests/test_bond_preservation_cases.py new file mode 100644 index 00000000..678c1090 --- /dev/null +++ b/models/rfd3/tests/test_bond_preservation_cases.py @@ -0,0 +1,360 @@ +#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../scripts/shebang/modelhub_exec.sh" "$0" "$@"' +""" +Bond preservation regression tests for representative connection types. +""" + +from pathlib import Path + +import numpy as np +import pytest +from atomworks.io.parser import STANDARD_PARSER_ARGS, parse +from atomworks.io.tools.inference import components_to_atom_array +from biotite import structure as struc +from rfd3.inference.input_parsing import ( + accumulate_components, + create_atom_array_from_design_specification, +) +from rfd3.transforms.conditioning_base import ( + set_default_conditioning_annotations, +) +from rfd3.utils.inference import set_common_annotations + +from foundry.utils.components import fetch_mask_from_idx + +TEST_DATA_DIR = Path(__file__).parent / "test_data" + + +def _load_atom_array(pdb_id: str): + path = TEST_DATA_DIR / f"{pdb_id.lower()}.cif" + if not path.exists(): + pytest.skip(f"Test data file missing: {path}") + parser_args = { + **STANDARD_PARSER_ARGS, + # Ignore metal coordination; only covalent/disulfide bonds are restored. + "add_bond_types_from_struct_conn": ["covale", "disulf"], + } + result = parse(filename=path, build_assembly=("1",), **parser_args) + return result["assemblies"]["1"][0] + + +def _prepare_token(atom_array, component: str): + mask = fetch_mask_from_idx(component, atom_array=atom_array) + token = atom_array[mask].copy() + token = set_default_conditioning_annotations(token, motif=True, dtype=int) + token = set_common_annotations(token) + token.res_id = np.ones(token.shape[0], dtype=token.res_id.dtype) + return token + + +def _accumulate(atom_array, components): + tokens = {c: _prepare_token(atom_array, c) for c in components} + return accumulate_components( + components_to_accumulate=components, + indexed_tokens=tokens, + unindexed_tokens={}, + atom_array_accum=[], + start_chain="A", + start_resid=1, + unindexed_breaks=[None] * len(components), + src_atom_array=atom_array, + ) + + +def _create_ptm_atom_array(): + """Create a simple protein with PTMs: AG(PTR)(SEP)SA.""" + components = [ + { + "seq": "AG(PTR)(SEP)SA", + "chain_type": "polypeptide(l)", + "is_polymer": True, + "chain_id": "A", + }, + ] + atom_array = components_to_atom_array(components) + atom_array.coord = np.random.randn(len(atom_array), 3).astype(np.float32) * 10 + return atom_array + + +def _atom_index(arr, res_id: int, atom_name: str, chain: str = "A") -> int: + mask = ( + (arr.chain_id == chain) & (arr.res_id == res_id) & (arr.atom_name == atom_name) + ) + idx = np.where(mask)[0] + assert len(idx) == 1, f"Atom {chain}{res_id}:{atom_name} not unique/found" + return int(idx[0]) + + +def _bond_exists( + arr: struc.AtomArray, + idx_a: int, + idx_b: int, + bond_type: struc.BondType | None = None, +) -> bool: + bonds = arr.bonds.as_array() + mask = ((bonds[:, 0] == idx_a) & (bonds[:, 1] == idx_b)) | ( + (bonds[:, 0] == idx_b) & (bonds[:, 1] == idx_a) + ) + if bond_type is not None: + mask &= bonds[:, 2] == bond_type + return np.any(mask) + + +def _bond_label(arr: struc.AtomArray, idx: int) -> tuple[str, int, str, str]: + """Return a human-friendly label for assertions.""" + return ( + arr.chain_id[idx], + int(arr.res_id[idx]), + arr.res_name[idx], + arr.atom_name[idx], + ) + + +def _cross_residue_bonds( + arr: struc.AtomArray, +) -> set[tuple[tuple[str, int, str, str], tuple[str, int, str, str]]]: + bonds = set() + for a, b, _ in arr.bonds.as_array(): + a = int(a) + b = int(b) + if arr.chain_id[a] == arr.chain_id[b] and arr.res_id[a] == arr.res_id[b]: + continue + bond = tuple(sorted((_bond_label(arr, a), _bond_label(arr, b)))) + bonds.add(bond) + return bonds + + +@pytest.mark.slow +def test_disulfide_preserved(): + """ + 1crn.cif struct_conn disulf1: A CYS 3 SG <-> A CYS 40 SG. + """ + arr = _load_atom_array("1crn") + accum = accumulate_components( + components_to_accumulate=[1, "A3", "A40", 1], + indexed_tokens={ + "A3": _prepare_token(arr, "A3"), + "A40": _prepare_token(arr, "A40"), + }, + unindexed_tokens={}, + atom_array_accum=[], + start_chain="A", + start_resid=1, + unindexed_breaks=[None] * 4, + src_atom_array=arr, + ) + sg1 = _atom_index(accum, 2, "SG") # CYS3 after one diffused residue + sg2 = _atom_index(accum, 3, "SG") # CYS40 after one diffused + one indexed + assert _bond_exists(accum, sg1, sg2) + expected_bonds = {tuple(sorted((_bond_label(accum, sg1), _bond_label(accum, sg2))))} + assert _cross_residue_bonds(accum) == expected_bonds + + +@pytest.mark.slow +def test_covalent_ligand_preserved(): + """ + 4qdv.cif struct_conn covale1: A TYR 143 OH <-> E 30U 401 S1. + """ + arr = _load_atom_array("4qdv") + accum = _accumulate(arr, ["A143", "E401"]) + oh_tyr = _atom_index(accum, 1, "OH") + s1_30u = _atom_index(accum, 2, "S1") + expected_bonds = { + tuple(sorted((_bond_label(accum, oh_tyr), _bond_label(accum, s1_30u)))) + } + assert _cross_residue_bonds(accum) == expected_bonds + + +@pytest.mark.slow +def test_dna_af_adduct_preserved(): + """ + 1ua0.cif AF adduct: label chains B/E (auth C/C) DG4 C8 <-> AF333 N. + """ + arr = _load_atom_array("1ua0") + accum = _accumulate(arr, ["B4", "E333"]) + c8_dg = _atom_index(accum, 1, "C8") + n_af = _atom_index(accum, 2, "N") + assert _bond_exists(accum, c8_dg, n_af) + expected_bonds = { + tuple(sorted((_bond_label(accum, c8_dg), _bond_label(accum, n_af)))) + } + assert _cross_residue_bonds(accum) == expected_bonds + + +@pytest.mark.slow +def test_cyclic_thioether_link_preserved(): + """ + 6u6k.cif struct_conn: + - covale1: B ACE 1 C <-> B TRP 2 N + - covale2: B ACE 1 CH3 <-> B CYS 12 SG + """ + arr = _load_atom_array("6u6k") + accum = _accumulate(arr, ["B1", "B2", "B12"]) + ch3 = _atom_index(accum, 1, "CH3") + c_ace = _atom_index(accum, 1, "C") + n_trp = _atom_index(accum, 2, "N") + sg = _atom_index(accum, 3, "SG") + assert _bond_exists(accum, ch3, sg) + assert _bond_exists(accum, c_ace, n_trp) + expected_bonds = { + tuple(sorted((_bond_label(accum, ch3), _bond_label(accum, sg)))), + tuple(sorted((_bond_label(accum, c_ace), _bond_label(accum, n_trp)))), + } + assert _cross_residue_bonds(accum) == expected_bonds + + +@pytest.mark.slow +def test_nonpeptide_noncanonical_not_backbone_linked(): + """ + NIO in 3o14.cif has an atom named N but no struct_conn. + with only diffused neighbors no backbone bond should be synthesized. + """ + arr = _load_atom_array("3o14") + accum = accumulate_components( + components_to_accumulate=[1, "D300", 1], + indexed_tokens={"D300": _prepare_token(arr, "D300")}, + unindexed_tokens={}, + atom_array_accum=[], + start_chain="A", + start_resid=1, + unindexed_breaks=[None] * 3, + src_atom_array=arr, + ) + # No cross-residue bonds because nothing is connected to the ligand. + assert _cross_residue_bonds(accum) == set() + + +@pytest.mark.slow +def test_glycan_links_and_absence_when_partner_missing(): + """ + 8f7t.cif struct_conn covale5: C ASN 403 ND2 <-> G NAG 1 C1. + Also ensure NAG1 has no cross-res bonds if ASN403 is not included. + """ + arr = _load_atom_array("8f7t") + accum = _accumulate(arr, ["C403", "G1"]) + nd2 = _atom_index(accum, 1, "ND2") + c1 = _atom_index(accum, 2, "C1") + assert _bond_exists(accum, nd2, c1) + expected_bonds = {tuple(sorted((_bond_label(accum, nd2), _bond_label(accum, c1))))} + assert _cross_residue_bonds(accum) == expected_bonds + + accum_no_asn = accumulate_components( + components_to_accumulate=[1, "G1", 1], + indexed_tokens={"G1": _prepare_token(arr, "G1")}, + unindexed_tokens={}, + atom_array_accum=[], + start_chain="A", + start_resid=1, + unindexed_breaks=[None] * 3, + src_atom_array=arr, + ) + c1_lonely = _atom_index(accum_no_asn, 2, "C1") + assert _cross_residue_bonds(accum_no_asn) == set() + partners = [ + int(bond[1]) if bond[0] == c1_lonely else int(bond[0]) + for bond in accum_no_asn.bonds.as_array() + if c1_lonely in bond[:2] + ] + partner_res_ids = accum_no_asn.res_id[partners] if partners else [] + assert len(partner_res_ids) == 0 or np.all(partner_res_ids == 2) + + +@pytest.mark.slow +def test_backbone_struct_conn_preserved_with_diffusion(): + """ + 1p5d.cif struct_conn covale1/2: + GLY107 C <-> SEP108 N, SEP108 C <-> HIS109 N. + """ + arr = _load_atom_array("1p5d") + accum = accumulate_components( + components_to_accumulate=[2, "A108", 2], + indexed_tokens={"A108": _prepare_token(arr, "A108")}, + unindexed_tokens={}, + atom_array_accum=[], + start_chain="A", + start_resid=1, + unindexed_breaks=[None] * 3, + src_atom_array=arr, + ) + c_prev = _atom_index(accum, 2, "C") + n_sep = _atom_index(accum, 3, "N") + c_sep = _atom_index(accum, 3, "C") + n_next = _atom_index(accum, 4, "N") + assert _bond_exists(accum, c_prev, n_sep) + assert _bond_exists(accum, c_sep, n_next) + expected_bonds = { + tuple(sorted((_bond_label(accum, c_prev), _bond_label(accum, n_sep)))), + tuple(sorted((_bond_label(accum, c_sep), _bond_label(accum, n_next)))), + } + assert _cross_residue_bonds(accum) == expected_bonds + + +@pytest.mark.fast +def test_ptm_backbone_bonds_preserved_with_diffusion(): + """ + Synthetic PTR/SEP motif with 5 diffused residues on each side; ensure backbone + bonds span the diffused neighbors. + """ + src_atom_array = _create_ptm_atom_array() + indexed_tokens = {c: _prepare_token(src_atom_array, c) for c in ["A3", "A4"]} + + accum = accumulate_components( + components_to_accumulate=[5, "A3", "A4", 5], + indexed_tokens=indexed_tokens, + unindexed_tokens={}, + atom_array_accum=[], + start_chain="A", + start_resid=1, + unindexed_breaks=[None] * 4, + src_atom_array=src_atom_array, + ) + + diffused_c = _atom_index(accum, 5, "C") + ptr_n = _atom_index(accum, 6, "N") + assert _bond_exists(accum, diffused_c, ptr_n, struc.BondType.SINGLE) + + ptr_c = _atom_index(accum, 6, "C") + sep_n = _atom_index(accum, 7, "N") + assert _bond_exists(accum, ptr_c, sep_n, struc.BondType.SINGLE) + + sep_c = _atom_index(accum, 7, "C") + diffused_after_n = _atom_index(accum, 8, "N") + assert _bond_exists(accum, sep_c, diffused_after_n, struc.BondType.SINGLE) + + cross_bonds = _cross_residue_bonds(accum) + expected_cross = { + tuple(sorted((_bond_label(accum, diffused_c), _bond_label(accum, ptr_n)))), + tuple(sorted((_bond_label(accum, ptr_c), _bond_label(accum, sep_n)))), + tuple( + sorted((_bond_label(accum, sep_c), _bond_label(accum, diffused_after_n))) + ), + } + assert expected_cross.issubset(cross_bonds) + + +@pytest.mark.fast +def test_ptm_backbone_bonds_preserved_full_pipeline(): + """ + End-to-end test through create_atom_array_from_design_specification (dialect 2). + Ensures PTM backbone bonds survive the normal loading pipeline. + """ + atom_array_input = _create_ptm_atom_array() + + contig = "5-5,A3-4,5-5" # -> [5, A3, A4, 5] + atom_array, _ = create_atom_array_from_design_specification( + atom_array_input=atom_array_input, + input=None, + contig=contig, + length="12-12", + dialect=2, + ) + + diffused_c = _atom_index(atom_array, 5, "C") + ptr_n = _atom_index(atom_array, 6, "N") + ptr_c = _atom_index(atom_array, 6, "C") + sep_n = _atom_index(atom_array, 7, "N") + sep_c = _atom_index(atom_array, 7, "C") + diffused_after_n = _atom_index(atom_array, 8, "N") + + assert _bond_exists(atom_array, diffused_c, ptr_n, struc.BondType.SINGLE) + assert _bond_exists(atom_array, ptr_c, sep_n, struc.BondType.SINGLE) + assert _bond_exists(atom_array, sep_c, diffused_after_n, struc.BondType.SINGLE) diff --git a/models/rfd3/tests/test_legacy_ptm_bonds.py b/models/rfd3/tests/test_legacy_ptm_bonds.py new file mode 100644 index 00000000..fb849da8 --- /dev/null +++ b/models/rfd3/tests/test_legacy_ptm_bonds.py @@ -0,0 +1,107 @@ +""" +Tests that legacy accumulate_components preserves backbone bonds involving PTMs. +""" + +import numpy as np +import pytest +from atomworks.io.tools.inference import components_to_atom_array +from biotite import structure as struc +from rfd3.inference.legacy_input_parsing import accumulate_components +from rfd3.transforms.conditioning_base import set_default_conditioning_annotations +from rfd3.utils.inference import set_common_annotations + +from foundry.utils.components import fetch_mask_from_idx + + +def _create_ptm_structure(): + """Create a simple protein with PTMs: AG(PTR)(SEP)SA.""" + components = [ + { + "seq": "AG(PTR)(SEP)SA", + "chain_type": "polypeptide(l)", + "is_polymer": True, + "chain_id": "A", + }, + ] + atom_array = components_to_atom_array(components) + # Add coordinates so bond inference code that may rely on coords won't hit NaNs + atom_array.coord = np.random.randn(len(atom_array), 3).astype(np.float32) * 10 + return atom_array + + +def _prepare_indexed_tokens(atom_array, components): + """Create motif tokens with required annotations for accumulate_components.""" + tokens = {} + for component in components: + mask = fetch_mask_from_idx(component, atom_array=atom_array) + token = atom_array[mask].copy() + token = set_default_conditioning_annotations( + token, motif=True, unindexed=False, dtype=int + ) + token = set_common_annotations(token) + token.res_id = np.ones(token.shape[0], dtype=token.res_id.dtype) + tokens[component] = token + return tokens + + +def _connection_exists(bonds, a, b, bond_type=None): + """Check if a connection between atoms a and b exists (in either direction).""" + mask = ((bonds[:, 0] == a) & (bonds[:, 1] == b)) | ( + (bonds[:, 0] == b) & (bonds[:, 1] == a) + ) + if bond_type is not None: + mask &= bonds[:, 2] == bond_type + return np.any(mask) + + +@pytest.mark.fast +def test_legacy_ptm_backbone_bonds(): + """ + Verify that PTM backbone bonds are restored in legacy parser. + + Setup: 5 diffused residues -> PTR-SEP (connected motif) -> 5 diffused residues + Expect backbone bonds: diffused->PTR, PTR->SEP, SEP->diffused. + """ + + src_atom_array = _create_ptm_structure() + components = [5, "A3", "A4", 5] + + result = accumulate_components( + components=components, + src_atom_array=src_atom_array, + redesign_motif_sidechains=False, + unindexed_components=[], + unfixed_sequence_components=[], + breaks=[None] * len(components), + fixed_atoms={}, + unfix_all=False, + optional_conditions=[], + flexible_backbone=False, + unfix_residues=[], + ) + bonds = result.bonds.as_array() + + def atom_idx(chain, resid, atom_name): + mask = ( + (result.chain_id == chain) + & (result.res_id == resid) + & (result.atom_name == atom_name) + ) + idx = np.where(mask)[0] + assert ( + len(idx) == 1 + ), f"Expected unique atom for {chain}{resid}:{atom_name}, got {len(idx)}" + return idx[0] + + # Residue IDs after accumulation: 1-5 diffused, 6=A3 (PTR), 7=A4 (SEP), 8+=diffused + diffused_c = atom_idx("A", 5, "C") + ptr_n = atom_idx("A", 6, "N") + assert _connection_exists(bonds, diffused_c, ptr_n, struc.BondType.SINGLE) + + ptr_c = atom_idx("A", 6, "C") + sep_n = atom_idx("A", 7, "N") + assert _connection_exists(bonds, ptr_c, sep_n, struc.BondType.SINGLE) + + sep_c = atom_idx("A", 7, "C") + diffused_after_n = atom_idx("A", 8, "N") + assert _connection_exists(bonds, sep_c, diffused_after_n, struc.BondType.SINGLE)