Improve PTM bond handling in RFdiffusion3#256
Improve PTM bond handling in RFdiffusion3#256Ubiquinone-dot wants to merge 2 commits intoproductionfrom
Conversation
Replace the old _restore_bonds_for_nonstandard_residues approach with a more robust bond restoration system that properly handles unindexed components, backbone-like bonds for non-standard residues (PTMs), and cross-residue bond preservation from source structures. Adds the legacy counterpart and regression tests for both code paths. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
This PR improves bond restoration/preservation in RFdiffusion3 component accumulation, with a focus on PTM/backbone-like links and cross-residue bond replay from source structures, plus regression coverage for both the new and legacy code paths.
Changes:
- Adds a new bond restoration pipeline in
input_parsing.accumulate_components()that (1) replays eligible source bonds, (2) synthesizes backbone-like polymer bonds for non-standards, and (3) sorts bonds deterministically. - Updates the legacy input parsing path to restore PTM backbone bonds and ensures required motif annotations exist.
- Adds regression tests covering PTM backbone preservation (fast) and representative cross-residue bond types (mostly slow, data-dependent).
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| models/rfd3/src/rfd3/inference/input_parsing.py | Introduces _restore_component_bonds(), _add_backbone_bonds_for_nonstandard_residues(), and _sort_bonds() and wires them into accumulate_components(). |
| models/rfd3/src/rfd3/inference/legacy_input_parsing.py | Adds legacy PTM/backbone bond restoration support and ensures is_motif_atom is present on constructed arrays. |
| models/rfd3/tests/test_bond_preservation_cases.py | Adds regression tests for bond preservation across multiple connection types and PTM backbone cases. |
| models/rfd3/tests/test_legacy_ptm_bonds.py | Adds a legacy regression test ensuring PTM backbone bonds are restored in the legacy parser path. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if ( | ||
| atom_i.chain_id != atom_j.chain_id | ||
| or atom_i.res_id != atom_j.res_id | ||
| or not ( | ||
| _is_standard_polymer_backbone_bond(atom_i, atom_j) | ||
| or _is_polymer_backbone_like(atom_i, atom_j) | ||
| ) | ||
| ): | ||
| raise AssertionError( | ||
| f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} " | ||
| f"and omitted residue {atom_j.chain_id}{atom_j.res_id}." | ||
| ) | ||
| if mapped_j is not None and _is_unindexed_source(atom_j_idx): | ||
| if ( | ||
| atom_i.chain_id != atom_j.chain_id | ||
| or atom_i.res_id != atom_j.res_id | ||
| or not ( | ||
| _is_standard_polymer_backbone_bond(atom_i, atom_j) | ||
| or _is_polymer_backbone_like(atom_i, atom_j) | ||
| ) | ||
| ): |
There was a problem hiding this comment.
In _restore_component_bonds(), the unindexed-component guard for the “one endpoint missing” case will always raise for any non-backbone bond because the or not (_is_standard_polymer_backbone_bond(...) or _is_polymer_backbone_like(...)) clause is redundant (those cases are already continued above). This makes the condition effectively always-true and can trigger unexpected AssertionErrors. Consider changing the condition to only check whether the bond crosses residue/chain boundaries (and keep backbone-like early-continue), so intra-residue bonds with missing atoms don’t hard-fail.
| if ( | |
| atom_i.chain_id != atom_j.chain_id | |
| or atom_i.res_id != atom_j.res_id | |
| or not ( | |
| _is_standard_polymer_backbone_bond(atom_i, atom_j) | |
| or _is_polymer_backbone_like(atom_i, atom_j) | |
| ) | |
| ): | |
| raise AssertionError( | |
| f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} " | |
| f"and omitted residue {atom_j.chain_id}{atom_j.res_id}." | |
| ) | |
| if mapped_j is not None and _is_unindexed_source(atom_j_idx): | |
| if ( | |
| atom_i.chain_id != atom_j.chain_id | |
| or atom_i.res_id != atom_j.res_id | |
| or not ( | |
| _is_standard_polymer_backbone_bond(atom_i, atom_j) | |
| or _is_polymer_backbone_like(atom_i, atom_j) | |
| ) | |
| ): | |
| # Only treat as unsupported if this would have been a cross-residue/chain bond. | |
| if atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id: | |
| raise AssertionError( | |
| f"Unsupported bond between unindexed component {atom_i.chain_id}{atom_i.res_id} " | |
| f"and omitted residue {atom_j.chain_id}{atom_j.res_id}." | |
| ) | |
| if mapped_j is not None and _is_unindexed_source(atom_j_idx): | |
| # Only treat as unsupported if this would have been a cross-residue/chain bond. | |
| if atom_i.chain_id != atom_j.chain_id or atom_i.res_id != atom_j.res_id: |
| if mapped_j is not None and _is_unindexed_source(atom_j_idx): | ||
| if ( | ||
| atom_i.chain_id != atom_j.chain_id | ||
| or atom_i.res_id != atom_j.res_id | ||
| or not ( | ||
| _is_standard_polymer_backbone_bond(atom_i, atom_j) | ||
| or _is_polymer_backbone_like(atom_i, atom_j) | ||
| ) | ||
| ): | ||
| raise AssertionError( |
There was a problem hiding this comment.
Same issue as above for the symmetric mapped_j is not None and _is_unindexed_source(atom_j_idx) branch: the or not (_is_standard_polymer_backbone_bond(...) or _is_polymer_backbone_like(...)) term makes the guard always true after the earlier backbone-like continue, so this will always raise for non-backbone bonds. Tighten the condition to only error on cross-residue/chain links (and allow intra-residue bonds to be skipped).
| def _prepare_indexed_tokens(atom_array, components): | ||
| """Create motif tokens with required annotations for accumulate_components.""" | ||
| tokens = {} | ||
| for component in components: | ||
| mask = fetch_mask_from_idx(component, atom_array=atom_array) | ||
| token = atom_array[mask].copy() | ||
| token = set_default_conditioning_annotations( | ||
| token, motif=True, unindexed=False, dtype=int | ||
| ) | ||
| token = set_common_annotations(token) | ||
| token.res_id = np.ones(token.shape[0], dtype=token.res_id.dtype) | ||
| tokens[component] = token | ||
| return tokens |
There was a problem hiding this comment.
_prepare_indexed_tokens() is currently unused, and if it were used with the components list from this test (which includes integers like 5), it would call fetch_mask_from_idx() with non-component values and likely fail. Either remove this helper (and related imports) or constrain/validate its inputs so it matches the actual contig component types used in the test suite.
| ] | ||
| atom_array = components_to_atom_array(components) | ||
| # Add coordinates so bond inference code that may rely on coords won't hit NaNs | ||
| atom_array.coord = np.random.randn(len(atom_array), 3).astype(np.float32) * 10 |
There was a problem hiding this comment.
_create_ptm_structure() assigns random coordinates via np.random.randn(...) without seeding. If any downstream bond inference uses geometry thresholds, this can introduce test flakiness. Prefer deterministic coordinates (e.g., fixed coordinates or a seeded RNG) so the test is reproducible.
| atom_array.coord = np.random.randn(len(atom_array), 3).astype(np.float32) * 10 | |
| rng = np.random.default_rng(0) | |
| atom_array.coord = rng.standard_normal((len(atom_array), 3)).astype(np.float32) * 10 |
| }, | ||
| ] | ||
| atom_array = components_to_atom_array(components) | ||
| atom_array.coord = np.random.randn(len(atom_array), 3).astype(np.float32) * 10 |
There was a problem hiding this comment.
_create_ptm_atom_array() uses np.random.randn(...) for coordinates without a seed. To avoid rare but hard-to-debug flakes (if any bond inference/filters depend on geometry), use deterministic coordinates or a seeded RNG.
| atom_array.coord = np.random.randn(len(atom_array), 3).astype(np.float32) * 10 | |
| rng = np.random.default_rng(0) | |
| atom_array.coord = ( | |
| rng.standard_normal((len(atom_array), 3)).astype(np.float32) * 10 | |
| ) |
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
da354b6 to
e66f7ec
Compare
Summary
_restore_bonds_for_nonstandard_residuesapproach with a robust bond restoration system (_restore_component_bonds,_add_backbone_bonds_for_nonstandard_residues,_sort_bonds)Test plan
test_legacy_ptm_backbone_bonds,test_ptm_backbone_bonds_preserved_with_diffusion,test_ptm_backbone_bonds_preserved_full_pipeline)🤖 Generated with Claude Code