Skip to content

Commit e64976c

Browse files
authored
Merge pull request RosettaCommons#71 from baker-laboratory/feat/dssp-secondary-structure-annotation
feat: Add DSSP secondary structure annotation
2 parents 2a9505c + f4ef080 commit e64976c

4 files changed

Lines changed: 468 additions & 2 deletions

File tree

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""DSSP executable wrapper for secondary structure annotation."""
2+
3+
import logging
4+
import os
5+
from os import PathLike
6+
7+
from atomworks.ml.executables import Executable, ExecutableError
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class DSSPExecutable(Executable):
    """Executable wrapper for the DSSP program.

    DSSP (Define Secondary Structure of Proteins) is used to annotate secondary
    structure elements in protein structures based on hydrogen bonding patterns.

    Examples:
        >>> dssp = DSSPExecutable.get_or_initialize()
        >>> version = dssp.get_version()
        >>> bin_path = dssp.get_bin_path()
    """

    # NOTE(review): these class attributes are presumably consumed by the
    # ``Executable`` base class to run and verify the binary -- confirm there.
    name = "mkdssp"
    required_verification_text = ("DSSP", "output-format")
    version_cmd = "--version"
    verification_cmd = "--help"

    @classmethod
    def initialize(cls, bin_path: PathLike | None = None, *args, **kwargs) -> "DSSPExecutable":
        """Initialize DSSP executable.

        Args:
            bin_path: Path to DSSP executable. If ``None``, attempts to find using ``DSSP`` env variable.
            *args: Forwarded unchanged to ``Executable.initialize``.
            **kwargs: Forwarded unchanged to ``Executable.initialize``.

        Returns:
            Initialized DSSPExecutable.

        Raises:
            ExecutableError: If executable not found or invalid.
        """
        if bin_path is None:
            bin_path = cls._infer_bin_path_from_env_var()
        return super().initialize(bin_path, *args, **kwargs)

    @staticmethod
    def _infer_bin_path_from_env_var() -> PathLike:
        """Get the path to the DSSP executable from the ``DSSP`` environment variable.

        Returns:
            The value of ``DSSP`` when it names an existing, executable file.

        Raises:
            ExecutableError: If ``DSSP`` is unset/empty, does not point to an
                existing file, or points to a file that is not executable.
                Each case gets its own message so the user can tell a missing
                variable apart from a wrong path or missing permissions.
        """
        how_to_fix = (
            "Please set the `DSSP` environment variable to the path of the DSSP executable "
            "or provide a `bin_path` to the `DSSPExecutable` constructor: "
            "`DSSPExecutable.initialize(bin_path='/path/to/mkdssp')`."
        )

        dssp_path = os.environ.get("DSSP")
        if not dssp_path:
            raise ExecutableError(
                "No `bin_path` provided and `DSSP` environment variable not set.\n" + how_to_fix
            )
        if not os.path.isfile(dssp_path):
            raise ExecutableError(
                f"No `bin_path` provided and `DSSP` environment variable is set to '{dssp_path}', "
                "which is not an existing file.\n" + how_to_fix
            )
        if not os.access(dssp_path, os.X_OK):
            raise ExecutableError(
                f"No `bin_path` provided and `DSSP` environment variable is set to '{dssp_path}', "
                "which is not executable by the current user.\n" + how_to_fix
            )
        return dssp_path

    @classmethod
    def _setup(cls, bin_path: PathLike, *args, **kwargs) -> None:
        """Setup method for DSSP (no special setup required)."""
        pass
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
"""Secondary structure annotation using DSSP."""
2+
3+
import logging
4+
from enum import IntEnum, StrEnum
5+
6+
import biotite.application.dssp as dssp
7+
import numpy as np
8+
from biotite.structure import AtomArray
9+
10+
from atomworks.enums import ChainType
11+
from atomworks.ml.executables.dssp import DSSPExecutable
12+
from atomworks.ml.transforms._checks import (
13+
check_atom_array_annotation,
14+
)
15+
from atomworks.ml.transforms.base import Transform
16+
from atomworks.ml.utils.token import get_token_starts, spread_token_wise
17+
18+
logger = logging.getLogger("atomworks.ml")
19+
20+
21+
class SSEnum(IntEnum):
    """Secondary structure enum for protein residues.

    Groups DSSP codes into five categories for efficient storage and manipulation.
    Used for both DSSP output annotations and secondary structure conditioning.

    Values:
        NONE: Not set/not conditioned (value: -1)
        ALPHA_HELIX: Alpha helix (H, G, I in DSSP) (value: 0)
        BETA_SHEET: Beta sheet (E, B in DSSP) (value: 1)
        OTHER_PROTEIN: Coil/loop/turn (T, S, C, P in DSSP) (value: 2)
        NON_PROTEIN: Non-protein chain or DSSP failed (value: 3)
    """

    NONE = -1
    ALPHA_HELIX = 0
    BETA_SHEET = 1
    OTHER_PROTEIN = 2
    NON_PROTEIN = 3

    @classmethod
    def names(cls) -> list[str]:
        """Return human-readable names for each group, in member definition order.

        Derived from the members themselves so the list cannot drift out of
        sync with the enum when a member is added or renamed.
        """
        return [member.name.lower() for member in cls]

    @classmethod
    def to_string(cls, value: int) -> str:
        """Convert integer value to human-readable string.

        Args:
            value: Integer value of one of the enum members (``-1`` to ``3``).

        Returns:
            The lower-case member name, e.g. ``"alpha_helix"`` for ``0``.

        Raises:
            ValueError: If ``value`` does not correspond to any member.
        """
        # Look the value up through the enum rather than indexing a list with
        # ``value + 1``: list indexing silently wraps around for values below
        # -1 (e.g. -2 would be reported as "non_protein").
        return cls(value).name.lower()
53+
54+
55+
class DSSPCode(StrEnum):
56+
"""DSSP secondary structure codes as defined by the DSSP program."""
57+
58+
ALPHA_HELIX = "H" # alpha-helix
59+
ISOLATED_BETA_BRIDGE = "B" # residue in isolated beta-bridge
60+
EXTENDED_STRAND = "E" # extended strand, participates in beta ladder
61+
THREE_TEN_HELIX = "G" # 3-10 helix
62+
PI_HELIX = "I" # pi-helix
63+
POLYPROLINE_HELIX = "P" # kappa-helix (poly-proline II helix)
64+
HYDROGEN_BONDED_TURN = "T" # hydrogen-bonded turn
65+
BEND = "S" # bend
66+
OTHER = "C" # loop, coil, or irregular
67+
NON_PROTEIN = "!" # non-protein
68+
69+
@classmethod
70+
def valid_codes(cls) -> set[str]:
71+
"""Return set of valid DSSP codes."""
72+
return {e.value for e in cls}
73+
74+
@classmethod
75+
def to_group_index(cls, code: str) -> int:
76+
"""Map DSSP code to SSEnum index."""
77+
if code in {cls.ALPHA_HELIX.value, cls.THREE_TEN_HELIX.value, cls.PI_HELIX.value}:
78+
return SSEnum.ALPHA_HELIX
79+
if code in {cls.EXTENDED_STRAND.value, cls.ISOLATED_BETA_BRIDGE.value}:
80+
return SSEnum.BETA_SHEET
81+
if code == cls.NON_PROTEIN.value:
82+
return SSEnum.NON_PROTEIN
83+
return SSEnum.OTHER_PROTEIN
84+
85+
86+
def _get_chain_sse_and_valid(chain_atom_array: AtomArray, bin_path: str) -> tuple[np.ndarray, bool]:
    """Run DSSP on a chain's protein atoms, return group indices and whether DSSP ran successfully.

    Args:
        chain_atom_array: AtomArray containing atoms from a single chain.
        bin_path: Path to DSSP executable.

    Returns:
        Tuple of (group_indices, is_valid) where group_indices are integers from
        SSEnum and is_valid indicates if DSSP ran successfully. On success there
        is one entry per code returned by DSSP. On failure the array is filled
        with ``SSEnum.NON_PROTEIN`` and has ``len(chain_atom_array)`` entries, so
        callers should check ``is_valid`` before relying on its length.
    """
    try:
        dssp_codes = dssp.DsspApp.annotate_sse(chain_atom_array, bin_path=bin_path)
        # Convert DSSP codes to group indices
        group_indices = np.array([DSSPCode.to_group_index(code) for code in dssp_codes], dtype=np.int8)
    except Exception as e:
        # Deliberately broad: DSSP is best-effort here and a failure for one
        # chain must not abort annotation of the whole structure.
        # Guard the chain-id lookup so that an empty chain array cannot raise
        # an IndexError from inside this handler and mask the original error.
        chain_ids = getattr(chain_atom_array, "chain_id", None)
        chain_id = chain_ids[0] if chain_ids is not None and len(chain_ids) > 0 else "?"
        logger.error(
            f"Error running DSSP for entity {chain_id}: {e}; "
            f"using NON_PROTEIN code for this entity's residues, and setting is_valid annotation to False"
        )
        return (
            np.full(len(chain_atom_array), SSEnum.NON_PROTEIN, dtype=np.int8),
            False,
        )
    return group_indices, True
112+
113+
114+
def annotate_secondary_structure(
    atom_array: AtomArray,
    bin_path: str | None = None,
    annotation_name: str = "dssp_sse",
    is_valid_annotation_name: str | None = None,
) -> AtomArray:
    """Annotate secondary structure for each residue using DSSP.

    Only protein tokens are assigned secondary structure groups; all others are
    set to NON_PROTEIN.

    Also adds a boolean annotation indicating whether the SSE is valid (not default
    NON_PROTEIN due to error).

    The annotations are set on ``atom_array`` itself, which is also returned.

    Args:
        atom_array: AtomArray to annotate.
        bin_path: Path to DSSP executable. If ``None``, uses executable from ``DSSPExecutable``.
        annotation_name: Name for the SSE annotation. Defaults to ``"dssp_sse"``.
        is_valid_annotation_name: Name for the validity annotation. If ``None``,
            uses ``"{annotation_name}_is_valid"``. Defaults to ``None``.

    Returns:
        AtomArray with secondary structure annotations added.
    """
    # Get bin_path from executable manager if not provided
    if bin_path is None:
        dssp_exec = DSSPExecutable.get_or_initialize()
        bin_path = dssp_exec.get_bin_path()

    # Atom-level masks
    is_protein_atom_lvl = np.isin(atom_array.chain_type, ChainType.get_proteins())
    is_atomized_atom_lvl = (
        atom_array.atomize
        if "atomize" in atom_array.get_annotation_categories()
        else np.zeros(atom_array.array_length(), dtype=bool)
    )

    # Token-level masks
    token_starts = get_token_starts(atom_array)
    atom_array_token_lvl = atom_array[token_starts]

    # Default all tokens to NON_PROTEIN and all is_valid to False
    sse = np.full(len(atom_array_token_lvl), SSEnum.NON_PROTEIN, dtype=np.int8)
    is_valid = np.zeros(len(atom_array_token_lvl), dtype=bool)

    if np.any(is_protein_atom_lvl):
        # Loop over chain instances
        for chain_iid in np.unique(atom_array.chain_iid):
            chain_iid_mask = atom_array.chain_iid == chain_iid
            # Atomized atoms are excluded so that only regular protein residues reach DSSP
            chain_iid_protein_mask = chain_iid_mask & is_protein_atom_lvl & ~is_atomized_atom_lvl

            if not np.any(chain_iid_protein_mask):
                # Early exit if this chain has no protein atoms
                continue

            # Get chain atoms and compute SSE
            chain_atom_array = atom_array[chain_iid_protein_mask]
            sse_chain, is_valid_chain = _get_chain_sse_and_valid(chain_atom_array, bin_path)

            if not is_valid_chain:
                # DSSP failed and the failure was already logged by the helper.
                # Its fallback array is sized per atom rather than per token, so
                # comparing lengths below would only emit a misleading
                # "mismatch" warning. The defaults (NON_PROTEIN / False) stay.
                continue

            # Assign to all tokens in this chain instance
            token_mask = chain_iid_protein_mask[token_starts]
            if len(sse_chain) == token_mask.sum():
                sse[token_mask] = sse_chain
                is_valid[token_mask] = is_valid_chain
            else:
                # Catch-all for situations that arise (usually due to cropping)
                logger.warning(
                    f"Mismatch in SSE length for chain {chain_iid}: {len(sse_chain)} != {token_mask.sum()}. "
                    f"We will use NON_PROTEIN for this chain, and set is_valid to False."
                )

    # Spread token-wise to all atoms and set annotations
    sse_spread = spread_token_wise(atom_array, sse)
    is_valid_spread = spread_token_wise(atom_array, is_valid)

    # Use provided is_valid annotation name or default
    if is_valid_annotation_name is None:
        is_valid_annotation_name = f"{annotation_name}_is_valid"

    atom_array.set_annotation(annotation_name, sse_spread)
    atom_array.set_annotation(is_valid_annotation_name, is_valid_spread)

    return atom_array
197+
198+
199+
class AnnotateSecondaryStructure(Transform):
    """Transform that adds DSSP-derived secondary structure annotations to the atom array.

    The annotation values are integers from :py:class:`SSEnum`.

    Args:
        bin_path: Path to DSSP executable, used to initialize ``DSSPExecutable``.
            If ``None``, uses ``DSSP`` environment variable. Defaults to ``None``.
        annotation_name: Name for the SSE annotation. Defaults to ``"dssp_sse"``.
        is_valid_annotation_name: Name for the validity annotation. If ``None``,
            uses ``"{annotation_name}_is_valid"``. Defaults to ``None``.
        max_n_tokens: Maximum number of tokens to run DSSP on. Structures with more
            tokens are returned unchanged, without annotations. ``None`` disables
            the limit. Defaults to 800.
    """

    def __init__(
        self,
        bin_path: str | None = None,
        annotation_name: str = "dssp_sse",
        is_valid_annotation_name: str | None = None,
        max_n_tokens: int | None = 800,
    ):
        # Make sure the executable is initialized up front; only forward
        # ``bin_path`` when the caller actually supplied one.
        executable_args = () if bin_path is None else (bin_path,)
        DSSPExecutable.get_or_initialize(*executable_args)

        self.annotation_name = annotation_name
        self.is_valid_annotation_name = is_valid_annotation_name
        self.max_n_tokens = max_n_tokens

    def check_input(self, data: dict) -> None:
        """Require the ``chain_type`` and ``chain_iid`` annotations on the atom array."""
        required_annotations = ["chain_type", "chain_iid"]
        check_atom_array_annotation(data, required_annotations)

    def forward(self, data: dict) -> dict:
        """Annotate ``data["atom_array"]`` unless it exceeds ``max_n_tokens``.

        Returns:
            The same ``data`` dict. It is untouched when the structure is too
            large, otherwise ``"atom_array"`` carries the new annotations.
        """
        structure: AtomArray = data["atom_array"]

        # Size guard: only count tokens when a limit is configured
        n_tokens = None if self.max_n_tokens is None else len(get_token_starts(structure))
        if n_tokens is not None and n_tokens > self.max_n_tokens:
            # Too large: leave the data as-is, with no annotations added
            logger.info(f"Skipping DSSP: structure has {n_tokens} tokens, exceeds max_n_tokens={self.max_n_tokens}")
            return data

        # ``bin_path=None`` makes the function resolve the path via DSSPExecutable
        annotated = annotate_secondary_structure(
            structure,
            bin_path=None,
            annotation_name=self.annotation_name,
            is_valid_annotation_name=self.is_valid_annotation_name,
        )
        data["atom_array"] = annotated
        return data

src/atomworks/ml/transforms/template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,9 @@ def to_atom_array(self, template_idx: int) -> AtomArray:
179179
# Create atom array
180180
atom_array = atom_array_from_encoding(
181181
atom14_coords,
182-
atom14_mask,
183182
seq_tokenized,
184-
encoding=LEGACY_RF2_ATOM14_ENCODING,
183+
LEGACY_RF2_ATOM14_ENCODING,
184+
encoded_mask=atom14_mask,
185185
)
186186
n_atom = len(atom_array)
187187

0 commit comments

Comments
 (0)