Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions models/rfd3/src/rfd3/inference/input_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class DesignInputSpecification(BaseModel):
# Extra args:
length: Optional[str] = Field(None, description="Length range as 'min-max' or int. Constrains length of contig if provided")
ligand: Optional[str] = Field(None, description="Ligand name or index to include in design.")
allow_ligand_on_existing_chain: bool = Field(False, description="If True, suppress the error when a ligand shares a chain ID with the built atom array. Use with caution — chain ID is leaked to the model.")
cif_parser_args: Optional[Dict[str, Any]] = Field(None, description="CIF parser arguments")
extra: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Extra metadata to include in output (useful for logging additional info in metadata)")
dialect: int = Field(2, description="RFdiffusion3 input dialect. 1: legacy, 2: release.")
Expand Down Expand Up @@ -672,14 +673,45 @@ def _append_ligand(self, atom_array, atom_array_input_annotated):
+ list(atom_array_input_annotated.get_annotation_categories())
),
)
# Offset ligand residue ids based on the original input to avoid clashes
# with any newly created residues (matches legacy behaviour).
ligand_array.res_id = (
ligand_array.res_id
- np.min(ligand_array.res_id)
+ np.max(atom_array.res_id)
+ 1
)
# Validate chain assignments — chain ID is leaked to the model
# so collisions are a significant deviation from convention.
ligand_chains = np.unique(ligand_array.chain_id)
existing_chains = set(np.unique(atom_array.chain_id))
overlapping = sorted(existing_chains & set(ligand_chains))
if not self.allow_ligand_on_existing_chain:
if overlapping:
raise ValueError(
f"Ligand chain(s) {overlapping} overlap with existing "
f"chain(s) {sorted(existing_chains)}. Place ligands on "
f"separate chains or set 'allow_ligand_on_existing_chain: "
f"true' to restore the old behaviour."
)
# Multiple ligands must each be on their own chain.
for chain in ligand_chains:
n_residues = len(
np.unique(ligand_array.res_id[ligand_array.chain_id == chain])
)
if n_residues > 1:
raise ValueError(
f"Multiple ligand residues on chain {chain}. Each "
f"ligand must be on its own chain, or set "
f"'allow_ligand_on_existing_chain: true' to restore "
f"the old behaviour."
)
if self.allow_ligand_on_existing_chain:
# Legacy behaviour: offset from protein max to avoid clashes.
ligand_array.res_id = (
ligand_array.res_id
- np.min(ligand_array.res_id)
+ np.max(atom_array.res_id)
+ 1
)
else:
# Each ligand chain holds exactly one residue (validated above), so
# assign res_id 1 to every atom on that chain, matching the
# convention AF3 uses in its output CIF files.
for chain in ligand_chains:
mask = ligand_array.chain_id == chain
ligand_array.res_id[mask] = 1
# Harmonize conditioning annotations before concatenation: biotite's
# concatenate only preserves annotations present in ALL arrays (set
# intersection), so mismatched optional conditioning annotations
Expand Down
Loading