Skip to content

Commit 7dd7b31

Browse files
Improve PTM bond handling in RFdiffusion3 (#256)
* Improve PTM bond handling in RFdiffusion3 input parsing Replace the old _restore_bonds_for_nonstandard_residues approach with a more robust bond restoration system that properly handles unindexed components, backbone-like bonds for non-standard residues (PTMs), and cross-residue bond preservation from source structures. Adds the legacy counterpart and regression tests for both code paths. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Apply ruff formatting to PTM bond handling files Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Remove is_motif_atom guards from legacy input parsing The bare is_motif_atom annotation is being abolished; remove the guards that were adding it when missing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Gate bond restoration behind nonstandard-residue check Only run _restore_component_bonds, _add_backbone_bonds_for_nonstandard_residues, and _sort_bonds when the source structure actually has backbone connections to nonstandard residues. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b5395e4 commit 7dd7b31

4 files changed

Lines changed: 955 additions & 9 deletions

File tree

models/rfd3/src/rfd3/inference/input_parsing.py

Lines changed: 269 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from typing import Any, Dict, List, Optional, Union
1010

1111
import numpy as np
12-
from atomworks.constants import STANDARD_AA
12+
from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA
1313
from atomworks.io.parser import parse_atom_array
14+
from atomworks.io.utils.bonds import get_inferred_polymer_bonds
1415

1516
# from atomworks.ml.datasets.datasets import BaseDataset
1617
from atomworks.ml.transforms.base import TransformedDict
@@ -32,6 +33,7 @@
3233
REQUIRED_INFERENCE_ANNOTATIONS,
3334
)
3435
from rfd3.inference.legacy_input_parsing import (
36+
_check_has_backbone_connections_to_nonstandard_residues,
3537
create_atom_array_from_design_specification_legacy,
3638
)
3739
from rfd3.inference.parsing import InputSelection
@@ -48,7 +50,6 @@
4850
)
4951
from rfd3.transforms.util_transforms import assign_types_
5052
from rfd3.utils.inference import (
51-
_restore_bonds_for_nonstandard_residues,
5253
extract_ligand_array,
5354
inference_load_,
5455
set_com,
@@ -58,6 +59,7 @@
5859

5960
from foundry.common import exists
6061
from foundry.utils.components import (
62+
fetch_mask_from_idx,
6163
get_design_pattern_with_constraints,
6264
get_motif_components_and_breaks,
6365
)
@@ -1000,6 +1002,259 @@ def create_motif_residue(
10001002
return token
10011003

10021004

1005+
def _polymer_link_atoms_for_residue(res_name: str) -> set[frozenset[str]]:
1006+
"""
1007+
Return the atom-name pairs that represent canonical polymerization atoms for a residue.
1008+
1009+
Only standard AA/DNA/RNA residues are treated as having canonical backbone links; PTMs
1010+
and other chem comp types fall back to nonstandard handling.
1011+
"""
1012+
if res_name in STANDARD_AA:
1013+
return {frozenset({"C", "N"})}
1014+
if res_name in STANDARD_DNA or res_name in STANDARD_RNA:
1015+
return {frozenset({"O3'", "P"}), frozenset({"O3*", "P"})} # allow legacy O3*
1016+
return set()
1017+
1018+
1019+
def _is_standard_polymer_backbone_bond(atom_a: struc.Atom, atom_b: struc.Atom) -> bool:
1020+
if atom_a.chain_id != atom_b.chain_id:
1021+
return False
1022+
if abs(atom_a.res_id - atom_b.res_id) != 1:
1023+
return False
1024+
atom_pair = frozenset({atom_a.atom_name, atom_b.atom_name})
1025+
1026+
pairs_a = _polymer_link_atoms_for_residue(atom_a.res_name)
1027+
pairs_b = _polymer_link_atoms_for_residue(atom_b.res_name)
1028+
shared_pairs = pairs_a & pairs_b
1029+
return atom_pair in shared_pairs
1030+
1031+
1032+
def _is_polymer_backbone_like(atom_a: struc.Atom, atom_b: struc.Atom) -> bool:
1033+
"""
1034+
Broader backbone check that treats canonical polymer atoms (C/N for peptide,
1035+
O3'/O3*–P for nucleic) as backbone links even when the residue itself is
1036+
non-standard (e.g., PTR/SEP).
1037+
"""
1038+
if atom_a.chain_id != atom_b.chain_id:
1039+
return False
1040+
if abs(atom_a.res_id - atom_b.res_id) != 1:
1041+
return False
1042+
atom_pair = frozenset({atom_a.atom_name, atom_b.atom_name})
1043+
return atom_pair in {
1044+
frozenset({"C", "N"}),
1045+
frozenset({"O3'", "P"}),
1046+
frozenset({"O3*", "P"}),
1047+
}
1048+
1049+
1050+
def _restore_component_bonds(
1051+
atom_array_accum: struc.AtomArray,
1052+
src_atom_array: Optional[struc.AtomArray],
1053+
source_to_accum_idx: Dict[int, int],
1054+
source_idx_to_component: Dict[int, str],
1055+
unindexed_components: set[str],
1056+
) -> struc.AtomArray:
1057+
"""
1058+
Rehydrate bonds from the input structure onto the accumulated array.
1059+
1060+
- Replays bonds from `src_atom_array` using the provided source→accum mappings.
1061+
- Skips canonical polymer backbone bonds (peptide/nucleic) that are reconstructed elsewhere.
1062+
- Protects unindexed components by disallowing cross-residue bonds that involve
1063+
an unindexed residue (except standard backbone).
1064+
- Emits warnings when a source bond cannot be remapped because one endpoint was
1065+
dropped during accumulation.
1066+
"""
1067+
if atom_array_accum.bonds is None:
1068+
atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())
1069+
1070+
if (
1071+
src_atom_array is None
1072+
or not hasattr(src_atom_array, "bonds")
1073+
or src_atom_array.bonds is None
1074+
or not source_to_accum_idx
1075+
):
1076+
return atom_array_accum
1077+
1078+
bonds_to_add: List[List[int]] = []
1079+
seen_pairs: set[tuple[int, int]] = set()
1080+
src_bonds = np.asarray(src_atom_array.bonds.as_array(), dtype=np.int64)
1081+
1082+
def _is_unindexed_source(idx: int) -> bool:
1083+
component = source_idx_to_component.get(idx)
1084+
return component in unindexed_components if component is not None else False
1085+
1086+
def _fmt_atom(atom: struc.Atom) -> str:
1087+
# Use a readable residue/atom separator to avoid names running together
1088+
return f"{atom.chain_id}{atom.res_id}:{atom.res_name}_{atom.atom_name}"
1089+
1090+
for atom_i_idx, atom_j_idx, bond_type in src_bonds:
1091+
atom_i_idx = int(atom_i_idx)
1092+
atom_j_idx = int(atom_j_idx)
1093+
bond_type = int(bond_type)
1094+
mapped_i = source_to_accum_idx.get(atom_i_idx)
1095+
mapped_j = source_to_accum_idx.get(atom_j_idx)
1096+
1097+
atom_i = src_atom_array[atom_i_idx]
1098+
atom_j = src_atom_array[atom_j_idx]
1099+
1100+
# If we only have one side of the bond, assert if the mapped atom is from an
1101+
# unindexed component and the bond would connect across residues.
1102+
if mapped_i is None or mapped_j is None:
1103+
if _is_standard_polymer_backbone_bond(
1104+
atom_i, atom_j
1105+
) or _is_polymer_backbone_like(atom_i, atom_j):
1106+
continue
1107+
if mapped_i is not None and _is_unindexed_source(atom_i_idx):
1108+
if (
1109+
atom_i.chain_id != atom_j.chain_id
1110+
or atom_i.res_id != atom_j.res_id
1111+
or not (
1112+
_is_standard_polymer_backbone_bond(atom_i, atom_j)
1113+
or _is_polymer_backbone_like(atom_i, atom_j)
1114+
)
1115+
):
1116+
raise AssertionError(
1117+
f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} "
1118+
f"and omitted residue {atom_j.chain_id}{atom_j.res_id}."
1119+
)
1120+
if mapped_j is not None and _is_unindexed_source(atom_j_idx):
1121+
if (
1122+
atom_i.chain_id != atom_j.chain_id
1123+
or atom_i.res_id != atom_j.res_id
1124+
or not (
1125+
_is_standard_polymer_backbone_bond(atom_i, atom_j)
1126+
or _is_polymer_backbone_like(atom_i, atom_j)
1127+
)
1128+
):
1129+
raise AssertionError(
1130+
f"Unsupported bond between unindexed component {atom_j.chain_id}{atom_j.res_id} "
1131+
f"and omitted residue {atom_i.chain_id}{atom_i.res_id}."
1132+
)
1133+
# Only warn when we retained one side of a cross-residue/chain linkage
1134+
# (e.g., glycan partner missing), not for missing intra-residue atoms.
1135+
if (mapped_i is not None or mapped_j is not None) and (
1136+
atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id
1137+
):
1138+
logger.warning(
1139+
(
1140+
"Skipping non-backbone bond from source structure between %s and %s (type %d): "
1141+
"one atom is not present in accumulated components. "
1142+
"Bond cannot be inferred automatically; set it manually if needed."
1143+
)
1144+
% (_fmt_atom(atom_i), _fmt_atom(atom_j), bond_type)
1145+
)
1146+
continue
1147+
1148+
# Do not connect unindexed residues to anything else for now.
1149+
comp_i = source_idx_to_component.get(atom_i_idx)
1150+
comp_j = source_idx_to_component.get(atom_j_idx)
1151+
if (comp_i in unindexed_components or comp_j in unindexed_components) and (
1152+
atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id
1153+
):
1154+
if not (
1155+
_is_standard_polymer_backbone_bond(atom_i, atom_j)
1156+
or _is_polymer_backbone_like(atom_i, atom_j)
1157+
):
1158+
raise AssertionError(
1159+
"Bonds involving unindexed residues are not yet supported."
1160+
)
1161+
continue
1162+
1163+
if _is_standard_polymer_backbone_bond(
1164+
atom_i, atom_j
1165+
) or _is_polymer_backbone_like(atom_i, atom_j):
1166+
continue
1167+
1168+
pair = (min(mapped_i, mapped_j), max(mapped_i, mapped_j))
1169+
if pair in seen_pairs:
1170+
continue
1171+
seen_pairs.add(pair)
1172+
bonds_to_add.append([mapped_i, mapped_j, bond_type])
1173+
1174+
bond_array = (
1175+
np.array(bonds_to_add, dtype=np.int64)
1176+
if bonds_to_add
1177+
else np.empty((0, 3), dtype=np.int64)
1178+
)
1179+
new_bonds = struc.BondList(atom_array_accum.array_length(), bond_array)
1180+
atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
1181+
return atom_array_accum
1182+
1183+
1184+
def _add_backbone_bonds_for_nonstandard_residues(
1185+
atom_array_accum: struc.AtomArray,
1186+
) -> struc.AtomArray:
1187+
"""
1188+
Add backbone/polymer bonds for cases where at least one residue is non-standard.
1189+
1190+
Uses `atomworks.io.utils.bonds.get_inferred_polymer_bonds`, which consults CCD
1191+
chem-comp metadata to decide the correct polymerization atoms (C/N, CG/N, etc.).
1192+
Only bonds involving at least one non-standard residue are added; standard
1193+
AA/DNA/RNA pairs are assumed to already carry their backbone bonds.
1194+
"""
1195+
if atom_array_accum.bonds is None:
1196+
atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())
1197+
1198+
unindexed_mask = (
1199+
atom_array_accum.get_annotation("is_motif_atom_unindexed")
1200+
if "is_motif_atom_unindexed" in atom_array_accum.get_annotation_categories()
1201+
else np.zeros(atom_array_accum.array_length(), dtype=bool)
1202+
)
1203+
1204+
existing_pairs = {
1205+
(min(a, b), max(a, b)) for a, b, _ in atom_array_accum.bonds.as_array()
1206+
}
1207+
bonds_to_add: List[List[int]] = []
1208+
1209+
inferred_bonds, _ = get_inferred_polymer_bonds(atom_array_accum)
1210+
for atom_i_idx, atom_j_idx, bond_type in inferred_bonds:
1211+
atom_i_idx = int(atom_i_idx)
1212+
atom_j_idx = int(atom_j_idx)
1213+
1214+
# Do not connect unindexed residues across residue boundaries
1215+
if (unindexed_mask[atom_i_idx] or unindexed_mask[atom_j_idx]) and (
1216+
atom_array_accum.chain_id[atom_i_idx]
1217+
!= atom_array_accum.chain_id[atom_j_idx]
1218+
or atom_array_accum.res_id[atom_i_idx]
1219+
!= atom_array_accum.res_id[atom_j_idx]
1220+
):
1221+
continue
1222+
1223+
# Only synthesize bonds when at least one residue is non-standard; standard
1224+
# backbone bonds should already exist.
1225+
if _is_standard_polymer_backbone_bond(
1226+
atom_array_accum[atom_i_idx], atom_array_accum[atom_j_idx]
1227+
):
1228+
continue
1229+
1230+
pair = (min(atom_i_idx, atom_j_idx), max(atom_i_idx, atom_j_idx))
1231+
if pair in existing_pairs:
1232+
continue
1233+
existing_pairs.add(pair)
1234+
bonds_to_add.append([pair[0], pair[1], int(bond_type)])
1235+
1236+
if bonds_to_add:
1237+
new_bonds = struc.BondList(
1238+
atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64)
1239+
)
1240+
atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
1241+
return atom_array_accum
1242+
1243+
1244+
def _sort_bonds(atom_array_accum: struc.AtomArray) -> struc.AtomArray:
1245+
"""Sort bonds deterministically by atom indices then bond type."""
1246+
bonds_arr = atom_array_accum.bonds.as_array().copy()
1247+
# ensure lower index first
1248+
swap_mask = bonds_arr[:, 0] > bonds_arr[:, 1]
1249+
bonds_arr[swap_mask, :2] = bonds_arr[swap_mask][:, [1, 0]]
1250+
order = np.lexsort((bonds_arr[:, 2], bonds_arr[:, 1], bonds_arr[:, 0]))
1251+
bonds_arr = bonds_arr[order]
1252+
atom_array_accum.bonds = struc.BondList(
1253+
atom_array_accum.array_length(), bonds_arr.astype(np.int64)
1254+
)
1255+
return atom_array_accum
1256+
1257+
10031258
def accumulate_components(
10041259
components_to_accumulate: List[Union[str, int]],
10051260
*,
@@ -1045,6 +1300,8 @@ def accumulate_components(
10451300
res_id = start_resid
10461301
molecule_id = 0
10471302
source_to_accum_idx: Dict[int, int] = {}
1303+
source_idx_to_component: Dict[int, str] = {}
1304+
unindexed_component_names = set(unindexed_tokens.keys())
10481305
current_accum_idx = sum(len(arr) for arr in atom_array_accum)
10491306

10501307
# ... Insert contig information one- by one-
@@ -1127,6 +1384,7 @@ def accumulate_components(
11271384
):
11281385
for i, src_idx in enumerate(src_indices):
11291386
source_to_accum_idx[int(src_idx)] = current_accum_idx + i
1387+
source_idx_to_component[int(src_idx)] = str(component)
11301388

11311389
# ... Insert & Increment residue ID
11321390
atom_array_accum.append(token)
@@ -1136,7 +1394,6 @@ def accumulate_components(
11361394
# ... Concatenate all components
11371395
atom_array_accum = struc.concatenate(atom_array_accum)
11381396
atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
1139-
11401397
should_restore_bonds = (
11411398
src_atom_array is not None
11421399
and bool(source_to_accum_idx)
@@ -1145,14 +1402,17 @@ def accumulate_components(
11451402
)
11461403
)
11471404
if should_restore_bonds:
1148-
assert not unindexed_tokens, (
1149-
"PTM backbone bond restoration is not compatible with unindexed components. "
1150-
"PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). "
1151-
f"Found unindexed components: {list(unindexed_tokens.keys())}"
1405+
atom_array_accum = _restore_component_bonds(
1406+
atom_array_accum=atom_array_accum,
1407+
src_atom_array=src_atom_array,
1408+
source_to_accum_idx=source_to_accum_idx,
1409+
source_idx_to_component=source_idx_to_component,
1410+
unindexed_components=unindexed_component_names,
11521411
)
1153-
atom_array_accum = _restore_bonds_for_nonstandard_residues(
1154-
atom_array_accum, src_atom_array, source_to_accum_idx
1412+
atom_array_accum = _add_backbone_bonds_for_nonstandard_residues(
1413+
atom_array_accum=atom_array_accum
11551414
)
1415+
atom_array_accum = _sort_bonds(atom_array_accum)
11561416

11571417
# Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later)
11581418
if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(

0 commit comments

Comments
 (0)