-
Notifications
You must be signed in to change notification settings - Fork 134
Improve PTM bond handling in RFdiffusion3 #256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: production
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,8 +9,9 @@ | |
| from typing import Any, Dict, List, Optional, Union | ||
|
|
||
| import numpy as np | ||
| from atomworks.constants import STANDARD_AA | ||
| from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA | ||
| from atomworks.io.parser import parse_atom_array | ||
| from atomworks.io.utils.bonds import get_inferred_polymer_bonds | ||
|
|
||
| # from atomworks.ml.datasets.datasets import BaseDataset | ||
| from atomworks.ml.transforms.base import TransformedDict | ||
|
|
@@ -48,7 +49,6 @@ | |
| ) | ||
| from rfd3.transforms.util_transforms import assign_types_ | ||
| from rfd3.utils.inference import ( | ||
| _restore_bonds_for_nonstandard_residues, | ||
| extract_ligand_array, | ||
| inference_load_, | ||
| set_com, | ||
|
|
@@ -58,6 +58,7 @@ | |
|
|
||
| from foundry.common import exists | ||
| from foundry.utils.components import ( | ||
| fetch_mask_from_idx, | ||
| get_design_pattern_with_constraints, | ||
| get_motif_components_and_breaks, | ||
| ) | ||
|
|
@@ -968,6 +969,259 @@ def create_motif_residue( | |
| return token | ||
|
|
||
|
|
||
| def _polymer_link_atoms_for_residue(res_name: str) -> set[frozenset[str]]: | ||
| """ | ||
| Return the atom-name pairs that represent canonical polymerization atoms for a residue. | ||
|
|
||
| Only standard AA/DNA/RNA residues are treated as having canonical backbone links; PTMs | ||
| and other chem comp types fall back to nonstandard handling. | ||
| """ | ||
| if res_name in STANDARD_AA: | ||
| return {frozenset({"C", "N"})} | ||
| if res_name in STANDARD_DNA or res_name in STANDARD_RNA: | ||
| return {frozenset({"O3'", "P"}), frozenset({"O3*", "P"})} # allow legacy O3* | ||
| return set() | ||
|
|
||
|
|
||
| def _is_standard_polymer_backbone_bond(atom_a: struc.Atom, atom_b: struc.Atom) -> bool: | ||
| if atom_a.chain_id != atom_b.chain_id: | ||
| return False | ||
| if abs(atom_a.res_id - atom_b.res_id) != 1: | ||
| return False | ||
| atom_pair = frozenset({atom_a.atom_name, atom_b.atom_name}) | ||
|
|
||
| pairs_a = _polymer_link_atoms_for_residue(atom_a.res_name) | ||
| pairs_b = _polymer_link_atoms_for_residue(atom_b.res_name) | ||
| shared_pairs = pairs_a & pairs_b | ||
| return atom_pair in shared_pairs | ||
|
|
||
|
|
||
| def _is_polymer_backbone_like(atom_a: struc.Atom, atom_b: struc.Atom) -> bool: | ||
| """ | ||
| Broader backbone check that treats canonical polymer atoms (C/N for peptide, | ||
| O3'/O3*–P for nucleic) as backbone links even when the residue itself is | ||
| non-standard (e.g., PTR/SEP). | ||
| """ | ||
| if atom_a.chain_id != atom_b.chain_id: | ||
| return False | ||
| if abs(atom_a.res_id - atom_b.res_id) != 1: | ||
| return False | ||
| atom_pair = frozenset({atom_a.atom_name, atom_b.atom_name}) | ||
| return atom_pair in { | ||
| frozenset({"C", "N"}), | ||
| frozenset({"O3'", "P"}), | ||
| frozenset({"O3*", "P"}), | ||
| } | ||
|
|
||
|
|
||
| def _restore_component_bonds( | ||
| atom_array_accum: struc.AtomArray, | ||
| src_atom_array: Optional[struc.AtomArray], | ||
| source_to_accum_idx: Dict[int, int], | ||
| source_idx_to_component: Dict[int, str], | ||
| unindexed_components: set[str], | ||
| ) -> struc.AtomArray: | ||
| """ | ||
| Rehydrate bonds from the input structure onto the accumulated array. | ||
|
|
||
| - Replays bonds from `src_atom_array` using the provided source→accum mappings. | ||
| - Skips canonical polymer backbone bonds (peptide/nucleic) that are reconstructed elsewhere. | ||
| - Protects unindexed components by disallowing cross-residue bonds that involve | ||
| an unindexed residue (except standard backbone). | ||
| - Emits warnings when a source bond cannot be remapped because one endpoint was | ||
| dropped during accumulation. | ||
| """ | ||
| if atom_array_accum.bonds is None: | ||
| atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length()) | ||
|
|
||
| if ( | ||
| src_atom_array is None | ||
| or not hasattr(src_atom_array, "bonds") | ||
| or src_atom_array.bonds is None | ||
| or not source_to_accum_idx | ||
| ): | ||
| return atom_array_accum | ||
|
|
||
| bonds_to_add: List[List[int]] = [] | ||
| seen_pairs: set[tuple[int, int]] = set() | ||
| src_bonds = np.asarray(src_atom_array.bonds.as_array(), dtype=np.int64) | ||
|
|
||
| def _is_unindexed_source(idx: int) -> bool: | ||
| component = source_idx_to_component.get(idx) | ||
| return component in unindexed_components if component is not None else False | ||
|
|
||
| def _fmt_atom(atom: struc.Atom) -> str: | ||
| # Use a readable residue/atom separator to avoid names running together | ||
| return f"{atom.chain_id}{atom.res_id}:{atom.res_name}_{atom.atom_name}" | ||
|
|
||
| for atom_i_idx, atom_j_idx, bond_type in src_bonds: | ||
| atom_i_idx = int(atom_i_idx) | ||
| atom_j_idx = int(atom_j_idx) | ||
| bond_type = int(bond_type) | ||
| mapped_i = source_to_accum_idx.get(atom_i_idx) | ||
| mapped_j = source_to_accum_idx.get(atom_j_idx) | ||
|
|
||
| atom_i = src_atom_array[atom_i_idx] | ||
| atom_j = src_atom_array[atom_j_idx] | ||
|
|
||
| # If we only have one side of the bond, assert if the mapped atom is from an | ||
| # unindexed component and the bond would connect across residues. | ||
| if mapped_i is None or mapped_j is None: | ||
| if _is_standard_polymer_backbone_bond( | ||
| atom_i, atom_j | ||
| ) or _is_polymer_backbone_like(atom_i, atom_j): | ||
| continue | ||
| if mapped_i is not None and _is_unindexed_source(atom_i_idx): | ||
| if ( | ||
| atom_i.chain_id != atom_j.chain_id | ||
| or atom_i.res_id != atom_j.res_id | ||
| or not ( | ||
| _is_standard_polymer_backbone_bond(atom_i, atom_j) | ||
| or _is_polymer_backbone_like(atom_i, atom_j) | ||
| ) | ||
| ): | ||
| raise AssertionError( | ||
| f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} " | ||
| f"and omitted residue {atom_j.chain_id}{atom_j.res_id}." | ||
| ) | ||
| if mapped_j is not None and _is_unindexed_source(atom_j_idx): | ||
| if ( | ||
| atom_i.chain_id != atom_j.chain_id | ||
| or atom_i.res_id != atom_j.res_id | ||
| or not ( | ||
| _is_standard_polymer_backbone_bond(atom_i, atom_j) | ||
| or _is_polymer_backbone_like(atom_i, atom_j) | ||
| ) | ||
| ): | ||
| raise AssertionError( | ||
|
Comment on lines
+1087
to
+1096
|
||
| f"Unsupported bond between unindexed component {atom_j.chain_id}{atom_j.res_id} " | ||
| f"and omitted residue {atom_i.chain_id}{atom_i.res_id}." | ||
| ) | ||
| # Only warn when we retained one side of a cross-residue/chain linkage | ||
| # (e.g., glycan partner missing), not for missing intra-residue atoms. | ||
| if (mapped_i is not None or mapped_j is not None) and ( | ||
| atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id | ||
| ): | ||
| logger.warning( | ||
| ( | ||
| "Skipping non-backbone bond from source structure between %s and %s (type %d): " | ||
| "one atom is not present in accumulated components. " | ||
| "Bond cannot be inferred automatically; set it manually if needed." | ||
| ) | ||
| % (_fmt_atom(atom_i), _fmt_atom(atom_j), bond_type) | ||
| ) | ||
| continue | ||
|
|
||
| # Do not connect unindexed residues to anything else for now. | ||
| comp_i = source_idx_to_component.get(atom_i_idx) | ||
| comp_j = source_idx_to_component.get(atom_j_idx) | ||
| if (comp_i in unindexed_components or comp_j in unindexed_components) and ( | ||
| atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id | ||
| ): | ||
| if not ( | ||
| _is_standard_polymer_backbone_bond(atom_i, atom_j) | ||
| or _is_polymer_backbone_like(atom_i, atom_j) | ||
| ): | ||
| raise AssertionError( | ||
| "Bonds involving unindexed residues are not yet supported." | ||
| ) | ||
| continue | ||
|
|
||
| if _is_standard_polymer_backbone_bond( | ||
| atom_i, atom_j | ||
| ) or _is_polymer_backbone_like(atom_i, atom_j): | ||
| continue | ||
|
|
||
| pair = (min(mapped_i, mapped_j), max(mapped_i, mapped_j)) | ||
| if pair in seen_pairs: | ||
| continue | ||
| seen_pairs.add(pair) | ||
| bonds_to_add.append([mapped_i, mapped_j, bond_type]) | ||
|
|
||
| bond_array = ( | ||
| np.array(bonds_to_add, dtype=np.int64) | ||
| if bonds_to_add | ||
| else np.empty((0, 3), dtype=np.int64) | ||
| ) | ||
| new_bonds = struc.BondList(atom_array_accum.array_length(), bond_array) | ||
| atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds) | ||
| return atom_array_accum | ||
|
|
||
|
|
||
| def _add_backbone_bonds_for_nonstandard_residues( | ||
| atom_array_accum: struc.AtomArray, | ||
| ) -> struc.AtomArray: | ||
| """ | ||
| Add backbone/polymer bonds for cases where at least one residue is non-standard. | ||
|
|
||
| Uses `atomworks.io.utils.bonds.get_inferred_polymer_bonds`, which consults CCD | ||
| chem-comp metadata to decide the correct polymerization atoms (C/N, CG/N, etc.). | ||
| Only bonds involving at least one non-standard residue are added; standard | ||
| AA/DNA/RNA pairs are assumed to already carry their backbone bonds. | ||
| """ | ||
| if atom_array_accum.bonds is None: | ||
| atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length()) | ||
|
|
||
| unindexed_mask = ( | ||
| atom_array_accum.get_annotation("is_motif_atom_unindexed") | ||
| if "is_motif_atom_unindexed" in atom_array_accum.get_annotation_categories() | ||
| else np.zeros(atom_array_accum.array_length(), dtype=bool) | ||
| ) | ||
|
|
||
| existing_pairs = { | ||
| (min(a, b), max(a, b)) for a, b, _ in atom_array_accum.bonds.as_array() | ||
| } | ||
| bonds_to_add: List[List[int]] = [] | ||
|
|
||
| inferred_bonds, _ = get_inferred_polymer_bonds(atom_array_accum) | ||
| for atom_i_idx, atom_j_idx, bond_type in inferred_bonds: | ||
| atom_i_idx = int(atom_i_idx) | ||
| atom_j_idx = int(atom_j_idx) | ||
|
|
||
| # Do not connect unindexed residues across residue boundaries | ||
| if (unindexed_mask[atom_i_idx] or unindexed_mask[atom_j_idx]) and ( | ||
| atom_array_accum.chain_id[atom_i_idx] | ||
| != atom_array_accum.chain_id[atom_j_idx] | ||
| or atom_array_accum.res_id[atom_i_idx] | ||
| != atom_array_accum.res_id[atom_j_idx] | ||
| ): | ||
| continue | ||
|
|
||
| # Only synthesize bonds when at least one residue is non-standard; standard | ||
| # backbone bonds should already exist. | ||
| if _is_standard_polymer_backbone_bond( | ||
| atom_array_accum[atom_i_idx], atom_array_accum[atom_j_idx] | ||
| ): | ||
| continue | ||
|
|
||
| pair = (min(atom_i_idx, atom_j_idx), max(atom_i_idx, atom_j_idx)) | ||
| if pair in existing_pairs: | ||
| continue | ||
| existing_pairs.add(pair) | ||
| bonds_to_add.append([pair[0], pair[1], int(bond_type)]) | ||
|
|
||
| if bonds_to_add: | ||
| new_bonds = struc.BondList( | ||
| atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64) | ||
| ) | ||
| atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds) | ||
| return atom_array_accum | ||
|
|
||
|
|
||
| def _sort_bonds(atom_array_accum: struc.AtomArray) -> struc.AtomArray: | ||
| """Sort bonds deterministically by atom indices then bond type.""" | ||
| bonds_arr = atom_array_accum.bonds.as_array().copy() | ||
| # ensure lower index first | ||
| swap_mask = bonds_arr[:, 0] > bonds_arr[:, 1] | ||
| bonds_arr[swap_mask, :2] = bonds_arr[swap_mask][:, [1, 0]] | ||
| order = np.lexsort((bonds_arr[:, 2], bonds_arr[:, 1], bonds_arr[:, 0])) | ||
| bonds_arr = bonds_arr[order] | ||
| atom_array_accum.bonds = struc.BondList( | ||
| atom_array_accum.array_length(), bonds_arr.astype(np.int64) | ||
| ) | ||
| return atom_array_accum | ||
|
|
||
|
|
||
| def accumulate_components( | ||
| components_to_accumulate: List[Union[str, int]], | ||
| *, | ||
|
|
@@ -1013,6 +1267,8 @@ def accumulate_components( | |
| res_id = start_resid | ||
| molecule_id = 0 | ||
| source_to_accum_idx: Dict[int, int] = {} | ||
| source_idx_to_component: Dict[int, str] = {} | ||
| unindexed_component_names = set(unindexed_tokens.keys()) | ||
| current_accum_idx = sum(len(arr) for arr in atom_array_accum) | ||
|
|
||
| # ... Insert contig information one- by one- | ||
|
|
@@ -1095,6 +1351,7 @@ def accumulate_components( | |
| ): | ||
| for i, src_idx in enumerate(src_indices): | ||
| source_to_accum_idx[int(src_idx)] = current_accum_idx + i | ||
| source_idx_to_component[int(src_idx)] = str(component) | ||
|
|
||
| # ... Insert & Increment residue ID | ||
| atom_array_accum.append(token) | ||
|
|
@@ -1104,23 +1361,17 @@ def accumulate_components( | |
| # ... Concatenate all components | ||
| atom_array_accum = struc.concatenate(atom_array_accum) | ||
| atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id) | ||
|
|
||
| should_restore_bonds = ( | ||
| src_atom_array is not None | ||
| and bool(source_to_accum_idx) | ||
| and _check_has_backbone_connections_to_nonstandard_residues( | ||
| atom_array_accum, src_atom_array | ||
| ) | ||
| atom_array_accum = _restore_component_bonds( | ||
| atom_array_accum=atom_array_accum, | ||
| src_atom_array=src_atom_array, | ||
| source_to_accum_idx=source_to_accum_idx, | ||
| source_idx_to_component=source_idx_to_component, | ||
| unindexed_components=unindexed_component_names, | ||
| ) | ||
| if should_restore_bonds: | ||
| assert not unindexed_tokens, ( | ||
| "PTM backbone bond restoration is not compatible with unindexed components. " | ||
| "PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). " | ||
| f"Found unindexed components: {list(unindexed_tokens.keys())}" | ||
| ) | ||
| atom_array_accum = _restore_bonds_for_nonstandard_residues( | ||
| atom_array_accum, src_atom_array, source_to_accum_idx | ||
| ) | ||
| atom_array_accum = _add_backbone_bonds_for_nonstandard_residues( | ||
| atom_array_accum=atom_array_accum | ||
| ) | ||
| atom_array_accum = _sort_bonds(atom_array_accum) | ||
|
|
||
| # Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later) | ||
| if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In
_restore_component_bonds(), the unindexed-component guard for the “one endpoint missing” case will always raise for any non-backbone bond because theor not (_is_standard_polymer_backbone_bond(...) or _is_polymer_backbone_like(...))clause is redundant (those cases are alreadycontinued above). This makes the condition effectively always-true and can trigger unexpected AssertionErrors. Consider changing the condition to only check whether the bond crosses residue/chain boundaries (and keep backbone-like early-continue), so intra-residue bonds with missing atoms don’t hard-fail.