99from typing import Any , Dict , List , Optional , Union
1010
1111import numpy as np
12- from atomworks .constants import STANDARD_AA
12+ from atomworks .constants import STANDARD_AA , STANDARD_DNA , STANDARD_RNA
1313from atomworks .io .parser import parse_atom_array
14+ from atomworks .io .utils .bonds import get_inferred_polymer_bonds
1415
1516# from atomworks.ml.datasets.datasets import BaseDataset
1617from atomworks .ml .transforms .base import TransformedDict
3233 REQUIRED_INFERENCE_ANNOTATIONS ,
3334)
3435from rfd3 .inference .legacy_input_parsing import (
36+ _check_has_backbone_connections_to_nonstandard_residues ,
3537 create_atom_array_from_design_specification_legacy ,
3638)
3739from rfd3 .inference .parsing import InputSelection
4850)
4951from rfd3 .transforms .util_transforms import assign_types_
5052from rfd3 .utils .inference import (
51- _restore_bonds_for_nonstandard_residues ,
5253 extract_ligand_array ,
5354 inference_load_ ,
5455 set_com ,
5859
5960from foundry .common import exists
6061from foundry .utils .components import (
62+ fetch_mask_from_idx ,
6163 get_design_pattern_with_constraints ,
6264 get_motif_components_and_breaks ,
6365)
@@ -1000,6 +1002,259 @@ def create_motif_residue(
10001002 return token
10011003
10021004
def _polymer_link_atoms_for_residue(res_name: str) -> set[frozenset[str]]:
    """
    Look up the canonical inter-residue linking atom-name pairs for `res_name`.

    Residues listed in STANDARD_AA link through C/N. Residues listed in
    STANDARD_DNA or STANDARD_RNA link through O3'/P, with the legacy `O3*`
    spelling accepted as well. Every other residue name (PTMs, other chem comp
    types) yields an empty set so callers fall back to nonstandard handling.
    """
    link_pairs: set[frozenset[str]] = set()
    if res_name in STANDARD_AA:
        link_pairs.add(frozenset(("C", "N")))
    elif res_name in STANDARD_DNA or res_name in STANDARD_RNA:
        # Both the modern (O3') and the legacy (O3*) atom naming are recognised.
        link_pairs.update(frozenset((o3_name, "P")) for o3_name in ("O3'", "O3*"))
    return link_pairs
1017+
1018+
def _is_standard_polymer_backbone_bond(atom_a: struc.Atom, atom_b: struc.Atom) -> bool:
    """
    Tell whether two atoms form a canonical backbone link between standard residues.

    The atoms must share a chain, their `res_id` values must differ by exactly
    one, and their atom names must form a linking pair that is valid for the
    residue names of *both* atoms (as reported by
    `_polymer_link_atoms_for_residue`).
    """
    same_chain = atom_a.chain_id == atom_b.chain_id
    if not same_chain or abs(atom_a.res_id - atom_b.res_id) != 1:
        return False

    names = frozenset((atom_a.atom_name, atom_b.atom_name))
    # The pair has to be a valid link for each residue type, i.e. lie in the
    # intersection of both residues' link sets.
    return names in _polymer_link_atoms_for_residue(
        atom_a.res_name
    ) and names in _polymer_link_atoms_for_residue(atom_b.res_name)
1030+
1031+
def _is_polymer_backbone_like(atom_a: struc.Atom, atom_b: struc.Atom) -> bool:
    """
    Name-based backbone check that ignores the residue type.

    Any C/N, O3'/P or O3*/P atom pair between two residues of the same chain
    whose `res_id` values differ by one counts as a backbone link, even when
    the residues themselves are non-standard (e.g., PTR/SEP).
    """
    if atom_a.chain_id != atom_b.chain_id or abs(atom_a.res_id - atom_b.res_id) != 1:
        return False

    backbone_name_pairs = (
        frozenset(("C", "N")),
        frozenset(("O3'", "P")),
        frozenset(("O3*", "P")),
    )
    return frozenset((atom_a.atom_name, atom_b.atom_name)) in backbone_name_pairs
1048+
1049+
def _restore_component_bonds(
    atom_array_accum: struc.AtomArray,
    src_atom_array: Optional[struc.AtomArray],
    source_to_accum_idx: Dict[int, int],
    source_idx_to_component: Dict[int, str],
    unindexed_components: set[str],
) -> struc.AtomArray:
    """
    Rehydrate bonds from the input structure onto the accumulated array.

    - Replays bonds from `src_atom_array` using the provided source→accum mappings.
    - Skips polymer backbone bonds (peptide/nucleic), which are reconstructed elsewhere.
    - Protects unindexed components by disallowing cross-residue bonds that involve
      an unindexed residue (except backbone bonds, which are skipped).
    - Emits a warning when a cross-residue source bond cannot be remapped because
      one endpoint was dropped during accumulation.
    - Silently ignores intra-residue bonds whose partner atom was dropped (e.g. a
      partially retained residue); these never raise.

    Args:
        atom_array_accum: Accumulated atom array; its `bonds` attribute is
            created if missing and extended in place.
        src_atom_array: Source structure carrying the bonds to replay, or None.
        source_to_accum_idx: Source atom index → accumulated atom index.
        source_idx_to_component: Source atom index → name of the component it
            was accumulated from.
        unindexed_components: Names of components that are unindexed.

    Returns:
        `atom_array_accum`, with the restored bonds merged into its bond list.

    Raises:
        AssertionError: If a non-backbone bond crosses a residue/chain boundary
            and involves an atom that belongs to an unindexed component.
    """
    if atom_array_accum.bonds is None:
        atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())

    # Nothing to replay without a bonded source structure and an index mapping.
    if (
        src_atom_array is None
        or not hasattr(src_atom_array, "bonds")
        or src_atom_array.bonds is None
        or not source_to_accum_idx
    ):
        return atom_array_accum

    bonds_to_add: List[List[int]] = []
    seen_pairs: set[tuple[int, int]] = set()
    src_bonds = np.asarray(src_atom_array.bonds.as_array(), dtype=np.int64)

    def _is_unindexed_source(idx: int) -> bool:
        component = source_idx_to_component.get(idx)
        return component in unindexed_components if component is not None else False

    def _fmt_atom(atom: struc.Atom) -> str:
        # Use a readable residue/atom separator to avoid names running together
        return f"{atom.chain_id}{atom.res_id}:{atom.res_name}_{atom.atom_name}"

    for atom_i_idx, atom_j_idx, bond_type in src_bonds:
        atom_i_idx = int(atom_i_idx)
        atom_j_idx = int(atom_j_idx)
        bond_type = int(bond_type)
        mapped_i = source_to_accum_idx.get(atom_i_idx)
        mapped_j = source_to_accum_idx.get(atom_j_idx)

        atom_i = src_atom_array[atom_i_idx]
        atom_j = src_atom_array[atom_j_idx]

        # Backbone bonds are never replayed here, whether or not both endpoints
        # survived accumulation; they are reconstructed elsewhere.
        if _is_standard_polymer_backbone_bond(
            atom_i, atom_j
        ) or _is_polymer_backbone_like(atom_i, atom_j):
            continue

        crosses_residue = (
            atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id
        )

        if mapped_i is None or mapped_j is None:
            # At most one endpoint was retained. Only cross-residue links matter:
            # a dropped intra-residue partner (partially kept residue) is benign
            # and must neither raise nor warn.
            if crosses_residue:
                if mapped_i is not None and _is_unindexed_source(atom_i_idx):
                    raise AssertionError(
                        f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} "
                        f"and omitted residue {atom_j.chain_id}{atom_j.res_id}."
                    )
                if mapped_j is not None and _is_unindexed_source(atom_j_idx):
                    raise AssertionError(
                        f"Unsupported bond between unindexed component {atom_j.chain_id}{atom_j.res_id} "
                        f"and omitted residue {atom_i.chain_id}{atom_i.res_id}."
                    )
                # Only warn when we retained one side of a cross-residue/chain
                # linkage (e.g., glycan partner missing).
                if mapped_i is not None or mapped_j is not None:
                    logger.warning(
                        (
                            "Skipping non-backbone bond from source structure between %s and %s (type %d): "
                            "one atom is not present in accumulated components. "
                            "Bond cannot be inferred automatically; set it manually if needed."
                        )
                        % (_fmt_atom(atom_i), _fmt_atom(atom_j), bond_type)
                    )
            continue

        # Do not connect unindexed residues to anything else for now.
        if crosses_residue and (
            _is_unindexed_source(atom_i_idx) or _is_unindexed_source(atom_j_idx)
        ):
            raise AssertionError(
                "Bonds involving unindexed residues are not yet supported."
            )

        # De-duplicate on the direction-independent accumulated index pair.
        pair = (min(mapped_i, mapped_j), max(mapped_i, mapped_j))
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)
        bonds_to_add.append([mapped_i, mapped_j, bond_type])

    bond_array = (
        np.array(bonds_to_add, dtype=np.int64)
        if bonds_to_add
        else np.empty((0, 3), dtype=np.int64)
    )
    new_bonds = struc.BondList(atom_array_accum.array_length(), bond_array)
    atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
    return atom_array_accum
