diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py index d97b3be3..3d922466 100644 --- a/models/rfd3/src/rfd3/inference/input_parsing.py +++ b/models/rfd3/src/rfd3/inference/input_parsing.py @@ -136,6 +136,7 @@ class DesignInputSpecification(BaseModel): # Extra args: length: Optional[str] = Field(None, description="Length range as 'min-max' or int. Constrains length of contig if provided") ligand: Optional[str] = Field(None, description="Ligand name or index to include in design.") + allow_ligand_on_existing_chain: bool = Field(False, description="If True, suppress the error when a ligand shares a chain ID with the built atom array. Use with caution — chain ID is leaked to the model.") cif_parser_args: Optional[Dict[str, Any]] = Field(None, description="CIF parser arguments") extra: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Extra metadata to include in output (useful for logging additional info in metadata)") dialect: int = Field(2, description="RFdiffusion3 input dialect. 1: legacy, 2: release.") @@ -672,14 +673,45 @@ def _append_ligand(self, atom_array, atom_array_input_annotated): + list(atom_array_input_annotated.get_annotation_categories()) ), ) - # Offset ligand residue ids based on the original input to avoid clashes - # with any newly created residues (matches legacy behaviour). - ligand_array.res_id = ( - ligand_array.res_id - - np.min(ligand_array.res_id) - + np.max(atom_array.res_id) - + 1 - ) + # Validate chain assignments — chain ID is leaked to the model + # so collisions are a significant deviation from convention. + ligand_chains = np.unique(ligand_array.chain_id) + existing_chains = set(np.unique(atom_array.chain_id)) + overlapping = sorted(existing_chains & set(ligand_chains)) + if not self.allow_ligand_on_existing_chain: + if overlapping: + raise ValueError( + f"Ligand chain(s) {overlapping} overlap with existing " + f"chain(s) {sorted(existing_chains)}. Place ligands on " + f"separate chains or set 'allow_ligand_on_existing_chain: " + f"true' to restore the old behaviour." + ) + # Multiple ligands must each be on their own chain. + for chain in ligand_chains: + n_residues = len( + np.unique(ligand_array.res_id[ligand_array.chain_id == chain]) + ) + if n_residues > 1: + raise ValueError( + f"Multiple ligand residues on chain {chain}. Each " + f"ligand must be on its own chain, or set " + f"'allow_ligand_on_existing_chain: true' to restore " + f"the old behaviour." + ) + if self.allow_ligand_on_existing_chain: + # Legacy behaviour: offset from protein max to avoid clashes. + ligand_array.res_id = ( + ligand_array.res_id + - np.min(ligand_array.res_id) + + np.max(atom_array.res_id) + + 1 + ) + else: + # Reset ligand res_id to start from 1 per chain, matching + # the convention AF3 uses in its output CIF files. + for chain in ligand_chains: + mask = ligand_array.chain_id == chain + ligand_array.res_id[mask] = 1 # Harmonize conditioning annotations before concatenation: biotite's # concatenate only preserves annotations present in ALL arrays (set # intersection), so mismatched optional conditioning annotations