diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 269ea33..848375d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,3 +31,7 @@ repos: language: system types: [python] pass_filenames: true # For speed, we only check the files that are changed +- repo: https://github.com/gitleaks/gitleaks + rev: v8.24.2 + hooks: + - id: gitleaks diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index 67c94b0..ea2119b 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -29,14 +29,6 @@ from esm.utils.constants.api import MIMETYPE_ES_PICKLE from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor from esm.utils.msa import MSA -from esm.utils.structure.input_builder import ( - StructurePredictionInput, - serialize_structure_prediction_input, -) -from esm.utils.structure.molecular_complex import ( - MolecularComplex, - MolecularComplexResult, -) from esm.utils.types import FunctionAnnotation @@ -217,70 +209,6 @@ def fold( return self._process_fold_response(data, sequence) - @retry_decorator - async def async_fold_all_atom( - self, all_atom_input: StructurePredictionInput, model_name: str | None = None - ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: - """Fold a molecular complex containing proteins, nucleic acids, and/or ligands. - - Args: - all_atom_input: StructurePredictionInput containing sequences for different molecule types - model_name: Override the client level model name if needed - """ - request = self._process_fold_all_atom_request( - all_atom_input, model_name if model_name is not None else self.model - ) - - try: - data = await self._async_post("fold_all_atom", request) - except ESMProteinError as e: - return e - - return self._process_fold_all_atom_response(data) - - @retry_decorator - def fold_all_atom( - self, all_atom_input: StructurePredictionInput, model_name: str | None = None - ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: - """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands. - - Args: - all_atom_input: StructurePredictionInput containing sequences for different molecule types - model_name: Override the client level model name if needed - """ - request = self._process_fold_all_atom_request( - all_atom_input, model_name if model_name is not None else self.model - ) - - try: - data = self._post("fold_all_atom", request) - except ESMProteinError as e: - return e - - return self._process_fold_all_atom_response(data) - - @staticmethod - def _process_fold_all_atom_request( - all_atom_input: StructurePredictionInput, model_name: str | None = None - ) -> dict[str, Any]: - request: dict[str, Any] = { - "all_atom_input": serialize_structure_prediction_input(all_atom_input), - "model": model_name, - } - - return request - - @staticmethod - def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult: - complex_data = data.get("complex") - molecular_complex = MolecularComplex.from_state_dict(complex_data) - return MolecularComplexResult( - complex=molecular_complex, - plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True), - ptm=data.get("ptm", None), - distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True), - ) - @retry_decorator async def async_inverse_fold( self, diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py index f53ab9c..6b6da1c 100644 --- a/esm/utils/structure/molecular_complex.py +++ b/esm/utils/structure/molecular_complex.py @@ -9,11 +9,13 @@ from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any, List +import biotite.structure as bs import biotite.structure.io.pdbx as pdbx import brotli import msgpack import numpy as np import torch +from biotite.structure.io.pdbx import CIFFile, set_structure from esm.utils import residue_constants from esm.utils.structure.metrics import compute_lddt, compute_rmsd @@ -52,9 +54,11 @@ class Molecule: token_idx: int atom_positions: np.ndarray # [N_atoms, 3] atom_elements: np.ndarray # [N_atoms] element strings - residue_type: int - molecule_type: int # PROTEIN=0, RNA=1, DNA=2, LIGAND=3 - confidence: float + atom_names: np.ndarray | None = None # [N_atoms] atom names (optional) + atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional) + residue_type: int = 0 + molecule_type: int = 0 # PROTEIN=0, RNA=1, DNA=2, LIGAND=3 + confidence: float = 0.0 @dataclass(frozen=True) @@ -76,21 +80,40 @@ class MolecularComplex: # Token-to-atom mapping for efficient access token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array + # Chain information + chain_id: np.ndarray # [N_tokens] chain identifier for each token + # Confidence data plddt: np.ndarray # Per-token confidence scores [N_tokens] # Metadata metadata: MolecularComplexMetadata + # Optional atom names and hetero flags (preserved from original structures) + atom_names: np.ndarray | None = None # [N_atoms] atom names (optional) + atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional) + def __post_init__(self): """Validate array dimensions.""" n_tokens = len(self.sequence) + n_atoms = len(self.atom_positions) assert ( self.token_to_atoms.shape[0] == n_tokens ), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens" + assert ( + self.chain_id.shape[0] == n_tokens + ), f"chain_id shape {self.chain_id.shape} != {n_tokens} tokens" assert ( self.plddt.shape[0] == n_tokens ), f"plddt shape {self.plddt.shape} != {n_tokens} tokens" + if self.atom_names is not None: + assert ( + self.atom_names.shape[0] == n_atoms + ), f"atom_names shape {self.atom_names.shape} != {n_atoms} atoms" + if self.atom_hetero is not None: + assert ( + self.atom_hetero.shape[0] == n_atoms + ), f"atom_hetero shape {self.atom_hetero.shape} != {n_atoms} atoms" def __len__(self) -> int: """Return number of tokens.""" @@ -109,6 +132,12 @@ def __getitem__(self, idx: int) -> Molecule: # Extract atom data for this token token_atom_positions = self.atom_positions[start_atom:end_atom] token_atom_elements = self.atom_elements[start_atom:end_atom] + token_atom_names = None + if self.atom_names is not None: + token_atom_names = self.atom_names[start_atom:end_atom] + token_atom_hetero = None + if self.atom_hetero is not None: + token_atom_hetero = self.atom_hetero[start_atom:end_atom] # Default values for residue/molecule type (would be extended based on actual implementation) residue_type = 0 # Default to standard residue @@ -119,6 +148,8 @@ def __getitem__(self, idx: int) -> Molecule: token_idx=idx, atom_positions=token_atom_positions, atom_elements=token_atom_elements, + atom_names=token_atom_names, + atom_hetero=token_atom_hetero, residue_type=residue_type, molecule_type=molecule_type, confidence=self.plddt[idx], @@ -151,6 +182,8 @@ def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex": # Convert atom37 to flat arrays flat_positions = [] flat_elements = [] + flat_names = [] + flat_hetero = [] token_to_atoms = [] atom_idx = 0 @@ -180,6 +213,12 @@ def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex": ) # First character is element flat_elements.append(element) + # Add atom name + flat_names.append(atom_name) + + # Add hetero flag (all proteins are non-hetero) + flat_hetero.append(False) + atom_idx += 1 # Record token-to-atom mapping [start_idx, end_idx) @@ -189,17 +228,20 @@ def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex": # Convert to numpy arrays atom_positions = np.array(flat_positions, dtype=np.float32) atom_elements = np.array(flat_elements, dtype=object) + atom_names = np.array(flat_names, dtype=object) + atom_hetero = np.array(flat_hetero, dtype=bool) token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) - # Extract confidence scores (skip chain breaks) + # Extract confidence scores and chain_ids (skip chain breaks) confidence_scores = [] - residue_idx = 0 - for aa in pc.sequence: + chain_ids = [] + for seq_idx, aa in enumerate(pc.sequence): if aa != "|": - confidence_scores.append(pc.confidence[residue_idx]) - residue_idx += 1 + confidence_scores.append(pc.confidence[seq_idx]) + chain_ids.append(pc.chain_id[seq_idx]) confidence_array = np.array(confidence_scores, dtype=np.float32) + chain_id_array = np.array(chain_ids, dtype=np.int64) # Create metadata - convert entity IDs to strings for MolecularComplexMetadata entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()} @@ -215,8 +257,11 @@ def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex": atom_positions=atom_positions, atom_elements=atom_elements, token_to_atoms=token_to_atoms_array, + chain_id=chain_id_array, plddt=confidence_array, metadata=metadata, + atom_names=atom_names, + atom_hetero=atom_hetero, ) def to_protein_complex(self) -> ProteinComplex: @@ -251,13 +296,27 @@ def to_protein_complex(self) -> ProteinComplex: atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32) atom37_mask = np.zeros((n_residues, 37), dtype=bool) - # Convert tokens back to single-letter sequence - single_letter_sequence = "".join( - [residue_constants.restype_3to1[token] for token in protein_tokens] - ) - - # Extract confidence scores for protein residues only + # Extract confidence scores and chain_ids for protein residues only protein_confidence = self.plddt[protein_indices] + protein_chain_ids = self.chain_id[protein_indices] + + # Convert tokens back to single-letter sequence with chain breaks + single_letter_residues = [] + prev_chain_id = None + + for i, (token, chain_id_val) in enumerate( + zip(protein_tokens, protein_chain_ids) + ): + # Add chain break if we're switching to a new chain + if prev_chain_id is not None and chain_id_val != prev_chain_id: + single_letter_residues.append("|") + single_letter_residues.append(residue_constants.restype_3to1[token]) + prev_chain_id = chain_id_val + + single_letter_sequence = "".join(single_letter_residues) + + # Calculate final sequence length (includes chain breaks) + sequence_length = len(single_letter_sequence) # Convert flat atoms back to atom37 representation for res_idx, token_idx in enumerate(protein_indices): @@ -283,19 +342,69 @@ def to_protein_complex(self) -> ProteinComplex: atom37_mask[res_idx, atom_type_idx] = True atom_count += 1 - # Create other required arrays for ProteinComplex - # For simplicity, assume all protein residues belong to the same entity/chain - entity_id = np.zeros(n_residues, dtype=np.int64) - chain_id = np.zeros(n_residues, dtype=np.int64) - sym_id = np.zeros(n_residues, dtype=np.int64) - residue_index = np.arange(1, n_residues + 1, dtype=np.int64) - insertion_code = np.array([""] * n_residues, dtype=object) + # Create arrays that match sequence length (including chain breaks) + # Initialize arrays with proper size + chain_id_expanded = np.full(sequence_length, -1, dtype=np.int64) + entity_id_expanded = np.full(sequence_length, -1, dtype=np.int64) + sym_id_expanded = np.zeros(sequence_length, dtype=np.int64) + residue_index_expanded = np.zeros(sequence_length, dtype=np.int64) + insertion_code_expanded = np.array([""] * sequence_length, dtype=object) + confidence_expanded = np.zeros(sequence_length, dtype=np.float32) + atom37_positions_expanded = np.full( + (sequence_length, 37, 3), np.nan, dtype=np.float32 + ) + atom37_mask_expanded = np.zeros((sequence_length, 37), dtype=bool) + + # Map residue data to sequence positions (skipping chain breaks) + residue_idx = 0 + residue_counter_per_chain = {} + + for seq_pos, char in enumerate(single_letter_sequence): + if char != "|": + # This is a residue position + chain_id_val = protein_chain_ids[residue_idx] + + chain_id_expanded[seq_pos] = chain_id_val + entity_id_expanded[seq_pos] = chain_id_val # Simplified mapping + + # Track residue numbering per chain + if chain_id_val not in residue_counter_per_chain: + residue_counter_per_chain[chain_id_val] = 1 + else: + residue_counter_per_chain[chain_id_val] += 1 + + residue_index_expanded[seq_pos] = residue_counter_per_chain[ + chain_id_val + ] + confidence_expanded[seq_pos] = protein_confidence[residue_idx] + atom37_positions_expanded[seq_pos] = atom37_positions[residue_idx] + atom37_mask_expanded[seq_pos] = atom37_mask[residue_idx] + + residue_idx += 1 + # Chain break positions keep default values (-1, False, etc.) + + # Use the expanded arrays + chain_id = chain_id_expanded + entity_id = entity_id_expanded + sym_id = sym_id_expanded + residue_index = residue_index_expanded + insertion_code = insertion_code_expanded + protein_confidence = confidence_expanded + atom37_positions = atom37_positions_expanded + atom37_mask = atom37_mask_expanded + + # Create protein complex metadata preserving chain information + # Convert MolecularComplex metadata to ProteinComplex format + unique_chain_ids = np.unique(protein_chain_ids) + entity_lookup = {int(cid): int(cid) for cid in unique_chain_ids} + chain_lookup = { + int(cid): self.metadata.chain_lookup.get(int(cid), chr(65 + int(cid))) + for cid in unique_chain_ids + } - # Create simplified protein complex metadata - # Map the first entity/chain from molecular complex metadata protein_metadata = ProteinComplexMetadata( - entity_lookup={0: 1}, # Single entity (int for ProteinComplexMetadata) - chain_lookup={0: "A"}, # Single chain + entity_lookup=entity_lookup, + chain_lookup=chain_lookup, assembly_composition=self.metadata.assembly_composition, ) @@ -336,7 +445,9 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": # Get structure - handle missing model information gracefully try: - structure = pdbx.get_structure(mmcif_file, model=1) + structure = pdbx.get_structure( + mmcif_file, model=1, extra_fields=["b_factor"] + ) except (KeyError, ValueError): # Fallback for mmCIF files without model information try: @@ -374,8 +485,11 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": sequence_tokens = [] flat_positions = [] flat_elements = [] + flat_names = [] + flat_hetero = [] token_to_atoms = [] confidence_scores = [] + chain_ids = [] # Track chain IDs for each token atom_idx = 0 @@ -396,9 +510,16 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": } chain_residue_groups[chain_id][res_id]["atoms"].append(atom) + # Create a mapping from chain_id to numeric indices + chain_id_to_numeric = { + chain_id: idx + for idx, chain_id in enumerate(sorted(chain_residue_groups.keys())) + } + # Process each chain and residue for chain_id in sorted(chain_residue_groups.keys()): residues = chain_residue_groups[chain_id] + numeric_chain_id = chain_id_to_numeric[chain_id] for res_id in sorted(residues.keys()): residue_data = residues[res_id] @@ -422,6 +543,9 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": token_name = res_name sequence_tokens.append(token_name) + chain_ids.append( + numeric_chain_id + ) # Store the numeric chain ID for this token token_start = atom_idx # Add all atoms from this residue @@ -432,6 +556,14 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": element = atom.element flat_elements.append(element) + # Get atom name + atom_name = atom.atom_name + flat_names.append(atom_name) + + # Get hetero flag + hetero_flag = atom.hetero + flat_hetero.append(hetero_flag) + atom_idx += 1 # Record token-to-atom mapping @@ -446,20 +578,36 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": # Create minimal arrays if no atoms found atom_positions = np.zeros((0, 3), dtype=np.float32) atom_elements = np.zeros(0, dtype=object) + atom_names = np.zeros(0, dtype=object) + atom_hetero = np.zeros(0, dtype=bool) token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32) + chain_id_array = ( + np.array(chain_ids, dtype=np.int64) + if chain_ids + else np.zeros(len(sequence_tokens), dtype=np.int64) + ) else: atom_positions = np.array(flat_positions, dtype=np.float32) atom_elements = np.array(flat_elements, dtype=object) + atom_names = np.array(flat_names, dtype=object) + atom_hetero = np.array(flat_hetero, dtype=bool) token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) + chain_id_array = np.array(chain_ids, dtype=np.int64) confidence_array = np.array(confidence_scores, dtype=np.float32) - # Create metadata + # Create metadata using the chain_id_to_numeric mapping + if chain_residue_groups: + chain_lookup = { + numeric_id: chain_id + for chain_id, numeric_id in chain_id_to_numeric.items() + } + else: + chain_lookup = {} + metadata = MolecularComplexMetadata( entity_lookup=entity_info, - chain_lookup={ - i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys()) - }, + chain_lookup=chain_lookup, assembly_composition=None, ) @@ -475,168 +623,107 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": atom_positions=atom_positions, atom_elements=atom_elements, token_to_atoms=token_to_atoms_array, + chain_id=chain_id_array, plddt=confidence_array, metadata=metadata, + atom_names=atom_names, + atom_hetero=atom_hetero, ) def to_mmcif(self) -> str: - """Write MolecularComplex to mmcif string. + """Write MolecularComplex to mmcif string using biotite. Returns: String representation of the complex in mmCIF format """ - # No need for element mapping - already using element characters - - lines = [] - - # Header - lines.append(f"data_{self.id}") - lines.append("#") - lines.append(f"_entry.id {self.id}") - lines.append("#") - - # Structure metadata - lines.append("_struct.entry_id {}".format(self.id)) - lines.append("_struct.title 'Protein Structure'") - lines.append("#") - - # Entity information - entity_id = 1 - chain_counter = 0 - lines.append("loop_") - lines.append("_entity.id") - lines.append("_entity.type") - lines.append("_entity.pdbx_description") - - # Determine entities based on sequence - protein_tokens = [] - other_tokens = [] + # Pre-allocate AtomArray + n_atoms = len(self.atom_positions) + atom_array = bs.AtomArray(length=n_atoms) + + # Set coordinates directly (already vectorized) + atom_array.coord = self.atom_positions + + # Pre-allocate per-atom arrays + atom_res_ids = np.zeros(n_atoms, dtype=np.int32) + atom_chain_ids = np.empty(n_atoms, dtype=object) + atom_res_names = np.empty(n_atoms, dtype=object) + atom_hetero = np.zeros(n_atoms, dtype=bool) + atom_bfactors = np.zeros(n_atoms, dtype=np.float32) + atom_names = np.empty(n_atoms, dtype=object) + + # Track residue IDs per chain + chain_res_counters = {} + + # Vectorized expansion of token-level to atom-level annotations + for token_idx, (start, end) in enumerate(self.token_to_atoms): + token = self.sequence[token_idx] + chain_id_numeric = self.chain_id[token_idx] + chain_id_str = self.metadata.chain_lookup.get( + int(chain_id_numeric), chr(65 + int(chain_id_numeric)) + ) - for i, token in enumerate(self.sequence): - if token in residue_constants.restype_3to1: - protein_tokens.append((i, token)) - else: - other_tokens.append((i, token)) - - if protein_tokens: - lines.append(f"{entity_id} polymer 'Protein chain'") - entity_id += 1 - - for token in set(token for _, token in other_tokens): - lines.append(f"{entity_id} non-polymer 'Ligand {token}'") - entity_id += 1 - - lines.append("#") - - # Chain assignments - lines.append("loop_") - lines.append("_struct_asym.id") - lines.append("_struct_asym.entity_id") - - chain_id = "A" - if protein_tokens: - lines.append(f"{chain_id} 1") - chain_counter += 1 - chain_id = chr(ord(chain_id) + 1) - - entity_id = 2 - for token in set(token for _, token in other_tokens): - lines.append(f"{chain_id} {entity_id}") - entity_id += 1 - chain_counter += 1 - if chain_counter < 26: - chain_id = chr(ord(chain_id) + 1) - - lines.append("#") - - # Atom site information - lines.append("loop_") - lines.append("_atom_site.group_PDB") - lines.append("_atom_site.id") - lines.append("_atom_site.type_symbol") - lines.append("_atom_site.label_atom_id") - lines.append("_atom_site.label_alt_id") - lines.append("_atom_site.label_comp_id") - lines.append("_atom_site.label_asym_id") - lines.append("_atom_site.label_entity_id") - lines.append("_atom_site.label_seq_id") - lines.append("_atom_site.pdbx_PDB_ins_code") - lines.append("_atom_site.Cartn_x") - lines.append("_atom_site.Cartn_y") - lines.append("_atom_site.Cartn_z") - lines.append("_atom_site.occupancy") - lines.append("_atom_site.B_iso_or_equiv") - lines.append("_atom_site.pdbx_PDB_model_num") - lines.append("_atom_site.auth_seq_id") - lines.append("_atom_site.auth_comp_id") - lines.append("_atom_site.auth_asym_id") - lines.append("_atom_site.auth_atom_id") - - atom_id = 1 - seq_id = 1 - chain_id = "A" - entity_id = 1 - - for token_idx, token in enumerate(self.sequence): - start_atom, end_atom = self.token_to_atoms[token_idx] + # Track residue numbering per chain + if chain_id_numeric not in chain_res_counters: + chain_res_counters[chain_id_numeric] = 1 + res_id = chain_res_counters[chain_id_numeric] + chain_res_counters[chain_id_numeric] += 1 - # Determine if this is a protein residue or ligand + # Determine if protein is_protein = token in residue_constants.restype_3to1 - group_pdb = "ATOM" if is_protein else "HETATM" - current_entity_id = 1 if is_protein else 2 # Simplified entity assignment - current_chain_id = "A" if is_protein else "B" # Simplified chain assignment - - # Create atom names for this token - atom_names = [] - if is_protein: - # Use standard protein atom names - res_atoms = residue_constants.residue_atoms.get( + + # Get atom names for this residue + if self.atom_names is not None: + # Use stored atom names (preserves original names from mmCIF) + names = list(self.atom_names[start:end]) + elif is_protein: + # Fallback: use standard protein atom names + standard_names = residue_constants.residue_atoms.get( token, ["N", "CA", "C", "O"] ) - atom_names = res_atoms[: end_atom - start_atom] + names = standard_names[: end - start] + # Pad if needed + while len(names) < (end - start): + names.append(f"X{len(names)+1}") else: - # Generate generic atom names for ligands - for i in range(end_atom - start_atom): - atom_names.append(f"C{i+1}") - - # Pad atom names if needed - while len(atom_names) < (end_atom - start_atom): - atom_names.append(f"X{len(atom_names)+1}") - - # Write atoms for this token - for atom_idx_in_token, global_atom_idx in enumerate( - range(start_atom, end_atom) - ): - pos = self.atom_positions[global_atom_idx] - element_char = self.atom_elements[global_atom_idx] - element_symbol = element_char if isinstance(element_char, str) else "C" - - atom_name = ( - atom_names[atom_idx_in_token] - if atom_idx_in_token < len(atom_names) - else f"X{atom_idx_in_token+1}" - ) - - # Format atom site line - bfactor = ( - self.plddt[token_idx] * 100.0 - if len(self.plddt) > token_idx - else 50.0 - ) - - line = ( - f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . " - f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? " - f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 " - f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}" - ) - lines.append(line) - atom_id += 1 + # Fallback: generate names for ligands/nucleic acids + names = [f"C{i+1}" for i in range(end - start)] + + # Vectorized assignment for this token's atoms + atom_res_ids[start:end] = res_id + atom_chain_ids[start:end] = chain_id_str + atom_res_names[start:end] = token + # Use stored hetero flags if available, otherwise guess based on protein status + if self.atom_hetero is not None: + atom_hetero[start:end] = self.atom_hetero[start:end] + else: + atom_hetero[start:end] = not is_protein + atom_bfactors[start:end] = self.plddt[token_idx] * 100.0 + atom_names[start:end] = names + + # Set all AtomArray attributes at once (convert object arrays to proper string arrays) + atom_array.res_id = atom_res_ids + atom_array.chain_id = np.array(atom_chain_ids, dtype="U4") + atom_array.res_name = np.array(atom_res_names, dtype="U4") + atom_array.hetero = atom_hetero + atom_array.b_factor = atom_bfactors + atom_array.atom_name = np.array(atom_names, dtype="U4") + + # Use existing elements or infer them from atom names + if self.atom_elements is not None and len(self.atom_elements) == n_atoms: + # Convert object array to proper string array for biotite + atom_array.element = np.array(self.atom_elements, dtype="U4") + else: + # Use biotite's built-in element inference + atom_array.element = bs.infer_elements(atom_array) - seq_id += 1 + # Create CIF file and set structure + cif_file = CIFFile() + set_structure(cif_file, atom_array, data_block=self.id) - lines.append("#") - return "\n".join(lines) + # Convert to string + output = io.StringIO() + cif_file.write(output) + return output.getvalue() def dockq(self, native: "MolecularComplex") -> Any: """Compute DockQ score against native structure. @@ -909,7 +996,10 @@ def from_state_dict(cls, dct): if isinstance(v, list) and k in [ "atom_positions", "atom_elements", + "atom_names", + "atom_hetero", "token_to_atoms", + "chain_id", "plddt", ]: dct[k] = np.array(v) @@ -918,10 +1008,20 @@ def from_state_dict(cls, dct): if isinstance(v, np.ndarray): if k in ["atom_positions", "plddt"]: dct[k] = v.astype(np.float32) - elif k in ["token_to_atoms"]: - dct[k] = v.astype(np.int32) + elif k in ["token_to_atoms", "chain_id"]: + dct[k] = ( + v.astype(np.int32) + if k == "token_to_atoms" + else v.astype(np.int64) + ) dct["metadata"] = MolecularComplexMetadata(**dct["metadata"]) + + # Backward compatibility: if chain_id is missing, create default array + if "chain_id" not in dct: + # Default all tokens to chain 0 + dct["chain_id"] = np.zeros(len(dct["sequence"]), dtype=np.int64) + return cls(**dct) @classmethod