1182+
1183+
def _add_backbone_bonds_for_nonstandard_residues(
    atom_array_accum: struc.AtomArray,
) -> struc.AtomArray:
    """
    Synthesize polymer bonds for links that involve a non-standard residue.

    Candidate bonds come from `atomworks.io.utils.bonds.get_inferred_polymer_bonds`
    (per the original author's note it consults CCD chem-comp metadata to pick
    the polymerization atoms, e.g. C/N or CG/N). A candidate is dropped when:

    - it crosses a residue/chain boundary and touches an atom flagged by the
      `is_motif_atom_unindexed` annotation,
    - it is a canonical bond between two standard AA/DNA/RNA residues (those
      are assumed to be present already), or
    - a bond between the same two atoms already exists.

    The surviving bonds are merged into `atom_array_accum.bonds` (created when
    missing) and the same array object is returned.
    """
    n_atoms = atom_array_accum.array_length()
    if atom_array_accum.bonds is None:
        atom_array_accum.bonds = struc.BondList(n_atoms)

    if "is_motif_atom_unindexed" in atom_array_accum.get_annotation_categories():
        unindexed_mask = atom_array_accum.get_annotation("is_motif_atom_unindexed")
    else:
        # Without the annotation no atom is treated as unindexed.
        unindexed_mask = np.zeros(atom_array_accum.array_length(), dtype=bool)

    # Direction-independent keys of the bonds that are already present.
    existing_pairs = {
        (min(first, second), max(first, second))
        for first, second, _ in atom_array_accum.bonds.as_array()
    }
    bonds_to_add: List[List[int]] = []

    inferred_bonds, _ = get_inferred_polymer_bonds(atom_array_accum)
    for raw_i, raw_j, raw_type in inferred_bonds:
        idx_i, idx_j = int(raw_i), int(raw_j)

        # Do not connect unindexed residues across residue boundaries
        if unindexed_mask[idx_i] or unindexed_mask[idx_j]:
            different_chain = (
                atom_array_accum.chain_id[idx_i] != atom_array_accum.chain_id[idx_j]
            )
            if (
                different_chain
                or atom_array_accum.res_id[idx_i] != atom_array_accum.res_id[idx_j]
            ):
                continue

        # Standard/standard backbone bonds should already exist; only links with
        # at least one non-standard residue are synthesized here.
        if _is_standard_polymer_backbone_bond(
            atom_array_accum[idx_i], atom_array_accum[idx_j]
        ):
            continue

        low, high = min(idx_i, idx_j), max(idx_i, idx_j)
        if (low, high) in existing_pairs:
            continue
        existing_pairs.add((low, high))
        bonds_to_add.append([low, high, int(raw_type)])

    if not bonds_to_add:
        return atom_array_accum

    extra_bonds = struc.BondList(
        atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64)
    )
    atom_array_accum.bonds = atom_array_accum.bonds.merge(extra_bonds)
    return atom_array_accum
