Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 269 additions & 18 deletions models/rfd3/src/rfd3/inference/input_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from typing import Any, Dict, List, Optional, Union

import numpy as np
from atomworks.constants import STANDARD_AA
from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA
from atomworks.io.parser import parse_atom_array
from atomworks.io.utils.bonds import get_inferred_polymer_bonds

# from atomworks.ml.datasets.datasets import BaseDataset
from atomworks.ml.transforms.base import TransformedDict
Expand Down Expand Up @@ -48,7 +49,6 @@
)
from rfd3.transforms.util_transforms import assign_types_
from rfd3.utils.inference import (
_restore_bonds_for_nonstandard_residues,
extract_ligand_array,
inference_load_,
set_com,
Expand All @@ -58,6 +58,7 @@

from foundry.common import exists
from foundry.utils.components import (
fetch_mask_from_idx,
get_design_pattern_with_constraints,
get_motif_components_and_breaks,
)
Expand Down Expand Up @@ -968,6 +969,259 @@ def create_motif_residue(
return token


def _polymer_link_atoms_for_residue(res_name: str) -> set[frozenset[str]]:
"""
Return the atom-name pairs that represent canonical polymerization atoms for a residue.

Only standard AA/DNA/RNA residues are treated as having canonical backbone links; PTMs
and other chem comp types fall back to nonstandard handling.
"""
if res_name in STANDARD_AA:
return {frozenset({"C", "N"})}
if res_name in STANDARD_DNA or res_name in STANDARD_RNA:
return {frozenset({"O3'", "P"}), frozenset({"O3*", "P"})} # allow legacy O3*
return set()


def _is_standard_polymer_backbone_bond(atom_a: struc.Atom, atom_b: struc.Atom) -> bool:
if atom_a.chain_id != atom_b.chain_id:
return False
if abs(atom_a.res_id - atom_b.res_id) != 1:
return False
atom_pair = frozenset({atom_a.atom_name, atom_b.atom_name})

pairs_a = _polymer_link_atoms_for_residue(atom_a.res_name)
pairs_b = _polymer_link_atoms_for_residue(atom_b.res_name)
shared_pairs = pairs_a & pairs_b
return atom_pair in shared_pairs


def _is_polymer_backbone_like(atom_a: struc.Atom, atom_b: struc.Atom) -> bool:
"""
Broader backbone check that treats canonical polymer atoms (C/N for peptide,
O3'/O3*–P for nucleic) as backbone links even when the residue itself is
non-standard (e.g., PTR/SEP).
"""
if atom_a.chain_id != atom_b.chain_id:
return False
if abs(atom_a.res_id - atom_b.res_id) != 1:
return False
atom_pair = frozenset({atom_a.atom_name, atom_b.atom_name})
return atom_pair in {
frozenset({"C", "N"}),
frozenset({"O3'", "P"}),
frozenset({"O3*", "P"}),
}


def _restore_component_bonds(
atom_array_accum: struc.AtomArray,
src_atom_array: Optional[struc.AtomArray],
source_to_accum_idx: Dict[int, int],
source_idx_to_component: Dict[int, str],
unindexed_components: set[str],
) -> struc.AtomArray:
"""
Rehydrate bonds from the input structure onto the accumulated array.

- Replays bonds from `src_atom_array` using the provided source→accum mappings.
- Skips canonical polymer backbone bonds (peptide/nucleic) that are reconstructed elsewhere.
- Protects unindexed components by disallowing cross-residue bonds that involve
an unindexed residue (except standard backbone).
- Emits warnings when a source bond cannot be remapped because one endpoint was
dropped during accumulation.
"""
if atom_array_accum.bonds is None:
atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())

if (
src_atom_array is None
or not hasattr(src_atom_array, "bonds")
or src_atom_array.bonds is None
or not source_to_accum_idx
):
return atom_array_accum

bonds_to_add: List[List[int]] = []
seen_pairs: set[tuple[int, int]] = set()
src_bonds = np.asarray(src_atom_array.bonds.as_array(), dtype=np.int64)

def _is_unindexed_source(idx: int) -> bool:
component = source_idx_to_component.get(idx)
return component in unindexed_components if component is not None else False

def _fmt_atom(atom: struc.Atom) -> str:
# Use a readable residue/atom separator to avoid names running together
return f"{atom.chain_id}{atom.res_id}:{atom.res_name}_{atom.atom_name}"

for atom_i_idx, atom_j_idx, bond_type in src_bonds:
atom_i_idx = int(atom_i_idx)
atom_j_idx = int(atom_j_idx)
bond_type = int(bond_type)
mapped_i = source_to_accum_idx.get(atom_i_idx)
mapped_j = source_to_accum_idx.get(atom_j_idx)

atom_i = src_atom_array[atom_i_idx]
atom_j = src_atom_array[atom_j_idx]

# If we only have one side of the bond, assert if the mapped atom is from an
# unindexed component and the bond would connect across residues.
if mapped_i is None or mapped_j is None:
if _is_standard_polymer_backbone_bond(
atom_i, atom_j
) or _is_polymer_backbone_like(atom_i, atom_j):
continue
if mapped_i is not None and _is_unindexed_source(atom_i_idx):
if (
atom_i.chain_id != atom_j.chain_id
or atom_i.res_id != atom_j.res_id
or not (
_is_standard_polymer_backbone_bond(atom_i, atom_j)
or _is_polymer_backbone_like(atom_i, atom_j)
)
):
raise AssertionError(
f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} "
f"and omitted residue {atom_j.chain_id}{atom_j.res_id}."
)
if mapped_j is not None and _is_unindexed_source(atom_j_idx):
if (
atom_i.chain_id != atom_j.chain_id
or atom_i.res_id != atom_j.res_id
or not (
_is_standard_polymer_backbone_bond(atom_i, atom_j)
or _is_polymer_backbone_like(atom_i, atom_j)
)
):
Comment on lines +1075 to +1095
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In _restore_component_bonds(), the unindexed-component guard for the “one endpoint missing” case will always raise for any non-backbone bond because the or not (_is_standard_polymer_backbone_bond(...) or _is_polymer_backbone_like(...)) clause is redundant (those cases are already continued above). This makes the condition effectively always-true and can trigger unexpected AssertionErrors. Consider changing the condition to only check whether the bond crosses residue/chain boundaries (and keep backbone-like early-continue), so intra-residue bonds with missing atoms don’t hard-fail.

