|
| 1 | +"""Secondary structure annotation using DSSP.""" |
| 2 | + |
| 3 | +import logging |
| 4 | +from enum import IntEnum, StrEnum |
| 5 | + |
| 6 | +import biotite.application.dssp as dssp |
| 7 | +import numpy as np |
| 8 | +from biotite.structure import AtomArray |
| 9 | + |
| 10 | +from atomworks.enums import ChainType |
| 11 | +from atomworks.ml.executables.dssp import DSSPExecutable |
| 12 | +from atomworks.ml.transforms._checks import ( |
| 13 | + check_atom_array_annotation, |
| 14 | +) |
| 15 | +from atomworks.ml.transforms.base import Transform |
| 16 | +from atomworks.ml.utils.token import get_token_starts, spread_token_wise |
| 17 | + |
| 18 | +logger = logging.getLogger("atomworks.ml") |
| 19 | + |
| 20 | + |
| 21 | +class SSEnum(IntEnum): |
| 22 | + """Secondary structure enum for protein residues. |
| 23 | +
|
| 24 | + Groups DSSP codes into five categories for efficient storage and manipulation. |
| 25 | + Used for both DSSP output annotations and secondary structure conditioning. |
| 26 | +
|
| 27 | + Values: |
| 28 | + NONE: Not set/not conditioned (value: -1) |
| 29 | + ALPHA_HELIX: Alpha helix (H, G, I in DSSP) (value: 0) |
| 30 | + BETA_SHEET: Beta sheet (E, B in DSSP) (value: 1) |
| 31 | + OTHER_PROTEIN: Coil/loop/turn (T, S, C, P in DSSP) (value: 2) |
| 32 | + NON_PROTEIN: Non-protein chain or DSSP failed (value: 3) |
| 33 | + """ |
| 34 | + |
| 35 | + NONE = -1 |
| 36 | + ALPHA_HELIX = 0 |
| 37 | + BETA_SHEET = 1 |
| 38 | + OTHER_PROTEIN = 2 |
| 39 | + NON_PROTEIN = 3 |
| 40 | + |
| 41 | + @classmethod |
| 42 | + def names(cls) -> list[str]: |
| 43 | + """Return human-readable names for each group.""" |
| 44 | + return ["none", "alpha_helix", "beta_sheet", "other_protein", "non_protein"] |
| 45 | + |
| 46 | + @classmethod |
| 47 | + def to_string(cls, value: int) -> str: |
| 48 | + """Convert integer value to human-readable string.""" |
| 49 | + # Handle NONE specially since it's -1 |
| 50 | + if value == -1: |
| 51 | + return "none" |
| 52 | + return cls.names()[value + 1] # Offset by 1 since NONE is at index 0 |
| 53 | + |
| 54 | + |
| 55 | +class DSSPCode(StrEnum): |
| 56 | + """DSSP secondary structure codes as defined by the DSSP program.""" |
| 57 | + |
| 58 | + ALPHA_HELIX = "H" # alpha-helix |
| 59 | + ISOLATED_BETA_BRIDGE = "B" # residue in isolated beta-bridge |
| 60 | + EXTENDED_STRAND = "E" # extended strand, participates in beta ladder |
| 61 | + THREE_TEN_HELIX = "G" # 3-10 helix |
| 62 | + PI_HELIX = "I" # pi-helix |
| 63 | + POLYPROLINE_HELIX = "P" # kappa-helix (poly-proline II helix) |
| 64 | + HYDROGEN_BONDED_TURN = "T" # hydrogen-bonded turn |
| 65 | + BEND = "S" # bend |
| 66 | + OTHER = "C" # loop, coil, or irregular |
| 67 | + NON_PROTEIN = "!" # non-protein |
| 68 | + |
| 69 | + @classmethod |
| 70 | + def valid_codes(cls) -> set[str]: |
| 71 | + """Return set of valid DSSP codes.""" |
| 72 | + return {e.value for e in cls} |
| 73 | + |
| 74 | + @classmethod |
| 75 | + def to_group_index(cls, code: str) -> int: |
| 76 | + """Map DSSP code to SSEnum index.""" |
| 77 | + if code in {cls.ALPHA_HELIX.value, cls.THREE_TEN_HELIX.value, cls.PI_HELIX.value}: |
| 78 | + return SSEnum.ALPHA_HELIX |
| 79 | + if code in {cls.EXTENDED_STRAND.value, cls.ISOLATED_BETA_BRIDGE.value}: |
| 80 | + return SSEnum.BETA_SHEET |
| 81 | + if code == cls.NON_PROTEIN.value: |
| 82 | + return SSEnum.NON_PROTEIN |
| 83 | + return SSEnum.OTHER_PROTEIN |
| 84 | + |
| 85 | + |
| 86 | +def _get_chain_sse_and_valid(chain_atom_array: AtomArray, bin_path: str) -> tuple[np.ndarray, bool]: |
| 87 | + """Run DSSP on a chain's protein atoms, return group indices and whether DSSP ran successfully. |
| 88 | +
|
| 89 | + Args: |
| 90 | + chain_atom_array: AtomArray containing atoms from a single chain. |
| 91 | + bin_path: Path to DSSP executable. |
| 92 | +
|
| 93 | + Returns: |
| 94 | + Tuple of (group_indices, is_valid) where group_indices are integers from |
| 95 | + SSEnum and is_valid indicates if DSSP ran successfully. |
| 96 | + """ |
| 97 | + try: |
| 98 | + dssp_codes = dssp.DsspApp.annotate_sse(chain_atom_array, bin_path=bin_path) |
| 99 | + # Convert DSSP codes to group indices |
| 100 | + group_indices = np.array([DSSPCode.to_group_index(code) for code in dssp_codes], dtype=np.int8) |
| 101 | + return group_indices, True |
| 102 | + except Exception as e: |
| 103 | + chain_id = getattr(chain_atom_array, "chain_id", ["?"])[0] |
| 104 | + logger.error( |
| 105 | + f"Error running DSSP for entity {chain_id}: {e}; " |
| 106 | + f"using NON_PROTEIN code for this entity's residues, and setting is_valid annotation to False" |
| 107 | + ) |
| 108 | + return ( |
| 109 | + np.full(len(chain_atom_array), SSEnum.NON_PROTEIN, dtype=np.int8), |
| 110 | + False, |
| 111 | + ) |
| 112 | + |
| 113 | + |
| 114 | +def annotate_secondary_structure( |
| 115 | + atom_array: AtomArray, |
| 116 | + bin_path: str | None = None, |
| 117 | + annotation_name: str = "dssp_sse", |
| 118 | + is_valid_annotation_name: str | None = None, |
| 119 | +) -> AtomArray: |
| 120 | + """Annotate secondary structure for each residue using DSSP. |
| 121 | +
|
| 122 | + Only protein tokens are assigned secondary structure groups; all others are |
| 123 | + set to NON_PROTEIN. |
| 124 | +
|
| 125 | + Also adds a boolean annotation indicating whether the SSE is valid (not default |
| 126 | + NON_PROTEIN due to error). |
| 127 | +
|
| 128 | + Args: |
| 129 | + atom_array: AtomArray to annotate. |
| 130 | + bin_path: Path to DSSP executable. If ``None``, uses executable from ``DSSPExecutable``. |
| 131 | + annotation_name: Name for the SSE annotation. Defaults to ``"dssp_sse"``. |
| 132 | + is_valid_annotation_name: Name for the validity annotation. If ``None``, |
| 133 | + uses ``"{annotation_name}_is_valid"``. Defaults to ``None``. |
| 134 | +
|
| 135 | + Returns: |
| 136 | + AtomArray with secondary structure annotations added. |
| 137 | + """ |
| 138 | + # Get bin_path from executable manager if not provided |
| 139 | + if bin_path is None: |
| 140 | + dssp_exec = DSSPExecutable.get_or_initialize() |
| 141 | + bin_path = dssp_exec.get_bin_path() |
| 142 | + |
| 143 | + # Atom-level masks |
| 144 | + is_protein_atom_lvl = np.isin(atom_array.chain_type, ChainType.get_proteins()) |
| 145 | + is_atomized_atom_lvl = ( |
| 146 | + atom_array.atomize |
| 147 | + if "atomize" in atom_array.get_annotation_categories() |
| 148 | + else np.zeros(atom_array.array_length(), dtype=bool) |
| 149 | + ) |
| 150 | + |
| 151 | + # Token-level masks |
| 152 | + token_starts = get_token_starts(atom_array) |
| 153 | + atom_array_token_lvl = atom_array[token_starts] |
| 154 | + |
| 155 | + # Default all tokens to NON_PROTEIN and all is_valid to False |
| 156 | + sse = np.full(len(atom_array_token_lvl), SSEnum.NON_PROTEIN, dtype=np.int8) |
| 157 | + is_valid = np.zeros(len(atom_array_token_lvl), dtype=bool) |
| 158 | + |
| 159 | + if np.any(is_protein_atom_lvl): |
| 160 | + # Loop over chain instances |
| 161 | + for chain_iid in np.unique(atom_array.chain_iid): |
| 162 | + chain_iid_mask = atom_array.chain_iid == chain_iid |
| 163 | + chain_iid_protein_mask = chain_iid_mask & is_protein_atom_lvl & ~is_atomized_atom_lvl |
| 164 | + |
| 165 | + if not np.any(chain_iid_protein_mask): |
| 166 | + # Early exit if this chain has no protein atoms |
| 167 | + continue |
| 168 | + |
| 169 | + # Get chain atoms and compute SSE |
| 170 | + chain_atom_array = atom_array[chain_iid_protein_mask] |
| 171 | + sse_chain, is_valid_chain = _get_chain_sse_and_valid(chain_atom_array, bin_path) |
| 172 | + |
| 173 | + # Assign to all tokens in this chain instance |
| 174 | + token_mask = chain_iid_protein_mask[token_starts] |
| 175 | + if len(sse_chain) == token_mask.sum(): |
| 176 | + sse[token_mask] = sse_chain |
| 177 | + is_valid[token_mask] = is_valid_chain |
| 178 | + else: |
| 179 | + # Catch-all for situations that arise (usually due to cropping) |
| 180 | + logger.warning( |
| 181 | + f"Mismatch in SSE length for chain {chain_iid}: {len(sse_chain)} != {token_mask.sum()}. " |
| 182 | + f"We will use NON_PROTEIN for this chain, and set is_valid to False." |
| 183 | + ) |
| 184 | + |
| 185 | + # Spread token-wise to all atoms and set annotations |
| 186 | + sse_spread = spread_token_wise(atom_array, sse) |
| 187 | + is_valid_spread = spread_token_wise(atom_array, is_valid) |
| 188 | + |
| 189 | + # Use provided is_valid annotation name or default |
| 190 | + if is_valid_annotation_name is None: |
| 191 | + is_valid_annotation_name = f"{annotation_name}_is_valid" |
| 192 | + |
| 193 | + atom_array.set_annotation(annotation_name, sse_spread) |
| 194 | + atom_array.set_annotation(is_valid_annotation_name, is_valid_spread) |
| 195 | + |
| 196 | + return atom_array |
| 197 | + |
| 198 | + |
| 199 | +class AnnotateSecondaryStructure(Transform): |
| 200 | + """Annotate secondary structure for each residue using DSSP. |
| 201 | +
|
| 202 | + Adds integer annotations from :py:class:`SSEnum` indicating the secondary structure type. |
| 203 | +
|
| 204 | + Args: |
| 205 | + bin_path: Path to DSSP executable. If ``None``, uses ``DSSP`` environment variable. Defaults to ``None``. |
| 206 | + annotation_name: Name for the SSE annotation. Defaults to ``"dssp_sse"``. |
| 207 | + is_valid_annotation_name: Name for the validity annotation. If ``None``, |
| 208 | + uses ``"{annotation_name}_is_valid"``. Defaults to ``None``. |
| 209 | + max_n_tokens: Maximum number of tokens to run DSSP on. If structure exceeds this, |
| 210 | + DSSP is skipped and no annotations are added. Defaults to 800, which encompasses most proteins. |
| 211 | + """ |
| 212 | + |
| 213 | + def __init__( |
| 214 | + self, |
| 215 | + bin_path: str | None = None, |
| 216 | + annotation_name: str = "dssp_sse", |
| 217 | + is_valid_annotation_name: str | None = None, |
| 218 | + max_n_tokens: int | None = 800, |
| 219 | + ): |
| 220 | + # Initialize executable if not already done |
| 221 | + if bin_path is None: |
| 222 | + DSSPExecutable.get_or_initialize() |
| 223 | + else: |
| 224 | + DSSPExecutable.get_or_initialize(bin_path) |
| 225 | + self.annotation_name = annotation_name |
| 226 | + self.is_valid_annotation_name = is_valid_annotation_name |
| 227 | + self.max_n_tokens = max_n_tokens |
| 228 | + |
| 229 | + def check_input(self, data: dict) -> None: |
| 230 | + check_atom_array_annotation(data, ["chain_type", "chain_iid"]) |
| 231 | + |
| 232 | + def forward(self, data: dict) -> dict: |
| 233 | + atom_array: AtomArray = data["atom_array"] |
| 234 | + |
| 235 | + # Check if structure exceeds max_n_tokens |
| 236 | + if self.max_n_tokens is not None: |
| 237 | + token_starts = get_token_starts(atom_array) |
| 238 | + n_tokens = len(token_starts) |
| 239 | + |
| 240 | + if n_tokens > self.max_n_tokens: |
| 241 | + # Skip DSSP and return data without annotations |
| 242 | + logger.info(f"Skipping DSSP: structure has {n_tokens} tokens, exceeds max_n_tokens={self.max_n_tokens}") |
| 243 | + return data |
| 244 | + |
| 245 | + # Proceed with normal DSSP annotation |
| 246 | + data["atom_array"] = annotate_secondary_structure( |
| 247 | + atom_array, |
| 248 | + bin_path=None, # Use executable manager |
| 249 | + annotation_name=self.annotation_name, |
| 250 | + is_valid_annotation_name=self.is_valid_annotation_name, |
| 251 | + ) |
| 252 | + return data |
0 commit comments