1242+
1243+
def _sort_bonds(atom_array_accum: struc.AtomArray) -> struc.AtomArray:
    """
    Rebuild the bond list in a deterministic order.

    Each bond is first normalised so the lower atom index comes first; bonds
    are then ordered by first index, second index and finally bond type. The
    array's `bonds` attribute is replaced and the same array is returned.
    """
    raw_bonds = atom_array_accum.bonds.as_array()

    # Canonical endpoint order: (smaller index, larger index).
    lower = np.minimum(raw_bonds[:, 0], raw_bonds[:, 1])
    upper = np.maximum(raw_bonds[:, 0], raw_bonds[:, 1])
    canonical = np.stack((lower, upper, raw_bonds[:, 2]), axis=1)

    # np.lexsort treats the LAST key as the primary one.
    ranking = np.lexsort((canonical[:, 2], canonical[:, 1], canonical[:, 0]))
    sorted_bonds = canonical[ranking].astype(np.int64)

    atom_array_accum.bonds = struc.BondList(
        atom_array_accum.array_length(), sorted_bonds
    )
    return atom_array_accum
1256+
1257+
10031258def accumulate_components (
10041259 components_to_accumulate : List [Union [str , int ]],
10051260 * ,
@@ -1045,6 +1300,8 @@ def accumulate_components(
10451300 res_id = start_resid
10461301 molecule_id = 0
10471302 source_to_accum_idx : Dict [int , int ] = {}
1303+ source_idx_to_component : Dict [int , str ] = {}
1304+ unindexed_component_names = set (unindexed_tokens .keys ())
10481305 current_accum_idx = sum (len (arr ) for arr in atom_array_accum )
10491306
10501307 # ... Insert contig information one- by one-
@@ -1127,6 +1384,7 @@ def accumulate_components(
11271384 ):
11281385 for i , src_idx in enumerate (src_indices ):
11291386 source_to_accum_idx [int (src_idx )] = current_accum_idx + i
1387+ source_idx_to_component [int (src_idx )] = str (component )
11301388
11311389 # ... Insert & Increment residue ID
11321390 atom_array_accum .append (token )
@@ -1136,7 +1394,6 @@ def accumulate_components(
11361394 # ... Concatenate all components
11371395 atom_array_accum = struc .concatenate (atom_array_accum )
11381396 atom_array_accum .set_annotation ("pn_unit_iid" , atom_array_accum .chain_id )
1139-
11401397 should_restore_bonds = (
11411398 src_atom_array is not None
11421399 and bool (source_to_accum_idx )
@@ -1145,14 +1402,17 @@ def accumulate_components(
11451402 )
11461403 )
11471404 if should_restore_bonds :
1148- assert not unindexed_tokens , (
1149- "PTM backbone bond restoration is not compatible with unindexed components. "
1150- "PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). "
1151- f"Found unindexed components: { list (unindexed_tokens .keys ())} "
1405+ atom_array_accum = _restore_component_bonds (
1406+ atom_array_accum = atom_array_accum ,
1407+ src_atom_array = src_atom_array ,
1408+ source_to_accum_idx = source_to_accum_idx ,
1409+ source_idx_to_component = source_idx_to_component ,
1410+ unindexed_components = unindexed_component_names ,
11521411 )
1153- atom_array_accum = _restore_bonds_for_nonstandard_residues (
1154- atom_array_accum , src_atom_array , source_to_accum_idx
1412+ atom_array_accum = _add_backbone_bonds_for_nonstandard_residues (
1413+ atom_array_accum = atom_array_accum
11551414 )
1415+ atom_array_accum = _sort_bonds (atom_array_accum )
11561416
11571417 # Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later)
11581418 if np .any (atom_array_accum .is_motif_atom_unindexed .astype (bool )) and not np .all (
0 commit comments