diff --git a/cookbook/tutorials/2_embed.ipynb b/cookbook/tutorials/2_embed.ipynb index 459fa90e..61fdb397 100644 --- a/cookbook/tutorials/2_embed.ipynb +++ b/cookbook/tutorials/2_embed.ipynb @@ -49,18 +49,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from getpass import getpass\n", "\n", - "token = getpass(\"Token from Forge console: \")" + "token = getpass(\"Token from Forge: \")" ] }, { diff --git a/cookbook/tutorials/3_gfp_design.ipynb b/cookbook/tutorials/3_gfp_design.ipynb index bde09ad5..f2bb85b9 100644 --- a/cookbook/tutorials/3_gfp_design.ipynb +++ b/cookbook/tutorials/3_gfp_design.ipynb @@ -80,18 +80,18 @@ "\n", "The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n", "\n", - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n" + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. 
It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "id": "zNrU9Q2SYonX" }, "outputs": [], "source": [ - "token = getpass(\"Token from Forge console: \")" + "token = getpass(\"Token from Forge: \")" ] }, { diff --git a/cookbook/tutorials/4_forge_generate.ipynb b/cookbook/tutorials/4_forge_generate.ipynb index 7570fda9..9962e54d 100644 --- a/cookbook/tutorials/4_forge_generate.ipynb +++ b/cookbook/tutorials/4_forge_generate.ipynb @@ -53,7 +53,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." ] }, { @@ -64,7 +64,7 @@ "source": [ "from getpass import getpass\n", "\n", - "token = getpass(\"Token from Forge console: \")\n", + "token = getpass(\"Token from Forge: \")\n", "model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)" ] }, diff --git a/cookbook/tutorials/5_guided_generation.ipynb b/cookbook/tutorials/5_guided_generation.ipynb index f04d6be3..d35a8c94 100644 --- a/cookbook/tutorials/5_guided_generation.ipynb +++ b/cookbook/tutorials/5_guided_generation.ipynb @@ -120,7 +120,7 @@ "\n", "from esm.sdk import client\n", "\n", - "token = getpass(\"Token from Forge console: \")\n", + "token = getpass(\"Token from Forge: \")\n", "model = client(\n", " model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n", ")" diff --git a/esm/__init__.py b/esm/__init__.py index 1e3bed4c..98a35b2d 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1 +1 @@ -__version__ = "3.2.2" +__version__ = "3.2.2.post2" diff --git a/esm/sdk/api.py b/esm/sdk/api.py index 0212ddcd..bd93f0cd 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -148,12 +148,35 @@ def to_protein_complex( gt_chains = list(copy_annotations_from_ground_truth.chain_iter()) else: gt_chains = None + + # Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks + # This handles the case where the server doesn't include chain breaks in pLDDT + # We should fix this in the server side. 
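+        # Worked example of the expansion below (values illustrative): for
+        # sequence "AAA|GG" with a 5-element pLDDT from the server, the
+        # expanded tensor has length 6 with NaN at the "|" position, so the
+        # per-chain slices taken further down stay aligned with the sequence.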
+ if self.plddt is not None and len(self.plddt) != len(self.sequence): + # Only expand if there's a mismatch (likely due to chain breaks) + if "|" in self.sequence: + # Create expanded pLDDT with NaN at chain break positions + expanded_plddt = torch.full((len(self.sequence),), float("nan")) + plddt_idx = 0 + for i, aa in enumerate(self.sequence): + if aa != "|": + if plddt_idx < len(self.plddt): + expanded_plddt[i] = self.plddt[plddt_idx] + plddt_idx += 1 + plddt = expanded_plddt + else: + # Mismatch but no chain breaks - shouldn't happen but preserve original + plddt = self.plddt + else: + plddt = self.plddt + pred_chains = [] for i, (start, end) in enumerate(chain_boundaries): if i >= len(SINGLE_LETTER_CHAIN_IDS): raise ValueError( f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}" ) + pred_chain = ProteinChain.from_atom37( atom37_positions=coords[start:end], sequence=self.sequence[start:end], @@ -161,7 +184,7 @@ def to_protein_complex( if gt_chains is not None else SINGLE_LETTER_CHAIN_IDS[i], entity_id=gt_chains[i].entity_id if gt_chains is not None else None, - confidence=self.plddt[start:end] if self.plddt is not None else None, + confidence=plddt[start:end] if plddt is not None else None, ) pred_chains.append(pred_chain) return ProteinComplex.from_chains(pred_chains) @@ -298,13 +321,6 @@ def use_generative_unmasking_strategy(self): self.temperature_annealing = True -@define -class MSA: - # Paired MSA sequences. - # One would typically compute these using, for example, ColabFold. - sequences: list[str] - - @define class InverseFoldingConfig: invalid_ids: Sequence[int] = [] diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index e9bb2e3d..67c94b07 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import base64 import pickle @@ -7,7 +9,6 @@ import torch from esm.sdk.api import ( - MSA, ESM3InferenceClient, ESMCInferenceClient, ESMProtein, @@ -27,6 +28,15 @@ from esm.sdk.retry import retry_decorator from esm.utils.constants.api import MIMETYPE_ES_PICKLE from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor +from esm.utils.msa import MSA +from esm.utils.structure.input_builder import ( + StructurePredictionInput, + serialize_structure_prediction_input, +) +from esm.utils.structure.molecular_complex import ( + MolecularComplex, + MolecularComplexResult, +) from esm.utils.types import FunctionAnnotation @@ -36,10 +46,8 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None: return [FunctionAnnotation(*t) for t in l] -def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False): - ret = data.get("logits", {}).get(track, None) - # TODO(s22chan): just return this when removing return_bytes - return ret if ret is None or not return_bytes else maybe_tensor(ret) +def _maybe_logits(data: dict[str, Any], track: str): + return maybe_tensor(data.get("logits", {}).get(track, None)) def _maybe_b64_decode(obj, return_bytes: bool): @@ -137,7 +145,7 @@ async def _async_fetch_msa(self, sequence: str) -> MSA: data = await self._async_post( "msa", request={}, params={"sequence": sequence, "use_env": False} ) - return MSA(sequences=data["msa"]) + return MSA.from_sequences(sequences=data["msa"]) def _fetch_msa(self, sequence: str) -> MSA: print("Fetching MSA ... 
this may take a few minutes") @@ -146,7 +154,7 @@ def _fetch_msa(self, sequence: str) -> MSA: data = self._post( "msa", request={}, params={"sequence": sequence, "use_env": False} ) - return MSA(sequences=data["msa"]) + return MSA.from_sequences(sequences=data["msa"]) @retry_decorator async def async_fold( @@ -209,6 +217,70 @@ def fold( return self._process_fold_response(data, sequence) + @retry_decorator + async def async_fold_all_atom( + self, all_atom_input: StructurePredictionInput, model_name: str | None = None + ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: + """Fold a molecular complex containing proteins, nucleic acids, and/or ligands. + + Args: + all_atom_input: StructurePredictionInput containing sequences for different molecule types + model_name: Override the client level model name if needed + """ + request = self._process_fold_all_atom_request( + all_atom_input, model_name if model_name is not None else self.model + ) + + try: + data = await self._async_post("fold_all_atom", request) + except ESMProteinError as e: + return e + + return self._process_fold_all_atom_response(data) + + @retry_decorator + def fold_all_atom( + self, all_atom_input: StructurePredictionInput, model_name: str | None = None + ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: + """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands. + + Args: + all_atom_input: StructurePredictionInput containing sequences for different molecule types + model_name: Override the client level model name if needed + """ + request = self._process_fold_all_atom_request( + all_atom_input, model_name if model_name is not None else self.model + ) + + try: + data = self._post("fold_all_atom", request) + except ESMProteinError as e: + return e + + return self._process_fold_all_atom_response(data) + + @staticmethod + def _process_fold_all_atom_request( + all_atom_input: StructurePredictionInput, model_name: str | None = None + ) -> dict[str, Any]: + request: dict[str, Any] = { + "all_atom_input": serialize_structure_prediction_input(all_atom_input), + "model": model_name, + } + + return request + + @staticmethod + def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult: + complex_data = data.get("complex") + molecular_complex = MolecularComplex.from_state_dict(complex_data) + return MolecularComplexResult( + complex=molecular_complex, + plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True), + ptm=data.get("ptm", None), + distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True), + ) + @retry_decorator async def async_inverse_fold( self, @@ -602,19 +674,15 @@ def _process_logits_response( return LogitsOutput( logits=ForwardTrackData( - sequence=_maybe_logits(data, "sequence", return_bytes), - structure=_maybe_logits(data, "structure", return_bytes), - secondary_structure=_maybe_logits( - data, "secondary_structure", return_bytes - ), - sasa=_maybe_logits(data, "sasa", return_bytes), - function=_maybe_logits(data, "function", return_bytes), + sequence=_maybe_logits(data, "sequence"), + structure=_maybe_logits(data, "structure"), + secondary_structure=_maybe_logits(data, "secondary_structure"), + sasa=_maybe_logits(data, "sasa"), + function=_maybe_logits(data, "function"), ), embeddings=maybe_tensor(data["embeddings"]), mean_embedding=data["mean_embedding"], - residue_annotation_logits=_maybe_logits( - data, "residue_annotation", return_bytes - ), + 
residue_annotation_logits=_maybe_logits(data, "residue_annotation"), hidden_states=maybe_tensor(data["hidden_states"]), mean_hidden_state=maybe_tensor(data["mean_hidden_state"]), ) @@ -965,6 +1033,7 @@ def _process_logits_request( "sequence": config.sequence, "return_embeddings": config.return_embeddings, "return_mean_embedding": config.return_mean_embedding, + "return_mean_hidden_states": config.return_mean_hidden_states, "return_hidden_states": config.return_hidden_states, "ith_hidden_layer": config.ith_hidden_layer, } @@ -981,12 +1050,11 @@ def _process_logits_response( data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes) output = LogitsOutput( - logits=ForwardTrackData( - sequence=_maybe_logits(data, "sequence", return_bytes) - ), + logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")), embeddings=maybe_tensor(data["embeddings"]), mean_embedding=data["mean_embedding"], hidden_states=maybe_tensor(data["hidden_states"]), + mean_hidden_state=maybe_tensor(data["mean_hidden_state"]), ) return output diff --git a/esm/sdk/retry.py b/esm/sdk/retry.py index 16c354b6..302d6cf0 100644 --- a/esm/sdk/retry.py +++ b/esm/sdk/retry.py @@ -2,10 +2,9 @@ from contextvars import ContextVar from functools import wraps -import httpx from tenacity import ( retry, - retry_if_exception_type, + retry_if_exception, retry_if_result, stop_after_attempt, wait_incrementing, @@ -30,8 +29,12 @@ def retry_if_specific_error(exception): def log_retry_attempt(retry_state): + try: + outcome = retry_state.outcome.result() + except Exception: + outcome = retry_state.outcome.exception() print( - f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}" + f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}" ) @@ -41,13 +44,18 @@ def retry_decorator(func): instance's retry settings. """ + def return_last_value(retry_state): + """Return the result of the last call attempt.""" + return retry_state.outcome.result() + @wraps(func) async def async_wrapper(instance, *args, **kwargs): if skip_retries_var.get(): return await func(instance, *args, **kwargs) retry_decorator = retry( + retry_error_callback=return_last_value, retry=retry_if_result(retry_if_specific_error) - | retry_if_exception_type(httpx.ConnectTimeout), # ADDED + | retry_if_exception(retry_if_specific_error), wait=wait_incrementing( increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait ), @@ -62,8 +70,9 @@ def wrapper(instance, *args, **kwargs): if skip_retries_var.get(): return func(instance, *args, **kwargs) retry_decorator = retry( + retry_error_callback=return_last_value, retry=retry_if_result(retry_if_specific_error) - | retry_if_exception_type(httpx.ConnectTimeout), # ADDED + | retry_if_exception(retry_if_specific_error), wait=wait_incrementing( increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait ), diff --git a/esm/utils/generation.py b/esm/utils/generation.py index b2d4ede4..4d35e7ee 100644 --- a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -43,7 +43,9 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int): sliced = {} for k, v in attr.asdict(o, recurse=False).items(): - if v is None: + if k in ["mean_hidden_state", "mean_embedding"]: + sliced[k] = v + elif v is None: sliced[k] = None elif isinstance(v, torch.Tensor): # Trim padding. 
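For reviewers: a minimal, hypothetical sketch of how the new fold_all_atom entry point added to the Forge client above is expected to be called. It assumes the object returned by esm.sdk.client exposes fold_all_atom and that the model named below (copied from the tutorial notebooks in this diff) actually serves the endpoint; the protein sequence and SMILES string are placeholders.

from getpass import getpass

from esm.sdk import client
from esm.utils.structure.input_builder import (
    LigandInput,
    ProteinInput,
    StructurePredictionInput,
)

token = getpass("Token from Forge: ")
# Model name taken from the notebooks above; whether it serves fold_all_atom is an assumption.
model = client(
    model="esm3-medium-2024-08", url="https://forge.evolutionaryscale.ai", token=token
)

all_atom_input = StructurePredictionInput(
    sequences=[
        ProteinInput(id="A", sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
        LigandInput(id="L", smiles="CC(=O)O"),  # acetic acid, purely illustrative
    ]
)
result = model.fold_all_atom(all_atom_input)
# On success, result.complex is a MolecularComplex and result.plddt / result.ptm
# carry confidence estimates; on failure an ESMProteinError is returned instead.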
diff --git a/esm/utils/misc.py b/esm/utils/misc.py index 8ff7e24e..a9190041 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -1,8 +1,20 @@ +from __future__ import annotations + import os from collections import defaultdict from contextlib import nullcontext +from dataclasses import is_dataclass from io import BytesIO -from typing import Any, ContextManager, Sequence, TypeVar +from typing import ( + Any, + ContextManager, + Generator, + Iterable, + Protocol, + Sequence, + TypeVar, + runtime_checkable, +) from warnings import warn import huggingface_hub @@ -19,6 +31,12 @@ TSequence = TypeVar("TSequence", bound=Sequence) +@runtime_checkable +class Concatable(Protocol): + @classmethod + def concat(cls, objs: list[Concatable]) -> Concatable: ... + + def slice_python_object_as_numpy( obj: TSequence, idx: int | list[int] | slice | np.ndarray ) -> TSequence: @@ -53,6 +71,37 @@ def slice_python_object_as_numpy( return sliced_obj # type: ignore +def slice_any_object( + obj: TSequence, idx: int | list[int] | slice | np.ndarray +) -> TSequence: + """ + Slice a arbitrary object (like a list, string, or tuple) as if it was a numpy object. Similar to `slice_python_object_as_numpy`, but detects if it's a numpy array or Tensor and uses the existing slice method if so. + + If the object is a dataclass, it will simply apply the index to the object, under the assumption that the object has correcty implemented numpy indexing. + + Example: + >>> obj = "ABCDE" + >>> slice_any_object(obj, [1, 3, 4]) + "BDE" + + >>> obj = np.array([1, 2, 3, 4, 5]) + >>> slice_any_object(obj, np.arange(5) < 3) + np.array([1, 2, 3]) + + >>> obj = ProteinChain.from_rcsb("1a3a", "A") + >>> slice_any_object(obj, np.arange(len(obj)) < 10) + # ProteinChain w/ length 10 + + """ + if isinstance(obj, (np.ndarray, torch.Tensor)): + return obj[idx] # type: ignore + elif is_dataclass(obj): + # if passing a dataclass, assume it implements a custom slice + return obj[idx] # type: ignore + else: + return slice_python_object_as_numpy(obj, idx) + + def rbf(values, v_min, v_max, n_bins=16): """ Returns RBF encodings in a new dimension at the end. @@ -213,7 +262,7 @@ def unbinpack( return stack_variable_length_tensors(unpacked_tensors, pad_value) -def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: # type: ignore +def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore """ Returns an autocast context manager that disables downcasting by AMP. @@ -302,6 +351,8 @@ def replace_inf(data): def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: if x is None: return None + if isinstance(x, torch.Tensor): + return x if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x): return torch.stack(x) if convert_none_to_nan: @@ -361,3 +412,90 @@ def deserialize_tensors(b: bytes) -> Any: buf = BytesIO(zstd.ZSTD_uncompress(b)) d = torch.load(buf, map_location="cpu", weights_only=False) return d + + +def join_lists( + lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None +) -> list[Any]: + """Joins multiple lists with separator element. Like str.join but for lists. + + Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4] + + Args: + lists: Lists of elements to chain + separator: separators to intsert between chained output. + Returns: + Joined lists. 
+ """ + if not lists: + return [] + joined = [] + joined.extend(lists[0]) + for l in lists[1:]: + if separator: + joined.extend(separator) + joined.extend(l) + return joined + + +def iterate_with_intermediate( + lists: Iterable, intermediate +) -> Generator[Any, None, None]: + """ + Iterate over the iterable, yielding the intermediate value between + every element of the intermediate. Useful for joining objects with + separator tokens. + """ + it = iter(lists) + yield next(it) + for l in it: + yield intermediate + yield l + + +def concat_objects(objs: Sequence[Any], separator: Any | None = None): + """ + Concat objects with each other using a separator token. + + Supports: + - Concatable (objects that implement `concat` classmethod) + - strings + - lists + - numpy arrays + - torch Tensors + + Example: + >>> foo = "abc" + >>> bar = "def" + >>> concat_objects([foo, bar], "|") + "abc|def" + """ + match objs[0]: + case Concatable(): + return objs[0].__class__.concat(objs) # type: ignore + case str(): + assert isinstance( + separator, str + ), "Trying to join strings but separator is not a string" + return separator.join(objs) + case list(): + if separator is not None: + return join_lists(objs, [separator]) + else: + return join_lists(objs) + case np.ndarray(): + if separator is not None: + return np.concatenate( + list(iterate_with_intermediate(objs, np.array([separator]))) + ) + else: + return np.concatenate(objs) + case torch.Tensor(): + if separator is not None: + return torch.cat( + list(iterate_with_intermediate(objs, torch.tensor([separator]))) + ) + else: + return torch.cat(objs) # type: ignore + case _: + raise TypeError(type(objs[0])) diff --git a/esm/utils/msa/__init__.py b/esm/utils/msa/__init__.py new file mode 100644 index 00000000..5dd9965b --- /dev/null +++ b/esm/utils/msa/__init__.py @@ -0,0 +1,3 @@ +from esm.utils.msa.msa import MSA, FastMSA, remove_insertions_from_sequence + +__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"] diff --git a/esm/utils/msa/filter_sequences.py b/esm/utils/msa/filter_sequences.py new file mode 100644 index 00000000..d44549d5 --- /dev/null +++ b/esm/utils/msa/filter_sequences.py @@ -0,0 +1,79 @@ +import tempfile +from pathlib import Path + +import numpy as np +from scipy.spatial.distance import cdist + +from esm.utils.system import run_subprocess_with_errorcheck + + +def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]: + """Greedily select sequences that either maximize or minimize hamming distance. + + Algorithm proposed in the MSA Transformer paper. Starting from the query sequence, + iteratively add sequences to the list with the maximum (minimum) average Hamming + distance to the existing set of sequences. + + Args: + array (np.ndarray): Character array representing the sequences in the MSA + num_seqs (int): Number of sequences to select. + mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless + you're doing it to prove a point for a paper. 
+ + Returns: + list[int]: List of indices to select from the array + """ + assert mode in ("max", "min") + depth = array.shape[0] + if depth <= num_seqs: + return list(range(depth)) + array = array.view(np.uint8) + + optfunc = np.argmax if mode == "max" else np.argmin + all_indices = np.arange(depth) + indices = [0] + pairwise_distances = np.zeros((0, depth)) + for _ in range(num_seqs - 1): + dist = cdist(array[indices[-1:]], array, "hamming") + pairwise_distances = np.concatenate([pairwise_distances, dist]) + shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0) + shifted_index = optfunc(shifted_distance) + index = np.delete(all_indices, indices)[shifted_index] + indices.append(index) + indices = sorted(indices) + return indices + + +def hhfilter( + sequences: list[str], + seqid: int = 90, + diff: int = 0, + cov: int = 0, + qid: int = 0, + qsc: float = -20.0, + binary: str = "hhfilter", +) -> list[int]: + with tempfile.TemporaryDirectory(dir="/dev/shm") as tempdirname: + tempdir = Path(tempdirname) + fasta_file = tempdir / "input.fasta" + fasta_file.write_text( + "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences)) + ) + output_file = tempdir / "output.fasta" + command = " ".join( + [ + f"{binary}", + f"-i {fasta_file}", + "-M a3m", + f"-o {output_file}", + f"-id {seqid}", + f"-diff {diff}", + f"-cov {cov}", + f"-qid {qid}", + f"-qsc {qsc}", + ] + ).split(" ") + run_subprocess_with_errorcheck(command, capture_output=True) + with output_file.open() as f: + indices = [int(line[1:].strip()) for line in f if line.startswith(">")] + return indices diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py new file mode 100644 index 00000000..838e5b4b --- /dev/null +++ b/esm/utils/msa/msa.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +import dataclasses +import string +from dataclasses import dataclass +from functools import cached_property +from itertools import islice +from typing import Sequence + +import numpy as np +from Bio import SeqIO +from scipy.spatial.distance import cdist + +from esm.utils.misc import slice_any_object +from esm.utils.msa.filter_sequences import greedy_select_indices, hhfilter +from esm.utils.parsing import FastaEntry, read_sequences, write_sequences +from esm.utils.sequential_dataclass import SequentialDataclass +from esm.utils.system import PathOrBuffer + +REMOVE_LOWERCASE_TRANSLATION = str.maketrans(dict.fromkeys(string.ascii_lowercase)) + + +def remove_insertions_from_sequence(seq: str) -> str: + return seq.translate(REMOVE_LOWERCASE_TRANSLATION) + + +@dataclass(frozen=True) +class MSA(SequentialDataclass): + """Object-oriented interface to an MSA. 
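+
+    Illustrative usage (the file name is hypothetical):
+        >>> msa = MSA.from_a3m("query.a3m")
+        >>> msa.greedy_select(num_seqs=64).depth <= 64
+        True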
+ + Args: + sequences (list[str]): List of protein sequences + headers (list[str]): List of headers describing the sequences + + """ + + entries: list[FastaEntry] + + @cached_property + def sequences(self) -> list[str]: + return [entry.sequence for entry in self.entries] + + @cached_property + def headers(self) -> list[str]: + return [entry.header for entry in self.entries] + + def __repr__(self): + return ( + f"MSA({self.entries[0].header}: Depth={self.depth}, Length={self.seqlen})" + ) + + def to_fast_msa(self) -> FastMSA: + return FastMSA(self.array, self.headers) + + @classmethod + def from_a3m( + cls, + path: PathOrBuffer, + remove_insertions: bool = True, + max_sequences: int | None = None, + ) -> MSA: + entries = [] + for header, seq in islice(read_sequences(path), max_sequences): + if remove_insertions: + seq = remove_insertions_from_sequence(seq) + if entries: + assert ( + len(seq) == len(entries[0].sequence) + ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" + entries.append(FastaEntry(header, seq)) + return cls(entries) + + def to_a3m(self, path: PathOrBuffer) -> None: + write_sequences(self.entries, path) + + @classmethod + def from_stockholm( + cls, + path: PathOrBuffer, + remove_insertions: bool = True, + max_sequences: int | None = None, + ) -> MSA: + entries = [] + for record in islice(SeqIO.parse(path, "stockholm"), max_sequences): + header = f"{record.id} {record.description}" + seq = str(record.seq) + if entries: + assert ( + len(seq) == len(entries[0].sequence) + ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" + entries.append(FastaEntry(header, seq)) + msa = cls(entries) + if remove_insertions: + keep_inds = [i for i, aa in enumerate(msa.query) if aa != "-"] + msa = msa.select_positions(keep_inds) + return msa + + def to_bytes(self) -> bytes: + version = 1 + version_bytes = version.to_bytes(1, "little") + seqlen_bytes = self.seqlen.to_bytes(4, "little") + depth_bytes = self.depth.to_bytes(4, "little") + array_bytes = self.array.tobytes() + header_bytes = "\n".join(entry.header for entry in self.entries).encode() + all_bytes = ( + version_bytes + seqlen_bytes + depth_bytes + array_bytes + header_bytes + ) + return all_bytes + + @classmethod + def from_bytes(cls, data: bytes) -> MSA: + version_bytes, seqlen_bytes, depth_bytes, data = ( + data[:1], + data[1:5], + data[5:9], + data[9:], + ) + version = int.from_bytes(version_bytes, "little") + if version != 1: + raise ValueError(f"Unsupported version: {version}") + seqlen = int.from_bytes(seqlen_bytes, "little") + depth = int.from_bytes(depth_bytes, "little") + array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :] + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(depth, seqlen) + headers = header_bytes.decode().split("\n") + # Sometimes the separation is two newlines, which results in an empty header. 
+ headers = [header for header in headers if header] + entries = [ + FastaEntry(header, b"".join(row).decode()) + for header, row in zip(headers, array) + ] + return cls(entries) + + # TODO(jmaccarl): set remove_insertions to True by default here to match other utils + @classmethod + def from_sequences( + cls, sequences: list[str], remove_insertions: bool = False + ) -> MSA: + if remove_insertions: + entries = [ + FastaEntry("", remove_insertions_from_sequence(seq)) + for seq in sequences + ] + else: + entries = [FastaEntry("", seq) for seq in sequences] + return cls(entries) + + def to_sequence_bytes(self) -> bytes: + """Stores ONLY SEQUENCES in array format as bytes. Header information will be lost.""" + seqlen_bytes = self.seqlen.to_bytes(4, "little") + array_bytes = self.array.tobytes() + all_bytes = seqlen_bytes + array_bytes + return all_bytes + + @classmethod + def from_sequence_bytes(cls, data: bytes) -> MSA: + seqlen_bytes, array_bytes = data[:4], data[4:] + seqlen = int.from_bytes(seqlen_bytes, "little") + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(-1, seqlen) + entries = [FastaEntry("", b"".join(row).decode()) for row in array] + return cls(entries) + + @property + def depth(self) -> int: + return len(self.entries) + + @property + def seqlen(self) -> int: + return len(self.entries[0].sequence) + + @cached_property + def array(self) -> np.ndarray: + return np.array([list(seq) for seq in self.sequences], dtype="|S1") + + @property + def query(self) -> str: + return self.entries[0].sequence + + def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA: + """Subselect rows of the MSA.""" + entries = [self.entries[idx] for idx in indices] + return dataclasses.replace(self, entries=entries) + + def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA: + """Subselect columns of the MSA.""" + entries = [ + FastaEntry(header, "".join(seq[idx] for idx in indices)) + for header, seq in self.entries + ] + return dataclasses.replace(self, entries=entries) + + def __getitem__(self, indices: int | list[int] | slice | np.ndarray): + if isinstance(indices, int): + indices = [indices] + + entries = [ + FastaEntry(header, slice_any_object(seq, indices)) + for header, seq in self.entries + ] + return dataclasses.replace(self, entries=entries) + + def __len__(self): + return self.seqlen + + def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA: + """Greedily select sequences that either maximize or minimize hamming distance. + + Algorithm proposed in the MSA Transformer paper. Starting from the query sequence, + iteratively add sequences to the list with the maximum (minimum) average Hamming + distance to the existing set of sequences. + + Args: + num_seqs (int): Number of sequences to select. + mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless + you're doing it to prove a point for a paper. + + Returns: + MSA object w/ subselected sequences. 
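+            Keeps the query (row 0) and returns at most num_seqs sequences in total.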
+ """ + assert mode in ("max", "min") + if self.depth <= num_seqs: + return self + + indices = greedy_select_indices(self.array, num_seqs, mode) + return self.select_sequences(indices) + + def hhfilter( + self, + seqid: int = 90, + diff: int = 0, + cov: int = 0, + qid: int = 0, + qsc: float = -20.0, + binary: str = "hhfilter", + ) -> MSA: + """Apply hhfilter to the sequences in the MSA and return a filtered MSA.""" + + indices = hhfilter( + self.sequences, + seqid=seqid, + diff=diff, + cov=cov, + qid=qid, + qsc=qsc, + binary=binary, + ) + return self.select_sequences(indices) + + def select_random_sequences(self, num_seqs: int) -> MSA: + """Uses random sampling to subselect sequences from the MSA. Always + keeps the query sequence. + """ + if num_seqs >= self.depth: + return self + + # Subselect random, always keeping the query sequence. + indices = np.sort( + np.append( + 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 + ) + ) + msa = self.select_sequences(indices) # type: ignore + return msa + + def select_diverse_sequences(self, num_seqs: int) -> MSA: + """Applies hhfilter to select ~num_seqs sequences, then uses random sampling + to subselect if necessary. + """ + if num_seqs >= self.depth: + return self + + msa = self.hhfilter(diff=num_seqs) + if num_seqs < msa.depth: + msa = msa.select_random_sequences(num_seqs) + return msa + + def pad_to_depth(self, depth: int) -> MSA: + if depth < self.depth: + raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}") + elif depth == self.depth: + return self + + num_to_add = depth - self.depth + extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)] + return dataclasses.replace(self, entries=self.entries + extra_entries) + + @classmethod + def stack( + cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True + ) -> MSA: + """Stack a series of MSAs. 
Optionally remove the query from msas after the first.""" + all_entries = [] + for i, msa in enumerate(msas): + entries = msa.entries + if i > 0 and remove_query_from_later_msas: + entries = entries[1:] + all_entries.extend(entries) + return cls(entries=all_entries) + + @cached_property + def seqid(self) -> np.ndarray: + array = self.array.view(np.uint8) + seqid = 1 - cdist(array[0][None], array, "hamming") + return seqid[0] + + @classmethod + def concat( + cls, + msas: Sequence[MSA], + join_token: str | None = "|", + allow_depth_mismatch: bool = False, + ) -> MSA: + """Concatenate a series of MSAs horizontally, along the sequence dimension.""" + if not msas: + raise ValueError("Cannot concatenate an empty list of MSAs") + msa_depths = [msa.depth for msa in msas] + if len(set(msa_depths)) != 1: + if not allow_depth_mismatch: + raise ValueError("Depth mismatch in concatenating MSAs") + else: + max_depth = max(msa_depths) + msas = [msa.pad_to_depth(max_depth) for msa in msas] + headers = [ + "|".join([str(h) for h in headers]) + for headers in zip(*(msa.headers for msa in msas)) + ] + + if join_token is None: + join_token = "" + + seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))] + entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)] + return cls(entries) + + +@dataclass(frozen=True) +class FastMSA(SequentialDataclass): + """Object-oriented interface to an MSA stored as a numpy uint8 array.""" + + array: np.ndarray + headers: list[str] | None = None + + def __post_init__(self): + if self.headers is not None: + assert ( + len(self.headers) == self.depth + ), "Number of headers must match depth." + + @classmethod + def from_bytes(cls, data: bytes) -> FastMSA: + version_bytes, seqlen_bytes, depth_bytes, data = ( + data[:1], + data[1:5], + data[5:9], + data[9:], + ) + version = int.from_bytes(version_bytes, "little") + if version != 1: + raise ValueError(f"Unsupported version: {version}") + seqlen = int.from_bytes(seqlen_bytes, "little") + depth = int.from_bytes(depth_bytes, "little") + array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :] + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(depth, seqlen) + headers = header_bytes.decode().split("\n") + # Sometimes the separation is two newlines, which results in an empty header. + headers = [header for header in headers if header] + return cls(array, headers) + + @classmethod + def from_sequence_bytes(cls, data: bytes) -> FastMSA: + seqlen_bytes, array_bytes = data[:4], data[4:] + seqlen = int.from_bytes(seqlen_bytes, "little") + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(-1, seqlen) + return cls(array) + + @property + def depth(self) -> int: + return self.array.shape[0] + + @property + def seqlen(self) -> int: + return self.array.shape[1] + + def __len__(self): + return self.seqlen + + def __getitem__(self, indices: int | list[int] | slice | np.ndarray): + if isinstance(indices, int): + indices = [indices] + + return dataclasses.replace(self, array=self.array[:, indices]) + + def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA: + """Subselect rows of the MSA.""" + array = self.array[indices] + headers = ( + [self.headers[idx] for idx in indices] if self.headers is not None else None + ) + return dataclasses.replace(self, array=array, headers=headers) + + def select_random_sequences(self, num_seqs: int) -> FastMSA: + """Uses random sampling to subselect sequences from the MSA. 
Always + keeps the query sequence. + """ + if num_seqs >= self.depth: + return self + + # Subselect random, always keeping the query sequence. + indices = np.sort( + np.append( + 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 + ) + ) + msa = self.select_sequences(indices) # type: ignore + return msa + + def pad_to_depth(self, depth: int) -> FastMSA: + if depth < self.depth: + raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}") + elif depth == self.depth: + return self + + num_to_add = depth - self.depth + array = np.pad( + self.array, + [(0, num_to_add), (0, 0)], + constant_values=ord("-") if self.array.dtype == np.uint8 else b"-", + ) + headers = self.headers + if headers is not None: + headers = headers + [""] * num_to_add + return dataclasses.replace(self, array=array, headers=headers) + + @classmethod + def concat( + cls, + msas: Sequence[FastMSA], + join_token: str | None = None, + allow_depth_mismatch: bool = False, + ) -> FastMSA: + """Concatenate a series of MSAs horizontally, along the sequence dimension.""" + if not msas: + raise ValueError("Cannot concatenate an empty list of MSAs") + if join_token is not None and join_token != "": + raise NotImplementedError("join_token is not supported for FastMSA") + + msa_depths = [msa.depth for msa in msas] + if len(set(msa_depths)) != 1: + if not allow_depth_mismatch: + raise ValueError("Depth mismatch in concatenating MSAs") + else: + max_depth = max(msa_depths) + msas = [msa.pad_to_depth(max_depth) for msa in msas] + headers = [ + "|".join([str(h) for h in headers]) + for headers in zip( + *( + msa.headers if msa.headers is not None else [""] * msa.depth + for msa in msas + ) + ) + ] + + array = np.concatenate([msa.array for msa in msas], axis=1) + return cls(array, headers) + + def to_msa(self) -> MSA: + headers = ( + self.headers + if self.headers is not None + else [f"seq{i}" for i in range(self.depth)] + ) + entries = [ + FastaEntry(header, b"".join(row).decode()) + for header, row in zip(headers, self.array) + ] + return MSA(entries) + + @classmethod + def stack( + cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True + ) -> FastMSA: + """Stack a series of MSAs. 
Optionally remove the query from msas after the first.""" + arrays = [] + all_headers = [] + for i, msa in enumerate(msas): + array = msa.array + headers = msa.headers + if i > 0 and remove_query_from_later_msas: + array = array[1:] + if headers is not None: + headers = headers[1:] + arrays.append(array) + if headers is not None: + all_headers.extend(headers) + return cls(np.concatenate(arrays, axis=0), all_headers) diff --git a/esm/utils/parsing.py b/esm/utils/parsing.py new file mode 100644 index 00000000..c47938ab --- /dev/null +++ b/esm/utils/parsing.py @@ -0,0 +1,83 @@ +import io +from pathlib import Path +from typing import Generator, Iterable, NamedTuple + +PathOrBuffer = str | Path | io.TextIOBase +FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)]) + + +def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]: + """ + Parses a fasta file and yields FastaEntry objects + + Args: + fasta_string: The fasta file as a string + Returns: + A generator of FastaEntry objects + """ + header = None + seq = [] + num_sequences = 0 + for line in fasta_string.splitlines(): + if not line or line[0] == "#": + continue + if line.startswith(">"): + if header is not None: + yield FastaEntry(header, "".join(seq)) + seq = [] + header = line[1:].strip() + else: + seq.append(line) + if header is not None: + num_sequences += 1 + yield FastaEntry(header, "".join(seq)) + + if num_sequences == 0: + raise ValueError("Found no sequences in input") + + +def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]: + # Uses duck typing to try and call the right method + # Doesn't use explicit isinstance check to support + # inputs that are not explicitly str/Path/TextIOBase but + # may support similar functionality + data = None # type: ignore + try: + if str(path).endswith(".gz"): + import gzip + + data = gzip.open(path, "rt") # type: ignore + else: + try: + data = open(path) # type: ignore + except TypeError: + data: io.TextIOBase = path # type: ignore + + yield from parse_fasta(data.read()) + finally: + if data is not None: + data.close() + + +def read_first_sequence(path: PathOrBuffer) -> FastaEntry: + return next(iter(read_sequences(path))) + + +def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None: + needs_closing = False + handle = None + try: + try: + handle = open(path, "w") # type: ignore + needs_closing = True + except TypeError: + handle = path + has_prev = False + for header, seq in sequences: + if has_prev: + handle.write("\n") # type: ignore + handle.write(f">{header}\n{seq}") # type: ignore + has_prev = True + finally: + if needs_closing: + handle.close() # type: ignore diff --git a/esm/utils/sequential_dataclass.py b/esm/utils/sequential_dataclass.py new file mode 100644 index 00000000..8e935798 --- /dev/null +++ b/esm/utils/sequential_dataclass.py @@ -0,0 +1,157 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields, replace +from typing import TypeVar + +import numpy as np + +from esm.utils.misc import concat_objects, slice_any_object + +T = TypeVar("T") + + +@dataclass(frozen=True) +class SequentialDataclass(ABC): + """ + This is a builder on a dataclass that allows for automatic slicing and concatenation. + + When representing multimodal data, we often have multiple datatypes which have sequence dimensions that are the same (e.g. the length of the protein). + + When applying a transformation like a crop, we want to apply this to all tensors at the same time (e.g. 
crop the sequence, structure, and function). + + We also have some fields that are not sequential (like an id, or data source), which we don't want to crop. + + The SequentialDataclass abstracts this cropping away, allowing you to define dataclasses that implement `__len__`, `__getitem__` and `concat` automatically. + + This is done through the `metadata` field, which can take 3 values: + `sequence` (bool): True or False, tells the dataclass whether this field is a sequential type. Default: False. + `sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0. + `join_token` (Any): What token to use to join when concatenating elements. Default: None. + + + Example: + + @dataclass(frozen=True) + class Foo(SequentialDataclass): + id: str + sequence: str = field(metadata={"sequence": True, "join_token": "|"}) + tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan}) + + def __len__(self): + # Must implement the __len__ method + return len(self.sequence) + + >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5)) + Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143, 0.0251, -1.0717])) + + >>> foo[1:4] + Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143, 0.0251])) + + >>> foo[np.arange(5) < 3] + Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143])) + + >>> Foo.concat([foo[:2], foo[3:]]) + Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335, nan, 0.0251, -1.0717])) + + # Trying to create a type where the sequence lengths do not match raises an error + >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6)) + ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6 + + """ + + def __post_init__(self): + self._check_sequence_lengths_match() + + @abstractmethod + def __len__(self): + raise NotImplementedError + + def __getitem__(self, idx: int | list[int] | slice | np.ndarray): + updated_fields = {} + if isinstance(idx, int): + # make it so that things remain sequential + idx = [idx] + + for fld in fields(self): + if fld.metadata.get("sequence", False): + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + value = getattr(self, fld.name) + if value is None: + continue + match sequence_dim: + case 0: + # sequence is first dimension + value = getattr(self, fld.name) + value = slice_any_object(value, idx) + updated_fields[fld.name] = value + case 1: + new_value = [slice_any_object(item, idx) for item in value] + updated_fields[fld.name] = value.__class__(new_value) + case _: + raise NotImplementedError( + "Arbitrary slicing for different sequence length fields is not implemented" + ) + + return replace(self, **updated_fields) + + def _check_sequence_lengths_match(self): + """Checks if sequence lengths of all "sequence" fields match.""" + for fld in fields(self): + if fld.metadata.get("sequence", False) and fld.name != "complex": + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + value = getattr(self, fld.name) + if value is None: + continue + match sequence_dim: + case 0: + # sequence is first dimension + value = getattr(self, fld.name) + if len(value) != len(self): + raise ValueError( + f"Mismatch in sequence length for field: {fld.name}. 
Expected {len(self)}, received {len(value)}" + ) + case 1: + for item in value: + if len(item) != len(self): + raise ValueError( + f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}" + ) + case _: + raise NotImplementedError( + "Arbitrary matching for different sequence length fields is not implemented" + ) + + @classmethod + def concat(cls, items: list[T], **kwargs) -> T: + updated_fields = {} + for fld in fields(cls): + if fld.metadata.get("sequence", False): + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + join_value = fld.metadata.get("join_token", None) + if getattr(items[0], fld.name) is None: + continue + values = [getattr(item, fld.name) for item in items] + match sequence_dim: + case 0: + # sequence is first dimension + value = concat_objects(values, join_value) + updated_fields[fld.name] = value + case 1: + new_value = [ + concat_objects(item, join_value) for item in zip(*values) + ] + updated_fields[fld.name] = getattr( + items[0], fld.name + ).__class__(new_value) + case _: + raise NotImplementedError( + "Arbitrary joining for different sequence length fields is not implemented" + ) + updated_fields.update(kwargs) + + return replace( + items[0], # type: ignore + **updated_fields, + ) diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py new file mode 100644 index 00000000..026912fc --- /dev/null +++ b/esm/utils/structure/input_builder.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass +from typing import Any, Sequence + +import numpy as np + + +@dataclass +class Modification: + position: int # zero-indexed + ccd: str + + +@dataclass +class ProteinInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class RNAInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class DNAInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class LigandInput: + id: str | list[str] + smiles: str + ccd: list[str] | None = None + + +@dataclass +class DistogramConditioning: + chain_id: str + distogram: np.ndarray + + +@dataclass +class PocketConditioning: + binder_chain_id: str + contacts: list[tuple[str, int]] + + +@dataclass +class StructurePredictionInput: + sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput] + pocket: PocketConditioning | None = None + distogram_conditioning: list[DistogramConditioning] | None = None + + +def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput): + def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]: + chain_data: dict[str, Any] = { + "sequence": seq_input.sequence, + "id": seq_input.id, + "type": chain_type, + } + if hasattr(seq_input, "modifications") and seq_input.modifications: + mods = [ + {"position": mod.position, "ccd": mod.ccd} + for mod in seq_input.modifications + ] + chain_data["modifications"] = mods + return chain_data + + sequences = [] + for seq_input in all_atom_input.sequences: + if isinstance(seq_input, ProteinInput): + sequences.append(create_chain_data(seq_input, "protein")) + elif isinstance(seq_input, RNAInput): + sequences.append(create_chain_data(seq_input, "rna")) + elif isinstance(seq_input, DNAInput): + sequences.append(create_chain_data(seq_input, "dna")) + elif isinstance(seq_input, LigandInput): + sequences.append( + { + "smiles": 
seq_input.smiles, + "id": seq_input.id, + "ccd": seq_input.ccd, + "type": "ligand", + } + ) + else: + raise ValueError(f"Unsupported sequence input type: {type(seq_input)}") + + return {"sequences": sequences} diff --git a/esm/utils/structure/metrics.py b/esm/utils/structure/metrics.py index d76ed766..b2e590db 100644 --- a/esm/utils/structure/metrics.py +++ b/esm/utils/structure/metrics.py @@ -264,7 +264,7 @@ def compute_lddt_ca( if all_atom_pred_pos.dim() != 3: all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :] - all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim + all_atom_mask = all_atom_mask[..., ca_pos] return compute_lddt( all_atom_pred_pos, diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py new file mode 100644 index 00000000..f53ab9c4 --- /dev/null +++ b/esm/utils/structure/molecular_complex.py @@ -0,0 +1,938 @@ +from __future__ import annotations + +import io +import os +import re +from dataclasses import asdict, dataclass +from pathlib import Path +from subprocess import check_output +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, List + +import biotite.structure.io.pdbx as pdbx +import brotli +import msgpack +import numpy as np +import torch + +from esm.utils import residue_constants +from esm.utils.structure.metrics import compute_lddt, compute_rmsd +from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata + + +@dataclass +class MolecularComplexResult: + """Result of molecular complex folding""" + + complex: MolecularComplex + plddt: torch.Tensor | None = None + ptm: float | None = None + iptm: float | None = None + pae: torch.Tensor | None = None + distogram: torch.Tensor | None = None + pair_chains_iptm: torch.Tensor | None = None + output_embedding_sequence: torch.Tensor | None = None + output_embedding_pair_pooled: torch.Tensor | None = None + + +@dataclass +class MolecularComplexMetadata: + """Metadata for MolecularComplex objects.""" + + entity_lookup: dict[int, str] + chain_lookup: dict[int, str] + assembly_composition: dict[str, list[str]] | None = None + + +@dataclass +class Molecule: + """Represents a single molecule/token within a MolecularComplex.""" + + token: str + token_idx: int + atom_positions: np.ndarray # [N_atoms, 3] + atom_elements: np.ndarray # [N_atoms] element strings + residue_type: int + molecule_type: int # PROTEIN=0, RNA=1, DNA=2, LIGAND=3 + confidence: float + + +@dataclass(frozen=True) +class MolecularComplex: + """ + Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands. + + Uses a flat atom representation with token-based sequence indexing, supporting all atom types + beyond the traditional atom37 protein representation. 
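+
+    Illustrative access pattern (the mmCIF path is hypothetical):
+        >>> mc = MolecularComplex.from_mmcif("complex.cif")
+        >>> len(mc), mc[0].token       # token count and first token's name
+        >>> mc.atom_positions.shape    # (N_atoms, 3)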
+ """ + + id: str + sequence: List[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP'] + + # Flat atom arrays - simplified representation + atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates + atom_elements: np.ndarray # [N_atoms] element strings + + # Token-to-atom mapping for efficient access + token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array + + # Confidence data + plddt: np.ndarray # Per-token confidence scores [N_tokens] + + # Metadata + metadata: MolecularComplexMetadata + + def __post_init__(self): + """Validate array dimensions.""" + n_tokens = len(self.sequence) + assert ( + self.token_to_atoms.shape[0] == n_tokens + ), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens" + assert ( + self.plddt.shape[0] == n_tokens + ), f"plddt shape {self.plddt.shape} != {n_tokens} tokens" + + def __len__(self) -> int: + """Return number of tokens.""" + return len(self.sequence) + + def __getitem__(self, idx: int) -> Molecule: + """Access individual molecules/tokens by index.""" + if idx >= len(self.sequence) or idx < 0: + raise IndexError( + f"Token index {idx} out of range for {len(self.sequence)} tokens" + ) + + token = self.sequence[idx] + start_atom, end_atom = self.token_to_atoms[idx] + + # Extract atom data for this token + token_atom_positions = self.atom_positions[start_atom:end_atom] + token_atom_elements = self.atom_elements[start_atom:end_atom] + + # Default values for residue/molecule type (would be extended based on actual implementation) + residue_type = 0 # Default to standard residue + molecule_type = 0 # Default to protein + + return Molecule( + token=token, + token_idx=idx, + atom_positions=token_atom_positions, + atom_elements=token_atom_elements, + residue_type=residue_type, + molecule_type=molecule_type, + confidence=self.plddt[idx], + ) + + @property + def atom_coordinates(self) -> np.ndarray: + """Get flat array of all atom coordinates [N_atoms, 3].""" + return self.atom_positions + + # Conversion methods + @classmethod + def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex": + """Convert a ProteinComplex to MolecularComplex. 
+ + Args: + pc: ProteinComplex object with atom37 representation + + Returns: + MolecularComplex with flat atom arrays and token-based indexing + """ + from esm.utils import residue_constants + + # Extract sequence without chain breaks + sequence_no_breaks = pc.sequence.replace("|", "") + sequence_tokens = [ + residue_constants.restype_1to3.get(aa, "UNK") for aa in sequence_no_breaks + ] + + # Convert atom37 to flat arrays + flat_positions = [] + flat_elements = [] + token_to_atoms = [] + + atom_idx = 0 + residue_idx = 0 + + for i, aa in enumerate(pc.sequence): + if aa == "|": + # Skip chain break tokens + continue + + # Get atom37 positions and mask for this residue + res_positions = pc.atom37_positions[residue_idx] # [37, 3] + res_mask = pc.atom37_mask[residue_idx] # [37] + + # Track start position for this token + token_start = atom_idx + + # Process each atom type in atom37 representation + for atom_type_idx, atom_name in enumerate(residue_constants.atom_types): + if res_mask[atom_type_idx]: # Atom is present + # Add position + flat_positions.append(res_positions[atom_type_idx]) + + # Determine element from atom name + element = ( + atom_name[0] if atom_name else "C" + ) # First character is element + flat_elements.append(element) + + atom_idx += 1 + + # Record token-to-atom mapping [start_idx, end_idx) + token_to_atoms.append([token_start, atom_idx]) + residue_idx += 1 + + # Convert to numpy arrays + atom_positions = np.array(flat_positions, dtype=np.float32) + atom_elements = np.array(flat_elements, dtype=object) + token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) + + # Extract confidence scores (skip chain breaks) + confidence_scores = [] + residue_idx = 0 + for aa in pc.sequence: + if aa != "|": + confidence_scores.append(pc.confidence[residue_idx]) + residue_idx += 1 + + confidence_array = np.array(confidence_scores, dtype=np.float32) + + # Create metadata - convert entity IDs to strings for MolecularComplexMetadata + entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()} + metadata = MolecularComplexMetadata( + entity_lookup=entity_lookup_str, + chain_lookup=pc.metadata.chain_lookup, + assembly_composition=pc.metadata.assembly_composition, + ) + + return cls( + id=pc.id, + sequence=sequence_tokens, + atom_positions=atom_positions, + atom_elements=atom_elements, + token_to_atoms=token_to_atoms_array, + plddt=confidence_array, + metadata=metadata, + ) + + def to_protein_complex(self) -> ProteinComplex: + """Convert MolecularComplex back to ProteinComplex format. + + Extracts only protein tokens and converts from flat atom representation + back to atom37 format used by ProteinComplex. 
+ + Returns: + ProteinComplex with protein residues only, excluding ligands/nucleic acids + """ + from esm.utils import residue_constants + + # No need for element mapping - already using element characters + + # Filter for protein tokens only (skip ligands, nucleic acids) + protein_tokens = [] + protein_indices = [] + + for i, token in enumerate(self.sequence): + # Check if token is a standard 3-letter amino acid code + if token in residue_constants.restype_3to1: + protein_tokens.append(token) + protein_indices.append(i) + + if not protein_tokens: + raise ValueError("No protein tokens found in MolecularComplex") + + n_residues = len(protein_tokens) + + # Initialize atom37 arrays + atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32) + atom37_mask = np.zeros((n_residues, 37), dtype=bool) + + # Convert tokens back to single-letter sequence + single_letter_sequence = "".join( + [residue_constants.restype_3to1[token] for token in protein_tokens] + ) + + # Extract confidence scores for protein residues only + protein_confidence = self.plddt[protein_indices] + + # Convert flat atoms back to atom37 representation + for res_idx, token_idx in enumerate(protein_indices): + token = self.sequence[token_idx] + start_atom, end_atom = self.token_to_atoms[token_idx] + + # Get atom data for this residue + res_atom_positions = self.atom_positions[start_atom:end_atom] + + # Reconstruct atom37 representation by exactly reversing the forward conversion logic + # In from_protein_complex, atoms are added in atom_types order if present in mask + # So we need to reconstruct the mask and positions in the same order + atom_count = 0 + for atom_type_idx, atom_name in enumerate(residue_constants.atom_types): + # Check if this atom type exists for this residue and was present + residue_atoms = residue_constants.residue_atoms.get(token, []) + if atom_name in residue_atoms: + # This atom type exists for this residue, so it should have been included + if atom_count < len(res_atom_positions): + atom37_positions[res_idx, atom_type_idx] = res_atom_positions[ + atom_count + ] + atom37_mask[res_idx, atom_type_idx] = True + atom_count += 1 + + # Create other required arrays for ProteinComplex + # For simplicity, assume all protein residues belong to the same entity/chain + entity_id = np.zeros(n_residues, dtype=np.int64) + chain_id = np.zeros(n_residues, dtype=np.int64) + sym_id = np.zeros(n_residues, dtype=np.int64) + residue_index = np.arange(1, n_residues + 1, dtype=np.int64) + insertion_code = np.array([""] * n_residues, dtype=object) + + # Create simplified protein complex metadata + # Map the first entity/chain from molecular complex metadata + protein_metadata = ProteinComplexMetadata( + entity_lookup={0: 1}, # Single entity (int for ProteinComplexMetadata) + chain_lookup={0: "A"}, # Single chain + assembly_composition=self.metadata.assembly_composition, + ) + + return ProteinComplex( + id=self.id, + sequence=single_letter_sequence, + entity_id=entity_id, + chain_id=chain_id, + sym_id=sym_id, + residue_index=residue_index, + insertion_code=insertion_code, + atom37_positions=atom37_positions, + atom37_mask=atom37_mask, + confidence=protein_confidence, + metadata=protein_metadata, + ) + + @classmethod + def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": + """Read MolecularComplex from mmcif file or string. 
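+
+        Example (illustrative; the file path is hypothetical):
+
+            >>> mc = MolecularComplex.from_mmcif("example.cif")
+            >>> mc.sequence[:3]  # mixed tokens, e.g. ['MET', 'LYS', 'ATP']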
+ + Args: + inp: Path to mmCIF file or mmCIF content as string + id: Optional identifier to assign to the complex + + Returns: + MolecularComplex with all molecules (proteins, ligands, nucleic acids) + """ + from io import StringIO + + # Check if input is a file path or mmCIF string content + if os.path.exists(inp): + # Input is a file path + mmcif_file = pdbx.CIFFile.read(inp) + else: + # Input is mmCIF string content + mmcif_file = pdbx.CIFFile.read(StringIO(inp)) + + # Get structure - handle missing model information gracefully + try: + structure = pdbx.get_structure(mmcif_file, model=1) + except (KeyError, ValueError): + # Fallback for mmCIF files without model information + try: + structure = pdbx.get_structure(mmcif_file) + except Exception: + # Last resort: use the first available model or all atoms + structure = pdbx.get_structure(mmcif_file, model=None) + # Type hint for pyright - structure is an AtomArray which is iterable + if TYPE_CHECKING: + structure: Any = structure + + # Get entity information from mmCIF + entity_info = {} + try: + # Access the first block in CIFFile + block = mmcif_file[0] + if "entity" in block: + entity_category = block["entity"] + if "id" in entity_category and "type" in entity_category: + entity_ids = entity_category["id"] + entity_types = entity_category["type"] + # Convert CIFColumn to list for iteration + if hasattr(entity_ids, "__iter__") and hasattr( + entity_types, "__iter__" + ): + # Type annotation to help pyright understand these are iterable + entity_ids_list = list(entity_ids) # type: ignore + entity_types_list = list(entity_types) # type: ignore + for eid, etype in zip(entity_ids_list, entity_types_list): + entity_info[eid] = etype + except Exception: + pass + + # Initialize arrays for flat atom representation + sequence_tokens = [] + flat_positions = [] + flat_elements = [] + token_to_atoms = [] + confidence_scores = [] + + atom_idx = 0 + + # Group atoms by chain and residue + chain_residue_groups = {} + for atom in structure: + chain_id = atom.chain_id + res_id = atom.res_id + res_name = atom.res_name + + if chain_id not in chain_residue_groups: + chain_residue_groups[chain_id] = {} + if res_id not in chain_residue_groups[chain_id]: + chain_residue_groups[chain_id][res_id] = { + "atoms": [], + "res_name": res_name, + "is_hetero": atom.hetero, + } + chain_residue_groups[chain_id][res_id]["atoms"].append(atom) + + # Process each chain and residue + for chain_id in sorted(chain_residue_groups.keys()): + residues = chain_residue_groups[chain_id] + + for res_id in sorted(residues.keys()): + residue_data = residues[res_id] + res_name = residue_data["res_name"] + atoms = residue_data["atoms"] + is_hetero = residue_data["is_hetero"] + + # Skip water molecules + if res_name == "HOH": + continue + + # Determine token name + if not is_hetero and res_name in residue_constants.restype_3to1: + # Standard amino acid + token_name = res_name + elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]: + # Nucleotide + token_name = res_name + else: + # Ligand or other molecule + token_name = res_name + + sequence_tokens.append(token_name) + token_start = atom_idx + + # Add all atoms from this residue + for atom in atoms: + flat_positions.append(atom.coord) + + # Get element character + element = atom.element + flat_elements.append(element) + + atom_idx += 1 + + # Record token-to-atom mapping + token_to_atoms.append([token_start, atom_idx]) + + # Add confidence score (B-factor if available, otherwise 1.0) + bfactor = getattr(atoms[0], "b_factor", 
50.0) if atoms else 50.0 + confidence_scores.append(min(bfactor / 100.0, 1.0)) + + # Convert to numpy arrays + if not flat_positions: + # Create minimal arrays if no atoms found + atom_positions = np.zeros((0, 3), dtype=np.float32) + atom_elements = np.zeros(0, dtype=object) + token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32) + else: + atom_positions = np.array(flat_positions, dtype=np.float32) + atom_elements = np.array(flat_elements, dtype=object) + token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) + + confidence_array = np.array(confidence_scores, dtype=np.float32) + + # Create metadata + metadata = MolecularComplexMetadata( + entity_lookup=entity_info, + chain_lookup={ + i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys()) + }, + assembly_composition=None, + ) + + # Set complex ID - if input was a path, use the stem; otherwise use default + if os.path.exists(inp): + complex_id = id or Path(inp).stem + else: + complex_id = id or "complex_from_string" + + return cls( + id=complex_id, + sequence=sequence_tokens, + atom_positions=atom_positions, + atom_elements=atom_elements, + token_to_atoms=token_to_atoms_array, + plddt=confidence_array, + metadata=metadata, + ) + + def to_mmcif(self) -> str: + """Write MolecularComplex to mmcif string. + + Returns: + String representation of the complex in mmCIF format + """ + # No need for element mapping - already using element characters + + lines = [] + + # Header + lines.append(f"data_{self.id}") + lines.append("#") + lines.append(f"_entry.id {self.id}") + lines.append("#") + + # Structure metadata + lines.append("_struct.entry_id {}".format(self.id)) + lines.append("_struct.title 'Protein Structure'") + lines.append("#") + + # Entity information + entity_id = 1 + chain_counter = 0 + lines.append("loop_") + lines.append("_entity.id") + lines.append("_entity.type") + lines.append("_entity.pdbx_description") + + # Determine entities based on sequence + protein_tokens = [] + other_tokens = [] + + for i, token in enumerate(self.sequence): + if token in residue_constants.restype_3to1: + protein_tokens.append((i, token)) + else: + other_tokens.append((i, token)) + + if protein_tokens: + lines.append(f"{entity_id} polymer 'Protein chain'") + entity_id += 1 + + for token in set(token for _, token in other_tokens): + lines.append(f"{entity_id} non-polymer 'Ligand {token}'") + entity_id += 1 + + lines.append("#") + + # Chain assignments + lines.append("loop_") + lines.append("_struct_asym.id") + lines.append("_struct_asym.entity_id") + + chain_id = "A" + if protein_tokens: + lines.append(f"{chain_id} 1") + chain_counter += 1 + chain_id = chr(ord(chain_id) + 1) + + entity_id = 2 + for token in set(token for _, token in other_tokens): + lines.append(f"{chain_id} {entity_id}") + entity_id += 1 + chain_counter += 1 + if chain_counter < 26: + chain_id = chr(ord(chain_id) + 1) + + lines.append("#") + + # Atom site information + lines.append("loop_") + lines.append("_atom_site.group_PDB") + lines.append("_atom_site.id") + lines.append("_atom_site.type_symbol") + lines.append("_atom_site.label_atom_id") + lines.append("_atom_site.label_alt_id") + lines.append("_atom_site.label_comp_id") + lines.append("_atom_site.label_asym_id") + lines.append("_atom_site.label_entity_id") + lines.append("_atom_site.label_seq_id") + lines.append("_atom_site.pdbx_PDB_ins_code") + lines.append("_atom_site.Cartn_x") + lines.append("_atom_site.Cartn_y") + lines.append("_atom_site.Cartn_z") + lines.append("_atom_site.occupancy") + 
lines.append("_atom_site.B_iso_or_equiv") + lines.append("_atom_site.pdbx_PDB_model_num") + lines.append("_atom_site.auth_seq_id") + lines.append("_atom_site.auth_comp_id") + lines.append("_atom_site.auth_asym_id") + lines.append("_atom_site.auth_atom_id") + + atom_id = 1 + seq_id = 1 + chain_id = "A" + entity_id = 1 + + for token_idx, token in enumerate(self.sequence): + start_atom, end_atom = self.token_to_atoms[token_idx] + + # Determine if this is a protein residue or ligand + is_protein = token in residue_constants.restype_3to1 + group_pdb = "ATOM" if is_protein else "HETATM" + current_entity_id = 1 if is_protein else 2 # Simplified entity assignment + current_chain_id = "A" if is_protein else "B" # Simplified chain assignment + + # Create atom names for this token + atom_names = [] + if is_protein: + # Use standard protein atom names + res_atoms = residue_constants.residue_atoms.get( + token, ["N", "CA", "C", "O"] + ) + atom_names = res_atoms[: end_atom - start_atom] + else: + # Generate generic atom names for ligands + for i in range(end_atom - start_atom): + atom_names.append(f"C{i+1}") + + # Pad atom names if needed + while len(atom_names) < (end_atom - start_atom): + atom_names.append(f"X{len(atom_names)+1}") + + # Write atoms for this token + for atom_idx_in_token, global_atom_idx in enumerate( + range(start_atom, end_atom) + ): + pos = self.atom_positions[global_atom_idx] + element_char = self.atom_elements[global_atom_idx] + element_symbol = element_char if isinstance(element_char, str) else "C" + + atom_name = ( + atom_names[atom_idx_in_token] + if atom_idx_in_token < len(atom_names) + else f"X{atom_idx_in_token+1}" + ) + + # Format atom site line + bfactor = ( + self.plddt[token_idx] * 100.0 + if len(self.plddt) > token_idx + else 50.0 + ) + + line = ( + f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . " + f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? " + f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 " + f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}" + ) + lines.append(line) + atom_id += 1 + + seq_id += 1 + + lines.append("#") + return "\n".join(lines) + + def dockq(self, native: "MolecularComplex") -> Any: + """Compute DockQ score against native structure. 
+ + Args: + native: Native MolecularComplex to compute DockQ against + + Returns: + DockQ result containing score and alignment information + """ + # Imports moved to top of file + + # Convert both complexes to ProteinComplex format for DockQ computation + # This extracts only the protein portion and converts to PDB format + try: + self_pc = self.to_protein_complex() + native_pc = native.to_protein_complex() + except ValueError as e: + raise ValueError( + f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}" + ) + + # Normalize chain IDs for PDB compatibility + self_pc = self_pc.normalize_chain_ids_for_pdb() + native_pc = native_pc.normalize_chain_ids_for_pdb() + + # Use the existing ProteinComplex.dockq() method + try: + dockq_result = self_pc.dockq(native_pc) + return dockq_result + except Exception: + # Fallback to manual DockQ computation if ProteinComplex.dockq() fails + return self._compute_dockq_manual(native) + + def _compute_dockq_manual(self, native: "MolecularComplex") -> Any: + """Manual DockQ computation fallback.""" + # Imports moved to top of file + + # Convert both complexes to ProteinComplex format + try: + self_pc = self.to_protein_complex() + native_pc = native.to_protein_complex() + except ValueError as e: + raise ValueError( + f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}" + ) + + # Normalize chain IDs for PDB compatibility + self_pc = self_pc.normalize_chain_ids_for_pdb() + native_pc = native_pc.normalize_chain_ids_for_pdb() + + # Write temporary PDB files and run DockQ + with TemporaryDirectory() as tdir: + dir_path = Path(tdir) + self_pdb = dir_path / "self.pdb" + native_pdb = dir_path / "native.pdb" + + # Write PDB files + self_pc.to_pdb(self_pdb) + native_pc.to_pdb(native_pdb) + + # Run DockQ + try: + output = check_output(["DockQ", str(self_pdb), str(native_pdb)]) + output_text = output.decode() + + # Parse DockQ output + lines = output_text.split("\n") + + # Find the total DockQ score + dockq_score = None + for line in lines: + if "Total DockQ" in line: + match = re.search(r"Total DockQ.*: ([\d.]+)", line) + if match: + dockq_score = float(match.group(1)) + break + + if dockq_score is None: + # Try to find individual DockQ scores + for line in lines: + if line.startswith("DockQ") and ":" in line: + try: + dockq_score = float(line.split(":")[1].strip()) + break + except (ValueError, IndexError): + continue + + if dockq_score is None: + raise ValueError("Could not parse DockQ score from output") + + # Return a simple result structure + return { + "total_dockq": dockq_score, + "raw_output": output_text, + "aligned": self, # Return self as aligned structure + } + + except FileNotFoundError: + raise RuntimeError( + "DockQ is not installed. Please install DockQ to use this method." + ) + except Exception as e: + raise RuntimeError(f"DockQ computation failed: {e}") + + def rmsd(self, target: "MolecularComplex", **kwargs) -> float: + """Compute RMSD against target structure. 
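+
+        Example (illustrative; ``predicted`` and ``target`` are assumed to share the
+        same tokenization):
+
+            >>> rmsd = predicted.rmsd(target)  # RMSD in Angstroms over per-token centroids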
+ + Args: + target: Target MolecularComplex to compute RMSD against + **kwargs: Additional arguments passed to compute_rmsd + + Returns: + float: RMSD value between the two structures + """ + # Imports moved to top of file + + # Ensure both complexes have the same number of tokens + if len(self) != len(target): + raise ValueError( + f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}" + ) + + # Extract center positions for each token (using centroid of atoms) + mobile_coords = [] + target_coords = [] + atom_mask = [] + + for i in range(len(self)): + # Get atom positions for this token + mobile_start, mobile_end = self.token_to_atoms[i] + target_start, target_end = target.token_to_atoms[i] + + # Extract atom positions + mobile_atoms = self.atom_positions[mobile_start:mobile_end] + target_atoms = target.atom_positions[target_start:target_end] + + # Check if both tokens have atoms + if len(mobile_atoms) == 0 or len(target_atoms) == 0: + # Skip tokens with no atoms + continue + + # For simplicity, use the centroid of atoms as the representative position + mobile_center = mobile_atoms.mean(axis=0) + target_center = target_atoms.mean(axis=0) + + mobile_coords.append(mobile_center) + target_coords.append(target_center) + atom_mask.append(True) + + if len(mobile_coords) == 0: + raise ValueError("No valid atoms found for RMSD computation") + + # Convert to tensors + mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N] + + # Compute RMSD using existing infrastructure + rmsd_value = compute_rmsd( + mobile=mobile_tensor, + target=target_tensor, + atom_exists_mask=mask_tensor, + reduction="batch", + **kwargs, + ) + + return float(rmsd_value) + + def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float: + """Compute LDDT score against target structure. 
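+
+        Example (illustrative; same assumptions as ``rmsd``):
+
+            >>> lddt = predicted.lddt_ca(target)  # overall LDDT over per-token centroids, in [0, 1]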
+ + Args: + target: Target MolecularComplex to compute LDDT against + **kwargs: Additional arguments passed to compute_lddt + + Returns: + float: LDDT value between the two structures + """ + # Imports moved to top of file + + # Ensure both complexes have the same number of tokens + if len(self) != len(target): + raise ValueError( + f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}" + ) + + # Extract center positions for each token (using centroid of atoms) + mobile_coords = [] + target_coords = [] + atom_mask = [] + + for i in range(len(self)): + # Get atom positions for this token + mobile_start, mobile_end = self.token_to_atoms[i] + target_start, target_end = target.token_to_atoms[i] + + # Extract atom positions + mobile_atoms = self.atom_positions[mobile_start:mobile_end] + target_atoms = target.atom_positions[target_start:target_end] + + # Check if both tokens have atoms + if len(mobile_atoms) == 0 or len(target_atoms) == 0: + # Skip tokens with no atoms + mobile_coords.append(np.full(3, np.nan)) + target_coords.append(np.full(3, np.nan)) + atom_mask.append(False) + continue + + # For simplicity, use the centroid of atoms as the representative position + mobile_center = mobile_atoms.mean(axis=0) + target_center = target_atoms.mean(axis=0) + + mobile_coords.append(mobile_center) + target_coords.append(target_center) + atom_mask.append(True) + + if not any(atom_mask): + raise ValueError("No valid atoms found for LDDT computation") + + # Convert to tensors + mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N] + + # Compute LDDT using existing infrastructure + lddt_value = compute_lddt( + all_atom_pred_pos=mobile_tensor, + all_atom_positions=target_tensor, + all_atom_mask=mask_tensor, + per_residue=False, # Return overall LDDT score + **kwargs, + ) + + return float(lddt_value) + + def state_dict(self): + """This state dict is optimized for storage, so it turns things to fp16 whenever + possible and converts numpy arrays to lists for JSON serialization. 
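+
+        Example (illustrative serialization round trip; ``mc`` is an existing
+        MolecularComplex):
+
+            >>> blob = mc.to_blob()  # brotli-compressed msgpack of state_dict()
+            >>> restored = MolecularComplex.from_blob(blob)
+            >>> len(restored) == len(mc)  # token count is preserved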
+ """ + dct = {k: v for k, v in vars(self).items()} + for k, v in dct.items(): + if isinstance(v, np.ndarray): + match v.dtype: + case np.int64: + dct[k] = v.astype(np.int32).tolist() + case np.float64 | np.float32: + dct[k] = v.astype(np.float16).tolist() + case _: + dct[k] = v.tolist() + elif isinstance(v, MolecularComplexMetadata): + dct[k] = asdict(v) + + return dct + + def to_blob(self) -> bytes: + return brotli.compress(msgpack.dumps(self.state_dict()), quality=5) + + @classmethod + def from_state_dict(cls, dct): + for k, v in dct.items(): + if isinstance(v, list) and k in [ + "atom_positions", + "atom_elements", + "token_to_atoms", + "plddt", + ]: + dct[k] = np.array(v) + + for k, v in dct.items(): + if isinstance(v, np.ndarray): + if k in ["atom_positions", "plddt"]: + dct[k] = v.astype(np.float32) + elif k in ["token_to_atoms"]: + dct[k] = v.astype(np.int32) + + dct["metadata"] = MolecularComplexMetadata(**dct["metadata"]) + return cls(**dct) + + @classmethod + def from_blob(cls, input: Path | str | io.BytesIO | bytes): + match input: + case Path() | str(): + bytes = Path(input).read_bytes() + case io.BytesIO(): + bytes = input.getvalue() + case _: + bytes = input + return cls.from_state_dict( + msgpack.loads(brotli.decompress(bytes), strict_map_key=False) + ) diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py index 7e2c38de..306c8832 100644 --- a/esm/utils/structure/protein_complex.py +++ b/esm/utils/structure/protein_complex.py @@ -377,11 +377,14 @@ def switch_assembly(self, id: str): assert self.metadata.mmcif is not None return get_assembly_fast(self.metadata.mmcif, assembly_id=id) - def state_dict(self, backbone_only=False): + def state_dict(self, backbone_only=False, json_serializable=False): """This state dict is optimized for storage, so it turns things to fp16 whenever possible. 
Note that we also only support int32 residue indices, I'm hoping we don't need more than 2**32 residues...""" dct = {k: v for k, v in vars(self).items()} + if backbone_only: + dct["atom37_mask"][:, 3:] = False + dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]] for k, v in dct.items(): if isinstance(v, np.ndarray): match v.dtype: @@ -391,9 +394,10 @@ def state_dict(self, backbone_only=False): dct[k] = v.astype(np.float16) case _: pass + if json_serializable: + dct[k] = v.tolist() elif isinstance(v, ProteinComplexMetadata): dct[k] = asdict(v) - dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]] dct["metadata"]["mmcif"] = None # These can be populated with non-serializable objects and are not needed for reconstruction dct.pop("atoms", None) @@ -406,6 +410,10 @@ def to_blob(self, backbone_only=False) -> bytes: @classmethod def from_state_dict(cls, dct): + for k, v in dct.items(): + if isinstance(v, list): + dct[k] = np.array(v) + atom37 = np.full((*dct["atom37_mask"].shape, 3), np.nan) atom37[dct["atom37_mask"]] = dct["atom37_positions"] dct["atom37_positions"] = atom37 diff --git a/esm/utils/system.py b/esm/utils/system.py new file mode 100644 index 00000000..c2800e57 --- /dev/null +++ b/esm/utils/system.py @@ -0,0 +1,45 @@ +import io +import subprocess +import typing as T +from pathlib import Path + +PathLike = T.Union[str, Path] +PathOrBuffer = T.Union[PathLike, io.StringIO] + + +def run_subprocess_with_errorcheck( + *popenargs, + capture_output: bool = False, + quiet: bool = False, + env: dict[str, str] | None = None, + shell: bool = False, + executable: str | None = None, + **kws, +) -> subprocess.CompletedProcess: + """A command similar to subprocess.run, however the errormessage will + contain the stderr when using this function. This makes it significantly + easier to diagnose issues. + """ + try: + if capture_output: + stdout = subprocess.PIPE + elif quiet: + stdout = subprocess.DEVNULL + else: + stdout = None + + p = subprocess.run( + *popenargs, + stderr=subprocess.PIPE, + stdout=stdout, + check=True, + env=env, + shell=shell, + executable=executable, + **kws, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"Command failed with errorcode {e.returncode}." 
f"\n\n{e.stderr.decode()}" + ) + return p diff --git a/pixi.lock b/pixi.lock index be3b375f..897f11a5 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1726,8 +1726,8 @@ packages: requires_python: '>=3.8' - pypi: ./ name: esm - version: 3.2.2 - sha256: c14e2546bda5f0910c14acfabb7ea334e7171905c6799b43178f0420a92d6f3e + version: 3.2.2.post2 + sha256: 3f59a2977c85d35b4b1353902fa90e35d02acbabe6ffb506727bd406ec987ad1 requires_dist: - torch>=2.2.0 - torchvision diff --git a/pyproject.toml b/pyproject.toml index 2f8008f8..0dcc398d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "esm" -version = "3.2.2" +version = "3.2.2.post2" description = "EvolutionaryScale open model repository" readme = "README.md" requires-python = ">=3.12,<3.13" @@ -45,7 +45,6 @@ dependencies = [ "pygtrie", "dna_features_viewer", ] - # Pytest [tool.pytest.ini_options] addopts = """ diff --git a/tests/oss_pytests/requirements.txt b/tests/oss_pytests/requirements.txt index d120ce35..143baa7a 100644 --- a/tests/oss_pytests/requirements.txt +++ b/tests/oss_pytests/requirements.txt @@ -1,3 +1,2 @@ -esm +esm >=3.2.1post1,<4.0.0 pytest -httpx # TODO(williamxi): Remove this after the esm repo is fixed diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index 9dd5b188..4e7638f6 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -1,6 +1,7 @@ import os import pytest +import torch from esm.sdk import client # pyright: ignore from esm.sdk.api import ( # pyright: ignore @@ -37,6 +38,8 @@ def test_oss_esm3_client(): logits_config = LogitsConfig(sequence=True, return_embeddings=True) result = esm3_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) + assert result.logits is not None + assert isinstance(result.logits.sequence, torch.Tensor) sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1)) result = esm3_client.forward_and_sample( @@ -69,6 +72,8 @@ def test_oss_esmc_client(): ) result = esmc_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) + assert result.logits is not None + assert isinstance(result.logits.sequence, torch.Tensor) @pytest.mark.sdk