diff --git a/cookbook/local/open_generate.ipynb b/cookbook/local/open_generate.ipynb index c8cec4c1..32a72c38 100644 --- a/cookbook/local/open_generate.ipynb +++ b/cookbook/local/open_generate.ipynb @@ -38,6 +38,7 @@ "\n", "!pip install py3Dmol\n", "import py3Dmol\n", + "\n", "from esm.models.esm3 import ESM3\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" diff --git a/cookbook/local/raw_forwards.py b/cookbook/local/raw_forwards.py index 5701fa2a..baad28ee 100644 --- a/cookbook/local/raw_forwards.py +++ b/cookbook/local/raw_forwards.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F + from esm.pretrained import ( ESM3_function_decoder_v0, ESM3_sm_open_v0, @@ -12,9 +13,7 @@ from esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer as EsmFunctionTokenizer, ) -from esm.tokenization.sequence_tokenizer import ( - EsmSequenceTokenizer, -) +from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation diff --git a/cookbook/snippets/fold_invfold.py b/cookbook/snippets/fold_invfold.py index cb24db6a..e6be5485 100644 --- a/cookbook/snippets/fold_invfold.py +++ b/cookbook/snippets/fold_invfold.py @@ -2,6 +2,7 @@ from typing import cast import numpy as np + from esm.sdk.api import ( ESM3InferenceClient, ESMProtein, diff --git a/cookbook/tutorials/1_esmprotein.ipynb b/cookbook/tutorials/1_esmprotein.ipynb index e143ff13..13017733 100644 --- a/cookbook/tutorials/1_esmprotein.ipynb +++ b/cookbook/tutorials/1_esmprotein.ipynb @@ -72,6 +72,7 @@ "outputs": [], "source": [ "from biotite.database import rcsb\n", + "\n", "from esm.sdk.api import ESMProtein\n", "from esm.utils.structure.protein_chain import ProteinChain\n", "from esm.utils.types import FunctionAnnotation\n", @@ -496,9 +497,10 @@ "# Functions for visualizing InterPro function annotations\n", "\n", "from dna_features_viewer import GraphicFeature, GraphicRecord\n", - "from esm.utils.function.interpro import InterPro, InterProEntryType\n", "from matplotlib import colormaps\n", "\n", + "from esm.utils.function.interpro import InterPro, InterProEntryType\n", + "\n", "\n", "def visualize_function_annotations(\n", " annotations: list[FunctionAnnotation],\n", diff --git a/cookbook/tutorials/2_embed.ipynb b/cookbook/tutorials/2_embed.ipynb index 61fdb397..459fa90e 100644 --- a/cookbook/tutorials/2_embed.ipynb +++ b/cookbook/tutorials/2_embed.ipynb @@ -49,18 +49,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." + "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. 
It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from getpass import getpass\n", "\n", - "token = getpass(\"Token from Forge: \")" + "token = getpass(\"Token from Forge console: \")" ] }, { diff --git a/cookbook/tutorials/3_gfp_design.ipynb b/cookbook/tutorials/3_gfp_design.ipynb index 95b42418..bde09ad5 100644 --- a/cookbook/tutorials/3_gfp_design.ipynb +++ b/cookbook/tutorials/3_gfp_design.ipynb @@ -64,6 +64,7 @@ "import matplotlib.pyplot as pl\n", "import py3Dmol\n", "import torch\n", + "\n", "from esm.sdk import client\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" @@ -79,18 +80,18 @@ "\n", "The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n", "\n", - "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n" + "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "zNrU9Q2SYonX" }, "outputs": [], "source": [ - "token = getpass(\"Token from Forge: \")" + "token = getpass(\"Token from Forge console: \")" ] }, { diff --git a/cookbook/tutorials/4_forge_generate.ipynb b/cookbook/tutorials/4_forge_generate.ipynb index 5fb6e676..7570fda9 100644 --- a/cookbook/tutorials/4_forge_generate.ipynb +++ b/cookbook/tutorials/4_forge_generate.ipynb @@ -36,6 +36,7 @@ "\n", "!pip install py3Dmol\n", "import py3Dmol\n", + "\n", "from esm.sdk import client\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" @@ -52,7 +53,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." 
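The token-handling guidance repeated in these notebook cells mentions environment variables as an alternative to the interactive prompt but only shows the `getpass` path. A minimal sketch of the combined pattern; the `ESM_FORGE_TOKEN` variable name is an assumption for illustration, not an SDK convention:

import os
from getpass import getpass

# Prefer a token already exported in the environment; fall back to an
# interactive prompt so the secret never lands in notebook history.
token = os.environ.get("ESM_FORGE_TOKEN") or getpass("Token from Forge console: ")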
+ "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." ] }, { @@ -63,7 +64,7 @@ "source": [ "from getpass import getpass\n", "\n", - "token = getpass(\"Token from Forge: \")\n", + "token = getpass(\"Token from Forge console: \")\n", "model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)" ] }, diff --git a/cookbook/tutorials/5_guided_generation.ipynb b/cookbook/tutorials/5_guided_generation.ipynb index b2d8c7dc..f04d6be3 100644 --- a/cookbook/tutorials/5_guided_generation.ipynb +++ b/cookbook/tutorials/5_guided_generation.ipynb @@ -49,6 +49,7 @@ "source": [ "import biotite.structure as bs\n", "import py3Dmol\n", + "\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction" ] @@ -119,7 +120,7 @@ "\n", "from esm.sdk import client\n", "\n", - "token = getpass(\"Token from Forge: \")\n", + "token = getpass(\"Token from Forge console: \")\n", "model = client(\n", " model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n", ")" diff --git a/esm/__init__.py b/esm/__init__.py index aec10a41..1e3bed4c 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1,2 +1 @@ -__version__ = "3.2.2.post1" - +__version__ = "3.2.2" diff --git a/esm/layers/attention.py b/esm/layers/attention.py index ce57632f..564ef90c 100644 --- a/esm/layers/attention.py +++ b/esm/layers/attention.py @@ -5,10 +5,7 @@ import torch.nn.functional as F from torch import nn -from esm.layers.rotary import ( - RotaryEmbedding, - TritonRotaryEmbedding, -) +from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding try: from flash_attn import flash_attn_varlen_qkvpacked_func diff --git a/esm/layers/blocks.py b/esm/layers/blocks.py index 593b277f..76ebbe06 100644 --- a/esm/layers/blocks.py +++ b/esm/layers/blocks.py @@ -2,13 +2,8 @@ import torch.nn as nn import torch.nn.functional as F -from esm.layers.attention import ( - FlashMultiHeadAttention, - MultiHeadAttention, -) -from esm.layers.geom_attention import ( - GeometricReasoningOriginalImpl, -) +from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention +from esm.layers.geom_attention import GeometricReasoningOriginalImpl from esm.utils.structure.affine3d import Affine3D diff --git a/esm/layers/structure_proj.py b/esm/layers/structure_proj.py index 783ddeb4..faad0fe9 100644 --- a/esm/layers/structure_proj.py +++ b/esm/layers/structure_proj.py @@ -2,10 +2,7 @@ import torch.nn as nn from esm.utils.constants.physics import BB_COORDINATES -from esm.utils.structure.affine3d import ( - Affine3D, - RotationMatrix, -) +from esm.utils.structure.affine3d import Affine3D, RotationMatrix class Dim6RotStructureHead(nn.Module): diff --git a/esm/models/esm3.py b/esm/models/esm3.py index cbe02ddd..218a8e90 100644 --- a/esm/models/esm3.py +++ b/esm/models/esm3.py @@ -13,10 +13,7 @@ from esm.layers.regression_head import RegressionHead from esm.layers.transformer_stack import TransformerStack from esm.models.function_decoder import FunctionTokenDecoder -from esm.models.vqvae import ( - 
StructureTokenDecoder, - StructureTokenEncoder, -) +from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder from esm.sdk.api import ( ESM3InferenceClient, ESMProtein, @@ -32,10 +29,7 @@ from esm.tokenization import TokenizerCollectionProtocol from esm.utils import encoding from esm.utils.constants import esm3 as C -from esm.utils.constants.models import ( - ESM3_OPEN_SMALL, - normalize_model_name, -) +from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name from esm.utils.decoding import decode_protein_tensor from esm.utils.generation import ( _batch_forward, @@ -50,9 +44,7 @@ get_default_sampling_config, validate_sampling_config, ) -from esm.utils.structure.affine3d import ( - build_affine3d_from_coordinates, -) +from esm.utils.structure.affine3d import build_affine3d_from_coordinates @dataclass diff --git a/esm/models/function_decoder.py b/esm/models/function_decoder.py index c4f32992..e5f1fb28 100644 --- a/esm/models/function_decoder.py +++ b/esm/models/function_decoder.py @@ -12,9 +12,7 @@ from esm.layers.regression_head import RegressionHead from esm.layers.transformer_stack import TransformerStack -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer from esm.utils.constants import esm3 as C from esm.utils.misc import merge_annotations, merge_ranges from esm.utils.types import FunctionAnnotation diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py index 0f5226a4..37bc3945 100644 --- a/esm/models/vqvae.py +++ b/esm/models/vqvae.py @@ -7,10 +7,7 @@ from esm.layers.transformer_stack import TransformerStack from esm.utils.constants import esm3 as C from esm.utils.misc import knn_graph -from esm.utils.structure.affine3d import ( - Affine3D, - build_affine3d_from_coordinates, -) +from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates from esm.utils.structure.predicted_aligned_error import ( compute_predicted_aligned_error, compute_tm, diff --git a/esm/pretrained.py b/esm/pretrained.py index b9121511..e452e1d2 100644 --- a/esm/pretrained.py +++ b/esm/pretrained.py @@ -6,14 +6,8 @@ from esm.models.esm3 import ESM3 from esm.models.esmc import ESMC from esm.models.function_decoder import FunctionTokenDecoder -from esm.models.vqvae import ( - StructureTokenDecoder, - StructureTokenEncoder, -) -from esm.tokenization import ( - get_esm3_model_tokenizers, - get_esmc_model_tokenizers, -) +from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder +from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers from esm.utils.constants.esm3 import data_root from esm.utils.constants.models import ( ESM3_FUNCTION_DECODER_V0, diff --git a/esm/sdk/api.py b/esm/sdk/api.py index 6b152556..0212ddcd 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -2,27 +2,19 @@ from abc import ABC from copy import deepcopy -from typing import List, Sequence +from typing import Sequence import attr import torch from attr import asdict, define import esm.utils.constants.api as C -from esm.tokenization import ( - TokenizerCollectionProtocol, - get_esm3_model_tokenizers, -) +from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers from esm.utils import encoding from esm.utils.constants.models import ESM3_OPEN_SMALL -from esm.utils.misc import ( - get_chainbreak_boundaries_from_sequence, -) +from esm.utils.misc import get_chainbreak_boundaries_from_sequence from esm.utils.structure.protein_chain 
import ProteinChain -from esm.utils.structure.protein_complex import ( - SINGLE_LETTER_CHAIN_IDS, - ProteinComplex, -) +from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex from esm.utils.types import FunctionAnnotation, PathOrBuffer @@ -43,7 +35,6 @@ class ESMProtein(ProteinType): plddt: torch.Tensor | None = None ptm: torch.Tensor | None = None - # When calling EvolutionaryScale API, use this flag to disclose any # sequences that may potentially have concerns. # Such sequences may not go through standard safety filter for approved users. @@ -157,35 +148,12 @@ def to_protein_complex( gt_chains = list(copy_annotations_from_ground_truth.chain_iter()) else: gt_chains = None - - # Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks - # This handles the case where the server doesn't include chain breaks in pLDDT - # We should fix this in the server side. - if self.plddt is not None and len(self.plddt) != len(self.sequence): - # Only expand if there's a mismatch (likely due to chain breaks) - if "|" in self.sequence: - # Create expanded pLDDT with NaN at chain break positions - expanded_plddt = torch.full((len(self.sequence),), float("nan")) - plddt_idx = 0 - for i, aa in enumerate(self.sequence): - if aa != "|": - if plddt_idx < len(self.plddt): - expanded_plddt[i] = self.plddt[plddt_idx] - plddt_idx += 1 - plddt = expanded_plddt - else: - # Mismatch but no chain breaks - shouldn't happen but preserve original - plddt = self.plddt - else: - plddt = self.plddt - pred_chains = [] for i, (start, end) in enumerate(chain_boundaries): if i >= len(SINGLE_LETTER_CHAIN_IDS): raise ValueError( f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}" ) - pred_chain = ProteinChain.from_atom37( atom37_positions=coords[start:end], sequence=self.sequence[start:end], @@ -193,7 +161,7 @@ def to_protein_complex( if gt_chains is not None else SINGLE_LETTER_CHAIN_IDS[i], entity_id=gt_chains[i].entity_id if gt_chains is not None else None, - confidence=plddt[start:end] if plddt is not None else None, + confidence=self.plddt[start:end] if self.plddt is not None else None, ) pred_chains.append(pred_chain) return ProteinComplex.from_chains(pred_chains) @@ -330,14 +298,19 @@ def use_generative_unmasking_strategy(self): self.temperature_annealing = True +@define +class MSA: + # Paired MSA sequences. + # One would typically compute these using, for example, ColabFold. 
+ sequences: list[str] + + @define class InverseFoldingConfig: invalid_ids: Sequence[int] = [] temperature: float = 1.0 - - ## Low Level Endpoint Types @define class SamplingTrackConfig: @@ -402,9 +375,6 @@ class LogitsConfig: ith_hidden_layer: int = -1 - - - @define class LogitsOutput: logits: ForwardTrackData | None = None diff --git a/esm/sdk/base_forge_client.py b/esm/sdk/base_forge_client.py index 3c60a25f..ff05b541 100644 --- a/esm/sdk/base_forge_client.py +++ b/esm/sdk/base_forge_client.py @@ -1,13 +1,9 @@ -import asyncio -import time -from abc import ABC, abstractmethod from typing import Any from urllib.parse import urljoin import httpx from esm.sdk.api import ESMProteinError -from esm.sdk.retry import retry_decorator from esm.utils.decoding import assemble_message @@ -84,10 +80,6 @@ def prepare_request( headers = {**self.headers, **headers} if return_bytes: headers["return-bytes"] = "true" - # __INTERNAL_BEGIN___ - if disable_cache: - headers["X-Disable-Cache"] = "true" - # __INTERNAL_END___ return request, headers def prepare_data(self, response, endpoint: str) -> dict[str, Any]: @@ -120,11 +112,7 @@ async def _async_post( ): try: request, headers = self.prepare_request( - request, - potential_sequence_of_concern, - return_bytes, - disable_cache, - headers, + request, potential_sequence_of_concern, return_bytes, headers ) response = await self.async_client.post( url=urljoin(self.url, f"/api/v1/{endpoint}"), @@ -154,10 +142,7 @@ def _post( ): try: request, headers = self.prepare_request( - request, - potential_sequence_of_concern, - return_bytes, - headers, + request, potential_sequence_of_concern, return_bytes, headers ) response = self.client.post( url=urljoin(self.url, f"/api/v1/{endpoint}"), @@ -175,5 +160,3 @@ def _post( error_code=500, error_msg=f"Failed to submit request to {endpoint}. 
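Because `_post` and `_async_post` return an `ESMProteinError` value instead of raising on failed requests, call sites branch on the result type. A hedged sketch of that pattern; the `unwrap` helper is illustrative and not part of the SDK:

from esm.sdk.api import ESMProtein, ESMProteinError

def unwrap(result) -> ESMProtein:
    # Forge clients signal failure by returning ESMProteinError values.
    if isinstance(result, ESMProteinError):
        raise RuntimeError(
            f"Forge call failed ({result.error_code}): {result.error_msg}"
        )
    assert isinstance(result, ESMProtein)
    return result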
Error: {e}", ) - - diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index ac105c03..e9bb2e3d 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -1,14 +1,13 @@ -from __future__ import annotations - import asyncio import base64 import pickle from concurrent.futures import ThreadPoolExecutor -from typing import Any, Literal, Sequence, cast +from typing import Any, Sequence import torch from esm.sdk.api import ( + MSA, ESM3InferenceClient, ESMCInferenceClient, ESMProtein, @@ -20,30 +19,14 @@ InverseFoldingConfig, LogitsConfig, LogitsOutput, - ProteinChain, ProteinType, SamplingConfig, SamplingTrackConfig, ) -from esm.sdk.base_forge_client import ( - _BaseForgeInferenceClient, -) +from esm.sdk.base_forge_client import _BaseForgeInferenceClient from esm.sdk.retry import retry_decorator from esm.utils.constants.api import MIMETYPE_ES_PICKLE -from esm.utils.misc import ( - deserialize_tensors, - maybe_list, - maybe_tensor, -) -from esm.utils.msa import MSA -from esm.utils.structure.input_builder import ( - StructurePredictionInput, - serialize_structure_prediction_input, -) -from esm.utils.structure.molecular_complex import ( - MolecularComplex, - MolecularComplexResult, -) +from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor from esm.utils.types import FunctionAnnotation @@ -53,8 +36,10 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None: return [FunctionAnnotation(*t) for t in l] -def _maybe_logits(data: dict[str, Any], track: str): - return maybe_tensor(data.get("logits", {}).get(track, None)) +def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False): + ret = data.get("logits", {}).get(track, None) + # TODO(s22chan): just return this when removing return_bytes + return ret if ret is None or not return_bytes else maybe_tensor(ret) def _maybe_b64_decode(obj, return_bytes: bool): @@ -108,13 +93,9 @@ def __init__( ) @staticmethod - def _process_fold_request( - sequence: str, - model_name: str | None, - ): + def _process_fold_request(sequence: str, model_name: str | None): request: dict[str, Any] = {"sequence": sequence} - request["model"] = model_name return request @@ -149,7 +130,6 @@ def process_inverse_fold_request( return request - async def _async_fetch_msa(self, sequence: str) -> MSA: print("Fetching MSA ... this may take a few minutes") # Accept both "|" and ":" as the chainbreak token. @@ -157,7 +137,7 @@ async def _async_fetch_msa(self, sequence: str) -> MSA: data = await self._async_post( "msa", request={}, params={"sequence": sequence, "use_env": False} ) - return MSA.from_sequences(sequences=data["msa"]) + return MSA(sequences=data["msa"]) def _fetch_msa(self, sequence: str) -> MSA: print("Fetching MSA ... 
this may take a few minutes") @@ -166,7 +146,7 @@ def _fetch_msa(self, sequence: str) -> MSA: data = self._post( "msa", request={}, params={"sequence": sequence, "use_env": False} ) - return MSA.from_sequences(sequences=data["msa"]) + return MSA(sequences=data["msa"]) @retry_decorator async def async_fold( @@ -188,15 +168,11 @@ async def async_fold( del potential_sequence_of_concern request = self._process_fold_request( - sequence, - model_name if model_name is not None else self.model, + sequence, model_name if model_name is not None else self.model ) try: - data = await self._async_post( - "fold", - request, - ) + data = await self._async_post("fold", request) except ESMProteinError as e: return e @@ -223,98 +199,16 @@ def fold( del potential_sequence_of_concern request = self._process_fold_request( - sequence, - model_name if model_name is not None else self.model, + sequence, model_name if model_name is not None else self.model ) try: - data = self._post( - "fold", - request, - ) + data = self._post("fold", request) except ESMProteinError as e: return e return self._process_fold_response(data, sequence) - @retry_decorator - async def async_fold_all_atom( - self, - all_atom_input: StructurePredictionInput, - model_name: str | None = None, - ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: - """Fold a molecular complex containing proteins, nucleic acids, and/or ligands. - - Args: - all_atom_input: StructurePredictionInput containing sequences for different molecule types - model_name: Override the client level model name if needed - """ - request = self._process_fold_all_atom_request( - all_atom_input, - model_name if model_name is not None else self.model, - ) - - try: - data = await self._async_post( - "fold_all_atom", - request, - ) - except ESMProteinError as e: - return e - - return self._process_fold_all_atom_response(data) - - @retry_decorator - def fold_all_atom( - self, - all_atom_input: StructurePredictionInput, - model_name: str | None = None, - ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: - """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands. 
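With the rich `esm.utils.msa.MSA` removed, the `MSA` returned by `_fetch_msa` above is the plain attrs class now defined in `esm.sdk.api`, so callers construct and read it directly. A small sketch with placeholder sequences:

from esm.sdk.api import MSA

# Pre-aligned sequences, query first; the filtering and insertion-handling
# utilities of the old esm.utils.msa.MSA are no longer attached to this type.
msa = MSA(sequences=["MKTAYIAKQR", "MKTA-IAKQR"])
print(len(msa.sequences), msa.sequences[0])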
- - Args: - all_atom_input: StructurePredictionInput containing sequences for different molecule types - model_name: Override the client level model name if needed - """ - request = self._process_fold_all_atom_request( - all_atom_input, - model_name if model_name is not None else self.model, - ) - - try: - data = self._post( - "fold_all_atom", - request, - ) - except ESMProteinError as e: - return e - - return self._process_fold_all_atom_response(data) - - @staticmethod - def _process_fold_all_atom_request( - all_atom_input: StructurePredictionInput, - model_name: str | None = None, - ) -> dict[str, Any]: - request: dict[str, Any] = { - "all_atom_input": serialize_structure_prediction_input(all_atom_input), - "model": model_name, - } - - - return request - - @staticmethod - def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult: - complex_data = data.get("complex") - molecular_complex = MolecularComplex.from_state_dict(complex_data) - return MolecularComplexResult( - complex=molecular_complex, - plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True), - ptm=data.get("ptm", None), - distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True), - ) - @retry_decorator async def async_inverse_fold( self, @@ -386,7 +280,6 @@ def inverse_fold( return ESMProtein(sequence=data["sequence"]) - class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient): def __init__( self, @@ -709,15 +602,19 @@ def _process_logits_response( return LogitsOutput( logits=ForwardTrackData( - sequence=_maybe_logits(data, "sequence"), - structure=_maybe_logits(data, "structure"), - secondary_structure=_maybe_logits(data, "secondary_structure"), - sasa=_maybe_logits(data, "sasa"), - function=_maybe_logits(data, "function"), + sequence=_maybe_logits(data, "sequence", return_bytes), + structure=_maybe_logits(data, "structure", return_bytes), + secondary_structure=_maybe_logits( + data, "secondary_structure", return_bytes + ), + sasa=_maybe_logits(data, "sasa", return_bytes), + function=_maybe_logits(data, "function", return_bytes), ), embeddings=maybe_tensor(data["embeddings"]), mean_embedding=data["mean_embedding"], - residue_annotation_logits=_maybe_logits(data, "residue_annotation"), + residue_annotation_logits=_maybe_logits( + data, "residue_annotation", return_bytes + ), hidden_states=maybe_tensor(data["hidden_states"]), mean_hidden_state=maybe_tensor(data["mean_hidden_state"]), ) @@ -1068,7 +965,6 @@ def _process_logits_request( "sequence": config.sequence, "return_embeddings": config.return_embeddings, "return_mean_embedding": config.return_mean_embedding, - "return_mean_hidden_states": config.return_mean_hidden_states, "return_hidden_states": config.return_hidden_states, "ith_hidden_layer": config.ith_hidden_layer, } @@ -1085,11 +981,12 @@ def _process_logits_response( data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes) output = LogitsOutput( - logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")), + logits=ForwardTrackData( + sequence=_maybe_logits(data, "sequence", return_bytes) + ), embeddings=maybe_tensor(data["embeddings"]), mean_embedding=data["mean_embedding"], hidden_states=maybe_tensor(data["hidden_states"]), - mean_hidden_state=maybe_tensor(data["mean_hidden_state"]), ) return output @@ -1212,5 +1109,3 @@ def raw_model(self): raise NotImplementedError( f"Can not get underlying remote model {self.model} from a Forge client." 
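A usage sketch for the logits path whose response handling changes above, following the client construction shown in the tutorials; the sequence, token value, and choice of `esm3-open` are placeholders:

from esm.sdk import client
from esm.sdk.api import ESMProtein, LogitsConfig

model = client(model="esm3-open", url="https://forge.evolutionaryscale.ai", token="<token>")
tensor = model.encode(ESMProtein(sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"))
out = model.logits(tensor, LogitsConfig(sequence=True, return_embeddings=True))
print(out.logits.sequence.shape, out.embeddings.shape)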
) - - diff --git a/esm/sdk/retry.py b/esm/sdk/retry.py index 302d6cf0..16c354b6 100644 --- a/esm/sdk/retry.py +++ b/esm/sdk/retry.py @@ -2,9 +2,10 @@ from contextvars import ContextVar from functools import wraps +import httpx from tenacity import ( retry, - retry_if_exception, + retry_if_exception_type, retry_if_result, stop_after_attempt, wait_incrementing, @@ -29,12 +30,8 @@ def retry_if_specific_error(exception): def log_retry_attempt(retry_state): - try: - outcome = retry_state.outcome.result() - except Exception: - outcome = retry_state.outcome.exception() print( - f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}" + f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}" ) @@ -44,18 +41,13 @@ def retry_decorator(func): instance's retry settings. """ - def return_last_value(retry_state): - """Return the result of the last call attempt.""" - return retry_state.outcome.result() - @wraps(func) async def async_wrapper(instance, *args, **kwargs): if skip_retries_var.get(): return await func(instance, *args, **kwargs) retry_decorator = retry( - retry_error_callback=return_last_value, retry=retry_if_result(retry_if_specific_error) - | retry_if_exception(retry_if_specific_error), + | retry_if_exception_type(httpx.ConnectTimeout), # ADDED wait=wait_incrementing( increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait ), @@ -70,9 +62,8 @@ def wrapper(instance, *args, **kwargs): if skip_retries_var.get(): return func(instance, *args, **kwargs) retry_decorator = retry( - retry_error_callback=return_last_value, retry=retry_if_result(retry_if_specific_error) - | retry_if_exception(retry_if_specific_error), + | retry_if_exception_type(httpx.ConnectTimeout), # ADDED wait=wait_incrementing( increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait ), diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py index ea609225..6db76554 100644 --- a/esm/tokenization/__init__.py +++ b/esm/tokenization/__init__.py @@ -1,10 +1,7 @@ from dataclasses import dataclass from typing import Protocol -from esm.utils.constants.models import ( - ESM3_OPEN_SMALL, - normalize_model_name, -) +from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name from .function_tokenizer import InterProQuantizedTokenizer from .residue_tokenizer import ResidueAnnotationsTokenizer diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py index b5588527..1fe256b6 100644 --- a/esm/utils/decoding.py +++ b/esm/utils/decoding.py @@ -10,24 +10,12 @@ from esm.models.vqvae import StructureTokenDecoder from esm.sdk.api import ESMProtein, ESMProteinTensor from esm.tokenization import TokenizerCollectionProtocol -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) -from esm.tokenization.residue_tokenizer import ( - ResidueAnnotationsTokenizer, -) -from esm.tokenization.sasa_tokenizer import ( - SASADiscretizingTokenizer, -) -from esm.tokenization.sequence_tokenizer import ( - EsmSequenceTokenizer, -) -from esm.tokenization.ss_tokenizer import ( - SecondaryStructureTokenizer, -) -from esm.tokenization.structure_tokenizer import ( - StructureTokenizer, -) +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer +from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer +from esm.tokenization.sequence_tokenizer import 
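A stand-alone sketch of the retry policy assembled in the hunks above: tenacity retries when the returned value is judged retryable or when `httpx.ConnectTimeout` is raised. The predicate body and the wait/stop values here are illustrative stand-ins, not the client defaults:

import httpx
from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_result,
    stop_after_attempt,
    wait_incrementing,
)

def is_retryable(result) -> bool:
    # Stand-in for retry_if_specific_error, which inspects ESMProteinError codes.
    return getattr(result, "error_code", None) in (429, 502, 504)

@retry(
    retry=retry_if_result(is_retryable) | retry_if_exception_type(httpx.ConnectTimeout),
    wait=wait_incrementing(start=1, increment=1, max=10),
    stop=stop_after_attempt(3),
)
def call_forge():
    ...  # the decorated SDK methods perform the HTTP round trip here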
EsmSequenceTokenizer +from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer +from esm.tokenization.structure_tokenizer import StructureTokenizer from esm.tokenization.tokenizer_base import EsmTokenizerBase from esm.utils.constants import api as api_constants from esm.utils.constants import esm3 as C diff --git a/esm/utils/encoding.py b/esm/utils/encoding.py index 83c9d033..8461709d 100644 --- a/esm/utils/encoding.py +++ b/esm/utils/encoding.py @@ -7,26 +7,13 @@ from esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer as EsmFunctionTokenizer, ) - -from esm.tokenization.residue_tokenizer import ( - ResidueAnnotationsTokenizer, -) -from esm.tokenization.sasa_tokenizer import ( - SASADiscretizingTokenizer, -) -from esm.tokenization.sequence_tokenizer import ( - EsmSequenceTokenizer, -) -from esm.tokenization.ss_tokenizer import ( - SecondaryStructureTokenizer, -) -from esm.tokenization.structure_tokenizer import ( - StructureTokenizer, -) +from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer +from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer +from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer +from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer +from esm.tokenization.structure_tokenizer import StructureTokenizer from esm.utils.constants import esm3 as C -from esm.utils.function.encode_decode import ( - encode_function_annotations, -) +from esm.utils.function.encode_decode import encode_function_annotations from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation @@ -165,8 +152,6 @@ def tokenize_function_annotations( return function_tokens, residue_annotation_tokens - - # Tokenized Defaults def get_default_sequence_tokens( sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer @@ -242,5 +227,3 @@ def get_default_residue_annotation_tokens( residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id return residue_annotation_tokens - - diff --git a/esm/utils/forge_context_manager.py b/esm/utils/forge_context_manager.py index fac0c3bd..b1c2bdf3 100644 --- a/esm/utils/forge_context_manager.py +++ b/esm/utils/forge_context_manager.py @@ -7,10 +7,7 @@ from tqdm import tqdm from esm.sdk.api import ESMProteinError -from esm.sdk.retry import ( - retry_if_specific_error, - skip_retries_var, -) +from esm.sdk.retry import retry_if_specific_error, skip_retries_var TQDM_BAR_FORMAT = ( "{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} " diff --git a/esm/utils/function/encode_decode.py b/esm/utils/function/encode_decode.py index 29534e34..a4029858 100644 --- a/esm/utils/function/encode_decode.py +++ b/esm/utils/function/encode_decode.py @@ -3,16 +3,9 @@ import torch -from esm.models.function_decoder import ( - FunctionTokenDecoder, - merge_annotations, -) -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) -from esm.tokenization.residue_tokenizer import ( - ResidueAnnotationsTokenizer, -) +from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer from esm.utils.constants import esm3 as C from esm.utils.types import FunctionAnnotation diff --git a/esm/utils/generation.py b/esm/utils/generation.py index cbc4306b..b2d4ede4 100644 --- 
a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -19,13 +19,8 @@ SamplingConfig, SamplingTrackConfig, ) -from esm.tokenization import ( - EsmTokenizerBase, - TokenizerCollectionProtocol, -) -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) +from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer from esm.utils.constants import esm3 as C from esm.utils.misc import stack_variable_length_tensors from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY @@ -48,9 +43,7 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int): sliced = {} for k, v in attr.asdict(o, recurse=False).items(): - if k in ["mean_hidden_state", "mean_embedding"]: - sliced[k] = v - elif v is None: + if v is None: sliced[k] = None elif isinstance(v, torch.Tensor): # Trim padding. diff --git a/esm/utils/misc.py b/esm/utils/misc.py index 409ba1bf..f0d7a602 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -1,19 +1,7 @@ -from __future__ import annotations - import os from collections import defaultdict -from dataclasses import is_dataclass from io import BytesIO -from typing import ( - Any, - ContextManager, - Generator, - Iterable, - Protocol, - Sequence, - TypeVar, - runtime_checkable, -) +from typing import Any, ContextManager, Sequence, TypeVar from warnings import warn import huggingface_hub @@ -30,12 +18,6 @@ TSequence = TypeVar("TSequence", bound=Sequence) -@runtime_checkable -class Concatable(Protocol): - @classmethod - def concat(cls, objs: list[Concatable]) -> Concatable: ... - - def slice_python_object_as_numpy( obj: TSequence, idx: int | list[int] | slice | np.ndarray ) -> TSequence: @@ -70,37 +52,6 @@ def slice_python_object_as_numpy( return sliced_obj # type: ignore -def slice_any_object( - obj: TSequence, idx: int | list[int] | slice | np.ndarray -) -> TSequence: - """ - Slice a arbitrary object (like a list, string, or tuple) as if it was a numpy object. Similar to `slice_python_object_as_numpy`, but detects if it's a numpy array or Tensor and uses the existing slice method if so. - - If the object is a dataclass, it will simply apply the index to the object, under the assumption that the object has correcty implemented numpy indexing. - - Example: - >>> obj = "ABCDE" - >>> slice_any_object(obj, [1, 3, 4]) - "BDE" - - >>> obj = np.array([1, 2, 3, 4, 5]) - >>> slice_any_object(obj, np.arange(5) < 3) - np.array([1, 2, 3]) - - >>> obj = ProteinChain.from_rcsb("1a3a", "A") - >>> slice_any_object(obj, np.arange(len(obj)) < 10) - # ProteinChain w/ length 10 - - """ - if isinstance(obj, (np.ndarray, torch.Tensor)): - return obj[idx] # type: ignore - elif is_dataclass(obj): - # if passing a dataclass, assume it implements a custom slice - return obj[idx] # type: ignore - else: - return slice_python_object_as_numpy(obj, idx) - - def rbf(values, v_min, v_max, n_bins=16): """ Returns RBF encodings in a new dimension at the end. 
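The `rbf` helper kept as context above is only described by its docstring; a hedged re-implementation of the standard construction, where the bin spacing and Gaussian width are assumptions based on that description:

import torch

def rbf(values: torch.Tensor, v_min: float, v_max: float, n_bins: int = 16) -> torch.Tensor:
    # Compare each value against n_bins evenly spaced centers and score it
    # with a Gaussian, appending a new bin dimension at the end.
    centers = torch.linspace(v_min, v_max, n_bins, device=values.device)
    std = (v_max - v_min) / n_bins
    z = (values.unsqueeze(-1) - centers) / std
    return torch.exp(-(z**2))

print(rbf(torch.randn(4, 3), -1.0, 1.0).shape)  # torch.Size([4, 3, 16])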
@@ -347,8 +298,6 @@ def replace_inf(data): def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: if x is None: return None - if isinstance(x, torch.Tensor): - return x if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x): return torch.stack(x) if convert_none_to_nan: @@ -408,90 +357,3 @@ def deserialize_tensors(b: bytes) -> Any: buf = BytesIO(zstd.ZSTD_uncompress(b)) d = torch.load(buf, map_location="cpu", weights_only=False) return d - - -def join_lists( - lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None -) -> list[Any]: - """Joins multiple lists with separator element. Like str.join but for lists. - - Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4] - - Args: - lists: Lists of elements to chain - separator: separators to intsert between chained output. - Returns: - Joined lists. - """ - if not lists: - return [] - joined = [] - joined.extend(lists[0]) - for l in lists[1:]: - if separator: - joined.extend(separator) - joined.extend(l) - return joined - - -def iterate_with_intermediate( - lists: Iterable, intermediate -) -> Generator[Any, None, None]: - """ - Iterate over the iterable, yielding the intermediate value between - every element of the intermediate. Useful for joining objects with - separator tokens. - """ - it = iter(lists) - yield next(it) - for l in it: - yield intermediate - yield l - - -def concat_objects(objs: Sequence[Any], separator: Any | None = None): - """ - Concat objects with each other using a separator token. - - Supports: - - Concatable (objects that implement `concat` classmethod) - - strings - - lists - - numpy arrays - - torch Tensors - - Example: - >>> foo = "abc" - >>> bar = "def" - >>> concat_objects([foo, bar], "|") - "abc|def" - """ - match objs[0]: - case Concatable(): - return objs[0].__class__.concat(objs) # type: ignore - case str(): - assert isinstance( - separator, str - ), "Trying to join strings but separator is not a string" - return separator.join(objs) - case list(): - if separator is not None: - return join_lists(objs, [separator]) - else: - return join_lists(objs) - case np.ndarray(): - if separator is not None: - return np.concatenate( - list(iterate_with_intermediate(objs, np.array([separator]))) - ) - else: - return np.concatenate(objs) - case torch.Tensor(): - if separator is not None: - return torch.cat( - list(iterate_with_intermediate(objs, torch.tensor([separator]))) - ) - else: - return torch.cat(objs) # type: ignore - case _: - raise TypeError(type(objs[0])) diff --git a/esm/utils/msa/__init__.py b/esm/utils/msa/__init__.py deleted file mode 100644 index 3804a365..00000000 --- a/esm/utils/msa/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from esm.utils.msa.msa import ( - MSA, - FastMSA, - remove_insertions_from_sequence, -) - -__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"] diff --git a/esm/utils/msa/filter_sequences.py b/esm/utils/msa/filter_sequences.py deleted file mode 100644 index d44549d5..00000000 --- a/esm/utils/msa/filter_sequences.py +++ /dev/null @@ -1,79 +0,0 @@ -import tempfile -from pathlib import Path - -import numpy as np -from scipy.spatial.distance import cdist - -from esm.utils.system import run_subprocess_with_errorcheck - - -def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]: - """Greedily select sequences that either maximize or minimize hamming distance. - - Algorithm proposed in the MSA Transformer paper. 
Starting from the query sequence, - iteratively add sequences to the list with the maximum (minimum) average Hamming - distance to the existing set of sequences. - - Args: - array (np.ndarray): Character array representing the sequences in the MSA - num_seqs (int): Number of sequences to select. - mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless - you're doing it to prove a point for a paper. - - Returns: - list[int]: List of indices to select from the array - """ - assert mode in ("max", "min") - depth = array.shape[0] - if depth <= num_seqs: - return list(range(depth)) - array = array.view(np.uint8) - - optfunc = np.argmax if mode == "max" else np.argmin - all_indices = np.arange(depth) - indices = [0] - pairwise_distances = np.zeros((0, depth)) - for _ in range(num_seqs - 1): - dist = cdist(array[indices[-1:]], array, "hamming") - pairwise_distances = np.concatenate([pairwise_distances, dist]) - shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0) - shifted_index = optfunc(shifted_distance) - index = np.delete(all_indices, indices)[shifted_index] - indices.append(index) - indices = sorted(indices) - return indices - - -def hhfilter( - sequences: list[str], - seqid: int = 90, - diff: int = 0, - cov: int = 0, - qid: int = 0, - qsc: float = -20.0, - binary: str = "hhfilter", -) -> list[int]: - with tempfile.TemporaryDirectory(dir="/dev/shm") as tempdirname: - tempdir = Path(tempdirname) - fasta_file = tempdir / "input.fasta" - fasta_file.write_text( - "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences)) - ) - output_file = tempdir / "output.fasta" - command = " ".join( - [ - f"{binary}", - f"-i {fasta_file}", - "-M a3m", - f"-o {output_file}", - f"-id {seqid}", - f"-diff {diff}", - f"-cov {cov}", - f"-qid {qid}", - f"-qsc {qsc}", - ] - ).split(" ") - run_subprocess_with_errorcheck(command, capture_output=True) - with output_file.open() as f: - indices = [int(line[1:].strip()) for line in f if line.startswith(">")] - return indices diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py deleted file mode 100644 index 8722b9d6..00000000 --- a/esm/utils/msa/msa.py +++ /dev/null @@ -1,507 +0,0 @@ -from __future__ import annotations - -import dataclasses -import string -from dataclasses import dataclass -from functools import cached_property -from itertools import islice -from typing import Sequence - -import numpy as np -from Bio import SeqIO -from scipy.spatial.distance import cdist - -from esm.utils.misc import slice_any_object -from esm.utils.msa.filter_sequences import ( - greedy_select_indices, - hhfilter, -) -from esm.utils.parsing import ( - FastaEntry, - read_sequences, - write_sequences, -) -from esm.utils.sequential_dataclass import SequentialDataclass -from esm.utils.system import PathOrBuffer - -REMOVE_LOWERCASE_TRANSLATION = str.maketrans(dict.fromkeys(string.ascii_lowercase)) - - -def remove_insertions_from_sequence(seq: str) -> str: - return seq.translate(REMOVE_LOWERCASE_TRANSLATION) - - -@dataclass(frozen=True) -class MSA(SequentialDataclass): - """Object-oriented interface to an MSA. 
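A toy run of the greedy max-Hamming selection that the docstring above describes: start from the query row, then repeatedly add the row with the greatest mean Hamming distance to the already-selected set.

import numpy as np
from scipy.spatial.distance import cdist

msa = np.array([list(s) for s in ["AAAA", "AAAT", "CCCC", "AACC"]], dtype="|S1")
arr = msa.view(np.uint8)
indices = [0]  # always keep the query
while len(indices) < 3:
    dist = cdist(arr[indices], arr, "hamming").mean(axis=0)
    dist[indices] = -1.0  # never re-select a chosen row
    indices.append(int(dist.argmax()))
print(sorted(indices))  # [0, 1, 2] for this toy input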
- - Args: - sequences (list[str]): List of protein sequences - headers (list[str]): List of headers describing the sequences - - """ - - entries: list[FastaEntry] - - @cached_property - def sequences(self) -> list[str]: - return [entry.sequence for entry in self.entries] - - @cached_property - def headers(self) -> list[str]: - return [entry.header for entry in self.entries] - - def __repr__(self): - return ( - f"MSA({self.entries[0].header}: Depth={self.depth}, Length={self.seqlen})" - ) - - def to_fast_msa(self) -> FastMSA: - return FastMSA(self.array, self.headers) - - @classmethod - def from_a3m( - cls, - path: PathOrBuffer, - remove_insertions: bool = True, - max_sequences: int | None = None, - ) -> MSA: - entries = [] - for header, seq in islice(read_sequences(path), max_sequences): - if remove_insertions: - seq = remove_insertions_from_sequence(seq) - if entries: - assert ( - len(seq) == len(entries[0].sequence) - ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" - entries.append(FastaEntry(header, seq)) - return cls(entries) - - def to_a3m(self, path: PathOrBuffer) -> None: - write_sequences(self.entries, path) - - @classmethod - def from_stockholm( - cls, - path: PathOrBuffer, - remove_insertions: bool = True, - max_sequences: int | None = None, - ) -> MSA: - entries = [] - for record in islice(SeqIO.parse(path, "stockholm"), max_sequences): - header = f"{record.id} {record.description}" - seq = str(record.seq) - if entries: - assert ( - len(seq) == len(entries[0].sequence) - ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" - entries.append(FastaEntry(header, seq)) - msa = cls(entries) - if remove_insertions: - keep_inds = [i for i, aa in enumerate(msa.query) if aa != "-"] - msa = msa.select_positions(keep_inds) - return msa - - def to_bytes(self) -> bytes: - version = 1 - version_bytes = version.to_bytes(1, "little") - seqlen_bytes = self.seqlen.to_bytes(4, "little") - depth_bytes = self.depth.to_bytes(4, "little") - array_bytes = self.array.tobytes() - header_bytes = "\n".join(entry.header for entry in self.entries).encode() - all_bytes = ( - version_bytes + seqlen_bytes + depth_bytes + array_bytes + header_bytes - ) - return all_bytes - - @classmethod - def from_bytes(cls, data: bytes) -> MSA: - version_bytes, seqlen_bytes, depth_bytes, data = ( - data[:1], - data[1:5], - data[5:9], - data[9:], - ) - version = int.from_bytes(version_bytes, "little") - if version != 1: - raise ValueError(f"Unsupported version: {version}") - seqlen = int.from_bytes(seqlen_bytes, "little") - depth = int.from_bytes(depth_bytes, "little") - array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :] - array = np.frombuffer(array_bytes, dtype="|S1") - array = array.reshape(depth, seqlen) - headers = header_bytes.decode().split("\n") - # Sometimes the separation is two newlines, which results in an empty header. 
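For reference, the binary layout that the deleted `to_bytes`/`from_bytes` pair above agrees on, reconstructed as a free function: one version byte, two little-endian uint32 fields, the raw residue array, then newline-joined headers.

import numpy as np

def pack_msa(sequences: list[str], headers: list[str]) -> bytes:
    arr = np.array([list(s) for s in sequences], dtype="|S1")
    depth, seqlen = arr.shape
    return (
        (1).to_bytes(1, "little")        # format version
        + seqlen.to_bytes(4, "little")   # columns
        + depth.to_bytes(4, "little")    # rows
        + arr.tobytes()                  # depth * seqlen residue bytes
        + "\n".join(headers).encode()    # trailing header block
    )

blob = pack_msa(["MKTA", "MK-A"], ["query", "hit1"])
assert blob[0] == 1 and len(blob) == 1 + 4 + 4 + 8 + len(b"query\nhit1")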
- headers = [header for header in headers if header] - entries = [ - FastaEntry(header, b"".join(row).decode()) - for header, row in zip(headers, array) - ] - return cls(entries) - - # TODO(jmaccarl): set remove_insertions to True by default here to match other utils - @classmethod - def from_sequences( - cls, sequences: list[str], remove_insertions: bool = False - ) -> MSA: - if remove_insertions: - entries = [ - FastaEntry("", remove_insertions_from_sequence(seq)) - for seq in sequences - ] - else: - entries = [FastaEntry("", seq) for seq in sequences] - return cls(entries) - - def to_sequence_bytes(self) -> bytes: - """Stores ONLY SEQUENCES in array format as bytes. Header information will be lost.""" - seqlen_bytes = self.seqlen.to_bytes(4, "little") - array_bytes = self.array.tobytes() - all_bytes = seqlen_bytes + array_bytes - return all_bytes - - @classmethod - def from_sequence_bytes(cls, data: bytes) -> MSA: - seqlen_bytes, array_bytes = data[:4], data[4:] - seqlen = int.from_bytes(seqlen_bytes, "little") - array = np.frombuffer(array_bytes, dtype="|S1") - array = array.reshape(-1, seqlen) - entries = [FastaEntry("", b"".join(row).decode()) for row in array] - return cls(entries) - - @property - def depth(self) -> int: - return len(self.entries) - - @property - def seqlen(self) -> int: - return len(self.entries[0].sequence) - - @cached_property - def array(self) -> np.ndarray: - return np.array([list(seq) for seq in self.sequences], dtype="|S1") - - @property - def query(self) -> str: - return self.entries[0].sequence - - def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA: - """Subselect rows of the MSA.""" - entries = [self.entries[idx] for idx in indices] - return dataclasses.replace(self, entries=entries) - - def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA: - """Subselect columns of the MSA.""" - entries = [ - FastaEntry(header, "".join(seq[idx] for idx in indices)) - for header, seq in self.entries - ] - return dataclasses.replace(self, entries=entries) - - def __getitem__(self, indices: int | list[int] | slice | np.ndarray): - if isinstance(indices, int): - indices = [indices] - - entries = [ - FastaEntry(header, slice_any_object(seq, indices)) - for header, seq in self.entries - ] - return dataclasses.replace(self, entries=entries) - - def __len__(self): - return self.seqlen - - def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA: - """Greedily select sequences that either maximize or minimize hamming distance. - - Algorithm proposed in the MSA Transformer paper. Starting from the query sequence, - iteratively add sequences to the list with the maximum (minimum) average Hamming - distance to the existing set of sequences. - - Args: - num_seqs (int): Number of sequences to select. - mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless - you're doing it to prove a point for a paper. - - Returns: - MSA object w/ subselected sequences. 
- """ - assert mode in ("max", "min") - if self.depth <= num_seqs: - return self - - indices = greedy_select_indices(self.array, num_seqs, mode) - return self.select_sequences(indices) - - def hhfilter( - self, - seqid: int = 90, - diff: int = 0, - cov: int = 0, - qid: int = 0, - qsc: float = -20.0, - binary: str = "hhfilter", - ) -> MSA: - """Apply hhfilter to the sequences in the MSA and return a filtered MSA.""" - - indices = hhfilter( - self.sequences, - seqid=seqid, - diff=diff, - cov=cov, - qid=qid, - qsc=qsc, - binary=binary, - ) - return self.select_sequences(indices) - - def select_random_sequences(self, num_seqs: int) -> MSA: - """Uses random sampling to subselect sequences from the MSA. Always - keeps the query sequence. - """ - if num_seqs >= self.depth: - return self - - # Subselect random, always keeping the query sequence. - indices = np.sort( - np.append( - 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 - ) - ) - msa = self.select_sequences(indices) # type: ignore - return msa - - def select_diverse_sequences(self, num_seqs: int) -> MSA: - """Applies hhfilter to select ~num_seqs sequences, then uses random sampling - to subselect if necessary. - """ - if num_seqs >= self.depth: - return self - - msa = self.hhfilter(diff=num_seqs) - if num_seqs < msa.depth: - msa = msa.select_random_sequences(num_seqs) - return msa - - def pad_to_depth(self, depth: int) -> MSA: - if depth < self.depth: - raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}") - elif depth == self.depth: - return self - - num_to_add = depth - self.depth - extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)] - return dataclasses.replace(self, entries=self.entries + extra_entries) - - @classmethod - def stack( - cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True - ) -> MSA: - """Stack a series of MSAs. 
Optionally remove the query from msas after the first.""" - all_entries = [] - for i, msa in enumerate(msas): - entries = msa.entries - if i > 0 and remove_query_from_later_msas: - entries = entries[1:] - all_entries.extend(entries) - return cls(entries=all_entries) - - @cached_property - def seqid(self) -> np.ndarray: - array = self.array.view(np.uint8) - seqid = 1 - cdist(array[0][None], array, "hamming") - return seqid[0] - - @classmethod - def concat( - cls, - msas: Sequence[MSA], - join_token: str | None = "|", - allow_depth_mismatch: bool = False, - ) -> MSA: - """Concatenate a series of MSAs horizontally, along the sequence dimension.""" - if not msas: - raise ValueError("Cannot concatenate an empty list of MSAs") - msa_depths = [msa.depth for msa in msas] - if len(set(msa_depths)) != 1: - if not allow_depth_mismatch: - raise ValueError("Depth mismatch in concatenating MSAs") - else: - max_depth = max(msa_depths) - msas = [msa.pad_to_depth(max_depth) for msa in msas] - headers = [ - "|".join([str(h) for h in headers]) - for headers in zip(*(msa.headers for msa in msas)) - ] - - if join_token is None: - join_token = "" - - seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))] - entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)] - return cls(entries) - - -@dataclass(frozen=True) -class FastMSA(SequentialDataclass): - """Object-oriented interface to an MSA stored as a numpy uint8 array.""" - - array: np.ndarray - headers: list[str] | None = None - - def __post_init__(self): - if self.headers is not None: - assert ( - len(self.headers) == self.depth - ), "Number of headers must match depth." - - @classmethod - def from_bytes(cls, data: bytes) -> FastMSA: - version_bytes, seqlen_bytes, depth_bytes, data = ( - data[:1], - data[1:5], - data[5:9], - data[9:], - ) - version = int.from_bytes(version_bytes, "little") - if version != 1: - raise ValueError(f"Unsupported version: {version}") - seqlen = int.from_bytes(seqlen_bytes, "little") - depth = int.from_bytes(depth_bytes, "little") - array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :] - array = np.frombuffer(array_bytes, dtype="|S1") - array = array.reshape(depth, seqlen) - headers = header_bytes.decode().split("\n") - # Sometimes the separation is two newlines, which results in an empty header. - headers = [header for header in headers if header] - return cls(array, headers) - - @classmethod - def from_sequence_bytes(cls, data: bytes) -> FastMSA: - seqlen_bytes, array_bytes = data[:4], data[4:] - seqlen = int.from_bytes(seqlen_bytes, "little") - array = np.frombuffer(array_bytes, dtype="|S1") - array = array.reshape(-1, seqlen) - return cls(array) - - @property - def depth(self) -> int: - return self.array.shape[0] - - @property - def seqlen(self) -> int: - return self.array.shape[1] - - def __len__(self): - return self.seqlen - - def __getitem__(self, indices: int | list[int] | slice | np.ndarray): - if isinstance(indices, int): - indices = [indices] - - return dataclasses.replace(self, array=self.array[:, indices]) - - def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA: - """Subselect rows of the MSA.""" - array = self.array[indices] - headers = ( - [self.headers[idx] for idx in indices] if self.headers is not None else None - ) - return dataclasses.replace(self, array=array, headers=headers) - - def select_random_sequences(self, num_seqs: int) -> FastMSA: - """Uses random sampling to subselect sequences from the MSA. 
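The horizontal-concat semantics of `MSA.concat` above, reduced to their core: rows are zipped across MSAs and joined with the chain-break token.

seqs_a = ["MKTA", "MK-A"]
seqs_b = ["GGS", "G-S"]
joined = ["|".join(pair) for pair in zip(seqs_a, seqs_b)]
assert joined == ["MKTA|GGS", "MK-A|G-S"]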
Always - keeps the query sequence. - """ - if num_seqs >= self.depth: - return self - - # Subselect random, always keeping the query sequence. - indices = np.sort( - np.append( - 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 - ) - ) - msa = self.select_sequences(indices) # type: ignore - return msa - - def pad_to_depth(self, depth: int) -> FastMSA: - if depth < self.depth: - raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}") - elif depth == self.depth: - return self - - num_to_add = depth - self.depth - array = np.pad( - self.array, - [(0, num_to_add), (0, 0)], - constant_values=ord("-") if self.array.dtype == np.uint8 else b"-", - ) - headers = self.headers - if headers is not None: - headers = headers + [""] * num_to_add - return dataclasses.replace(self, array=array, headers=headers) - - @classmethod - def concat( - cls, - msas: Sequence[FastMSA], - join_token: str | None = None, - allow_depth_mismatch: bool = False, - ) -> FastMSA: - """Concatenate a series of MSAs horizontally, along the sequence dimension.""" - if not msas: - raise ValueError("Cannot concatenate an empty list of MSAs") - if join_token is not None and join_token != "": - raise NotImplementedError("join_token is not supported for FastMSA") - - msa_depths = [msa.depth for msa in msas] - if len(set(msa_depths)) != 1: - if not allow_depth_mismatch: - raise ValueError("Depth mismatch in concatenating MSAs") - else: - max_depth = max(msa_depths) - msas = [msa.pad_to_depth(max_depth) for msa in msas] - headers = [ - "|".join([str(h) for h in headers]) - for headers in zip( - *( - msa.headers if msa.headers is not None else [""] * msa.depth - for msa in msas - ) - ) - ] - - array = np.concatenate([msa.array for msa in msas], axis=1) - return cls(array, headers) - - def to_msa(self) -> MSA: - headers = ( - self.headers - if self.headers is not None - else [f"seq{i}" for i in range(self.depth)] - ) - entries = [ - FastaEntry(header, b"".join(row).decode()) - for header, row in zip(headers, self.array) - ] - return MSA(entries) - - @classmethod - def stack( - cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True - ) -> FastMSA: - """Stack a series of MSAs. 
diff --git a/esm/utils/parsing.py b/esm/utils/parsing.py
deleted file mode 100644
index c47938ab..00000000
--- a/esm/utils/parsing.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import io
-from pathlib import Path
-from typing import Generator, Iterable, NamedTuple
-
-PathOrBuffer = str | Path | io.TextIOBase
-FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)])
-
-
-def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]:
-    """
-    Parses a fasta file and yields FastaEntry objects
-
-    Args:
-        fasta_string: The fasta file as a string
-    Returns:
-        A generator of FastaEntry objects
-    """
-    header = None
-    seq = []
-    num_sequences = 0
-    for line in fasta_string.splitlines():
-        if not line or line[0] == "#":
-            continue
-        if line.startswith(">"):
-            if header is not None:
-                yield FastaEntry(header, "".join(seq))
-                seq = []
-            header = line[1:].strip()
-        else:
-            seq.append(line)
-    if header is not None:
-        num_sequences += 1
-        yield FastaEntry(header, "".join(seq))
-
-    if num_sequences == 0:
-        raise ValueError("Found no sequences in input")
-
-
-def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]:
-    # Uses duck typing to try and call the right method
-    # Doesn't use explicit isinstance check to support
-    # inputs that are not explicitly str/Path/TextIOBase but
-    # may support similar functionality
-    data = None  # type: ignore
-    try:
-        if str(path).endswith(".gz"):
-            import gzip
-
-            data = gzip.open(path, "rt")  # type: ignore
-        else:
-            try:
-                data = open(path)  # type: ignore
-            except TypeError:
-                data: io.TextIOBase = path  # type: ignore
-
-        yield from parse_fasta(data.read())
-    finally:
-        if data is not None:
-            data.close()
-
-
-def read_first_sequence(path: PathOrBuffer) -> FastaEntry:
-    return next(iter(read_sequences(path)))
-
-
-def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None:
-    needs_closing = False
-    handle = None
-    try:
-        try:
-            handle = open(path, "w")  # type: ignore
-            needs_closing = True
-        except TypeError:
-            handle = path
-        has_prev = False
-        for header, seq in sequences:
-            if has_prev:
-                handle.write("\n")  # type: ignore
-            handle.write(f">{header}\n{seq}")  # type: ignore
-            has_prev = True
-    finally:
-        if needs_closing:
-            handle.close()  # type: ignore
diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py
index 68c5c868..fdf8658d 100644
--- a/esm/utils/sampling.py
+++ b/esm/utils/sampling.py
@@ -5,18 +5,9 @@
 import torch
 import torch.nn.functional as F
 
-from esm.sdk.api import (
-    ESMProteinTensor,
-    SamplingConfig,
-    SamplingTrackConfig,
-)
-from esm.tokenization import (
-    TokenizerCollectionProtocol,
-    get_invalid_tokenizer_ids,
-)
-from esm.tokenization.function_tokenizer import (
-    InterProQuantizedTokenizer,
-)
+from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig
+from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids
+from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
 from esm.utils.constants.esm3 import (
     MAX_RESIDUE_ANNOTATIONS,
     SASA_DISCRETIZATION_BOUNDARIES,
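Since `esm/utils/parsing.py` is deleted above, downstream callers lose `parse_fasta`, `read_sequences`, `read_first_sequence`, and `write_sequences`. Their round-trip behavior, reconstructed from the deleted code (not repo documentation), was roughly:

```python
fasta = ">seq1\nMKV\nLLT\n>seq2 description\nMALW"
entries = list(parse_fasta(fasta))
# [FastaEntry(header='seq1', sequence='MKVLLT'),
#  FastaEntry(header='seq2 description', sequence='MALW')]
write_sequences(entries, "out.fasta")     # writes ">seq1\nMKVLLT\n>seq2 description\nMALW"
first = read_first_sequence("out.fasta")  # FastaEntry('seq1', 'MKVLLT')
```

Callers will need to vendor these helpers or switch to another FASTA reader (biotite, already a dependency of the cookbook notebooks, is one option).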
diff --git a/esm/utils/sequential_dataclass.py b/esm/utils/sequential_dataclass.py
deleted file mode 100644
index 8e935798..00000000
--- a/esm/utils/sequential_dataclass.py
+++ /dev/null
@@ -1,157 +0,0 @@
-from abc import ABC, abstractmethod
-from dataclasses import dataclass, fields, replace
-from typing import TypeVar
-
-import numpy as np
-
-from esm.utils.misc import concat_objects, slice_any_object
-
-T = TypeVar("T")
-
-
-@dataclass(frozen=True)
-class SequentialDataclass(ABC):
-    """
-    This is a builder on a dataclass that allows for automatic slicing and concatenation.
-
-    When representing multimodal data, we often have multiple datatypes which have sequence dimensions that are the same (e.g. the length of the protein).
-
-    When applying a transformation like a crop, we want to apply this to all tensors at the same time (e.g. crop the sequence, structure, and function).
-
-    We also have some fields that are not sequential (like an id, or data source), which we don't want to crop.
-
-    The SequentialDataclass abstracts this cropping away, allowing you to define dataclasses that implement `__len__`, `__getitem__` and `concat` automatically.
-
-    This is done through the `metadata` field, which can take 3 values:
-        `sequence` (bool): True or False, tells the dataclass whether this field is a sequential type. Default: False.
-        `sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0.
-        `join_token` (Any): What token to use to join when concatenating elements. Default: None.
-
-
-    Example:
-
-    @dataclass(frozen=True)
-    class Foo(SequentialDataclass):
-        id: str
-        sequence: str = field(metadata={"sequence": True, "join_token": "|"})
-        tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan})
-
-        def __len__(self):
-            # Must implement the __len__ method
-            return len(self.sequence)
-
-    >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5))
-    Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143,  0.0251, -1.0717]))
-
-    >>> foo[1:4]
-    Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143,  0.0251]))
-
-    >>> foo[np.arange(5) < 3]
-    Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143]))
-
-    >>> Foo.concat([foo[:2], foo[3:]])
-    Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335,     nan,  0.0251, -1.0717]))
-
-    # Trying to create a type where the sequence lengths do not match raises an error
-    >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6))
-    ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6
-    """
-
-    def __post_init__(self):
-        self._check_sequence_lengths_match()
-
-    @abstractmethod
-    def __len__(self):
-        raise NotImplementedError
-
-    def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
-        updated_fields = {}
-        if isinstance(idx, int):
-            # make it so that things remain sequential
-            idx = [idx]
-
-        for fld in fields(self):
-            if fld.metadata.get("sequence", False):
-                # this is a sequence, should be the same length as all other sequences
-                sequence_dim = fld.metadata.get("sequence_dim", 0)
-                value = getattr(self, fld.name)
-                if value is None:
-                    continue
-                match sequence_dim:
-                    case 0:
-                        # sequence is first dimension
-                        value = getattr(self, fld.name)
-                        value = slice_any_object(value, idx)
-                        updated_fields[fld.name] = value
-                    case 1:
-                        new_value = [slice_any_object(item, idx) for item in value]
-                        updated_fields[fld.name] = value.__class__(new_value)
-                    case _:
-                        raise NotImplementedError(
-                            "Arbitrary slicing for different sequence length fields is not implemented"
-                        )
-
-        return replace(self, **updated_fields)
-
-    def _check_sequence_lengths_match(self):
-        """Checks if sequence lengths of all "sequence" fields match."""
-        for fld in fields(self):
-            if fld.metadata.get("sequence", False) and fld.name != "complex":
-                # this is a sequence, should be the same length as all other sequences
-                sequence_dim = fld.metadata.get("sequence_dim", 0)
-                value = getattr(self, fld.name)
-                if value is None:
-                    continue
-                match sequence_dim:
-                    case 0:
-                        # sequence is first dimension
-                        value = getattr(self, fld.name)
-                        if len(value) != len(self):
-                            raise ValueError(
-                                f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(value)}"
-                            )
-                    case 1:
-                        for item in value:
-                            if len(item) != len(self):
-                                raise ValueError(
-                                    f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}"
-                                )
-                    case _:
-                        raise NotImplementedError(
-                            "Arbitrary matching for different sequence length fields is not implemented"
-                        )
-
-    @classmethod
-    def concat(cls, items: list[T], **kwargs) -> T:
-        updated_fields = {}
-        for fld in fields(cls):
-            if fld.metadata.get("sequence", False):
-                # this is a sequence, should be the same length as all other sequences
-                sequence_dim = fld.metadata.get("sequence_dim", 0)
-                join_value = fld.metadata.get("join_token", None)
-                if getattr(items[0], fld.name) is None:
-                    continue
-                values = [getattr(item, fld.name) for item in items]
-                match sequence_dim:
-                    case 0:
-                        # sequence is first dimension
-                        value = concat_objects(values, join_value)
-                        updated_fields[fld.name] = value
-                    case 1:
-                        new_value = [
-                            concat_objects(item, join_value) for item in zip(*values)
-                        ]
-                        updated_fields[fld.name] = getattr(
-                            items[0], fld.name
-                        ).__class__(new_value)
-                    case _:
-                        raise NotImplementedError(
-                            "Arbitrary joining for different sequence length fields is not implemented"
-                        )
-        updated_fields.update(kwargs)
-
-        return replace(
-            items[0],  # type: ignore
-            **updated_fields,
-        )
diff --git a/esm/utils/structure/aligner.py b/esm/utils/structure/aligner.py
index dd6702aa..f25d9987 100644
--- a/esm/utils/structure/aligner.py
+++ b/esm/utils/structure/aligner.py
@@ -6,9 +6,7 @@
 import numpy as np
 import torch
 
-from esm.utils.structure.protein_structure import (
-    compute_affine_and_rmsd,
-)
+from esm.utils.structure.protein_structure import compute_affine_and_rmsd
 
 
 class Alignable(Protocol):
diff --git a/esm/utils/structure/atom_indexer.py b/esm/utils/structure/atom_indexer.py
index 2f588b98..d62f05c9 100644
--- a/esm/utils/structure/atom_indexer.py
+++ b/esm/utils/structure/atom_indexer.py
@@ -1,8 +1,6 @@
 import numpy as np
 
-from esm.utils.structure.protein_structure import (
-    index_by_atom_name,
-)
+from esm.utils.structure.protein_structure import index_by_atom_name
 
 
 class AtomIndexer:
diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py
deleted file mode 100644
index 158c887c..00000000
--- a/esm/utils/structure/input_builder.py
+++ /dev/null
@@ -1,96 +0,0 @@
-from dataclasses import dataclass
-from typing import Any, Sequence
-
-import numpy as np
-
-
-
-@dataclass
-class Modification:
-    position: int  # zero-indexed
-    ccd: str
-
-
-@dataclass
-class ProteinInput:
-    id: str | list[str]
-    sequence: str
-    modifications: list[Modification] | None = None
-
-
-@dataclass
-class RNAInput:
-    id: str | list[str]
-    sequence: str
-    modifications: list[Modification] | None = None
-
-
-@dataclass
-class DNAInput:
-    id: str | list[str]
-    sequence: str
-    modifications: list[Modification] | None = None
-
-
-@dataclass
-class LigandInput:
-    id: str | list[str]
-    smiles: str
-    ccd: list[str] | None = None
-
-
-@dataclass
-class DistogramConditioning:
-    chain_id: str
-    distogram: np.ndarray
-
-
-@dataclass
-class PocketConditioning:
-    binder_chain_id: str
-    contacts: list[tuple[str, int]]
-
-
-@dataclass
-class StructurePredictionInput:
-    sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput]
-    pocket: PocketConditioning | None = None
-    distogram_conditioning: list[DistogramConditioning] | None = None
-
-
-def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput):
-    def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]:
-        chain_data: dict[str, Any] = {
-            "sequence": seq_input.sequence,
-            "id": seq_input.id,
-            "type": chain_type,
-        }
-        if hasattr(seq_input, "modifications") and seq_input.modifications:
-            mods = [
-                {"position": mod.position, "ccd": mod.ccd}
-                for mod in seq_input.modifications
-            ]
-            chain_data["modifications"] = mods
-        return chain_data
-
-    sequences = []
-    for seq_input in all_atom_input.sequences:
-        if isinstance(seq_input, ProteinInput):
-            sequences.append(create_chain_data(seq_input, "protein"))
-        elif isinstance(seq_input, RNAInput):
-            sequences.append(create_chain_data(seq_input, "rna"))
-        elif isinstance(seq_input, DNAInput):
-            sequences.append(create_chain_data(seq_input, "dna"))
-        elif isinstance(seq_input, LigandInput):
-            sequences.append(
-                {
-                    "smiles": seq_input.smiles,
-                    "id": seq_input.id,
-                    "ccd": seq_input.ccd,
-                    "type": "ligand",
-                }
-            )
-        else:
-            raise ValueError(f"Unsupported sequence input type: {type(seq_input)}")
-
-    return {"sequences": sequences}
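For anyone migrating off the deleted `input_builder` module: the serializer produced a plain JSON-ready dict, and only the `sequences` field was serialized (pocket and distogram conditioning were ignored by this function). An example reconstructed directly from the deleted code:

```python
inp = StructurePredictionInput(
    sequences=[
        ProteinInput(id="A", sequence="MKVL",
                     modifications=[Modification(position=0, ccd="MSE")]),
        LigandInput(id="L", smiles="c1ccccc1"),
    ]
)
serialize_structure_prediction_input(inp)
# {'sequences': [
#     {'sequence': 'MKVL', 'id': 'A', 'type': 'protein',
#      'modifications': [{'position': 0, 'ccd': 'MSE'}]},
#     {'smiles': 'c1ccccc1', 'id': 'L', 'ccd': None, 'type': 'ligand'}]}
```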
"modifications") and seq_input.modifications: - mods = [ - {"position": mod.position, "ccd": mod.ccd} - for mod in seq_input.modifications - ] - chain_data["modifications"] = mods - return chain_data - - sequences = [] - for seq_input in all_atom_input.sequences: - if isinstance(seq_input, ProteinInput): - sequences.append(create_chain_data(seq_input, "protein")) - elif isinstance(seq_input, RNAInput): - sequences.append(create_chain_data(seq_input, "rna")) - elif isinstance(seq_input, DNAInput): - sequences.append(create_chain_data(seq_input, "dna")) - elif isinstance(seq_input, LigandInput): - sequences.append( - { - "smiles": seq_input.smiles, - "id": seq_input.id, - "ccd": seq_input.ccd, - "type": "ligand", - } - ) - else: - raise ValueError(f"Unsupported sequence input type: {type(seq_input)}") - - return {"sequences": sequences} diff --git a/esm/utils/structure/metrics.py b/esm/utils/structure/metrics.py index b2e590db..d76ed766 100644 --- a/esm/utils/structure/metrics.py +++ b/esm/utils/structure/metrics.py @@ -264,7 +264,7 @@ def compute_lddt_ca( if all_atom_pred_pos.dim() != 3: all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :] - all_atom_mask = all_atom_mask[..., ca_pos] + all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim return compute_lddt( all_atom_pred_pos, diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py deleted file mode 100644 index 40bfae4e..00000000 --- a/esm/utils/structure/molecular_complex.py +++ /dev/null @@ -1,944 +0,0 @@ -from __future__ import annotations - -import io -import os -import re -from dataclasses import asdict, dataclass -from pathlib import Path -from subprocess import check_output -from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, List - -import biotite.structure.io.pdbx as pdbx -import brotli -import msgpack -import numpy as np -import torch - -from esm.utils import residue_constants -from esm.utils.structure.metrics import ( - compute_lddt, - compute_rmsd, -) -from esm.utils.structure.protein_complex import ( - ProteinComplex, - ProteinComplexMetadata, -) - - -@dataclass -class MolecularComplexResult: - """Result of molecular complex folding""" - - complex: MolecularComplex - plddt: torch.Tensor | None = None - ptm: float | None = None - iptm: float | None = None - pae: torch.Tensor | None = None - distogram: torch.Tensor | None = None - pair_chains_iptm: torch.Tensor | None = None - output_embedding_sequence: torch.Tensor | None = None - output_embedding_pair_pooled: torch.Tensor | None = None - - -@dataclass -class MolecularComplexMetadata: - """Metadata for MolecularComplex objects.""" - - entity_lookup: dict[int, str] - chain_lookup: dict[int, str] - assembly_composition: dict[str, list[str]] | None = None - - -@dataclass -class Molecule: - """Represents a single molecule/token within a MolecularComplex.""" - - token: str - token_idx: int - atom_positions: np.ndarray # [N_atoms, 3] - atom_elements: np.ndarray # [N_atoms] element strings - residue_type: int - molecule_type: int # PROTEIN=0, RNA=1, DNA=2, LIGAND=3 - confidence: float - - -@dataclass(frozen=True) -class MolecularComplex: - """ - Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands. - - Uses a flat atom representation with token-based sequence indexing, supporting all atom types - beyond the traditional atom37 protein representation. 
diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py
deleted file mode 100644
index 40bfae4e..00000000
--- a/esm/utils/structure/molecular_complex.py
+++ /dev/null
@@ -1,944 +0,0 @@
-from __future__ import annotations
-
-import io
-import os
-import re
-from dataclasses import asdict, dataclass
-from pathlib import Path
-from subprocess import check_output
-from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, List
-
-import biotite.structure.io.pdbx as pdbx
-import brotli
-import msgpack
-import numpy as np
-import torch
-
-from esm.utils import residue_constants
-from esm.utils.structure.metrics import (
-    compute_lddt,
-    compute_rmsd,
-)
-from esm.utils.structure.protein_complex import (
-    ProteinComplex,
-    ProteinComplexMetadata,
-)
-
-
-@dataclass
-class MolecularComplexResult:
-    """Result of molecular complex folding"""
-
-    complex: MolecularComplex
-    plddt: torch.Tensor | None = None
-    ptm: float | None = None
-    iptm: float | None = None
-    pae: torch.Tensor | None = None
-    distogram: torch.Tensor | None = None
-    pair_chains_iptm: torch.Tensor | None = None
-    output_embedding_sequence: torch.Tensor | None = None
-    output_embedding_pair_pooled: torch.Tensor | None = None
-
-
-@dataclass
-class MolecularComplexMetadata:
-    """Metadata for MolecularComplex objects."""
-
-    entity_lookup: dict[int, str]
-    chain_lookup: dict[int, str]
-    assembly_composition: dict[str, list[str]] | None = None
-
-
-@dataclass
-class Molecule:
-    """Represents a single molecule/token within a MolecularComplex."""
-
-    token: str
-    token_idx: int
-    atom_positions: np.ndarray  # [N_atoms, 3]
-    atom_elements: np.ndarray  # [N_atoms] element strings
-    residue_type: int
-    molecule_type: int  # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
-    confidence: float
-
-
-@dataclass(frozen=True)
-class MolecularComplex:
-    """
-    Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands.
-
-    Uses a flat atom representation with token-based sequence indexing, supporting all atom types
-    beyond the traditional atom37 protein representation.
-    """
-
-    id: str
-    sequence: List[str]  # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP']
-
-    # Flat atom arrays - simplified representation
-    atom_positions: np.ndarray  # [N_atoms, 3] 3D coordinates
-    atom_elements: np.ndarray  # [N_atoms] element strings
-
-    # Token-to-atom mapping for efficient access
-    token_to_atoms: np.ndarray  # [N_tokens, 2] start/end indices into atoms array
-
-    # Confidence data
-    plddt: np.ndarray  # Per-token confidence scores [N_tokens]
-
-    # Metadata
-    metadata: MolecularComplexMetadata
-
-    def __post_init__(self):
-        """Validate array dimensions."""
-        n_tokens = len(self.sequence)
-        assert (
-            self.token_to_atoms.shape[0] == n_tokens
-        ), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens"
-        assert (
-            self.plddt.shape[0] == n_tokens
-        ), f"plddt shape {self.plddt.shape} != {n_tokens} tokens"
-
-    def __len__(self) -> int:
-        """Return number of tokens."""
-        return len(self.sequence)
-
-    def __getitem__(self, idx: int) -> Molecule:
-        """Access individual molecules/tokens by index."""
-        if idx >= len(self.sequence) or idx < 0:
-            raise IndexError(
-                f"Token index {idx} out of range for {len(self.sequence)} tokens"
-            )
-
-        token = self.sequence[idx]
-        start_atom, end_atom = self.token_to_atoms[idx]
-
-        # Extract atom data for this token
-        token_atom_positions = self.atom_positions[start_atom:end_atom]
-        token_atom_elements = self.atom_elements[start_atom:end_atom]
-
-        # Default values for residue/molecule type (would be extended based on actual implementation)
-        residue_type = 0  # Default to standard residue
-        molecule_type = 0  # Default to protein
-
-        return Molecule(
-            token=token,
-            token_idx=idx,
-            atom_positions=token_atom_positions,
-            atom_elements=token_atom_elements,
-            residue_type=residue_type,
-            molecule_type=molecule_type,
-            confidence=self.plddt[idx],
-        )
-
-    @property
-    def atom_coordinates(self) -> np.ndarray:
-        """Get flat array of all atom coordinates [N_atoms, 3]."""
-        return self.atom_positions
-
-    # Conversion methods
-    @classmethod
-    def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex":
-        """Convert a ProteinComplex to MolecularComplex.
-
-        Args:
-            pc: ProteinComplex object with atom37 representation
-
-        Returns:
-            MolecularComplex with flat atom arrays and token-based indexing
-        """
-        from esm.utils import residue_constants
-
-        # Extract sequence without chain breaks
-        sequence_no_breaks = pc.sequence.replace("|", "")
-        sequence_tokens = [
-            residue_constants.restype_1to3.get(aa, "UNK") for aa in sequence_no_breaks
-        ]
-
-        # Convert atom37 to flat arrays
-        flat_positions = []
-        flat_elements = []
-        token_to_atoms = []
-
-        atom_idx = 0
-        residue_idx = 0
-
-        for i, aa in enumerate(pc.sequence):
-            if aa == "|":
-                # Skip chain break tokens
-                continue
-
-            # Get atom37 positions and mask for this residue
-            res_positions = pc.atom37_positions[residue_idx]  # [37, 3]
-            res_mask = pc.atom37_mask[residue_idx]  # [37]
-
-            # Track start position for this token
-            token_start = atom_idx
-
-            # Process each atom type in atom37 representation
-            for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
-                if res_mask[atom_type_idx]:  # Atom is present
-                    # Add position
-                    flat_positions.append(res_positions[atom_type_idx])
-
-                    # Determine element from atom name
-                    element = (
-                        atom_name[0] if atom_name else "C"
-                    )  # First character is element
-                    flat_elements.append(element)
-
-                    atom_idx += 1
-
-            # Record token-to-atom mapping [start_idx, end_idx)
-            token_to_atoms.append([token_start, atom_idx])
-            residue_idx += 1
-
-        # Convert to numpy arrays
-        atom_positions = np.array(flat_positions, dtype=np.float32)
-        atom_elements = np.array(flat_elements, dtype=object)
-        token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
-
-        # Extract confidence scores (skip chain breaks)
-        confidence_scores = []
-        residue_idx = 0
-        for aa in pc.sequence:
-            if aa != "|":
-                confidence_scores.append(pc.confidence[residue_idx])
-                residue_idx += 1
-
-        confidence_array = np.array(confidence_scores, dtype=np.float32)
-
-        # Create metadata - convert entity IDs to strings for MolecularComplexMetadata
-        entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()}
-        metadata = MolecularComplexMetadata(
-            entity_lookup=entity_lookup_str,
-            chain_lookup=pc.metadata.chain_lookup,
-            assembly_composition=pc.metadata.assembly_composition,
-        )
-
-        return cls(
-            id=pc.id,
-            sequence=sequence_tokens,
-            atom_positions=atom_positions,
-            atom_elements=atom_elements,
-            token_to_atoms=token_to_atoms_array,
-            plddt=confidence_array,
-            metadata=metadata,
-        )
-
-    def to_protein_complex(self) -> ProteinComplex:
-        """Convert MolecularComplex back to ProteinComplex format.
-
-        Extracts only protein tokens and converts from flat atom representation
-        back to atom37 format used by ProteinComplex.
-
-        Returns:
-            ProteinComplex with protein residues only, excluding ligands/nucleic acids
-        """
-        from esm.utils import residue_constants
-
-        # No need for element mapping - already using element characters
-
-        # Filter for protein tokens only (skip ligands, nucleic acids)
-        protein_tokens = []
-        protein_indices = []
-
-        for i, token in enumerate(self.sequence):
-            # Check if token is a standard 3-letter amino acid code
-            if token in residue_constants.restype_3to1:
-                protein_tokens.append(token)
-                protein_indices.append(i)
-
-        if not protein_tokens:
-            raise ValueError("No protein tokens found in MolecularComplex")
-
-        n_residues = len(protein_tokens)
-
-        # Initialize atom37 arrays
-        atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32)
-        atom37_mask = np.zeros((n_residues, 37), dtype=bool)
-
-        # Convert tokens back to single-letter sequence
-        single_letter_sequence = "".join(
-            [residue_constants.restype_3to1[token] for token in protein_tokens]
-        )
-
-        # Extract confidence scores for protein residues only
-        protein_confidence = self.plddt[protein_indices]
-
-        # Convert flat atoms back to atom37 representation
-        for res_idx, token_idx in enumerate(protein_indices):
-            token = self.sequence[token_idx]
-            start_atom, end_atom = self.token_to_atoms[token_idx]
-
-            # Get atom data for this residue
-            res_atom_positions = self.atom_positions[start_atom:end_atom]
-
-            # Reconstruct atom37 representation by exactly reversing the forward conversion logic
-            # In from_protein_complex, atoms are added in atom_types order if present in mask
-            # So we need to reconstruct the mask and positions in the same order
-            atom_count = 0
-            for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
-                # Check if this atom type exists for this residue and was present
-                residue_atoms = residue_constants.residue_atoms.get(token, [])
-                if atom_name in residue_atoms:
-                    # This atom type exists for this residue, so it should have been included
-                    if atom_count < len(res_atom_positions):
-                        atom37_positions[res_idx, atom_type_idx] = res_atom_positions[
-                            atom_count
-                        ]
-                        atom37_mask[res_idx, atom_type_idx] = True
-                        atom_count += 1
-
-        # Create other required arrays for ProteinComplex
-        # For simplicity, assume all protein residues belong to the same entity/chain
-        entity_id = np.zeros(n_residues, dtype=np.int64)
-        chain_id = np.zeros(n_residues, dtype=np.int64)
-        sym_id = np.zeros(n_residues, dtype=np.int64)
-        residue_index = np.arange(1, n_residues + 1, dtype=np.int64)
-        insertion_code = np.array([""] * n_residues, dtype=object)
-
-        # Create simplified protein complex metadata
-        # Map the first entity/chain from molecular complex metadata
-        protein_metadata = ProteinComplexMetadata(
-            entity_lookup={0: 1},  # Single entity (int for ProteinComplexMetadata)
-            chain_lookup={0: "A"},  # Single chain
-            assembly_composition=self.metadata.assembly_composition,
-        )
-
-        return ProteinComplex(
-            id=self.id,
-            sequence=single_letter_sequence,
-            entity_id=entity_id,
-            chain_id=chain_id,
-            sym_id=sym_id,
-            residue_index=residue_index,
-            insertion_code=insertion_code,
-            atom37_positions=atom37_positions,
-            atom37_mask=atom37_mask,
-            confidence=protein_confidence,
-            metadata=protein_metadata,
-        )
-
-    @classmethod
-    def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex":
-        """Read MolecularComplex from mmcif file or string.
-
-        Args:
-            inp: Path to mmCIF file or mmCIF content as string
-            id: Optional identifier to assign to the complex
-
-        Returns:
-            MolecularComplex with all molecules (proteins, ligands, nucleic acids)
-        """
-        from io import StringIO
-
-        # Check if input is a file path or mmCIF string content
-        if os.path.exists(inp):
-            # Input is a file path
-            mmcif_file = pdbx.CIFFile.read(inp)
-        else:
-            # Input is mmCIF string content
-            mmcif_file = pdbx.CIFFile.read(StringIO(inp))
-
-        # Get structure - handle missing model information gracefully
-        try:
-            structure = pdbx.get_structure(mmcif_file, model=1)
-        except (KeyError, ValueError):
-            # Fallback for mmCIF files without model information
-            try:
-                structure = pdbx.get_structure(mmcif_file)
-            except Exception:
-                # Last resort: use the first available model or all atoms
-                structure = pdbx.get_structure(mmcif_file, model=None)
-        # Type hint for pyright - structure is an AtomArray which is iterable
-        if TYPE_CHECKING:
-            structure: Any = structure
-
-        # Get entity information from mmCIF
-        entity_info = {}
-        try:
-            # Access the first block in CIFFile
-            block = mmcif_file[0]
-            if "entity" in block:
-                entity_category = block["entity"]
-                if "id" in entity_category and "type" in entity_category:
-                    entity_ids = entity_category["id"]
-                    entity_types = entity_category["type"]
-                    # Convert CIFColumn to list for iteration
-                    if hasattr(entity_ids, "__iter__") and hasattr(
-                        entity_types, "__iter__"
-                    ):
-                        # Type annotation to help pyright understand these are iterable
-                        entity_ids_list = list(entity_ids)  # type: ignore
-                        entity_types_list = list(entity_types)  # type: ignore
-                        for eid, etype in zip(entity_ids_list, entity_types_list):
-                            entity_info[eid] = etype
-        except Exception:
-            pass
-
-        # Initialize arrays for flat atom representation
-        sequence_tokens = []
-        flat_positions = []
-        flat_elements = []
-        token_to_atoms = []
-        confidence_scores = []
-
-        atom_idx = 0
-
-        # Group atoms by chain and residue
-        chain_residue_groups = {}
-        for atom in structure:
-            chain_id = atom.chain_id
-            res_id = atom.res_id
-            res_name = atom.res_name
-
-            if chain_id not in chain_residue_groups:
-                chain_residue_groups[chain_id] = {}
-            if res_id not in chain_residue_groups[chain_id]:
-                chain_residue_groups[chain_id][res_id] = {
-                    "atoms": [],
-                    "res_name": res_name,
-                    "is_hetero": atom.hetero,
-                }
-            chain_residue_groups[chain_id][res_id]["atoms"].append(atom)
-
-        # Process each chain and residue
-        for chain_id in sorted(chain_residue_groups.keys()):
-            residues = chain_residue_groups[chain_id]
-
-            for res_id in sorted(residues.keys()):
-                residue_data = residues[res_id]
-                res_name = residue_data["res_name"]
-                atoms = residue_data["atoms"]
-                is_hetero = residue_data["is_hetero"]
-
-                # Skip water molecules
-                if res_name == "HOH":
-                    continue
-
-                # Determine token name
-                if not is_hetero and res_name in residue_constants.restype_3to1:
-                    # Standard amino acid
-                    token_name = res_name
-                elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]:
-                    # Nucleotide
-                    token_name = res_name
-                else:
-                    # Ligand or other molecule
-                    token_name = res_name
-
-                sequence_tokens.append(token_name)
-                token_start = atom_idx
-
-                # Add all atoms from this residue
-                for atom in atoms:
-                    flat_positions.append(atom.coord)
-
-                    # Get element character
-                    element = atom.element
-                    flat_elements.append(element)
-
-                    atom_idx += 1
-
-                # Record token-to-atom mapping
-                token_to_atoms.append([token_start, atom_idx])
-
-                # Add confidence score (B-factor if available, otherwise 1.0)
-                bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0
-                confidence_scores.append(min(bfactor / 100.0, 1.0))
-
-        # Convert to numpy arrays
-        if not flat_positions:
-            # Create minimal arrays if no atoms found
-            atom_positions = np.zeros((0, 3), dtype=np.float32)
-            atom_elements = np.zeros(0, dtype=object)
-            token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32)
-        else:
-            atom_positions = np.array(flat_positions, dtype=np.float32)
-            atom_elements = np.array(flat_elements, dtype=object)
-            token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
-
-        confidence_array = np.array(confidence_scores, dtype=np.float32)
-
-        # Create metadata
-        metadata = MolecularComplexMetadata(
-            entity_lookup=entity_info,
-            chain_lookup={
-                i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys())
-            },
-            assembly_composition=None,
-        )
-
-        # Set complex ID - if input was a path, use the stem; otherwise use default
-        if os.path.exists(inp):
-            complex_id = id or Path(inp).stem
-        else:
-            complex_id = id or "complex_from_string"
-
-        return cls(
-            id=complex_id,
-            sequence=sequence_tokens,
-            atom_positions=atom_positions,
-            atom_elements=atom_elements,
-            token_to_atoms=token_to_atoms_array,
-            plddt=confidence_array,
-            metadata=metadata,
-        )
-
-    def to_mmcif(self) -> str:
-        """Write MolecularComplex to mmcif string.
-
-        Returns:
-            String representation of the complex in mmCIF format
-        """
-        # No need for element mapping - already using element characters
-
-        lines = []
-
-        # Header
-        lines.append(f"data_{self.id}")
-        lines.append("#")
-        lines.append(f"_entry.id {self.id}")
-        lines.append("#")
-
-        # Structure metadata
-        lines.append("_struct.entry_id {}".format(self.id))
-        lines.append("_struct.title 'Protein Structure'")
-        lines.append("#")
-
-        # Entity information
-        entity_id = 1
-        chain_counter = 0
-        lines.append("loop_")
-        lines.append("_entity.id")
-        lines.append("_entity.type")
-        lines.append("_entity.pdbx_description")
-
-        # Determine entities based on sequence
-        protein_tokens = []
-        other_tokens = []
-
-        for i, token in enumerate(self.sequence):
-            if token in residue_constants.restype_3to1:
-                protein_tokens.append((i, token))
-            else:
-                other_tokens.append((i, token))
-
-        if protein_tokens:
-            lines.append(f"{entity_id} polymer 'Protein chain'")
-            entity_id += 1
-
-        for token in set(token for _, token in other_tokens):
-            lines.append(f"{entity_id} non-polymer 'Ligand {token}'")
-            entity_id += 1
-
-        lines.append("#")
-
-        # Chain assignments
-        lines.append("loop_")
-        lines.append("_struct_asym.id")
-        lines.append("_struct_asym.entity_id")
-
-        chain_id = "A"
-        if protein_tokens:
-            lines.append(f"{chain_id} 1")
-            chain_counter += 1
-            chain_id = chr(ord(chain_id) + 1)
-
-        entity_id = 2
-        for token in set(token for _, token in other_tokens):
-            lines.append(f"{chain_id} {entity_id}")
-            entity_id += 1
-            chain_counter += 1
-            if chain_counter < 26:
-                chain_id = chr(ord(chain_id) + 1)
-
-        lines.append("#")
-
-        # Atom site information
-        lines.append("loop_")
-        lines.append("_atom_site.group_PDB")
-        lines.append("_atom_site.id")
-        lines.append("_atom_site.type_symbol")
-        lines.append("_atom_site.label_atom_id")
-        lines.append("_atom_site.label_alt_id")
-        lines.append("_atom_site.label_comp_id")
-        lines.append("_atom_site.label_asym_id")
-        lines.append("_atom_site.label_entity_id")
-        lines.append("_atom_site.label_seq_id")
-        lines.append("_atom_site.pdbx_PDB_ins_code")
-        lines.append("_atom_site.Cartn_x")
-        lines.append("_atom_site.Cartn_y")
-        lines.append("_atom_site.Cartn_z")
-        lines.append("_atom_site.occupancy")
-        lines.append("_atom_site.B_iso_or_equiv")
-        lines.append("_atom_site.pdbx_PDB_model_num")
-        lines.append("_atom_site.auth_seq_id")
-        lines.append("_atom_site.auth_comp_id")
-        lines.append("_atom_site.auth_asym_id")
-        lines.append("_atom_site.auth_atom_id")
-
-        atom_id = 1
-        seq_id = 1
-        chain_id = "A"
-        entity_id = 1
-
-        for token_idx, token in enumerate(self.sequence):
-            start_atom, end_atom = self.token_to_atoms[token_idx]
-
-            # Determine if this is a protein residue or ligand
-            is_protein = token in residue_constants.restype_3to1
-            group_pdb = "ATOM" if is_protein else "HETATM"
-            current_entity_id = 1 if is_protein else 2  # Simplified entity assignment
-            current_chain_id = "A" if is_protein else "B"  # Simplified chain assignment
-
-            # Create atom names for this token
-            atom_names = []
-            if is_protein:
-                # Use standard protein atom names
-                res_atoms = residue_constants.residue_atoms.get(
-                    token, ["N", "CA", "C", "O"]
-                )
-                atom_names = res_atoms[: end_atom - start_atom]
-            else:
-                # Generate generic atom names for ligands
-                for i in range(end_atom - start_atom):
-                    atom_names.append(f"C{i+1}")
-
-            # Pad atom names if needed
-            while len(atom_names) < (end_atom - start_atom):
-                atom_names.append(f"X{len(atom_names)+1}")
-
-            # Write atoms for this token
-            for atom_idx_in_token, global_atom_idx in enumerate(
-                range(start_atom, end_atom)
-            ):
-                pos = self.atom_positions[global_atom_idx]
-                element_char = self.atom_elements[global_atom_idx]
-                element_symbol = element_char if isinstance(element_char, str) else "C"
-
-                atom_name = (
-                    atom_names[atom_idx_in_token]
-                    if atom_idx_in_token < len(atom_names)
-                    else f"X{atom_idx_in_token+1}"
-                )
-
-                # Format atom site line
-                bfactor = (
-                    self.plddt[token_idx] * 100.0
-                    if len(self.plddt) > token_idx
-                    else 50.0
-                )
-
-                line = (
-                    f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . "
-                    f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? "
-                    f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 "
-                    f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}"
-                )
-                lines.append(line)
-                atom_id += 1
-
-            seq_id += 1
-
-        lines.append("#")
-        return "\n".join(lines)
-
-    def dockq(self, native: "MolecularComplex") -> Any:
-        """Compute DockQ score against native structure.
-
-        Args:
-            native: Native MolecularComplex to compute DockQ against
-
-        Returns:
-            DockQ result containing score and alignment information
-        """
-        # Imports moved to top of file
-
-        # Convert both complexes to ProteinComplex format for DockQ computation
-        # This extracts only the protein portion and converts to PDB format
-        try:
-            self_pc = self.to_protein_complex()
-            native_pc = native.to_protein_complex()
-        except ValueError as e:
-            raise ValueError(
-                f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
-            )
-
-        # Normalize chain IDs for PDB compatibility
-        self_pc = self_pc.normalize_chain_ids_for_pdb()
-        native_pc = native_pc.normalize_chain_ids_for_pdb()
-
-        # Use the existing ProteinComplex.dockq() method
-        try:
-            dockq_result = self_pc.dockq(native_pc)
-            return dockq_result
-        except Exception:
-            # Fallback to manual DockQ computation if ProteinComplex.dockq() fails
-            return self._compute_dockq_manual(native)
-
-    def _compute_dockq_manual(self, native: "MolecularComplex") -> Any:
-        """Manual DockQ computation fallback."""
-        # Imports moved to top of file
-
-        # Convert both complexes to ProteinComplex format
-        try:
-            self_pc = self.to_protein_complex()
-            native_pc = native.to_protein_complex()
-        except ValueError as e:
-            raise ValueError(
-                f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
-            )
-
-        # Normalize chain IDs for PDB compatibility
-        self_pc = self_pc.normalize_chain_ids_for_pdb()
-        native_pc = native_pc.normalize_chain_ids_for_pdb()
-
-        # Write temporary PDB files and run DockQ
-        with TemporaryDirectory() as tdir:
-            dir_path = Path(tdir)
-            self_pdb = dir_path / "self.pdb"
-            native_pdb = dir_path / "native.pdb"
-
-            # Write PDB files
-            self_pc.to_pdb(self_pdb)
-            native_pc.to_pdb(native_pdb)
-
-            # Run DockQ
-            try:
-                output = check_output(["DockQ", str(self_pdb), str(native_pdb)])
-                output_text = output.decode()
-
-                # Parse DockQ output
-                lines = output_text.split("\n")
-
-                # Find the total DockQ score
-                dockq_score = None
-                for line in lines:
-                    if "Total DockQ" in line:
-                        match = re.search(r"Total DockQ.*: ([\d.]+)", line)
-                        if match:
-                            dockq_score = float(match.group(1))
-                        break
-
-                if dockq_score is None:
-                    # Try to find individual DockQ scores
-                    for line in lines:
-                        if line.startswith("DockQ") and ":" in line:
-                            try:
-                                dockq_score = float(line.split(":")[1].strip())
-                                break
-                            except (ValueError, IndexError):
-                                continue
-
-                if dockq_score is None:
-                    raise ValueError("Could not parse DockQ score from output")
-
-                # Return a simple result structure
-                return {
-                    "total_dockq": dockq_score,
-                    "raw_output": output_text,
-                    "aligned": self,  # Return self as aligned structure
-                }
-
-            except FileNotFoundError:
-                raise RuntimeError(
-                    "DockQ is not installed. Please install DockQ to use this method."
-                )
-            except Exception as e:
-                raise RuntimeError(f"DockQ computation failed: {e}")
-
-    def rmsd(self, target: "MolecularComplex", **kwargs) -> float:
-        """Compute RMSD against target structure.
-
-        Args:
-            target: Target MolecularComplex to compute RMSD against
-            **kwargs: Additional arguments passed to compute_rmsd
-
-        Returns:
-            float: RMSD value between the two structures
-        """
-        # Imports moved to top of file
-
-        # Ensure both complexes have the same number of tokens
-        if len(self) != len(target):
-            raise ValueError(
-                f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
-            )
-
-        # Extract center positions for each token (using centroid of atoms)
-        mobile_coords = []
-        target_coords = []
-        atom_mask = []
-
-        for i in range(len(self)):
-            # Get atom positions for this token
-            mobile_start, mobile_end = self.token_to_atoms[i]
-            target_start, target_end = target.token_to_atoms[i]
-
-            # Extract atom positions
-            mobile_atoms = self.atom_positions[mobile_start:mobile_end]
-            target_atoms = target.atom_positions[target_start:target_end]
-
-            # Check if both tokens have atoms
-            if len(mobile_atoms) == 0 or len(target_atoms) == 0:
-                # Skip tokens with no atoms
-                continue
-
-            # For simplicity, use the centroid of atoms as the representative position
-            mobile_center = mobile_atoms.mean(axis=0)
-            target_center = target_atoms.mean(axis=0)
-
-            mobile_coords.append(mobile_center)
-            target_coords.append(target_center)
-            atom_mask.append(True)
-
-        if len(mobile_coords) == 0:
-            raise ValueError("No valid atoms found for RMSD computation")
-
-        # Convert to tensors
-        mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(
-            0
-        )  # [1, N, 3]
-        target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(
-            0
-        )  # [1, N, 3]
-        mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0)  # [1, N]
-
-        # Compute RMSD using existing infrastructure
-        rmsd_value = compute_rmsd(
-            mobile=mobile_tensor,
-            target=target_tensor,
-            atom_exists_mask=mask_tensor,
-            reduction="batch",
-            **kwargs,
-        )
-
-        return float(rmsd_value)
-
-    def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float:
-        """Compute LDDT score against target structure.
-
-        Args:
-            target: Target MolecularComplex to compute LDDT against
-            **kwargs: Additional arguments passed to compute_lddt
-
-        Returns:
-            float: LDDT value between the two structures
-        """
-        # Imports moved to top of file
-
-        # Ensure both complexes have the same number of tokens
-        if len(self) != len(target):
-            raise ValueError(
-                f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
-            )
-
-        # Extract center positions for each token (using centroid of atoms)
-        mobile_coords = []
-        target_coords = []
-        atom_mask = []
-
-        for i in range(len(self)):
-            # Get atom positions for this token
-            mobile_start, mobile_end = self.token_to_atoms[i]
-            target_start, target_end = target.token_to_atoms[i]
-
-            # Extract atom positions
-            mobile_atoms = self.atom_positions[mobile_start:mobile_end]
-            target_atoms = target.atom_positions[target_start:target_end]
-
-            # Check if both tokens have atoms
-            if len(mobile_atoms) == 0 or len(target_atoms) == 0:
-                # Skip tokens with no atoms
-                mobile_coords.append(np.full(3, np.nan))
-                target_coords.append(np.full(3, np.nan))
-                atom_mask.append(False)
-                continue
-
-            # For simplicity, use the centroid of atoms as the representative position
-            mobile_center = mobile_atoms.mean(axis=0)
-            target_center = target_atoms.mean(axis=0)
-
-            mobile_coords.append(mobile_center)
-            target_coords.append(target_center)
-            atom_mask.append(True)
-
-        if not any(atom_mask):
-            raise ValueError("No valid atoms found for LDDT computation")
-
-        # Convert to tensors
-        mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(
-            0
-        )  # [1, N, 3]
-        target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(
-            0
-        )  # [1, N, 3]
-        mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0)  # [1, N]
-
-        # Compute LDDT using existing infrastructure
-        lddt_value = compute_lddt(
-            all_atom_pred_pos=mobile_tensor,
-            all_atom_positions=target_tensor,
-            all_atom_mask=mask_tensor,
-            per_residue=False,  # Return overall LDDT score
-            **kwargs,
-        )
-
-        return float(lddt_value)
-
-    def state_dict(self):
-        """This state dict is optimized for storage, so it turns things to fp16 whenever
-        possible and converts numpy arrays to lists for JSON serialization.
-        """
-        dct = {k: v for k, v in vars(self).items()}
-        for k, v in dct.items():
-            if isinstance(v, np.ndarray):
-                match v.dtype:
-                    case np.int64:
-                        dct[k] = v.astype(np.int32).tolist()
-                    case np.float64 | np.float32:
-                        dct[k] = v.astype(np.float16).tolist()
-                    case _:
-                        dct[k] = v.tolist()
-            elif isinstance(v, MolecularComplexMetadata):
-                dct[k] = asdict(v)
-
-        return dct
-
-    def to_blob(self) -> bytes:
-        return brotli.compress(msgpack.dumps(self.state_dict()), quality=5)
-
-    @classmethod
-    def from_state_dict(cls, dct):
-        for k, v in dct.items():
-            if isinstance(v, list) and k in [
-                "atom_positions",
-                "atom_elements",
-                "token_to_atoms",
-                "plddt",
-            ]:
-                dct[k] = np.array(v)
-
-        for k, v in dct.items():
-            if isinstance(v, np.ndarray):
-                if k in ["atom_positions", "plddt"]:
-                    dct[k] = v.astype(np.float32)
-                elif k in ["token_to_atoms"]:
-                    dct[k] = v.astype(np.int32)
-
-        dct["metadata"] = MolecularComplexMetadata(**dct["metadata"])
-        return cls(**dct)
-
-    @classmethod
-    def from_blob(cls, input: Path | str | io.BytesIO | bytes):
-        match input:
-            case Path() | str():
-                bytes = Path(input).read_bytes()
-            case io.BytesIO():
-                bytes = input.getvalue()
-            case _:
-                bytes = input
-        return cls.from_state_dict(
-            msgpack.loads(brotli.decompress(bytes), strict_map_key=False)
-        )
diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py
index b4db7081..4889886e 100644
--- a/esm/utils/structure/protein_chain.py
+++ b/esm/utils/structure/protein_chain.py
@@ -25,21 +25,13 @@
 from esm.utils.structure.affine3d import Affine3D
 from esm.utils.structure.aligner import Aligner
 from esm.utils.structure.atom_indexer import AtomIndexer
-from esm.utils.structure.metrics import (
-    compute_gdt_ts,
-    compute_lddt_ca,
-)
-from esm.utils.structure.mmcif_parsing import (
-    MmcifWrapper,
-    Residue,
-)
+from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
+from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue
 from esm.utils.structure.normalize_coordinates import (
     apply_frame_to_coords,
     get_protein_normalization_frame,
 )
-from esm.utils.structure.protein_structure import (
-    index_by_atom_name,
-)
+from esm.utils.structure.protein_structure import index_by_atom_name
 from esm.utils.types import PathOrBuffer
 
 msgpack_numpy.patch()
@@ -401,7 +393,6 @@ def from_blob(cls, input: Path | str | io.BytesIO | bytes):
             bytes = input
         return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
 
-
     def sasa(self, by_residue: bool = True):
         arr = self.atom_array_no_insertions
         sasa_per_atom = bs.sasa(arr)  # type: ignore
@@ -707,7 +698,6 @@ def gdt_ts(
         )
         return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
 
-
    @classmethod
     def chain_iterable_from_mmcif(
         cls,
diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py
index 2bd20398..7e2c38de 100644
--- a/esm/utils/structure/protein_complex.py
+++ b/esm/utils/structure/protein_complex.py
@@ -32,14 +32,8 @@
 from esm.utils.structure.affine3d import Affine3D
 from esm.utils.structure.aligner import Aligner
 from esm.utils.structure.atom_indexer import AtomIndexer
-from esm.utils.structure.metrics import (
-    compute_gdt_ts,
-    compute_lddt_ca,
-)
-from esm.utils.structure.mmcif_parsing import (
-    MmcifWrapper,
-    NoProteinError,
-)
+from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
+from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
 from esm.utils.structure.protein_chain import (
     ProteinChain,
     chain_to_ndarray,
diff --git a/esm/utils/system.py b/esm/utils/system.py
deleted file mode 100644
index c2800e57..00000000
--- a/esm/utils/system.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import io
-import subprocess
-import typing as T
-from pathlib import Path
-
-PathLike = T.Union[str, Path]
-PathOrBuffer = T.Union[PathLike, io.StringIO]
-
-
-def run_subprocess_with_errorcheck(
-    *popenargs,
-    capture_output: bool = False,
-    quiet: bool = False,
-    env: dict[str, str] | None = None,
-    shell: bool = False,
-    executable: str | None = None,
-    **kws,
-) -> subprocess.CompletedProcess:
-    """A command similar to subprocess.run, however the errormessage will
-    contain the stderr when using this function. This makes it significantly
-    easier to diagnose issues.
-    """
-    try:
-        if capture_output:
-            stdout = subprocess.PIPE
-        elif quiet:
-            stdout = subprocess.DEVNULL
-        else:
-            stdout = None
-
-        p = subprocess.run(
-            *popenargs,
-            stderr=subprocess.PIPE,
-            stdout=stdout,
-            check=True,
-            env=env,
-            shell=shell,
-            executable=executable,
-            **kws,
-        )
-    except subprocess.CalledProcessError as e:
-        raise RuntimeError(
-            f"Command failed with errorcode {e.returncode}." f"\n\n{e.stderr.decode()}"
-        )
-    return p
diff --git a/esm/widgets/components/function_annotator.py b/esm/widgets/components/function_annotator.py
index 714238f9..f567f94d 100644
--- a/esm/widgets/components/function_annotator.py
+++ b/esm/widgets/components/function_annotator.py
@@ -4,9 +4,7 @@
 from ipywidgets import widgets
 
 from esm.sdk.api import FunctionAnnotation
-from esm.tokenization.function_tokenizer import (
-    InterProQuantizedTokenizer,
-)
+from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
 
 TRIE: pygtrie.CharTrie | None = None
diff --git a/esm/widgets/components/results_visualizer.py b/esm/widgets/components/results_visualizer.py
index 99692e51..261c0a5d 100644
--- a/esm/widgets/components/results_visualizer.py
+++ b/esm/widgets/components/results_visualizer.py
@@ -7,15 +7,11 @@
 import matplotlib.pyplot as plt
 
 from esm.sdk.api import ESMProtein
-from esm.widgets.utils.drawing.draw_category_array import (
-    draw_data_array,
-)
+from esm.widgets.utils.drawing.draw_category_array import draw_data_array
 from esm.widgets.utils.drawing.draw_function_annotations import (
     draw_function_annotations,
 )
-from esm.widgets.utils.drawing.draw_protein_structure import (
-    draw_protein_structure,
-)
+from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
 from esm.widgets.utils.serialization import (
     create_download_button_from_buffer,
     protein_to_pdb_buffer,
diff --git a/esm/widgets/components/sasa_prompt_selector.py b/esm/widgets/components/sasa_prompt_selector.py
index ecc5c1b4..9c026500 100644
--- a/esm/widgets/components/sasa_prompt_selector.py
+++ b/esm/widgets/components/sasa_prompt_selector.py
@@ -3,16 +3,9 @@
 import ipywidgets as widgets
 
 from esm.utils.structure.protein_chain import ProteinChain
-from esm.widgets.utils.drawing.colors import (
-    hex_to_rgba_tuple,
-    rgba_tuple_to_hex,
-)
-from esm.widgets.utils.drawing.draw_category_array import (
-    draw_data_array,
-)
-from esm.widgets.utils.parsing import (
-    convert_range_string_to_list_of_ranges,
-)
+from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
+from esm.widgets.utils.drawing.draw_category_array import draw_data_array
+from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
 from esm.widgets.utils.prompting import PromptManager
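With `esm/utils/system.py` deleted above, callers of `run_subprocess_with_errorcheck` need a replacement. Its essential behavior, surfacing stderr in the raised error, reduces to a few lines of stock `subprocess` (a sketch with a hypothetical name, not repo code):

```python
import subprocess

def run_checked(*popenargs, **kws) -> subprocess.CompletedProcess:
    # Like the removed run_subprocess_with_errorcheck: include stderr in the error message.
    try:
        return subprocess.run(*popenargs, stderr=subprocess.PIPE, check=True, **kws)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(
            f"Command failed with errorcode {e.returncode}.\n\n{e.stderr.decode()}"
        )
```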
diff --git a/esm/widgets/components/secondary_structure_prompt_selector.py b/esm/widgets/components/secondary_structure_prompt_selector.py
index 020180fe..b8007d69 100644
--- a/esm/widgets/components/secondary_structure_prompt_selector.py
+++ b/esm/widgets/components/secondary_structure_prompt_selector.py
@@ -4,16 +4,9 @@
 import pydssp
 
 from esm.utils.structure.protein_chain import ProteinChain
-from esm.widgets.utils.drawing.colors import (
-    hex_to_rgba_tuple,
-    rgba_tuple_to_hex,
-)
-from esm.widgets.utils.drawing.draw_category_array import (
-    draw_data_array,
-)
-from esm.widgets.utils.parsing import (
-    convert_range_string_to_list_of_ranges,
-)
+from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
+from esm.widgets.utils.drawing.draw_category_array import draw_data_array
+from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
 from esm.widgets.utils.prompting import PromptManager
diff --git a/esm/widgets/components/sequence_prompt_selector.py b/esm/widgets/components/sequence_prompt_selector.py
index d538e670..c5e3526f 100644
--- a/esm/widgets/components/sequence_prompt_selector.py
+++ b/esm/widgets/components/sequence_prompt_selector.py
@@ -6,9 +6,7 @@
     hex_to_rgba_tuple,
     rgba_tuple_to_rgba_html_string,
 )
-from esm.widgets.utils.parsing import (
-    convert_range_string_to_list_of_ranges,
-)
+from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
 from esm.widgets.utils.prompting import PromptManager
diff --git a/esm/widgets/components/structure_prompt_selector.py b/esm/widgets/components/structure_prompt_selector.py
index f4b497c0..13a9df78 100644
--- a/esm/widgets/components/structure_prompt_selector.py
+++ b/esm/widgets/components/structure_prompt_selector.py
@@ -10,12 +10,8 @@
 from esm.utils.structure.protein_chain import ProteinChain
 from esm.widgets.utils import indexing
-from esm.widgets.utils.drawing.draw_protein_structure import (
-    draw_protein_structure,
-)
-from esm.widgets.utils.parsing import (
-    convert_range_string_to_list_of_ranges,
-)
+from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
+from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
 from esm.widgets.utils.printing import wrapped_print
 from esm.widgets.utils.prompting import PromptManager
diff --git a/esm/widgets/utils/drawing/draw_function_annotations.py b/esm/widgets/utils/drawing/draw_function_annotations.py
index c71e5434..59e9f7cf 100644
--- a/esm/widgets/utils/drawing/draw_function_annotations.py
+++ b/esm/widgets/utils/drawing/draw_function_annotations.py
@@ -9,10 +9,7 @@
 from PIL import Image
 
 from esm.sdk.api import FunctionAnnotation
-from esm.utils.function.interpro import (
-    InterPro,
-    InterProEntryType,
-)
+from esm.utils.function.interpro import InterPro, InterProEntryType
 
 
 @contextmanager
diff --git a/esm/widgets/utils/prompting.py b/esm/widgets/utils/prompting.py
index 1ce6d9c6..1e89bb64 100644
--- a/esm/widgets/utils/prompting.py
+++ b/esm/widgets/utils/prompting.py
@@ -9,9 +9,7 @@
 from esm.utils import encoding
 from esm.widgets.utils import indexing
 from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
-from esm.widgets.utils.drawing.draw_category_array import (
-    draw_data_array,
-)
+from esm.widgets.utils.drawing.draw_category_array import draw_data_array
 from esm.widgets.utils.printing import wrapped_print
diff --git a/esm/widgets/views/esm3_generation_launcher.py b/esm/widgets/views/esm3_generation_launcher.py
index e94c60ff..f8bf8f3b 100644
--- a/esm/widgets/views/esm3_generation_launcher.py
+++ b/esm/widgets/views/esm3_generation_launcher.py
@@ -13,13 +13,9 @@
     GenerationConfig,
 )
 from esm.utils.constants import models
-from esm.widgets.components.results_visualizer import (
-    create_results_visualizer,
-)
+from esm.widgets.components.results_visualizer import create_results_visualizer
 from esm.widgets.utils.printing import wrapped_print
-from esm.widgets.utils.serialization import (
-    create_download_results_button,
-)
+from esm.widgets.utils.serialization import create_download_results_button
 
 
 def create_esm3_generation_launcher(
diff --git a/esm/widgets/views/esm3_prompt_selector.py b/esm/widgets/views/esm3_prompt_selector.py
index 035db28e..f7e60686 100644
--- a/esm/widgets/views/esm3_prompt_selector.py
+++ b/esm/widgets/views/esm3_prompt_selector.py
@@ -1,8 +1,6 @@
 from ipywidgets import widgets
 
-from esm.widgets.components.sasa_prompt_selector import (
-    create_sasa_prompt_selector,
-)
+from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector
 from esm.widgets.components.secondary_structure_prompt_selector import (
     create_secondary_structure_prompt_selector,
 )
diff --git a/esm/widgets/views/generation.py b/esm/widgets/views/generation.py
index 19015f60..fdec2094 100644
--- a/esm/widgets/views/generation.py
+++ b/esm/widgets/views/generation.py
@@ -4,20 +4,12 @@
 from esm.sdk.api import ESM3InferenceClient, ESMProtein
 from esm.utils.constants import esm3 as C
-from esm.widgets.components.function_annotator import (
-    create_function_annotator,
-)
+from esm.widgets.components.function_annotator import create_function_annotator
 from esm.widgets.utils.prompting import PromptManagerCollection
 from esm.widgets.utils.protein_import import ProteinImporter
-from esm.widgets.views.esm3_generation_launcher import (
-    create_esm3_generation_launcher,
-)
-from esm.widgets.views.esm3_prompt_preview import (
-    create_esm3_prompt_preview,
-)
-from esm.widgets.views.esm3_prompt_selector import (
-    create_esm3_prompt_selector,
-)
+from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher
+from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview
+from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector
 
 
 def create_generation_ui(
diff --git a/esm/widgets/views/inverse_folding.py b/esm/widgets/views/inverse_folding.py
index 8becb8eb..50d9a128 100644
--- a/esm/widgets/views/inverse_folding.py
+++ b/esm/widgets/views/inverse_folding.py
@@ -6,9 +6,7 @@
     ESMProteinError,
     GenerationConfig,
 )
-from esm.widgets.components.results_visualizer import (
-    create_results_visualizer,
-)
+from esm.widgets.components.results_visualizer import create_results_visualizer
 from esm.widgets.utils.printing import wrapped_print
 from esm.widgets.utils.protein_import import ProteinImporter
diff --git a/esm/widgets/views/login.py b/esm/widgets/views/login.py
index 2d8be5a3..5c7b6706 100644
--- a/esm/widgets/views/login.py
+++ b/esm/widgets/views/login.py
@@ -4,10 +4,7 @@
 from ipywidgets import widgets
 
-from esm.widgets.utils.clients import (
-    get_forge_client,
-    get_local_client,
-)
+from esm.widgets.utils.clients import get_forge_client, get_local_client
 from esm.widgets.utils.types import ClientInitContainer
diff --git a/esm/widgets/views/prediction.py b/esm/widgets/views/prediction.py
index de6666d2..94ff49dc 100644
--- a/esm/widgets/views/prediction.py
+++ b/esm/widgets/views/prediction.py
@@ -6,9 +6,7 @@
     ESMProteinError,
     GenerationConfig,
 )
-from esm.widgets.components.results_visualizer import (
-    create_results_visualizer,
-)
+from esm.widgets.components.results_visualizer import create_results_visualizer
 from esm.widgets.utils.printing import wrapped_print
 from esm.widgets.utils.protein_import import ProteinImporter
diff --git a/pyproject.toml b/pyproject.toml
index 0ca682fd..2f8008f8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "esm"
-version = "3.2.2.post1"
+version = "3.2.2"
 description = "EvolutionaryScale open model repository"
 readme = "README.md"
 requires-python = ">=3.12,<3.13"
@@ -45,6 +45,7 @@
     "pygtrie",
     "dna_features_viewer",
 ]
+
 # Pytest
 [tool.pytest.ini_options]
 addopts = """
diff --git a/tests/oss_pytests/requirements.txt b/tests/oss_pytests/requirements.txt
index 143baa7a..d120ce35 100644
--- a/tests/oss_pytests/requirements.txt
+++ b/tests/oss_pytests/requirements.txt
@@ -1,2 +1,3 @@
-esm >=3.2.1post1,<4.0.0
+esm
 pytest
+httpx # TODO(williamxi): Remove this after the esm repo is fixed
diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py
index e2dbafb5..6fa4a30d 100644
--- a/tests/oss_pytests/test_oss_client.py
+++ b/tests/oss_pytests/test_oss_client.py
@@ -1,7 +1,7 @@
 import os
 
 import pytest
-import torch
+
 from esm.sdk import client  # pyright: ignore
 from esm.sdk.api import (  # pyright: ignore
     ESMProtein,
@@ -37,7 +37,6 @@ def test_oss_esm3_client():
     logits_config = LogitsConfig(sequence=True, return_embeddings=True)
     result = esm3_client.logits(input=encoded_protein, config=logits_config)
     assert isinstance(result, LogitsOutput)
-    assert isinstance(result.logits.sequence, torch.Tensor)
 
     sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1))
     result = esm3_client.forward_and_sample(
@@ -54,7 +53,7 @@ def test_oss_esm3_client():
 def test_oss_esmc_client():
     assert URL is not None
 
-    sequence = "MALWMRLLPLLALLALAVPDPAAA"
+    sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
     model = "esmc-300m-2024-12"
 
     esmc_client = client(model=model, url=URL, token=API_TOKEN)
@@ -70,14 +69,13 @@
     )
     result = esmc_client.logits(input=encoded_protein, config=logits_config)
     assert isinstance(result, LogitsOutput)
-    assert isinstance(result.logits.sequence, torch.Tensor)
 
 
 @pytest.mark.sdk
 def test_oss_sequence_structure_forge_inference_client():
     assert URL is not None
 
-    sequence = "MALWMRLLPLLALLALAVPDPAAA"
+    sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
     model = "esm3-small-2024-03"
 
     client = SequenceStructureForgeInferenceClient(
         model=model, url=URL, token=API_TOKEN
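The test changes drop the direct torch dependency (the `torch.Tensor` assertions are removed) and switch the test sequences to variants containing `U` residues. A condensed sketch of the client flow these tests exercise, with `URL` and `API_TOKEN` read from the environment as in the test module (the `encode` call is assumed from the visible test body, not shown in full here):

```python
from esm.sdk import client
from esm.sdk.api import ESMProtein, LogitsConfig

esmc = client(model="esmc-300m-2024-12", url=URL, token=API_TOKEN)
encoded = esmc.encode(ESMProtein(sequence="MALWMRLLPLLALLALAVUUPDPAAA"))
result = esmc.logits(
    input=encoded, config=LogitsConfig(sequence=True, return_embeddings=True)
)
```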