Suggested change
if (
atom_i.chain_id != atom_j.chain_id
or atom_i.res_id != atom_j.res_id
or not (
_is_standard_polymer_backbone_bond(atom_i, atom_j)
or _is_polymer_backbone_like(atom_i, atom_j)
)
):
raise AssertionError(
f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} "
f"and omitted residue {atom_j.chain_id}{atom_j.res_id}."
)
if mapped_j is not None and _is_unindexed_source(atom_j_idx):
if (
atom_i.chain_id != atom_j.chain_id
or atom_i.res_id != atom_j.res_id
or not (
_is_standard_polymer_backbone_bond(atom_i, atom_j)
or _is_polymer_backbone_like(atom_i, atom_j)
)
):
# Only treat as unsupported if this would have been a cross-residue/chain bond.
if atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id:
raise AssertionError(
f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} "
f"and omitted residue {atom_j.chain_id}{atom_j.res_id}."
)
if mapped_j is not None and _is_unindexed_source(atom_j_idx):
# Only treat as unsupported if this would have been a cross-residue/chain bond.
if atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id:

Copilot uses AI. Check for mistakes.
raise AssertionError(
Comment on lines +1087 to +1096
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as above for the symmetric mapped_j is not None and _is_unindexed_source(atom_j_idx) branch: the or not (_is_standard_polymer_backbone_bond(...) or _is_polymer_backbone_like(...)) term makes the guard always true after the earlier backbone-like continue, so this will always raise for non-backbone bonds. Tighten the condition to only error on cross-residue/chain links (and allow intra-residue bonds to be skipped).

Copilot uses AI. Check for mistakes.
f"Unsupported bond between unindexed component {atom_j.chain_id}{atom_j.res_id} "
f"and omitted residue {atom_i.chain_id}{atom_i.res_id}."
)
# Only warn when we retained one side of a cross-residue/chain linkage
# (e.g., glycan partner missing), not for missing intra-residue atoms.
if (mapped_i is not None or mapped_j is not None) and (
atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id
):
logger.warning(
(
"Skipping non-backbone bond from source structure between %s and %s (type %d): "
"one atom is not present in accumulated components. "
"Bond cannot be inferred automatically; set it manually if needed."
)
% (_fmt_atom(atom_i), _fmt_atom(atom_j), bond_type)
)
continue

# Do not connect unindexed residues to anything else for now.
comp_i = source_idx_to_component.get(atom_i_idx)
comp_j = source_idx_to_component.get(atom_j_idx)
if (comp_i in unindexed_components or comp_j in unindexed_components) and (
atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id
):
if not (
_is_standard_polymer_backbone_bond(atom_i, atom_j)
or _is_polymer_backbone_like(atom_i, atom_j)
):
raise AssertionError(
"Bonds involving unindexed residues are not yet supported."
)
continue

if _is_standard_polymer_backbone_bond(
atom_i, atom_j
) or _is_polymer_backbone_like(atom_i, atom_j):
continue

pair = (min(mapped_i, mapped_j), max(mapped_i, mapped_j))
if pair in seen_pairs:
continue
seen_pairs.add(pair)
bonds_to_add.append([mapped_i, mapped_j, bond_type])

bond_array = (
np.array(bonds_to_add, dtype=np.int64)
if bonds_to_add
else np.empty((0, 3), dtype=np.int64)
)
new_bonds = struc.BondList(atom_array_accum.array_length(), bond_array)
atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
return atom_array_accum


