From aaed83a018ae3c8a7153ef7921fadd49f2c15003 Mon Sep 17 00:00:00 2001 From: Neil Thomas Date: Wed, 17 Sep 2025 21:14:52 +0000 Subject: [PATCH] sync 3.2.2.post1 --- cookbook/local/open_generate.ipynb | 1 - cookbook/local/raw_forwards.py | 5 +- cookbook/snippets/fold_invfold.py | 1 - cookbook/tutorials/1_esmprotein.ipynb | 4 +- cookbook/tutorials/2_embed.ipynb | 6 +- cookbook/tutorials/3_gfp_design.ipynb | 7 +- cookbook/tutorials/4_forge_generate.ipynb | 5 +- cookbook/tutorials/5_guided_generation.ipynb | 3 +- esm/__init__.py | 3 +- esm/layers/attention.py | 5 +- esm/layers/blocks.py | 9 +- esm/layers/structure_proj.py | 5 +- esm/models/esm3.py | 14 +- esm/models/function_decoder.py | 4 +- esm/models/vqvae.py | 5 +- esm/pretrained.py | 10 +- esm/sdk/api.py | 54 +- esm/sdk/base_forge_client.py | 21 +- esm/sdk/forge.py | 161 ++- esm/sdk/retry.py | 19 +- esm/tokenization/__init__.py | 5 +- esm/utils/decoding.py | 24 +- esm/utils/encoding.py | 29 +- esm/utils/forge_context_manager.py | 5 +- esm/utils/function/encode_decode.py | 13 +- esm/utils/generation.py | 13 +- esm/utils/misc.py | 140 ++- esm/utils/msa/__init__.py | 7 + esm/utils/msa/filter_sequences.py | 79 ++ esm/utils/msa/msa.py | 507 ++++++++++ esm/utils/parsing.py | 83 ++ esm/utils/sampling.py | 15 +- esm/utils/sequential_dataclass.py | 157 +++ esm/utils/structure/aligner.py | 4 +- esm/utils/structure/atom_indexer.py | 4 +- esm/utils/structure/input_builder.py | 96 ++ esm/utils/structure/metrics.py | 2 +- esm/utils/structure/molecular_complex.py | 944 ++++++++++++++++++ esm/utils/structure/protein_chain.py | 16 +- esm/utils/structure/protein_complex.py | 10 +- esm/utils/system.py | 45 + esm/widgets/components/function_annotator.py | 4 +- esm/widgets/components/results_visualizer.py | 8 +- .../components/sasa_prompt_selector.py | 13 +- .../secondary_structure_prompt_selector.py | 13 +- .../components/sequence_prompt_selector.py | 4 +- .../components/structure_prompt_selector.py | 8 +- .../drawing/draw_function_annotations.py | 5 +- esm/widgets/utils/prompting.py | 4 +- esm/widgets/views/esm3_generation_launcher.py | 8 +- esm/widgets/views/esm3_prompt_selector.py | 4 +- esm/widgets/views/generation.py | 16 +- esm/widgets/views/inverse_folding.py | 4 +- esm/widgets/views/login.py | 5 +- esm/widgets/views/prediction.py | 4 +- pyproject.toml | 3 +- tests/oss_pytests/requirements.txt | 3 +- tests/oss_pytests/test_oss_client.py | 8 +- 58 files changed, 2513 insertions(+), 141 deletions(-) create mode 100644 esm/utils/msa/__init__.py create mode 100644 esm/utils/msa/filter_sequences.py create mode 100644 esm/utils/msa/msa.py create mode 100644 esm/utils/parsing.py create mode 100644 esm/utils/sequential_dataclass.py create mode 100644 esm/utils/structure/input_builder.py create mode 100644 esm/utils/structure/molecular_complex.py create mode 100644 esm/utils/system.py diff --git a/cookbook/local/open_generate.ipynb b/cookbook/local/open_generate.ipynb index 32a72c38..c8cec4c1 100644 --- a/cookbook/local/open_generate.ipynb +++ b/cookbook/local/open_generate.ipynb @@ -38,7 +38,6 @@ "\n", "!pip install py3Dmol\n", "import py3Dmol\n", - "\n", "from esm.models.esm3 import ESM3\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" diff --git a/cookbook/local/raw_forwards.py b/cookbook/local/raw_forwards.py index baad28ee..5701fa2a 100644 --- a/cookbook/local/raw_forwards.py +++ b/cookbook/local/raw_forwards.py @@ -2,7 +2,6 @@ import torch import torch.nn.functional as F - from 
esm.pretrained import ( ESM3_function_decoder_v0, ESM3_sm_open_v0, @@ -13,7 +12,9 @@ from esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer as EsmFunctionTokenizer, ) -from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation diff --git a/cookbook/snippets/fold_invfold.py b/cookbook/snippets/fold_invfold.py index e6be5485..cb24db6a 100644 --- a/cookbook/snippets/fold_invfold.py +++ b/cookbook/snippets/fold_invfold.py @@ -2,7 +2,6 @@ from typing import cast import numpy as np - from esm.sdk.api import ( ESM3InferenceClient, ESMProtein, diff --git a/cookbook/tutorials/1_esmprotein.ipynb b/cookbook/tutorials/1_esmprotein.ipynb index 13017733..e143ff13 100644 --- a/cookbook/tutorials/1_esmprotein.ipynb +++ b/cookbook/tutorials/1_esmprotein.ipynb @@ -72,7 +72,6 @@ "outputs": [], "source": [ "from biotite.database import rcsb\n", - "\n", "from esm.sdk.api import ESMProtein\n", "from esm.utils.structure.protein_chain import ProteinChain\n", "from esm.utils.types import FunctionAnnotation\n", @@ -497,9 +496,8 @@ "# Functions for visualizing InterPro function annotations\n", "\n", "from dna_features_viewer import GraphicFeature, GraphicRecord\n", - "from matplotlib import colormaps\n", - "\n", "from esm.utils.function.interpro import InterPro, InterProEntryType\n", + "from matplotlib import colormaps\n", "\n", "\n", "def visualize_function_annotations(\n", diff --git a/cookbook/tutorials/2_embed.ipynb b/cookbook/tutorials/2_embed.ipynb index 459fa90e..61fdb397 100644 --- a/cookbook/tutorials/2_embed.ipynb +++ b/cookbook/tutorials/2_embed.ipynb @@ -49,18 +49,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." 
] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from getpass import getpass\n", "\n", - "token = getpass(\"Token from Forge console: \")" + "token = getpass(\"Token from Forge: \")" ] }, { diff --git a/cookbook/tutorials/3_gfp_design.ipynb b/cookbook/tutorials/3_gfp_design.ipynb index bde09ad5..95b42418 100644 --- a/cookbook/tutorials/3_gfp_design.ipynb +++ b/cookbook/tutorials/3_gfp_design.ipynb @@ -64,7 +64,6 @@ "import matplotlib.pyplot as pl\n", "import py3Dmol\n", "import torch\n", - "\n", "from esm.sdk import client\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" @@ -80,18 +79,18 @@ "\n", "The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n", "\n", - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n" + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "id": "zNrU9Q2SYonX" }, "outputs": [], "source": [ - "token = getpass(\"Token from Forge console: \")" + "token = getpass(\"Token from Forge: \")" ] }, { diff --git a/cookbook/tutorials/4_forge_generate.ipynb b/cookbook/tutorials/4_forge_generate.ipynb index 7570fda9..5fb6e676 100644 --- a/cookbook/tutorials/4_forge_generate.ipynb +++ b/cookbook/tutorials/4_forge_generate.ipynb @@ -36,7 +36,6 @@ "\n", "!pip install py3Dmol\n", "import py3Dmol\n", - "\n", "from esm.sdk import client\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" @@ -53,7 +52,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. 
For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." ] }, { @@ -64,7 +63,7 @@ "source": [ "from getpass import getpass\n", "\n", - "token = getpass(\"Token from Forge console: \")\n", + "token = getpass(\"Token from Forge: \")\n", "model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)" ] }, diff --git a/cookbook/tutorials/5_guided_generation.ipynb b/cookbook/tutorials/5_guided_generation.ipynb index f04d6be3..b2d8c7dc 100644 --- a/cookbook/tutorials/5_guided_generation.ipynb +++ b/cookbook/tutorials/5_guided_generation.ipynb @@ -49,7 +49,6 @@ "source": [ "import biotite.structure as bs\n", "import py3Dmol\n", - "\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction" ] @@ -120,7 +119,7 @@ "\n", "from esm.sdk import client\n", "\n", - "token = getpass(\"Token from Forge console: \")\n", + "token = getpass(\"Token from Forge: \")\n", "model = client(\n", " model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n", ")" diff --git a/esm/__init__.py b/esm/__init__.py index 1e3bed4c..aec10a41 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1 +1,2 @@ -__version__ = "3.2.2" +__version__ = "3.2.2.post1" + diff --git a/esm/layers/attention.py b/esm/layers/attention.py index 564ef90c..ce57632f 100644 --- a/esm/layers/attention.py +++ b/esm/layers/attention.py @@ -5,7 +5,10 @@ import torch.nn.functional as F from torch import nn -from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding +from esm.layers.rotary import ( + RotaryEmbedding, + TritonRotaryEmbedding, +) try: from flash_attn import flash_attn_varlen_qkvpacked_func diff --git a/esm/layers/blocks.py b/esm/layers/blocks.py index 76ebbe06..593b277f 100644 --- a/esm/layers/blocks.py +++ b/esm/layers/blocks.py @@ -2,8 +2,13 @@ import torch.nn as nn import torch.nn.functional as F -from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention -from esm.layers.geom_attention import GeometricReasoningOriginalImpl +from esm.layers.attention import ( + FlashMultiHeadAttention, + MultiHeadAttention, +) +from esm.layers.geom_attention import ( + GeometricReasoningOriginalImpl, +) from esm.utils.structure.affine3d import Affine3D diff --git a/esm/layers/structure_proj.py b/esm/layers/structure_proj.py index faad0fe9..783ddeb4 100644 --- a/esm/layers/structure_proj.py +++ b/esm/layers/structure_proj.py @@ -2,7 +2,10 @@ import torch.nn as nn from esm.utils.constants.physics import BB_COORDINATES -from esm.utils.structure.affine3d import Affine3D, RotationMatrix +from esm.utils.structure.affine3d import ( + Affine3D, + RotationMatrix, +) class Dim6RotStructureHead(nn.Module): diff --git a/esm/models/esm3.py b/esm/models/esm3.py index 218a8e90..cbe02ddd 100644 --- a/esm/models/esm3.py +++ b/esm/models/esm3.py @@ -13,7 +13,10 @@ from esm.layers.regression_head import RegressionHead from esm.layers.transformer_stack import TransformerStack from esm.models.function_decoder import FunctionTokenDecoder -from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder +from esm.models.vqvae import ( + StructureTokenDecoder, + StructureTokenEncoder, +) from esm.sdk.api import ( ESM3InferenceClient, ESMProtein, @@ -29,7 
+32,10 @@ from esm.tokenization import TokenizerCollectionProtocol from esm.utils import encoding from esm.utils.constants import esm3 as C -from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name +from esm.utils.constants.models import ( + ESM3_OPEN_SMALL, + normalize_model_name, +) from esm.utils.decoding import decode_protein_tensor from esm.utils.generation import ( _batch_forward, @@ -44,7 +50,9 @@ get_default_sampling_config, validate_sampling_config, ) -from esm.utils.structure.affine3d import build_affine3d_from_coordinates +from esm.utils.structure.affine3d import ( + build_affine3d_from_coordinates, +) @dataclass diff --git a/esm/models/function_decoder.py b/esm/models/function_decoder.py index e5f1fb28..c4f32992 100644 --- a/esm/models/function_decoder.py +++ b/esm/models/function_decoder.py @@ -12,7 +12,9 @@ from esm.layers.regression_head import RegressionHead from esm.layers.transformer_stack import TransformerStack -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) from esm.utils.constants import esm3 as C from esm.utils.misc import merge_annotations, merge_ranges from esm.utils.types import FunctionAnnotation diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py index 37bc3945..0f5226a4 100644 --- a/esm/models/vqvae.py +++ b/esm/models/vqvae.py @@ -7,7 +7,10 @@ from esm.layers.transformer_stack import TransformerStack from esm.utils.constants import esm3 as C from esm.utils.misc import knn_graph -from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates +from esm.utils.structure.affine3d import ( + Affine3D, + build_affine3d_from_coordinates, +) from esm.utils.structure.predicted_aligned_error import ( compute_predicted_aligned_error, compute_tm, diff --git a/esm/pretrained.py b/esm/pretrained.py index e452e1d2..b9121511 100644 --- a/esm/pretrained.py +++ b/esm/pretrained.py @@ -6,8 +6,14 @@ from esm.models.esm3 import ESM3 from esm.models.esmc import ESMC from esm.models.function_decoder import FunctionTokenDecoder -from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder -from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers +from esm.models.vqvae import ( + StructureTokenDecoder, + StructureTokenEncoder, +) +from esm.tokenization import ( + get_esm3_model_tokenizers, + get_esmc_model_tokenizers, +) from esm.utils.constants.esm3 import data_root from esm.utils.constants.models import ( ESM3_FUNCTION_DECODER_V0, diff --git a/esm/sdk/api.py b/esm/sdk/api.py index 0212ddcd..6b152556 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -2,19 +2,27 @@ from abc import ABC from copy import deepcopy -from typing import Sequence +from typing import List, Sequence import attr import torch from attr import asdict, define import esm.utils.constants.api as C -from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers +from esm.tokenization import ( + TokenizerCollectionProtocol, + get_esm3_model_tokenizers, +) from esm.utils import encoding from esm.utils.constants.models import ESM3_OPEN_SMALL -from esm.utils.misc import get_chainbreak_boundaries_from_sequence +from esm.utils.misc import ( + get_chainbreak_boundaries_from_sequence, +) from esm.utils.structure.protein_chain import ProteinChain -from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex +from esm.utils.structure.protein_complex import ( + SINGLE_LETTER_CHAIN_IDS, + 
ProteinComplex, +) from esm.utils.types import FunctionAnnotation, PathOrBuffer @@ -35,6 +43,7 @@ class ESMProtein(ProteinType): plddt: torch.Tensor | None = None ptm: torch.Tensor | None = None + # When calling EvolutionaryScale API, use this flag to disclose any # sequences that may potentially have concerns. # Such sequences may not go through standard safety filter for approved users. @@ -148,12 +157,35 @@ def to_protein_complex( gt_chains = list(copy_annotations_from_ground_truth.chain_iter()) else: gt_chains = None + + # Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks + # This handles the case where the server doesn't include chain breaks in pLDDT + # This should be fixed on the server side. + if self.plddt is not None and len(self.plddt) != len(self.sequence): + # Only expand if there's a mismatch (likely due to chain breaks) + if "|" in self.sequence: + # Create expanded pLDDT with NaN at chain break positions + expanded_plddt = torch.full((len(self.sequence),), float("nan")) + plddt_idx = 0 + for i, aa in enumerate(self.sequence): + if aa != "|": + if plddt_idx < len(self.plddt): + expanded_plddt[i] = self.plddt[plddt_idx] + plddt_idx += 1 + plddt = expanded_plddt + else: + # Mismatch but no chain breaks - shouldn't happen but preserve original + plddt = self.plddt + else: + plddt = self.plddt + + pred_chains = [] for i, (start, end) in enumerate(chain_boundaries): if i >= len(SINGLE_LETTER_CHAIN_IDS): raise ValueError( f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}" ) + pred_chain = ProteinChain.from_atom37( atom37_positions=coords[start:end], sequence=self.sequence[start:end], @@ -161,7 +193,7 @@ if gt_chains is not None else SINGLE_LETTER_CHAIN_IDS[i], entity_id=gt_chains[i].entity_id if gt_chains is not None else None, - confidence=self.plddt[start:end] if self.plddt is not None else None, + confidence=plddt[start:end] if plddt is not None else None, ) pred_chains.append(pred_chain) return ProteinComplex.from_chains(pred_chains) @@ -298,19 +330,14 @@ def use_generative_unmasking_strategy(self): self.temperature_annealing = True -@define -class MSA: - # Paired MSA sequences. - # One would typically compute these using, for example, ColabFold. 
- sequences: list[str] - - @define class InverseFoldingConfig: invalid_ids: Sequence[int] = [] temperature: float = 1.0 + + ## Low Level Endpoint Types @define class SamplingTrackConfig: @@ -375,6 +402,9 @@ class LogitsConfig: ith_hidden_layer: int = -1 + + + @define class LogitsOutput: logits: ForwardTrackData | None = None diff --git a/esm/sdk/base_forge_client.py b/esm/sdk/base_forge_client.py index ff05b541..3c60a25f 100644 --- a/esm/sdk/base_forge_client.py +++ b/esm/sdk/base_forge_client.py @@ -1,9 +1,13 @@ +import asyncio +import time +from abc import ABC, abstractmethod from typing import Any from urllib.parse import urljoin import httpx from esm.sdk.api import ESMProteinError +from esm.sdk.retry import retry_decorator from esm.utils.decoding import assemble_message @@ -80,6 +84,10 @@ def prepare_request( headers = {**self.headers, **headers} if return_bytes: headers["return-bytes"] = "true" + # __INTERNAL_BEGIN___ + if disable_cache: + headers["X-Disable-Cache"] = "true" + # __INTERNAL_END___ return request, headers def prepare_data(self, response, endpoint: str) -> dict[str, Any]: @@ -112,7 +120,11 @@ async def _async_post( ): try: request, headers = self.prepare_request( - request, potential_sequence_of_concern, return_bytes, headers + request, + potential_sequence_of_concern, + return_bytes, + disable_cache, + headers, ) response = await self.async_client.post( url=urljoin(self.url, f"/api/v1/{endpoint}"), @@ -142,7 +154,10 @@ def _post( ): try: request, headers = self.prepare_request( - request, potential_sequence_of_concern, return_bytes, headers + request, + potential_sequence_of_concern, + return_bytes, + headers, ) response = self.client.post( url=urljoin(self.url, f"/api/v1/{endpoint}"), @@ -160,3 +175,5 @@ def _post( error_code=500, error_msg=f"Failed to submit request to {endpoint}. 
Error: {e}", ) + + diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index e9bb2e3d..ac105c03 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -1,13 +1,14 @@ +from __future__ import annotations + import asyncio import base64 import pickle from concurrent.futures import ThreadPoolExecutor -from typing import Any, Sequence +from typing import Any, Literal, Sequence, cast import torch from esm.sdk.api import ( - MSA, ESM3InferenceClient, ESMCInferenceClient, ESMProtein, @@ -19,14 +20,30 @@ InverseFoldingConfig, LogitsConfig, LogitsOutput, + ProteinChain, ProteinType, SamplingConfig, SamplingTrackConfig, ) -from esm.sdk.base_forge_client import _BaseForgeInferenceClient +from esm.sdk.base_forge_client import ( + _BaseForgeInferenceClient, +) from esm.sdk.retry import retry_decorator from esm.utils.constants.api import MIMETYPE_ES_PICKLE -from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor +from esm.utils.misc import ( + deserialize_tensors, + maybe_list, + maybe_tensor, +) +from esm.utils.msa import MSA +from esm.utils.structure.input_builder import ( + StructurePredictionInput, + serialize_structure_prediction_input, +) +from esm.utils.structure.molecular_complex import ( + MolecularComplex, + MolecularComplexResult, +) from esm.utils.types import FunctionAnnotation @@ -36,10 +53,8 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None: return [FunctionAnnotation(*t) for t in l] -def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False): - ret = data.get("logits", {}).get(track, None) - # TODO(s22chan): just return this when removing return_bytes - return ret if ret is None or not return_bytes else maybe_tensor(ret) +def _maybe_logits(data: dict[str, Any], track: str): + return maybe_tensor(data.get("logits", {}).get(track, None)) def _maybe_b64_decode(obj, return_bytes: bool): @@ -93,9 +108,13 @@ def __init__( ) @staticmethod - def _process_fold_request(sequence: str, model_name: str | None): + def _process_fold_request( + sequence: str, + model_name: str | None, + ): request: dict[str, Any] = {"sequence": sequence} + request["model"] = model_name return request @@ -130,6 +149,7 @@ def process_inverse_fold_request( return request + async def _async_fetch_msa(self, sequence: str) -> MSA: print("Fetching MSA ... this may take a few minutes") # Accept both "|" and ":" as the chainbreak token. @@ -137,7 +157,7 @@ async def _async_fetch_msa(self, sequence: str) -> MSA: data = await self._async_post( "msa", request={}, params={"sequence": sequence, "use_env": False} ) - return MSA(sequences=data["msa"]) + return MSA.from_sequences(sequences=data["msa"]) def _fetch_msa(self, sequence: str) -> MSA: print("Fetching MSA ... 
this may take a few minutes") @@ -146,7 +166,7 @@ def _fetch_msa(self, sequence: str) -> MSA: data = self._post( "msa", request={}, params={"sequence": sequence, "use_env": False} ) - return MSA(sequences=data["msa"]) + return MSA.from_sequences(sequences=data["msa"]) @retry_decorator async def async_fold( @@ -168,11 +188,15 @@ async def async_fold( del potential_sequence_of_concern request = self._process_fold_request( - sequence, model_name if model_name is not None else self.model + sequence, + model_name if model_name is not None else self.model, ) try: - data = await self._async_post("fold", request) + data = await self._async_post( + "fold", + request, + ) except ESMProteinError as e: return e @@ -199,16 +223,98 @@ def fold( del potential_sequence_of_concern request = self._process_fold_request( - sequence, model_name if model_name is not None else self.model + sequence, + model_name if model_name is not None else self.model, ) try: - data = self._post("fold", request) + data = self._post( + "fold", + request, + ) except ESMProteinError as e: return e return self._process_fold_response(data, sequence) + @retry_decorator + async def async_fold_all_atom( + self, + all_atom_input: StructurePredictionInput, + model_name: str | None = None, + ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: + """Fold a molecular complex containing proteins, nucleic acids, and/or ligands. + + Args: + all_atom_input: StructurePredictionInput containing sequences for different molecule types + model_name: Override the client level model name if needed + """ + request = self._process_fold_all_atom_request( + all_atom_input, + model_name if model_name is not None else self.model, + ) + + try: + data = await self._async_post( + "fold_all_atom", + request, + ) + except ESMProteinError as e: + return e + + return self._process_fold_all_atom_response(data) + + @retry_decorator + def fold_all_atom( + self, + all_atom_input: StructurePredictionInput, + model_name: str | None = None, + ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: + """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands. 
+ + Args: + all_atom_input: StructurePredictionInput containing sequences for different molecule types + model_name: Override the client level model name if needed + """ + request = self._process_fold_all_atom_request( + all_atom_input, + model_name if model_name is not None else self.model, + ) + + try: + data = self._post( + "fold_all_atom", + request, + ) + except ESMProteinError as e: + return e + + return self._process_fold_all_atom_response(data) + + @staticmethod + def _process_fold_all_atom_request( + all_atom_input: StructurePredictionInput, + model_name: str | None = None, + ) -> dict[str, Any]: + request: dict[str, Any] = { + "all_atom_input": serialize_structure_prediction_input(all_atom_input), + "model": model_name, + } + + + return request + + @staticmethod + def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult: + complex_data = data.get("complex") + molecular_complex = MolecularComplex.from_state_dict(complex_data) + return MolecularComplexResult( + complex=molecular_complex, + plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True), + ptm=data.get("ptm", None), + distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True), + ) + @retry_decorator async def async_inverse_fold( self, @@ -280,6 +386,7 @@ def inverse_fold( return ESMProtein(sequence=data["sequence"]) + class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient): def __init__( self, @@ -602,19 +709,15 @@ def _process_logits_response( return LogitsOutput( logits=ForwardTrackData( - sequence=_maybe_logits(data, "sequence", return_bytes), - structure=_maybe_logits(data, "structure", return_bytes), - secondary_structure=_maybe_logits( - data, "secondary_structure", return_bytes - ), - sasa=_maybe_logits(data, "sasa", return_bytes), - function=_maybe_logits(data, "function", return_bytes), + sequence=_maybe_logits(data, "sequence"), + structure=_maybe_logits(data, "structure"), + secondary_structure=_maybe_logits(data, "secondary_structure"), + sasa=_maybe_logits(data, "sasa"), + function=_maybe_logits(data, "function"), ), embeddings=maybe_tensor(data["embeddings"]), mean_embedding=data["mean_embedding"], - residue_annotation_logits=_maybe_logits( - data, "residue_annotation", return_bytes - ), + residue_annotation_logits=_maybe_logits(data, "residue_annotation"), hidden_states=maybe_tensor(data["hidden_states"]), mean_hidden_state=maybe_tensor(data["mean_hidden_state"]), ) @@ -965,6 +1068,7 @@ def _process_logits_request( "sequence": config.sequence, "return_embeddings": config.return_embeddings, "return_mean_embedding": config.return_mean_embedding, + "return_mean_hidden_states": config.return_mean_hidden_states, "return_hidden_states": config.return_hidden_states, "ith_hidden_layer": config.ith_hidden_layer, } @@ -981,12 +1085,11 @@ def _process_logits_response( data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes) output = LogitsOutput( - logits=ForwardTrackData( - sequence=_maybe_logits(data, "sequence", return_bytes) - ), + logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")), embeddings=maybe_tensor(data["embeddings"]), mean_embedding=data["mean_embedding"], hidden_states=maybe_tensor(data["hidden_states"]), + mean_hidden_state=maybe_tensor(data["mean_hidden_state"]), ) return output @@ -1109,3 +1212,5 @@ def raw_model(self): raise NotImplementedError( f"Can not get underlying remote model {self.model} from a Forge client." 
) + + diff --git a/esm/sdk/retry.py b/esm/sdk/retry.py index 16c354b6..302d6cf0 100644 --- a/esm/sdk/retry.py +++ b/esm/sdk/retry.py @@ -2,10 +2,9 @@ from contextvars import ContextVar from functools import wraps -import httpx from tenacity import ( retry, - retry_if_exception_type, + retry_if_exception, retry_if_result, stop_after_attempt, wait_incrementing, @@ -30,8 +29,12 @@ def retry_if_specific_error(exception): def log_retry_attempt(retry_state): + try: + outcome = retry_state.outcome.result() + except Exception: + outcome = retry_state.outcome.exception() print( - f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}" + f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}" ) @@ -41,13 +44,18 @@ def retry_decorator(func): instance's retry settings. """ + def return_last_value(retry_state): + """Return the result of the last call attempt.""" + return retry_state.outcome.result() + @wraps(func) async def async_wrapper(instance, *args, **kwargs): if skip_retries_var.get(): return await func(instance, *args, **kwargs) retry_decorator = retry( + retry_error_callback=return_last_value, retry=retry_if_result(retry_if_specific_error) - | retry_if_exception_type(httpx.ConnectTimeout), # ADDED + | retry_if_exception(retry_if_specific_error), wait=wait_incrementing( increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait ), @@ -62,8 +70,9 @@ def wrapper(instance, *args, **kwargs): if skip_retries_var.get(): return func(instance, *args, **kwargs) retry_decorator = retry( + retry_error_callback=return_last_value, retry=retry_if_result(retry_if_specific_error) - | retry_if_exception_type(httpx.ConnectTimeout), # ADDED + | retry_if_exception(retry_if_specific_error), wait=wait_incrementing( increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait ), diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py index 6db76554..ea609225 100644 --- a/esm/tokenization/__init__.py +++ b/esm/tokenization/__init__.py @@ -1,7 +1,10 @@ from dataclasses import dataclass from typing import Protocol -from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name +from esm.utils.constants.models import ( + ESM3_OPEN_SMALL, + normalize_model_name, +) from .function_tokenizer import InterProQuantizedTokenizer from .residue_tokenizer import ResidueAnnotationsTokenizer diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py index 1fe256b6..b5588527 100644 --- a/esm/utils/decoding.py +++ b/esm/utils/decoding.py @@ -10,12 +10,24 @@ from esm.models.vqvae import StructureTokenDecoder from esm.sdk.api import ESMProtein, ESMProteinTensor from esm.tokenization import TokenizerCollectionProtocol -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer -from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer -from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer -from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer -from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer -from esm.tokenization.structure_tokenizer import StructureTokenizer +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) +from esm.tokenization.sasa_tokenizer import ( + SASADiscretizingTokenizer, +) +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, 
+) +from esm.tokenization.ss_tokenizer import ( + SecondaryStructureTokenizer, +) +from esm.tokenization.structure_tokenizer import ( + StructureTokenizer, +) from esm.tokenization.tokenizer_base import EsmTokenizerBase from esm.utils.constants import api as api_constants from esm.utils.constants import esm3 as C diff --git a/esm/utils/encoding.py b/esm/utils/encoding.py index 8461709d..83c9d033 100644 --- a/esm/utils/encoding.py +++ b/esm/utils/encoding.py @@ -7,13 +7,26 @@ from esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer as EsmFunctionTokenizer, ) -from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer -from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer -from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer -from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer -from esm.tokenization.structure_tokenizer import StructureTokenizer + +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) +from esm.tokenization.sasa_tokenizer import ( + SASADiscretizingTokenizer, +) +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) +from esm.tokenization.ss_tokenizer import ( + SecondaryStructureTokenizer, +) +from esm.tokenization.structure_tokenizer import ( + StructureTokenizer, +) from esm.utils.constants import esm3 as C -from esm.utils.function.encode_decode import encode_function_annotations +from esm.utils.function.encode_decode import ( + encode_function_annotations, +) from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation @@ -152,6 +165,8 @@ def tokenize_function_annotations( return function_tokens, residue_annotation_tokens + + # Tokenized Defaults def get_default_sequence_tokens( sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer @@ -227,3 +242,5 @@ def get_default_residue_annotation_tokens( residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id return residue_annotation_tokens + + diff --git a/esm/utils/forge_context_manager.py b/esm/utils/forge_context_manager.py index b1c2bdf3..fac0c3bd 100644 --- a/esm/utils/forge_context_manager.py +++ b/esm/utils/forge_context_manager.py @@ -7,7 +7,10 @@ from tqdm import tqdm from esm.sdk.api import ESMProteinError -from esm.sdk.retry import retry_if_specific_error, skip_retries_var +from esm.sdk.retry import ( + retry_if_specific_error, + skip_retries_var, +) TQDM_BAR_FORMAT = ( "{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} " diff --git a/esm/utils/function/encode_decode.py b/esm/utils/function/encode_decode.py index a4029858..29534e34 100644 --- a/esm/utils/function/encode_decode.py +++ b/esm/utils/function/encode_decode.py @@ -3,9 +3,16 @@ import torch -from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer -from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer +from esm.models.function_decoder import ( + FunctionTokenDecoder, + merge_annotations, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) from esm.utils.constants import esm3 as C from esm.utils.types import FunctionAnnotation diff --git a/esm/utils/generation.py b/esm/utils/generation.py index b2d4ede4..cbc4306b 100644 --- 
a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -19,8 +19,13 @@ SamplingConfig, SamplingTrackConfig, ) -from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization import ( + EsmTokenizerBase, + TokenizerCollectionProtocol, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) from esm.utils.constants import esm3 as C from esm.utils.misc import stack_variable_length_tensors from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY @@ -43,7 +48,9 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int): sliced = {} for k, v in attr.asdict(o, recurse=False).items(): - if v is None: + if k in ["mean_hidden_state", "mean_embedding"]: + sliced[k] = v + elif v is None: sliced[k] = None elif isinstance(v, torch.Tensor): # Trim padding. diff --git a/esm/utils/misc.py b/esm/utils/misc.py index f0d7a602..409ba1bf 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -1,7 +1,19 @@ +from __future__ import annotations + import os from collections import defaultdict +from dataclasses import is_dataclass from io import BytesIO -from typing import Any, ContextManager, Sequence, TypeVar +from typing import ( + Any, + ContextManager, + Generator, + Iterable, + Protocol, + Sequence, + TypeVar, + runtime_checkable, +) from warnings import warn import huggingface_hub @@ -18,6 +30,12 @@ TSequence = TypeVar("TSequence", bound=Sequence) +@runtime_checkable +class Concatable(Protocol): + @classmethod + def concat(cls, objs: list[Concatable]) -> Concatable: ... + + def slice_python_object_as_numpy( obj: TSequence, idx: int | list[int] | slice | np.ndarray ) -> TSequence: @@ -52,6 +70,37 @@ def slice_python_object_as_numpy( return sliced_obj # type: ignore +def slice_any_object( + obj: TSequence, idx: int | list[int] | slice | np.ndarray +) -> TSequence: + """ + Slice an arbitrary object (like a list, string, or tuple) as if it were a numpy object. Similar to `slice_python_object_as_numpy`, but detects if it's a numpy array or Tensor and uses the existing slice method if so. + + If the object is a dataclass, it will simply apply the index to the object, under the assumption that the object has correctly implemented numpy indexing. + + Example: + >>> obj = "ABCDE" + >>> slice_any_object(obj, [1, 3, 4]) + "BDE" + + >>> obj = np.array([1, 2, 3, 4, 5]) + >>> slice_any_object(obj, np.arange(5) < 3) + np.array([1, 2, 3]) + + >>> obj = ProteinChain.from_rcsb("1a3a", "A") + >>> slice_any_object(obj, np.arange(len(obj)) < 10) + # ProteinChain w/ length 10 + + """ + if isinstance(obj, (np.ndarray, torch.Tensor)): + return obj[idx] # type: ignore + elif is_dataclass(obj): + # if passing a dataclass, assume it implements a custom slice + return obj[idx] # type: ignore + else: + return slice_python_object_as_numpy(obj, idx) + + def rbf(values, v_min, v_max, n_bins=16): """ Returns RBF encodings in a new dimension at the end. 
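Editor's note (not part of the patch): the new `Concatable` protocol and `slice_any_object` above, together with `concat_objects` introduced in the next hunk of esm/utils/misc.py, form a small generic-container toolkit. The minimal sketch below shows how a dataclass can opt into both helpers; the `Track` class and its fields are hypothetical illustrations, not part of the codebase.

import dataclasses
from dataclasses import dataclass

import numpy as np

from esm.utils.misc import concat_objects, slice_any_object


@dataclass
class Track:
    # Hypothetical per-residue track; not part of the esm package.
    name: str
    values: np.ndarray

    def __getitem__(self, idx):
        # Numpy-style indexing delegated to the underlying array, which is
        # the behavior slice_any_object assumes a dataclass implements.
        return dataclasses.replace(self, values=self.values[idx])

    @classmethod
    def concat(cls, objs):
        # Having this classmethod satisfies the runtime-checkable Concatable
        # protocol, so concat_objects dispatches here.
        return cls(objs[0].name, np.concatenate([o.values for o in objs]))


track = Track("plddt", np.arange(5, dtype=float))
short = slice_any_object(track, np.arange(5) < 3)  # dataclass branch: Track of length 3
merged = concat_objects([track, track])            # Concatable branch: Track of length 10
print(short.values.shape, merged.values.shape)     # (3,) (10,)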
@@ -298,6 +347,8 @@ def replace_inf(data): def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: if x is None: return None + if isinstance(x, torch.Tensor): + return x if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x): return torch.stack(x) if convert_none_to_nan: @@ -357,3 +408,90 @@ def deserialize_tensors(b: bytes) -> Any: buf = BytesIO(zstd.ZSTD_uncompress(b)) d = torch.load(buf, map_location="cpu", weights_only=False) return d + + +def join_lists( + lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None +) -> list[Any]: + """Joins multiple lists with a separator element. Like str.join but for lists. + + Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4] + + Args: + lists: Lists of elements to chain + separator: separator elements to insert between the chained lists. + Returns: + Joined lists. + """ + if not lists: + return [] + joined = [] + joined.extend(lists[0]) + for l in lists[1:]: + if separator: + joined.extend(separator) + joined.extend(l) + return joined + + +def iterate_with_intermediate( + lists: Iterable, intermediate +) -> Generator[Any, None, None]: + """ + Iterate over the iterable, yielding the intermediate value between + every element of the iterable. Useful for joining objects with + separator tokens. + """ + it = iter(lists) + yield next(it) + for l in it: + yield intermediate + yield l + + +def concat_objects(objs: Sequence[Any], separator: Any | None = None): + """ + Concat objects with each other using a separator token. + + Supports: + - Concatable (objects that implement `concat` classmethod) + - strings + - lists + - numpy arrays + - torch Tensors + + Example: + >>> foo = "abc" + >>> bar = "def" + >>> concat_objects([foo, bar], "|") + "abc|def" + """ + match objs[0]: + case Concatable(): + return objs[0].__class__.concat(objs) # type: ignore + case str(): + assert isinstance( + separator, str + ), "Trying to join strings but separator is not a string" + return separator.join(objs) + case list(): + if separator is not None: + return join_lists(objs, [separator]) + else: + return join_lists(objs) + case np.ndarray(): + if separator is not None: + return np.concatenate( + list(iterate_with_intermediate(objs, np.array([separator]))) + ) + else: + return np.concatenate(objs) + case torch.Tensor(): + if separator is not None: + return torch.cat( + list(iterate_with_intermediate(objs, torch.tensor([separator]))) + ) + else: + return torch.cat(objs) # type: ignore + case _: + raise TypeError(type(objs[0])) diff --git a/esm/utils/msa/__init__.py b/esm/utils/msa/__init__.py new file mode 100644 index 00000000..3804a365 --- /dev/null +++ b/esm/utils/msa/__init__.py @@ -0,0 +1,7 @@ +from esm.utils.msa.msa import ( + MSA, + FastMSA, + remove_insertions_from_sequence, +) + +__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"] diff --git a/esm/utils/msa/filter_sequences.py b/esm/utils/msa/filter_sequences.py new file mode 100644 index 00000000..d44549d5 --- /dev/null +++ b/esm/utils/msa/filter_sequences.py @@ -0,0 +1,79 @@ +import tempfile +from pathlib import Path + +import numpy as np +from scipy.spatial.distance import cdist + +from esm.utils.system import run_subprocess_with_errorcheck + + +def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]: + """Greedily select sequences that either maximize or minimize Hamming distance. + + Algorithm proposed in the MSA Transformer paper. 
Starting from the query sequence, + iteratively add sequences to the list with the maximum (minimum) average Hamming + distance to the existing set of sequences. + + Args: + array (np.ndarray): Character array representing the sequences in the MSA + num_seqs (int): Number of sequences to select. + mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless + you're doing it to prove a point for a paper. + + Returns: + list[int]: List of indices to select from the array + """ + assert mode in ("max", "min") + depth = array.shape[0] + if depth <= num_seqs: + return list(range(depth)) + array = array.view(np.uint8) + + optfunc = np.argmax if mode == "max" else np.argmin + all_indices = np.arange(depth) + indices = [0] + pairwise_distances = np.zeros((0, depth)) + for _ in range(num_seqs - 1): + dist = cdist(array[indices[-1:]], array, "hamming") + pairwise_distances = np.concatenate([pairwise_distances, dist]) + shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0) + shifted_index = optfunc(shifted_distance) + index = np.delete(all_indices, indices)[shifted_index] + indices.append(index) + indices = sorted(indices) + return indices + + +def hhfilter( + sequences: list[str], + seqid: int = 90, + diff: int = 0, + cov: int = 0, + qid: int = 0, + qsc: float = -20.0, + binary: str = "hhfilter", +) -> list[int]: + with tempfile.TemporaryDirectory(dir="/dev/shm") as tempdirname: + tempdir = Path(tempdirname) + fasta_file = tempdir / "input.fasta" + fasta_file.write_text( + "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences)) + ) + output_file = tempdir / "output.fasta" + command = " ".join( + [ + f"{binary}", + f"-i {fasta_file}", + "-M a3m", + f"-o {output_file}", + f"-id {seqid}", + f"-diff {diff}", + f"-cov {cov}", + f"-qid {qid}", + f"-qsc {qsc}", + ] + ).split(" ") + run_subprocess_with_errorcheck(command, capture_output=True) + with output_file.open() as f: + indices = [int(line[1:].strip()) for line in f if line.startswith(">")] + return indices diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py new file mode 100644 index 00000000..8722b9d6 --- /dev/null +++ b/esm/utils/msa/msa.py @@ -0,0 +1,507 @@ +from __future__ import annotations + +import dataclasses +import string +from dataclasses import dataclass +from functools import cached_property +from itertools import islice +from typing import Sequence + +import numpy as np +from Bio import SeqIO +from scipy.spatial.distance import cdist + +from esm.utils.misc import slice_any_object +from esm.utils.msa.filter_sequences import ( + greedy_select_indices, + hhfilter, +) +from esm.utils.parsing import ( + FastaEntry, + read_sequences, + write_sequences, +) +from esm.utils.sequential_dataclass import SequentialDataclass +from esm.utils.system import PathOrBuffer + +REMOVE_LOWERCASE_TRANSLATION = str.maketrans(dict.fromkeys(string.ascii_lowercase)) + + +def remove_insertions_from_sequence(seq: str) -> str: + return seq.translate(REMOVE_LOWERCASE_TRANSLATION) + + +@dataclass(frozen=True) +class MSA(SequentialDataclass): + """Object-oriented interface to an MSA. 
+ + Args: + entries (list[FastaEntry]): List of (header, sequence) entries comprising the MSA + + """ + + entries: list[FastaEntry] + + @cached_property + def sequences(self) -> list[str]: + return [entry.sequence for entry in self.entries] + + @cached_property + def headers(self) -> list[str]: + return [entry.header for entry in self.entries] + + def __repr__(self): + return ( + f"MSA({self.entries[0].header}: Depth={self.depth}, Length={self.seqlen})" + ) + + def to_fast_msa(self) -> FastMSA: + return FastMSA(self.array, self.headers) + + @classmethod + def from_a3m( + cls, + path: PathOrBuffer, + remove_insertions: bool = True, + max_sequences: int | None = None, + ) -> MSA: + entries = [] + for header, seq in islice(read_sequences(path), max_sequences): + if remove_insertions: + seq = remove_insertions_from_sequence(seq) + if entries: + assert ( + len(seq) == len(entries[0].sequence) + ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" + entries.append(FastaEntry(header, seq)) + return cls(entries) + + def to_a3m(self, path: PathOrBuffer) -> None: + write_sequences(self.entries, path) + + @classmethod + def from_stockholm( + cls, + path: PathOrBuffer, + remove_insertions: bool = True, + max_sequences: int | None = None, + ) -> MSA: + entries = [] + for record in islice(SeqIO.parse(path, "stockholm"), max_sequences): + header = f"{record.id} {record.description}" + seq = str(record.seq) + if entries: + assert ( + len(seq) == len(entries[0].sequence) + ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" + entries.append(FastaEntry(header, seq)) + msa = cls(entries) + if remove_insertions: + keep_inds = [i for i, aa in enumerate(msa.query) if aa != "-"] + msa = msa.select_positions(keep_inds) + return msa + + def to_bytes(self) -> bytes: + version = 1 + version_bytes = version.to_bytes(1, "little") + seqlen_bytes = self.seqlen.to_bytes(4, "little") + depth_bytes = self.depth.to_bytes(4, "little") + array_bytes = self.array.tobytes() + header_bytes = "\n".join(entry.header for entry in self.entries).encode() + all_bytes = ( + version_bytes + seqlen_bytes + depth_bytes + array_bytes + header_bytes + ) + return all_bytes + + @classmethod + def from_bytes(cls, data: bytes) -> MSA: + version_bytes, seqlen_bytes, depth_bytes, data = ( + data[:1], + data[1:5], + data[5:9], + data[9:], + ) + version = int.from_bytes(version_bytes, "little") + if version != 1: + raise ValueError(f"Unsupported version: {version}") + seqlen = int.from_bytes(seqlen_bytes, "little") + depth = int.from_bytes(depth_bytes, "little") + array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :] + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(depth, seqlen) + headers = header_bytes.decode().split("\n") + # Sometimes the separation is two newlines, which results in an empty header. 
+ headers = [header for header in headers if header] + entries = [ + FastaEntry(header, b"".join(row).decode()) + for header, row in zip(headers, array) + ] + return cls(entries) + + # TODO(jmaccarl): set remove_insertions to True by default here to match other utils + @classmethod + def from_sequences( + cls, sequences: list[str], remove_insertions: bool = False + ) -> MSA: + if remove_insertions: + entries = [ + FastaEntry("", remove_insertions_from_sequence(seq)) + for seq in sequences + ] + else: + entries = [FastaEntry("", seq) for seq in sequences] + return cls(entries) + + def to_sequence_bytes(self) -> bytes: + """Stores ONLY SEQUENCES in array format as bytes. Header information will be lost.""" + seqlen_bytes = self.seqlen.to_bytes(4, "little") + array_bytes = self.array.tobytes() + all_bytes = seqlen_bytes + array_bytes + return all_bytes + + @classmethod + def from_sequence_bytes(cls, data: bytes) -> MSA: + seqlen_bytes, array_bytes = data[:4], data[4:] + seqlen = int.from_bytes(seqlen_bytes, "little") + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(-1, seqlen) + entries = [FastaEntry("", b"".join(row).decode()) for row in array] + return cls(entries) + + @property + def depth(self) -> int: + return len(self.entries) + + @property + def seqlen(self) -> int: + return len(self.entries[0].sequence) + + @cached_property + def array(self) -> np.ndarray: + return np.array([list(seq) for seq in self.sequences], dtype="|S1") + + @property + def query(self) -> str: + return self.entries[0].sequence + + def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA: + """Subselect rows of the MSA.""" + entries = [self.entries[idx] for idx in indices] + return dataclasses.replace(self, entries=entries) + + def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA: + """Subselect columns of the MSA.""" + entries = [ + FastaEntry(header, "".join(seq[idx] for idx in indices)) + for header, seq in self.entries + ] + return dataclasses.replace(self, entries=entries) + + def __getitem__(self, indices: int | list[int] | slice | np.ndarray): + if isinstance(indices, int): + indices = [indices] + + entries = [ + FastaEntry(header, slice_any_object(seq, indices)) + for header, seq in self.entries + ] + return dataclasses.replace(self, entries=entries) + + def __len__(self): + return self.seqlen + + def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA: + """Greedily select sequences that either maximize or minimize hamming distance. + + Algorithm proposed in the MSA Transformer paper. Starting from the query sequence, + iteratively add sequences to the list with the maximum (minimum) average Hamming + distance to the existing set of sequences. + + Args: + num_seqs (int): Number of sequences to select. + mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless + you're doing it to prove a point for a paper. + + Returns: + MSA object w/ subselected sequences. 
+ """ + assert mode in ("max", "min") + if self.depth <= num_seqs: + return self + + indices = greedy_select_indices(self.array, num_seqs, mode) + return self.select_sequences(indices) + + def hhfilter( + self, + seqid: int = 90, + diff: int = 0, + cov: int = 0, + qid: int = 0, + qsc: float = -20.0, + binary: str = "hhfilter", + ) -> MSA: + """Apply hhfilter to the sequences in the MSA and return a filtered MSA.""" + + indices = hhfilter( + self.sequences, + seqid=seqid, + diff=diff, + cov=cov, + qid=qid, + qsc=qsc, + binary=binary, + ) + return self.select_sequences(indices) + + def select_random_sequences(self, num_seqs: int) -> MSA: + """Uses random sampling to subselect sequences from the MSA. Always + keeps the query sequence. + """ + if num_seqs >= self.depth: + return self + + # Subselect random, always keeping the query sequence. + indices = np.sort( + np.append( + 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 + ) + ) + msa = self.select_sequences(indices) # type: ignore + return msa + + def select_diverse_sequences(self, num_seqs: int) -> MSA: + """Applies hhfilter to select ~num_seqs sequences, then uses random sampling + to subselect if necessary. + """ + if num_seqs >= self.depth: + return self + + msa = self.hhfilter(diff=num_seqs) + if num_seqs < msa.depth: + msa = msa.select_random_sequences(num_seqs) + return msa + + def pad_to_depth(self, depth: int) -> MSA: + if depth < self.depth: + raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}") + elif depth == self.depth: + return self + + num_to_add = depth - self.depth + extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)] + return dataclasses.replace(self, entries=self.entries + extra_entries) + + @classmethod + def stack( + cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True + ) -> MSA: + """Stack a series of MSAs. 
Optionally remove the query from msas after the first.""" + all_entries = [] + for i, msa in enumerate(msas): + entries = msa.entries + if i > 0 and remove_query_from_later_msas: + entries = entries[1:] + all_entries.extend(entries) + return cls(entries=all_entries) + + @cached_property + def seqid(self) -> np.ndarray: + array = self.array.view(np.uint8) + seqid = 1 - cdist(array[0][None], array, "hamming") + return seqid[0] + + @classmethod + def concat( + cls, + msas: Sequence[MSA], + join_token: str | None = "|", + allow_depth_mismatch: bool = False, + ) -> MSA: + """Concatenate a series of MSAs horizontally, along the sequence dimension.""" + if not msas: + raise ValueError("Cannot concatenate an empty list of MSAs") + msa_depths = [msa.depth for msa in msas] + if len(set(msa_depths)) != 1: + if not allow_depth_mismatch: + raise ValueError("Depth mismatch in concatenating MSAs") + else: + max_depth = max(msa_depths) + msas = [msa.pad_to_depth(max_depth) for msa in msas] + headers = [ + "|".join([str(h) for h in headers]) + for headers in zip(*(msa.headers for msa in msas)) + ] + + if join_token is None: + join_token = "" + + seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))] + entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)] + return cls(entries) + + +@dataclass(frozen=True) +class FastMSA(SequentialDataclass): + """Object-oriented interface to an MSA stored as a numpy uint8 array.""" + + array: np.ndarray + headers: list[str] | None = None + + def __post_init__(self): + if self.headers is not None: + assert ( + len(self.headers) == self.depth + ), "Number of headers must match depth." + + @classmethod + def from_bytes(cls, data: bytes) -> FastMSA: + version_bytes, seqlen_bytes, depth_bytes, data = ( + data[:1], + data[1:5], + data[5:9], + data[9:], + ) + version = int.from_bytes(version_bytes, "little") + if version != 1: + raise ValueError(f"Unsupported version: {version}") + seqlen = int.from_bytes(seqlen_bytes, "little") + depth = int.from_bytes(depth_bytes, "little") + array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :] + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(depth, seqlen) + headers = header_bytes.decode().split("\n") + # Sometimes the separation is two newlines, which results in an empty header. + headers = [header for header in headers if header] + return cls(array, headers) + + @classmethod + def from_sequence_bytes(cls, data: bytes) -> FastMSA: + seqlen_bytes, array_bytes = data[:4], data[4:] + seqlen = int.from_bytes(seqlen_bytes, "little") + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(-1, seqlen) + return cls(array) + + @property + def depth(self) -> int: + return self.array.shape[0] + + @property + def seqlen(self) -> int: + return self.array.shape[1] + + def __len__(self): + return self.seqlen + + def __getitem__(self, indices: int | list[int] | slice | np.ndarray): + if isinstance(indices, int): + indices = [indices] + + return dataclasses.replace(self, array=self.array[:, indices]) + + def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA: + """Subselect rows of the MSA.""" + array = self.array[indices] + headers = ( + [self.headers[idx] for idx in indices] if self.headers is not None else None + ) + return dataclasses.replace(self, array=array, headers=headers) + + def select_random_sequences(self, num_seqs: int) -> FastMSA: + """Uses random sampling to subselect sequences from the MSA. 
Always + keeps the query sequence. + """ + if num_seqs >= self.depth: + return self + + # Subselect random, always keeping the query sequence. + indices = np.sort( + np.append( + 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 + ) + ) + msa = self.select_sequences(indices) # type: ignore + return msa + + def pad_to_depth(self, depth: int) -> FastMSA: + if depth < self.depth: + raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}") + elif depth == self.depth: + return self + + num_to_add = depth - self.depth + array = np.pad( + self.array, + [(0, num_to_add), (0, 0)], + constant_values=ord("-") if self.array.dtype == np.uint8 else b"-", + ) + headers = self.headers + if headers is not None: + headers = headers + [""] * num_to_add + return dataclasses.replace(self, array=array, headers=headers) + + @classmethod + def concat( + cls, + msas: Sequence[FastMSA], + join_token: str | None = None, + allow_depth_mismatch: bool = False, + ) -> FastMSA: + """Concatenate a series of MSAs horizontally, along the sequence dimension.""" + if not msas: + raise ValueError("Cannot concatenate an empty list of MSAs") + if join_token is not None and join_token != "": + raise NotImplementedError("join_token is not supported for FastMSA") + + msa_depths = [msa.depth for msa in msas] + if len(set(msa_depths)) != 1: + if not allow_depth_mismatch: + raise ValueError("Depth mismatch in concatenating MSAs") + else: + max_depth = max(msa_depths) + msas = [msa.pad_to_depth(max_depth) for msa in msas] + headers = [ + "|".join([str(h) for h in headers]) + for headers in zip( + *( + msa.headers if msa.headers is not None else [""] * msa.depth + for msa in msas + ) + ) + ] + + array = np.concatenate([msa.array for msa in msas], axis=1) + return cls(array, headers) + + def to_msa(self) -> MSA: + headers = ( + self.headers + if self.headers is not None + else [f"seq{i}" for i in range(self.depth)] + ) + entries = [ + FastaEntry(header, b"".join(row).decode()) + for header, row in zip(headers, self.array) + ] + return MSA(entries) + + @classmethod + def stack( + cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True + ) -> FastMSA: + """Stack a series of MSAs. 
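Rows of the underlying arrays are concatenated along the depth axis.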
Optionally remove the query from msas after the first.""" + arrays = [] + all_headers = [] + for i, msa in enumerate(msas): + array = msa.array + headers = msa.headers + if i > 0 and remove_query_from_later_msas: + array = array[1:] + if headers is not None: + headers = headers[1:] + arrays.append(array) + if headers is not None: + all_headers.extend(headers) + return cls(np.concatenate(arrays, axis=0), all_headers) diff --git a/esm/utils/parsing.py b/esm/utils/parsing.py new file mode 100644 index 00000000..c47938ab --- /dev/null +++ b/esm/utils/parsing.py @@ -0,0 +1,83 @@ +import io +from pathlib import Path +from typing import Generator, Iterable, NamedTuple + +PathOrBuffer = str | Path | io.TextIOBase +FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)]) + + +def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]: + """ + Parses a fasta file and yields FastaEntry objects + + Args: + fasta_string: The fasta file as a string + Returns: + A generator of FastaEntry objects + """ + header = None + seq = [] + num_sequences = 0 + for line in fasta_string.splitlines(): + if not line or line[0] == "#": + continue + if line.startswith(">"): + if header is not None: + yield FastaEntry(header, "".join(seq)) + seq = [] + header = line[1:].strip() + else: + seq.append(line) + if header is not None: + num_sequences += 1 + yield FastaEntry(header, "".join(seq)) + + if num_sequences == 0: + raise ValueError("Found no sequences in input") + + +def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]: + # Uses duck typing to try and call the right method + # Doesn't use explicit isinstance check to support + # inputs that are not explicitly str/Path/TextIOBase but + # may support similar functionality + data = None # type: ignore + try: + if str(path).endswith(".gz"): + import gzip + + data = gzip.open(path, "rt") # type: ignore + else: + try: + data = open(path) # type: ignore + except TypeError: + data: io.TextIOBase = path # type: ignore + + yield from parse_fasta(data.read()) + finally: + if data is not None: + data.close() + + +def read_first_sequence(path: PathOrBuffer) -> FastaEntry: + return next(iter(read_sequences(path))) + + +def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None: + needs_closing = False + handle = None + try: + try: + handle = open(path, "w") # type: ignore + needs_closing = True + except TypeError: + handle = path + has_prev = False + for header, seq in sequences: + if has_prev: + handle.write("\n") # type: ignore + handle.write(f">{header}\n{seq}") # type: ignore + has_prev = True + finally: + if needs_closing: + handle.close() # type: ignore diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py index fdf8658d..68c5c868 100644 --- a/esm/utils/sampling.py +++ b/esm/utils/sampling.py @@ -5,9 +5,18 @@ import torch import torch.nn.functional as F -from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig -from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.sdk.api import ( + ESMProteinTensor, + SamplingConfig, + SamplingTrackConfig, +) +from esm.tokenization import ( + TokenizerCollectionProtocol, + get_invalid_tokenizer_ids, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) from esm.utils.constants.esm3 import ( MAX_RESIDUE_ANNOTATIONS, SASA_DISCRETIZATION_BOUNDARIES, diff --git 
a/esm/utils/sequential_dataclass.py b/esm/utils/sequential_dataclass.py
new file mode 100644
index 00000000..8e935798
--- /dev/null
+++ b/esm/utils/sequential_dataclass.py
@@ -0,0 +1,157 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, fields, replace
+from typing import TypeVar
+
+import numpy as np
+
+from esm.utils.misc import concat_objects, slice_any_object
+
+T = TypeVar("T")
+
+
+@dataclass(frozen=True)
+class SequentialDataclass(ABC):
+    """
+    A dataclass base that provides automatic slicing and concatenation.
+
+    When representing multimodal data, we often have several fields that share
+    the same sequence dimension (e.g. the length of the protein).
+
+    When applying a transformation like a crop, we want to apply it to all of
+    these fields at once (e.g. crop the sequence, structure, and function together).
+
+    Other fields (like an id or a data source) are not sequential and should not
+    be cropped.
+
+    SequentialDataclass abstracts this cropping away: subclasses get `__getitem__`
+    and `concat` for free and only need to implement `__len__`.
+
+    Per-field behavior is controlled through the field `metadata`, which supports three keys:
+    `sequence` (bool): Whether this field is a sequential type. Default: False.
+    `sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0.
+    `join_token` (Any): What token to use to join elements when concatenating. Default: None.
+
+    Example:
+
+    @dataclass(frozen=True)
+    class Foo(SequentialDataclass):
+        id: str
+        sequence: str = field(metadata={"sequence": True, "join_token": "|"})
+        tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan})
+
+        def __len__(self):
+            # Subclasses must implement __len__
+            return len(self.sequence)
+
+    >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5))
+    >>> foo
+    Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143, 0.0251, -1.0717]))
+
+    >>> foo[1:4]
+    Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143, 0.0251]))
+
+    >>> foo[np.arange(5) < 3]
+    Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143]))
+
+    >>> Foo.concat([foo[:2], foo[3:]])
+    Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335, nan, 0.0251, -1.0717]))
+
+    # Constructing an instance whose sequence lengths do not match raises an error
+    >>> Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6))
+    ValueError: Mismatch in sequence length for field: tensor.
Expected 5, received 6 + + """ + + def __post_init__(self): + self._check_sequence_lengths_match() + + @abstractmethod + def __len__(self): + raise NotImplementedError + + def __getitem__(self, idx: int | list[int] | slice | np.ndarray): + updated_fields = {} + if isinstance(idx, int): + # make it so that things remain sequential + idx = [idx] + + for fld in fields(self): + if fld.metadata.get("sequence", False): + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + value = getattr(self, fld.name) + if value is None: + continue + match sequence_dim: + case 0: + # sequence is first dimension + value = getattr(self, fld.name) + value = slice_any_object(value, idx) + updated_fields[fld.name] = value + case 1: + new_value = [slice_any_object(item, idx) for item in value] + updated_fields[fld.name] = value.__class__(new_value) + case _: + raise NotImplementedError( + "Arbitrary slicing for different sequence length fields is not implemented" + ) + + return replace(self, **updated_fields) + + def _check_sequence_lengths_match(self): + """Checks if sequence lengths of all "sequence" fields match.""" + for fld in fields(self): + if fld.metadata.get("sequence", False) and fld.name != "complex": + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + value = getattr(self, fld.name) + if value is None: + continue + match sequence_dim: + case 0: + # sequence is first dimension + value = getattr(self, fld.name) + if len(value) != len(self): + raise ValueError( + f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(value)}" + ) + case 1: + for item in value: + if len(item) != len(self): + raise ValueError( + f"Mismatch in sequence length for field: {fld.name}. 
Expected {len(self)}, received {len(item)}" + ) + case _: + raise NotImplementedError( + "Arbitrary matching for different sequence length fields is not implemented" + ) + + @classmethod + def concat(cls, items: list[T], **kwargs) -> T: + updated_fields = {} + for fld in fields(cls): + if fld.metadata.get("sequence", False): + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + join_value = fld.metadata.get("join_token", None) + if getattr(items[0], fld.name) is None: + continue + values = [getattr(item, fld.name) for item in items] + match sequence_dim: + case 0: + # sequence is first dimension + value = concat_objects(values, join_value) + updated_fields[fld.name] = value + case 1: + new_value = [ + concat_objects(item, join_value) for item in zip(*values) + ] + updated_fields[fld.name] = getattr( + items[0], fld.name + ).__class__(new_value) + case _: + raise NotImplementedError( + "Arbitrary joining for different sequence length fields is not implemented" + ) + updated_fields.update(kwargs) + + return replace( + items[0], # type: ignore + **updated_fields, + ) diff --git a/esm/utils/structure/aligner.py b/esm/utils/structure/aligner.py index f25d9987..dd6702aa 100644 --- a/esm/utils/structure/aligner.py +++ b/esm/utils/structure/aligner.py @@ -6,7 +6,9 @@ import numpy as np import torch -from esm.utils.structure.protein_structure import compute_affine_and_rmsd +from esm.utils.structure.protein_structure import ( + compute_affine_and_rmsd, +) class Alignable(Protocol): diff --git a/esm/utils/structure/atom_indexer.py b/esm/utils/structure/atom_indexer.py index d62f05c9..2f588b98 100644 --- a/esm/utils/structure/atom_indexer.py +++ b/esm/utils/structure/atom_indexer.py @@ -1,6 +1,8 @@ import numpy as np -from esm.utils.structure.protein_structure import index_by_atom_name +from esm.utils.structure.protein_structure import ( + index_by_atom_name, +) class AtomIndexer: diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py new file mode 100644 index 00000000..158c887c --- /dev/null +++ b/esm/utils/structure/input_builder.py @@ -0,0 +1,96 @@ +from dataclasses import dataclass +from typing import Any, Sequence + +import numpy as np + + + +@dataclass +class Modification: + position: int # zero-indexed + ccd: str + + +@dataclass +class ProteinInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class RNAInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class DNAInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class LigandInput: + id: str | list[str] + smiles: str + ccd: list[str] | None = None + + +@dataclass +class DistogramConditioning: + chain_id: str + distogram: np.ndarray + + +@dataclass +class PocketConditioning: + binder_chain_id: str + contacts: list[tuple[str, int]] + + +@dataclass +class StructurePredictionInput: + sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput] + pocket: PocketConditioning | None = None + distogram_conditioning: list[DistogramConditioning] | None = None + + +def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput): + def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]: + chain_data: dict[str, Any] = { + "sequence": seq_input.sequence, + "id": seq_input.id, + "type": chain_type, + } + if hasattr(seq_input, 
"modifications") and seq_input.modifications: + mods = [ + {"position": mod.position, "ccd": mod.ccd} + for mod in seq_input.modifications + ] + chain_data["modifications"] = mods + return chain_data + + sequences = [] + for seq_input in all_atom_input.sequences: + if isinstance(seq_input, ProteinInput): + sequences.append(create_chain_data(seq_input, "protein")) + elif isinstance(seq_input, RNAInput): + sequences.append(create_chain_data(seq_input, "rna")) + elif isinstance(seq_input, DNAInput): + sequences.append(create_chain_data(seq_input, "dna")) + elif isinstance(seq_input, LigandInput): + sequences.append( + { + "smiles": seq_input.smiles, + "id": seq_input.id, + "ccd": seq_input.ccd, + "type": "ligand", + } + ) + else: + raise ValueError(f"Unsupported sequence input type: {type(seq_input)}") + + return {"sequences": sequences} diff --git a/esm/utils/structure/metrics.py b/esm/utils/structure/metrics.py index d76ed766..b2e590db 100644 --- a/esm/utils/structure/metrics.py +++ b/esm/utils/structure/metrics.py @@ -264,7 +264,7 @@ def compute_lddt_ca( if all_atom_pred_pos.dim() != 3: all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :] - all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim + all_atom_mask = all_atom_mask[..., ca_pos] return compute_lddt( all_atom_pred_pos, diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py new file mode 100644 index 00000000..40bfae4e --- /dev/null +++ b/esm/utils/structure/molecular_complex.py @@ -0,0 +1,944 @@ +from __future__ import annotations + +import io +import os +import re +from dataclasses import asdict, dataclass +from pathlib import Path +from subprocess import check_output +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, List + +import biotite.structure.io.pdbx as pdbx +import brotli +import msgpack +import numpy as np +import torch + +from esm.utils import residue_constants +from esm.utils.structure.metrics import ( + compute_lddt, + compute_rmsd, +) +from esm.utils.structure.protein_complex import ( + ProteinComplex, + ProteinComplexMetadata, +) + + +@dataclass +class MolecularComplexResult: + """Result of molecular complex folding""" + + complex: MolecularComplex + plddt: torch.Tensor | None = None + ptm: float | None = None + iptm: float | None = None + pae: torch.Tensor | None = None + distogram: torch.Tensor | None = None + pair_chains_iptm: torch.Tensor | None = None + output_embedding_sequence: torch.Tensor | None = None + output_embedding_pair_pooled: torch.Tensor | None = None + + +@dataclass +class MolecularComplexMetadata: + """Metadata for MolecularComplex objects.""" + + entity_lookup: dict[int, str] + chain_lookup: dict[int, str] + assembly_composition: dict[str, list[str]] | None = None + + +@dataclass +class Molecule: + """Represents a single molecule/token within a MolecularComplex.""" + + token: str + token_idx: int + atom_positions: np.ndarray # [N_atoms, 3] + atom_elements: np.ndarray # [N_atoms] element strings + residue_type: int + molecule_type: int # PROTEIN=0, RNA=1, DNA=2, LIGAND=3 + confidence: float + + +@dataclass(frozen=True) +class MolecularComplex: + """ + Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands. + + Uses a flat atom representation with token-based sequence indexing, supporting all atom types + beyond the traditional atom37 protein representation. 
+ """ + + id: str + sequence: List[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP'] + + # Flat atom arrays - simplified representation + atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates + atom_elements: np.ndarray # [N_atoms] element strings + + # Token-to-atom mapping for efficient access + token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array + + # Confidence data + plddt: np.ndarray # Per-token confidence scores [N_tokens] + + # Metadata + metadata: MolecularComplexMetadata + + def __post_init__(self): + """Validate array dimensions.""" + n_tokens = len(self.sequence) + assert ( + self.token_to_atoms.shape[0] == n_tokens + ), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens" + assert ( + self.plddt.shape[0] == n_tokens + ), f"plddt shape {self.plddt.shape} != {n_tokens} tokens" + + def __len__(self) -> int: + """Return number of tokens.""" + return len(self.sequence) + + def __getitem__(self, idx: int) -> Molecule: + """Access individual molecules/tokens by index.""" + if idx >= len(self.sequence) or idx < 0: + raise IndexError( + f"Token index {idx} out of range for {len(self.sequence)} tokens" + ) + + token = self.sequence[idx] + start_atom, end_atom = self.token_to_atoms[idx] + + # Extract atom data for this token + token_atom_positions = self.atom_positions[start_atom:end_atom] + token_atom_elements = self.atom_elements[start_atom:end_atom] + + # Default values for residue/molecule type (would be extended based on actual implementation) + residue_type = 0 # Default to standard residue + molecule_type = 0 # Default to protein + + return Molecule( + token=token, + token_idx=idx, + atom_positions=token_atom_positions, + atom_elements=token_atom_elements, + residue_type=residue_type, + molecule_type=molecule_type, + confidence=self.plddt[idx], + ) + + @property + def atom_coordinates(self) -> np.ndarray: + """Get flat array of all atom coordinates [N_atoms, 3].""" + return self.atom_positions + + # Conversion methods + @classmethod + def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex": + """Convert a ProteinComplex to MolecularComplex. 
+ + Args: + pc: ProteinComplex object with atom37 representation + + Returns: + MolecularComplex with flat atom arrays and token-based indexing + """ + from esm.utils import residue_constants + + # Extract sequence without chain breaks + sequence_no_breaks = pc.sequence.replace("|", "") + sequence_tokens = [ + residue_constants.restype_1to3.get(aa, "UNK") for aa in sequence_no_breaks + ] + + # Convert atom37 to flat arrays + flat_positions = [] + flat_elements = [] + token_to_atoms = [] + + atom_idx = 0 + residue_idx = 0 + + for i, aa in enumerate(pc.sequence): + if aa == "|": + # Skip chain break tokens + continue + + # Get atom37 positions and mask for this residue + res_positions = pc.atom37_positions[residue_idx] # [37, 3] + res_mask = pc.atom37_mask[residue_idx] # [37] + + # Track start position for this token + token_start = atom_idx + + # Process each atom type in atom37 representation + for atom_type_idx, atom_name in enumerate(residue_constants.atom_types): + if res_mask[atom_type_idx]: # Atom is present + # Add position + flat_positions.append(res_positions[atom_type_idx]) + + # Determine element from atom name + element = ( + atom_name[0] if atom_name else "C" + ) # First character is element + flat_elements.append(element) + + atom_idx += 1 + + # Record token-to-atom mapping [start_idx, end_idx) + token_to_atoms.append([token_start, atom_idx]) + residue_idx += 1 + + # Convert to numpy arrays + atom_positions = np.array(flat_positions, dtype=np.float32) + atom_elements = np.array(flat_elements, dtype=object) + token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) + + # Extract confidence scores (skip chain breaks) + confidence_scores = [] + residue_idx = 0 + for aa in pc.sequence: + if aa != "|": + confidence_scores.append(pc.confidence[residue_idx]) + residue_idx += 1 + + confidence_array = np.array(confidence_scores, dtype=np.float32) + + # Create metadata - convert entity IDs to strings for MolecularComplexMetadata + entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()} + metadata = MolecularComplexMetadata( + entity_lookup=entity_lookup_str, + chain_lookup=pc.metadata.chain_lookup, + assembly_composition=pc.metadata.assembly_composition, + ) + + return cls( + id=pc.id, + sequence=sequence_tokens, + atom_positions=atom_positions, + atom_elements=atom_elements, + token_to_atoms=token_to_atoms_array, + plddt=confidence_array, + metadata=metadata, + ) + + def to_protein_complex(self) -> ProteinComplex: + """Convert MolecularComplex back to ProteinComplex format. + + Extracts only protein tokens and converts from flat atom representation + back to atom37 format used by ProteinComplex. 
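+        Non-protein tokens (ligands, nucleic acids) are excluded, and all kept
+        residues are placed on a single chain "A" with sequential numbering.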
+ + Returns: + ProteinComplex with protein residues only, excluding ligands/nucleic acids + """ + from esm.utils import residue_constants + + # No need for element mapping - already using element characters + + # Filter for protein tokens only (skip ligands, nucleic acids) + protein_tokens = [] + protein_indices = [] + + for i, token in enumerate(self.sequence): + # Check if token is a standard 3-letter amino acid code + if token in residue_constants.restype_3to1: + protein_tokens.append(token) + protein_indices.append(i) + + if not protein_tokens: + raise ValueError("No protein tokens found in MolecularComplex") + + n_residues = len(protein_tokens) + + # Initialize atom37 arrays + atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32) + atom37_mask = np.zeros((n_residues, 37), dtype=bool) + + # Convert tokens back to single-letter sequence + single_letter_sequence = "".join( + [residue_constants.restype_3to1[token] for token in protein_tokens] + ) + + # Extract confidence scores for protein residues only + protein_confidence = self.plddt[protein_indices] + + # Convert flat atoms back to atom37 representation + for res_idx, token_idx in enumerate(protein_indices): + token = self.sequence[token_idx] + start_atom, end_atom = self.token_to_atoms[token_idx] + + # Get atom data for this residue + res_atom_positions = self.atom_positions[start_atom:end_atom] + + # Reconstruct atom37 representation by exactly reversing the forward conversion logic + # In from_protein_complex, atoms are added in atom_types order if present in mask + # So we need to reconstruct the mask and positions in the same order + atom_count = 0 + for atom_type_idx, atom_name in enumerate(residue_constants.atom_types): + # Check if this atom type exists for this residue and was present + residue_atoms = residue_constants.residue_atoms.get(token, []) + if atom_name in residue_atoms: + # This atom type exists for this residue, so it should have been included + if atom_count < len(res_atom_positions): + atom37_positions[res_idx, atom_type_idx] = res_atom_positions[ + atom_count + ] + atom37_mask[res_idx, atom_type_idx] = True + atom_count += 1 + + # Create other required arrays for ProteinComplex + # For simplicity, assume all protein residues belong to the same entity/chain + entity_id = np.zeros(n_residues, dtype=np.int64) + chain_id = np.zeros(n_residues, dtype=np.int64) + sym_id = np.zeros(n_residues, dtype=np.int64) + residue_index = np.arange(1, n_residues + 1, dtype=np.int64) + insertion_code = np.array([""] * n_residues, dtype=object) + + # Create simplified protein complex metadata + # Map the first entity/chain from molecular complex metadata + protein_metadata = ProteinComplexMetadata( + entity_lookup={0: 1}, # Single entity (int for ProteinComplexMetadata) + chain_lookup={0: "A"}, # Single chain + assembly_composition=self.metadata.assembly_composition, + ) + + return ProteinComplex( + id=self.id, + sequence=single_letter_sequence, + entity_id=entity_id, + chain_id=chain_id, + sym_id=sym_id, + residue_index=residue_index, + insertion_code=insertion_code, + atom37_positions=atom37_positions, + atom37_mask=atom37_mask, + confidence=protein_confidence, + metadata=protein_metadata, + ) + + @classmethod + def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": + """Read MolecularComplex from mmcif file or string. 
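+
+        Accepts either a path to an mmCIF file or raw mmCIF text; water
+        molecules (HOH) are skipped during parsing.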
+ + Args: + inp: Path to mmCIF file or mmCIF content as string + id: Optional identifier to assign to the complex + + Returns: + MolecularComplex with all molecules (proteins, ligands, nucleic acids) + """ + from io import StringIO + + # Check if input is a file path or mmCIF string content + if os.path.exists(inp): + # Input is a file path + mmcif_file = pdbx.CIFFile.read(inp) + else: + # Input is mmCIF string content + mmcif_file = pdbx.CIFFile.read(StringIO(inp)) + + # Get structure - handle missing model information gracefully + try: + structure = pdbx.get_structure(mmcif_file, model=1) + except (KeyError, ValueError): + # Fallback for mmCIF files without model information + try: + structure = pdbx.get_structure(mmcif_file) + except Exception: + # Last resort: use the first available model or all atoms + structure = pdbx.get_structure(mmcif_file, model=None) + # Type hint for pyright - structure is an AtomArray which is iterable + if TYPE_CHECKING: + structure: Any = structure + + # Get entity information from mmCIF + entity_info = {} + try: + # Access the first block in CIFFile + block = mmcif_file[0] + if "entity" in block: + entity_category = block["entity"] + if "id" in entity_category and "type" in entity_category: + entity_ids = entity_category["id"] + entity_types = entity_category["type"] + # Convert CIFColumn to list for iteration + if hasattr(entity_ids, "__iter__") and hasattr( + entity_types, "__iter__" + ): + # Type annotation to help pyright understand these are iterable + entity_ids_list = list(entity_ids) # type: ignore + entity_types_list = list(entity_types) # type: ignore + for eid, etype in zip(entity_ids_list, entity_types_list): + entity_info[eid] = etype + except Exception: + pass + + # Initialize arrays for flat atom representation + sequence_tokens = [] + flat_positions = [] + flat_elements = [] + token_to_atoms = [] + confidence_scores = [] + + atom_idx = 0 + + # Group atoms by chain and residue + chain_residue_groups = {} + for atom in structure: + chain_id = atom.chain_id + res_id = atom.res_id + res_name = atom.res_name + + if chain_id not in chain_residue_groups: + chain_residue_groups[chain_id] = {} + if res_id not in chain_residue_groups[chain_id]: + chain_residue_groups[chain_id][res_id] = { + "atoms": [], + "res_name": res_name, + "is_hetero": atom.hetero, + } + chain_residue_groups[chain_id][res_id]["atoms"].append(atom) + + # Process each chain and residue + for chain_id in sorted(chain_residue_groups.keys()): + residues = chain_residue_groups[chain_id] + + for res_id in sorted(residues.keys()): + residue_data = residues[res_id] + res_name = residue_data["res_name"] + atoms = residue_data["atoms"] + is_hetero = residue_data["is_hetero"] + + # Skip water molecules + if res_name == "HOH": + continue + + # Determine token name + if not is_hetero and res_name in residue_constants.restype_3to1: + # Standard amino acid + token_name = res_name + elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]: + # Nucleotide + token_name = res_name + else: + # Ligand or other molecule + token_name = res_name + + sequence_tokens.append(token_name) + token_start = atom_idx + + # Add all atoms from this residue + for atom in atoms: + flat_positions.append(atom.coord) + + # Get element character + element = atom.element + flat_elements.append(element) + + atom_idx += 1 + + # Record token-to-atom mapping + token_to_atoms.append([token_start, atom_idx]) + + # Add confidence score (B-factor if available, otherwise 1.0) + bfactor = getattr(atoms[0], "b_factor", 
50.0) if atoms else 50.0 + confidence_scores.append(min(bfactor / 100.0, 1.0)) + + # Convert to numpy arrays + if not flat_positions: + # Create minimal arrays if no atoms found + atom_positions = np.zeros((0, 3), dtype=np.float32) + atom_elements = np.zeros(0, dtype=object) + token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32) + else: + atom_positions = np.array(flat_positions, dtype=np.float32) + atom_elements = np.array(flat_elements, dtype=object) + token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) + + confidence_array = np.array(confidence_scores, dtype=np.float32) + + # Create metadata + metadata = MolecularComplexMetadata( + entity_lookup=entity_info, + chain_lookup={ + i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys()) + }, + assembly_composition=None, + ) + + # Set complex ID - if input was a path, use the stem; otherwise use default + if os.path.exists(inp): + complex_id = id or Path(inp).stem + else: + complex_id = id or "complex_from_string" + + return cls( + id=complex_id, + sequence=sequence_tokens, + atom_positions=atom_positions, + atom_elements=atom_elements, + token_to_atoms=token_to_atoms_array, + plddt=confidence_array, + metadata=metadata, + ) + + def to_mmcif(self) -> str: + """Write MolecularComplex to mmcif string. + + Returns: + String representation of the complex in mmCIF format + """ + # No need for element mapping - already using element characters + + lines = [] + + # Header + lines.append(f"data_{self.id}") + lines.append("#") + lines.append(f"_entry.id {self.id}") + lines.append("#") + + # Structure metadata + lines.append("_struct.entry_id {}".format(self.id)) + lines.append("_struct.title 'Protein Structure'") + lines.append("#") + + # Entity information + entity_id = 1 + chain_counter = 0 + lines.append("loop_") + lines.append("_entity.id") + lines.append("_entity.type") + lines.append("_entity.pdbx_description") + + # Determine entities based on sequence + protein_tokens = [] + other_tokens = [] + + for i, token in enumerate(self.sequence): + if token in residue_constants.restype_3to1: + protein_tokens.append((i, token)) + else: + other_tokens.append((i, token)) + + if protein_tokens: + lines.append(f"{entity_id} polymer 'Protein chain'") + entity_id += 1 + + for token in set(token for _, token in other_tokens): + lines.append(f"{entity_id} non-polymer 'Ligand {token}'") + entity_id += 1 + + lines.append("#") + + # Chain assignments + lines.append("loop_") + lines.append("_struct_asym.id") + lines.append("_struct_asym.entity_id") + + chain_id = "A" + if protein_tokens: + lines.append(f"{chain_id} 1") + chain_counter += 1 + chain_id = chr(ord(chain_id) + 1) + + entity_id = 2 + for token in set(token for _, token in other_tokens): + lines.append(f"{chain_id} {entity_id}") + entity_id += 1 + chain_counter += 1 + if chain_counter < 26: + chain_id = chr(ord(chain_id) + 1) + + lines.append("#") + + # Atom site information + lines.append("loop_") + lines.append("_atom_site.group_PDB") + lines.append("_atom_site.id") + lines.append("_atom_site.type_symbol") + lines.append("_atom_site.label_atom_id") + lines.append("_atom_site.label_alt_id") + lines.append("_atom_site.label_comp_id") + lines.append("_atom_site.label_asym_id") + lines.append("_atom_site.label_entity_id") + lines.append("_atom_site.label_seq_id") + lines.append("_atom_site.pdbx_PDB_ins_code") + lines.append("_atom_site.Cartn_x") + lines.append("_atom_site.Cartn_y") + lines.append("_atom_site.Cartn_z") + lines.append("_atom_site.occupancy") + 
lines.append("_atom_site.B_iso_or_equiv") + lines.append("_atom_site.pdbx_PDB_model_num") + lines.append("_atom_site.auth_seq_id") + lines.append("_atom_site.auth_comp_id") + lines.append("_atom_site.auth_asym_id") + lines.append("_atom_site.auth_atom_id") + + atom_id = 1 + seq_id = 1 + chain_id = "A" + entity_id = 1 + + for token_idx, token in enumerate(self.sequence): + start_atom, end_atom = self.token_to_atoms[token_idx] + + # Determine if this is a protein residue or ligand + is_protein = token in residue_constants.restype_3to1 + group_pdb = "ATOM" if is_protein else "HETATM" + current_entity_id = 1 if is_protein else 2 # Simplified entity assignment + current_chain_id = "A" if is_protein else "B" # Simplified chain assignment + + # Create atom names for this token + atom_names = [] + if is_protein: + # Use standard protein atom names + res_atoms = residue_constants.residue_atoms.get( + token, ["N", "CA", "C", "O"] + ) + atom_names = res_atoms[: end_atom - start_atom] + else: + # Generate generic atom names for ligands + for i in range(end_atom - start_atom): + atom_names.append(f"C{i+1}") + + # Pad atom names if needed + while len(atom_names) < (end_atom - start_atom): + atom_names.append(f"X{len(atom_names)+1}") + + # Write atoms for this token + for atom_idx_in_token, global_atom_idx in enumerate( + range(start_atom, end_atom) + ): + pos = self.atom_positions[global_atom_idx] + element_char = self.atom_elements[global_atom_idx] + element_symbol = element_char if isinstance(element_char, str) else "C" + + atom_name = ( + atom_names[atom_idx_in_token] + if atom_idx_in_token < len(atom_names) + else f"X{atom_idx_in_token+1}" + ) + + # Format atom site line + bfactor = ( + self.plddt[token_idx] * 100.0 + if len(self.plddt) > token_idx + else 50.0 + ) + + line = ( + f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . " + f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? " + f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 " + f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}" + ) + lines.append(line) + atom_id += 1 + + seq_id += 1 + + lines.append("#") + return "\n".join(lines) + + def dockq(self, native: "MolecularComplex") -> Any: + """Compute DockQ score against native structure. 
+ + Args: + native: Native MolecularComplex to compute DockQ against + + Returns: + DockQ result containing score and alignment information + """ + # Imports moved to top of file + + # Convert both complexes to ProteinComplex format for DockQ computation + # This extracts only the protein portion and converts to PDB format + try: + self_pc = self.to_protein_complex() + native_pc = native.to_protein_complex() + except ValueError as e: + raise ValueError( + f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}" + ) + + # Normalize chain IDs for PDB compatibility + self_pc = self_pc.normalize_chain_ids_for_pdb() + native_pc = native_pc.normalize_chain_ids_for_pdb() + + # Use the existing ProteinComplex.dockq() method + try: + dockq_result = self_pc.dockq(native_pc) + return dockq_result + except Exception: + # Fallback to manual DockQ computation if ProteinComplex.dockq() fails + return self._compute_dockq_manual(native) + + def _compute_dockq_manual(self, native: "MolecularComplex") -> Any: + """Manual DockQ computation fallback.""" + # Imports moved to top of file + + # Convert both complexes to ProteinComplex format + try: + self_pc = self.to_protein_complex() + native_pc = native.to_protein_complex() + except ValueError as e: + raise ValueError( + f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}" + ) + + # Normalize chain IDs for PDB compatibility + self_pc = self_pc.normalize_chain_ids_for_pdb() + native_pc = native_pc.normalize_chain_ids_for_pdb() + + # Write temporary PDB files and run DockQ + with TemporaryDirectory() as tdir: + dir_path = Path(tdir) + self_pdb = dir_path / "self.pdb" + native_pdb = dir_path / "native.pdb" + + # Write PDB files + self_pc.to_pdb(self_pdb) + native_pc.to_pdb(native_pdb) + + # Run DockQ + try: + output = check_output(["DockQ", str(self_pdb), str(native_pdb)]) + output_text = output.decode() + + # Parse DockQ output + lines = output_text.split("\n") + + # Find the total DockQ score + dockq_score = None + for line in lines: + if "Total DockQ" in line: + match = re.search(r"Total DockQ.*: ([\d.]+)", line) + if match: + dockq_score = float(match.group(1)) + break + + if dockq_score is None: + # Try to find individual DockQ scores + for line in lines: + if line.startswith("DockQ") and ":" in line: + try: + dockq_score = float(line.split(":")[1].strip()) + break + except (ValueError, IndexError): + continue + + if dockq_score is None: + raise ValueError("Could not parse DockQ score from output") + + # Return a simple result structure + return { + "total_dockq": dockq_score, + "raw_output": output_text, + "aligned": self, # Return self as aligned structure + } + + except FileNotFoundError: + raise RuntimeError( + "DockQ is not installed. Please install DockQ to use this method." + ) + except Exception as e: + raise RuntimeError(f"DockQ computation failed: {e}") + + def rmsd(self, target: "MolecularComplex", **kwargs) -> float: + """Compute RMSD against target structure. 
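+        Each token is represented by the centroid of its atoms, so this is a
+        per-token RMSD rather than a true all-atom RMSD.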
+ + Args: + target: Target MolecularComplex to compute RMSD against + **kwargs: Additional arguments passed to compute_rmsd + + Returns: + float: RMSD value between the two structures + """ + # Imports moved to top of file + + # Ensure both complexes have the same number of tokens + if len(self) != len(target): + raise ValueError( + f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}" + ) + + # Extract center positions for each token (using centroid of atoms) + mobile_coords = [] + target_coords = [] + atom_mask = [] + + for i in range(len(self)): + # Get atom positions for this token + mobile_start, mobile_end = self.token_to_atoms[i] + target_start, target_end = target.token_to_atoms[i] + + # Extract atom positions + mobile_atoms = self.atom_positions[mobile_start:mobile_end] + target_atoms = target.atom_positions[target_start:target_end] + + # Check if both tokens have atoms + if len(mobile_atoms) == 0 or len(target_atoms) == 0: + # Skip tokens with no atoms + continue + + # For simplicity, use the centroid of atoms as the representative position + mobile_center = mobile_atoms.mean(axis=0) + target_center = target_atoms.mean(axis=0) + + mobile_coords.append(mobile_center) + target_coords.append(target_center) + atom_mask.append(True) + + if len(mobile_coords) == 0: + raise ValueError("No valid atoms found for RMSD computation") + + # Convert to tensors + mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N] + + # Compute RMSD using existing infrastructure + rmsd_value = compute_rmsd( + mobile=mobile_tensor, + target=target_tensor, + atom_exists_mask=mask_tensor, + reduction="batch", + **kwargs, + ) + + return float(rmsd_value) + + def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float: + """Compute LDDT score against target structure. 
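+        Token centroids are used as representative positions, approximating
+        the C-alpha LDDT conventionally reported for proteins.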
+ + Args: + target: Target MolecularComplex to compute LDDT against + **kwargs: Additional arguments passed to compute_lddt + + Returns: + float: LDDT value between the two structures + """ + # Imports moved to top of file + + # Ensure both complexes have the same number of tokens + if len(self) != len(target): + raise ValueError( + f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}" + ) + + # Extract center positions for each token (using centroid of atoms) + mobile_coords = [] + target_coords = [] + atom_mask = [] + + for i in range(len(self)): + # Get atom positions for this token + mobile_start, mobile_end = self.token_to_atoms[i] + target_start, target_end = target.token_to_atoms[i] + + # Extract atom positions + mobile_atoms = self.atom_positions[mobile_start:mobile_end] + target_atoms = target.atom_positions[target_start:target_end] + + # Check if both tokens have atoms + if len(mobile_atoms) == 0 or len(target_atoms) == 0: + # Skip tokens with no atoms + mobile_coords.append(np.full(3, np.nan)) + target_coords.append(np.full(3, np.nan)) + atom_mask.append(False) + continue + + # For simplicity, use the centroid of atoms as the representative position + mobile_center = mobile_atoms.mean(axis=0) + target_center = target_atoms.mean(axis=0) + + mobile_coords.append(mobile_center) + target_coords.append(target_center) + atom_mask.append(True) + + if not any(atom_mask): + raise ValueError("No valid atoms found for LDDT computation") + + # Convert to tensors + mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N] + + # Compute LDDT using existing infrastructure + lddt_value = compute_lddt( + all_atom_pred_pos=mobile_tensor, + all_atom_positions=target_tensor, + all_atom_mask=mask_tensor, + per_residue=False, # Return overall LDDT score + **kwargs, + ) + + return float(lddt_value) + + def state_dict(self): + """This state dict is optimized for storage, so it turns things to fp16 whenever + possible and converts numpy arrays to lists for JSON serialization. 
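+
+        Example (a sketch of the storage round trip, assuming `mc` is a
+        MolecularComplex):
+
+            >>> blob = mc.to_blob()                    # brotli-compressed msgpack
+            >>> mc2 = MolecularComplex.from_blob(blob)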
+ """ + dct = {k: v for k, v in vars(self).items()} + for k, v in dct.items(): + if isinstance(v, np.ndarray): + match v.dtype: + case np.int64: + dct[k] = v.astype(np.int32).tolist() + case np.float64 | np.float32: + dct[k] = v.astype(np.float16).tolist() + case _: + dct[k] = v.tolist() + elif isinstance(v, MolecularComplexMetadata): + dct[k] = asdict(v) + + return dct + + def to_blob(self) -> bytes: + return brotli.compress(msgpack.dumps(self.state_dict()), quality=5) + + @classmethod + def from_state_dict(cls, dct): + for k, v in dct.items(): + if isinstance(v, list) and k in [ + "atom_positions", + "atom_elements", + "token_to_atoms", + "plddt", + ]: + dct[k] = np.array(v) + + for k, v in dct.items(): + if isinstance(v, np.ndarray): + if k in ["atom_positions", "plddt"]: + dct[k] = v.astype(np.float32) + elif k in ["token_to_atoms"]: + dct[k] = v.astype(np.int32) + + dct["metadata"] = MolecularComplexMetadata(**dct["metadata"]) + return cls(**dct) + + @classmethod + def from_blob(cls, input: Path | str | io.BytesIO | bytes): + match input: + case Path() | str(): + bytes = Path(input).read_bytes() + case io.BytesIO(): + bytes = input.getvalue() + case _: + bytes = input + return cls.from_state_dict( + msgpack.loads(brotli.decompress(bytes), strict_map_key=False) + ) diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py index 4889886e..b4db7081 100644 --- a/esm/utils/structure/protein_chain.py +++ b/esm/utils/structure/protein_chain.py @@ -25,13 +25,21 @@ from esm.utils.structure.affine3d import Affine3D from esm.utils.structure.aligner import Aligner from esm.utils.structure.atom_indexer import AtomIndexer -from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca -from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue +from esm.utils.structure.metrics import ( + compute_gdt_ts, + compute_lddt_ca, +) +from esm.utils.structure.mmcif_parsing import ( + MmcifWrapper, + Residue, +) from esm.utils.structure.normalize_coordinates import ( apply_frame_to_coords, get_protein_normalization_frame, ) -from esm.utils.structure.protein_structure import index_by_atom_name +from esm.utils.structure.protein_structure import ( + index_by_atom_name, +) from esm.utils.types import PathOrBuffer msgpack_numpy.patch() @@ -393,6 +401,7 @@ def from_blob(cls, input: Path | str | io.BytesIO | bytes): bytes = input return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes))) + def sasa(self, by_residue: bool = True): arr = self.atom_array_no_insertions sasa_per_atom = bs.sasa(arr) # type: ignore @@ -698,6 +707,7 @@ def gdt_ts( ) return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten() + @classmethod def chain_iterable_from_mmcif( cls, diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py index 7e2c38de..2bd20398 100644 --- a/esm/utils/structure/protein_complex.py +++ b/esm/utils/structure/protein_complex.py @@ -32,8 +32,14 @@ from esm.utils.structure.affine3d import Affine3D from esm.utils.structure.aligner import Aligner from esm.utils.structure.atom_indexer import AtomIndexer -from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca -from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError +from esm.utils.structure.metrics import ( + compute_gdt_ts, + compute_lddt_ca, +) +from esm.utils.structure.mmcif_parsing import ( + MmcifWrapper, + NoProteinError, +) from esm.utils.structure.protein_chain import ( ProteinChain, chain_to_ndarray, diff --git 
a/esm/utils/system.py b/esm/utils/system.py
new file mode 100644
index 00000000..c2800e57
--- /dev/null
+++ b/esm/utils/system.py
@@ -0,0 +1,45 @@
+import io
+import subprocess
+import typing as T
+from pathlib import Path
+
+PathLike = T.Union[str, Path]
+PathOrBuffer = T.Union[PathLike, io.StringIO]
+
+
+def run_subprocess_with_errorcheck(
+    *popenargs,
+    capture_output: bool = False,
+    quiet: bool = False,
+    env: dict[str, str] | None = None,
+    shell: bool = False,
+    executable: str | None = None,
+    **kws,
+) -> subprocess.CompletedProcess:
+    """A wrapper similar to subprocess.run, except that the error message
+    will contain the captured stderr when the command fails. This makes it
+    significantly easier to diagnose issues.
+    """
+    try:
+        if capture_output:
+            stdout = subprocess.PIPE
+        elif quiet:
+            stdout = subprocess.DEVNULL
+        else:
+            stdout = None
+
+        p = subprocess.run(
+            *popenargs,
+            stderr=subprocess.PIPE,
+            stdout=stdout,
+            check=True,
+            env=env,
+            shell=shell,
+            executable=executable,
+            **kws,
+        )
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(
+            f"Command failed with error code {e.returncode}." f"\n\n{e.stderr.decode()}"
+        ) from e
+    return p
diff --git a/esm/widgets/components/function_annotator.py b/esm/widgets/components/function_annotator.py
index f567f94d..714238f9 100644
--- a/esm/widgets/components/function_annotator.py
+++ b/esm/widgets/components/function_annotator.py
@@ -4,7 +4,9 @@
 from ipywidgets import widgets
 
 from esm.sdk.api import FunctionAnnotation
-from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
+from esm.tokenization.function_tokenizer import (
+    InterProQuantizedTokenizer,
+)
 
 TRIE: pygtrie.CharTrie | None = None
diff --git a/esm/widgets/components/results_visualizer.py b/esm/widgets/components/results_visualizer.py
index 261c0a5d..99692e51 100644
--- a/esm/widgets/components/results_visualizer.py
+++ b/esm/widgets/components/results_visualizer.py
@@ -7,11 +7,15 @@
 import matplotlib.pyplot as plt
 
 from esm.sdk.api import ESMProtein
-from esm.widgets.utils.drawing.draw_category_array import draw_data_array
+from esm.widgets.utils.drawing.draw_category_array import (
+    draw_data_array,
+)
 from esm.widgets.utils.drawing.draw_function_annotations import (
     draw_function_annotations,
 )
-from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
+from esm.widgets.utils.drawing.draw_protein_structure import (
+    draw_protein_structure,
+)
 from esm.widgets.utils.serialization import (
     create_download_button_from_buffer,
     protein_to_pdb_buffer,
diff --git a/esm/widgets/components/sasa_prompt_selector.py b/esm/widgets/components/sasa_prompt_selector.py
index 9c026500..ecc5c1b4 100644
--- a/esm/widgets/components/sasa_prompt_selector.py
+++ b/esm/widgets/components/sasa_prompt_selector.py
@@ -3,9 +3,16 @@
 import ipywidgets as widgets
 
 from esm.utils.structure.protein_chain import ProteinChain
-from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
-from esm.widgets.utils.drawing.draw_category_array import draw_data_array
-from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
+from esm.widgets.utils.drawing.colors import (
+    hex_to_rgba_tuple,
+    rgba_tuple_to_hex,
+)
+from esm.widgets.utils.drawing.draw_category_array import (
+    draw_data_array,
+)
+from esm.widgets.utils.parsing import (
+    convert_range_string_to_list_of_ranges,
+)
 from esm.widgets.utils.prompting import PromptManager
 
diff --git a/esm/widgets/components/secondary_structure_prompt_selector.py
b/esm/widgets/components/secondary_structure_prompt_selector.py index b8007d69..020180fe 100644 --- a/esm/widgets/components/secondary_structure_prompt_selector.py +++ b/esm/widgets/components/secondary_structure_prompt_selector.py @@ -4,9 +4,16 @@ import pydssp from esm.utils.structure.protein_chain import ProteinChain -from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex -from esm.widgets.utils.drawing.draw_category_array import draw_data_array -from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges +from esm.widgets.utils.drawing.colors import ( + hex_to_rgba_tuple, + rgba_tuple_to_hex, +) +from esm.widgets.utils.drawing.draw_category_array import ( + draw_data_array, +) +from esm.widgets.utils.parsing import ( + convert_range_string_to_list_of_ranges, +) from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/components/sequence_prompt_selector.py b/esm/widgets/components/sequence_prompt_selector.py index c5e3526f..d538e670 100644 --- a/esm/widgets/components/sequence_prompt_selector.py +++ b/esm/widgets/components/sequence_prompt_selector.py @@ -6,7 +6,9 @@ hex_to_rgba_tuple, rgba_tuple_to_rgba_html_string, ) -from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges +from esm.widgets.utils.parsing import ( + convert_range_string_to_list_of_ranges, +) from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/components/structure_prompt_selector.py b/esm/widgets/components/structure_prompt_selector.py index 13a9df78..f4b497c0 100644 --- a/esm/widgets/components/structure_prompt_selector.py +++ b/esm/widgets/components/structure_prompt_selector.py @@ -10,8 +10,12 @@ from esm.utils.structure.protein_chain import ProteinChain from esm.widgets.utils import indexing -from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure -from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges +from esm.widgets.utils.drawing.draw_protein_structure import ( + draw_protein_structure, +) +from esm.widgets.utils.parsing import ( + convert_range_string_to_list_of_ranges, +) from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/utils/drawing/draw_function_annotations.py b/esm/widgets/utils/drawing/draw_function_annotations.py index 59e9f7cf..c71e5434 100644 --- a/esm/widgets/utils/drawing/draw_function_annotations.py +++ b/esm/widgets/utils/drawing/draw_function_annotations.py @@ -9,7 +9,10 @@ from PIL import Image from esm.sdk.api import FunctionAnnotation -from esm.utils.function.interpro import InterPro, InterProEntryType +from esm.utils.function.interpro import ( + InterPro, + InterProEntryType, +) @contextmanager diff --git a/esm/widgets/utils/prompting.py b/esm/widgets/utils/prompting.py index 1e89bb64..1ce6d9c6 100644 --- a/esm/widgets/utils/prompting.py +++ b/esm/widgets/utils/prompting.py @@ -9,7 +9,9 @@ from esm.utils import encoding from esm.widgets.utils import indexing from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex -from esm.widgets.utils.drawing.draw_category_array import draw_data_array +from esm.widgets.utils.drawing.draw_category_array import ( + draw_data_array, +) from esm.widgets.utils.printing import wrapped_print diff --git a/esm/widgets/views/esm3_generation_launcher.py b/esm/widgets/views/esm3_generation_launcher.py index f8bf8f3b..e94c60ff 100644 --- a/esm/widgets/views/esm3_generation_launcher.py +++ 
b/esm/widgets/views/esm3_generation_launcher.py @@ -13,9 +13,13 @@ GenerationConfig, ) from esm.utils.constants import models -from esm.widgets.components.results_visualizer import create_results_visualizer +from esm.widgets.components.results_visualizer import ( + create_results_visualizer, +) from esm.widgets.utils.printing import wrapped_print -from esm.widgets.utils.serialization import create_download_results_button +from esm.widgets.utils.serialization import ( + create_download_results_button, +) def create_esm3_generation_launcher( diff --git a/esm/widgets/views/esm3_prompt_selector.py b/esm/widgets/views/esm3_prompt_selector.py index f7e60686..035db28e 100644 --- a/esm/widgets/views/esm3_prompt_selector.py +++ b/esm/widgets/views/esm3_prompt_selector.py @@ -1,6 +1,8 @@ from ipywidgets import widgets -from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector +from esm.widgets.components.sasa_prompt_selector import ( + create_sasa_prompt_selector, +) from esm.widgets.components.secondary_structure_prompt_selector import ( create_secondary_structure_prompt_selector, ) diff --git a/esm/widgets/views/generation.py b/esm/widgets/views/generation.py index fdec2094..19015f60 100644 --- a/esm/widgets/views/generation.py +++ b/esm/widgets/views/generation.py @@ -4,12 +4,20 @@ from esm.sdk.api import ESM3InferenceClient, ESMProtein from esm.utils.constants import esm3 as C -from esm.widgets.components.function_annotator import create_function_annotator +from esm.widgets.components.function_annotator import ( + create_function_annotator, +) from esm.widgets.utils.prompting import PromptManagerCollection from esm.widgets.utils.protein_import import ProteinImporter -from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher -from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview -from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector +from esm.widgets.views.esm3_generation_launcher import ( + create_esm3_generation_launcher, +) +from esm.widgets.views.esm3_prompt_preview import ( + create_esm3_prompt_preview, +) +from esm.widgets.views.esm3_prompt_selector import ( + create_esm3_prompt_selector, +) def create_generation_ui( diff --git a/esm/widgets/views/inverse_folding.py b/esm/widgets/views/inverse_folding.py index 50d9a128..8becb8eb 100644 --- a/esm/widgets/views/inverse_folding.py +++ b/esm/widgets/views/inverse_folding.py @@ -6,7 +6,9 @@ ESMProteinError, GenerationConfig, ) -from esm.widgets.components.results_visualizer import create_results_visualizer +from esm.widgets.components.results_visualizer import ( + create_results_visualizer, +) from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.protein_import import ProteinImporter diff --git a/esm/widgets/views/login.py b/esm/widgets/views/login.py index 5c7b6706..2d8be5a3 100644 --- a/esm/widgets/views/login.py +++ b/esm/widgets/views/login.py @@ -4,7 +4,10 @@ from ipywidgets import widgets -from esm.widgets.utils.clients import get_forge_client, get_local_client +from esm.widgets.utils.clients import ( + get_forge_client, + get_local_client, +) from esm.widgets.utils.types import ClientInitContainer diff --git a/esm/widgets/views/prediction.py b/esm/widgets/views/prediction.py index 94ff49dc..de6666d2 100644 --- a/esm/widgets/views/prediction.py +++ b/esm/widgets/views/prediction.py @@ -6,7 +6,9 @@ ESMProteinError, GenerationConfig, ) -from esm.widgets.components.results_visualizer import create_results_visualizer 
+from esm.widgets.components.results_visualizer import ( + create_results_visualizer, +) from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.protein_import import ProteinImporter diff --git a/pyproject.toml b/pyproject.toml index 2f8008f8..0ca682fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "esm" -version = "3.2.2" +version = "3.2.2.post1" description = "EvolutionaryScale open model repository" readme = "README.md" requires-python = ">=3.12,<3.13" @@ -45,7 +45,6 @@ dependencies = [ "pygtrie", "dna_features_viewer", ] - # Pytest [tool.pytest.ini_options] addopts = """ diff --git a/tests/oss_pytests/requirements.txt b/tests/oss_pytests/requirements.txt index d120ce35..143baa7a 100644 --- a/tests/oss_pytests/requirements.txt +++ b/tests/oss_pytests/requirements.txt @@ -1,3 +1,2 @@ -esm +esm >=3.2.1post1,<4.0.0 pytest -httpx # TODO(williamxi): Remove this after the esm repo is fixed diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index 6fa4a30d..e2dbafb5 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -1,7 +1,7 @@ import os import pytest - +import torch from esm.sdk import client # pyright: ignore from esm.sdk.api import ( # pyright: ignore ESMProtein, @@ -37,6 +37,7 @@ def test_oss_esm3_client(): logits_config = LogitsConfig(sequence=True, return_embeddings=True) result = esm3_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) + assert isinstance(result.logits.sequence, torch.Tensor) sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1)) result = esm3_client.forward_and_sample( @@ -53,7 +54,7 @@ def test_oss_esm3_client(): def test_oss_esmc_client(): assert URL is not None - sequence = "MALWMRLLPLLALLALAVUUPDPAAA" + sequence = "MALWMRLLPLLALLALAVPDPAAA" model = "esmc-300m-2024-12" esmc_client = client(model=model, url=URL, token=API_TOKEN) @@ -69,13 +70,14 @@ def test_oss_esmc_client(): ) result = esmc_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) + assert isinstance(result.logits.sequence, torch.Tensor) @pytest.mark.sdk def test_oss_sequence_structure_forge_inference_client(): assert URL is not None - sequence = "MALWMRLLPLLALLALAVUUPDPAAA" + sequence = "MALWMRLLPLLALLALAVPDPAAA" model = "esm3-small-2024-03" client = SequenceStructureForgeInferenceClient( model=model, url=URL, token=API_TOKEN