def _add_backbone_bonds_for_nonstandard_residues(
atom_array_accum: struc.AtomArray,
) -> struc.AtomArray:
"""
Add backbone/polymer bonds for cases where at least one residue is non-standard.

Uses `atomworks.io.utils.bonds.get_inferred_polymer_bonds`, which consults CCD
chem-comp metadata to decide the correct polymerization atoms (C/N, CG/N, etc.).
Only bonds involving at least one non-standard residue are added; standard
AA/DNA/RNA pairs are assumed to already carry their backbone bonds.
"""
if atom_array_accum.bonds is None:
atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())

unindexed_mask = (
atom_array_accum.get_annotation("is_motif_atom_unindexed")
if "is_motif_atom_unindexed" in atom_array_accum.get_annotation_categories()
else np.zeros(atom_array_accum.array_length(), dtype=bool)
)

existing_pairs = {
(min(a, b), max(a, b)) for a, b, _ in atom_array_accum.bonds.as_array()
}
bonds_to_add: List[List[int]] = []

inferred_bonds, _ = get_inferred_polymer_bonds(atom_array_accum)
for atom_i_idx, atom_j_idx, bond_type in inferred_bonds:
atom_i_idx = int(atom_i_idx)
atom_j_idx = int(atom_j_idx)

# Do not connect unindexed residues across residue boundaries
if (unindexed_mask[atom_i_idx] or unindexed_mask[atom_j_idx]) and (
atom_array_accum.chain_id[atom_i_idx]
!= atom_array_accum.chain_id[atom_j_idx]
or atom_array_accum.res_id[atom_i_idx]
!= atom_array_accum.res_id[atom_j_idx]
):
continue

# Only synthesize bonds when at least one residue is non-standard; standard
# backbone bonds should already exist.
if _is_standard_polymer_backbone_bond(
atom_array_accum[atom_i_idx], atom_array_accum[atom_j_idx]
):
continue

pair = (min(atom_i_idx, atom_j_idx), max(atom_i_idx, atom_j_idx))
if pair in existing_pairs:
continue
existing_pairs.add(pair)
bonds_to_add.append([pair[0], pair[1], int(bond_type)])

if bonds_to_add:
new_bonds = struc.BondList(
atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64)
)
atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
return atom_array_accum


def _sort_bonds(atom_array_accum: struc.AtomArray) -> struc.AtomArray:
"""Sort bonds deterministically by atom indices then bond type."""
bonds_arr = atom_array_accum.bonds.as_array().copy()
# ensure lower index first
swap_mask = bonds_arr[:, 0] > bonds_arr[:, 1]
bonds_arr[swap_mask, :2] = bonds_arr[swap_mask][:, [1, 0]]
order = np.lexsort((bonds_arr[:, 2], bonds_arr[:, 1], bonds_arr[:, 0]))
bonds_arr = bonds_arr[order]
atom_array_accum.bonds = struc.BondList(
atom_array_accum.array_length(), bonds_arr.astype(np.int64)
)
return atom_array_accum


def accumulate_components(
components_to_accumulate: List[Union[str, int]],
*,
Expand Down Expand Up @@ -1013,6 +1267,8 @@ def accumulate_components(
res_id = start_resid
molecule_id = 0
source_to_accum_idx: Dict[int, int] = {}
source_idx_to_component: Dict[int, str] = {}
unindexed_component_names = set(unindexed_tokens.keys())
current_accum_idx = sum(len(arr) for arr in atom_array_accum)

# ... Insert contig information one- by one-
Expand Down Expand Up @@ -1095,6 +1351,7 @@ def accumulate_components(
):
for i, src_idx in enumerate(src_indices):
source_to_accum_idx[int(src_idx)] = current_accum_idx + i
source_idx_to_component[int(src_idx)] = str(component)

# ... Insert & Increment residue ID
atom_array_accum.append(token)
Expand All @@ -1104,23 +1361,17 @@ def accumulate_components(
# ... Concatenate all components
atom_array_accum = struc.concatenate(atom_array_accum)
atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)

should_restore_bonds = (
src_atom_array is not None
and bool(source_to_accum_idx)
and _check_has_backbone_connections_to_nonstandard_residues(
atom_array_accum, src_atom_array
)
atom_array_accum = _restore_component_bonds(
atom_array_accum=atom_array_accum,
src_atom_array=src_atom_array,
source_to_accum_idx=source_to_accum_idx,
source_idx_to_component=source_idx_to_component,
unindexed_components=unindexed_component_names,
)
if should_restore_bonds:
assert not unindexed_tokens, (
"PTM backbone bond restoration is not compatible with unindexed components. "
"PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). "
f"Found unindexed components: {list(unindexed_tokens.keys())}"
)
atom_array_accum = _restore_bonds_for_nonstandard_residues(
atom_array_accum, src_atom_array, source_to_accum_idx
)
atom_array_accum = _add_backbone_bonds_for_nonstandard_residues(
atom_array_accum=atom_array_accum
)
atom_array_accum = _sort_bonds(atom_array_accum)

# Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later)
if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
Expand Down
Loading
Loading