From 1cfb08b424b23115a43364b82abccdd532cc3022 Mon Sep 17 00:00:00 2001 From: Neil Thomas Date: Fri, 19 Sep 2025 19:27:16 +0000 Subject: [PATCH 1/4] 3.2.2.post2 --- cookbook/tutorials/2_embed.ipynb | 6 +- cookbook/tutorials/3_gfp_design.ipynb | 6 +- cookbook/tutorials/4_forge_generate.ipynb | 4 +- cookbook/tutorials/5_guided_generation.ipynb | 2 +- esm/__init__.py | 2 +- esm/sdk/api.py | 32 +- esm/sdk/forge.py | 108 ++- esm/sdk/retry.py | 19 +- esm/utils/generation.py | 4 +- esm/utils/misc.py | 140 ++- esm/utils/msa/__init__.py | 3 + esm/utils/msa/filter_sequences.py | 79 ++ esm/utils/msa/msa.py | 500 ++++++++++ esm/utils/parsing.py | 83 ++ esm/utils/sequential_dataclass.py | 157 ++++ esm/utils/structure/input_builder.py | 95 ++ esm/utils/structure/metrics.py | 2 +- esm/utils/structure/molecular_complex.py | 938 +++++++++++++++++++ esm/utils/structure/protein_complex.py | 12 +- esm/utils/system.py | 45 + pixi.lock | 4 +- pyproject.toml | 3 +- tests/oss_pytests/requirements.txt | 3 +- tests/oss_pytests/test_oss_client.py | 3 + 24 files changed, 2196 insertions(+), 54 deletions(-) create mode 100644 esm/utils/msa/__init__.py create mode 100644 esm/utils/msa/filter_sequences.py create mode 100644 esm/utils/msa/msa.py create mode 100644 esm/utils/parsing.py create mode 100644 esm/utils/sequential_dataclass.py create mode 100644 esm/utils/structure/input_builder.py create mode 100644 esm/utils/structure/molecular_complex.py create mode 100644 esm/utils/system.py diff --git a/cookbook/tutorials/2_embed.ipynb b/cookbook/tutorials/2_embed.ipynb index 459fa90e..61fdb397 100644 --- a/cookbook/tutorials/2_embed.ipynb +++ b/cookbook/tutorials/2_embed.ipynb @@ -49,18 +49,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from getpass import getpass\n", "\n", - "token = getpass(\"Token from Forge console: \")" + "token = getpass(\"Token from Forge: \")" ] }, { diff --git a/cookbook/tutorials/3_gfp_design.ipynb b/cookbook/tutorials/3_gfp_design.ipynb index bde09ad5..f2bb85b9 100644 --- a/cookbook/tutorials/3_gfp_design.ipynb +++ b/cookbook/tutorials/3_gfp_design.ipynb @@ -80,18 +80,18 @@ "\n", "The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n", "\n", - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n" + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "id": "zNrU9Q2SYonX" }, "outputs": [], "source": [ - "token = getpass(\"Token from Forge console: \")" + "token = getpass(\"Token from Forge: \")" ] }, { diff --git a/cookbook/tutorials/4_forge_generate.ipynb b/cookbook/tutorials/4_forge_generate.ipynb index 7570fda9..9962e54d 100644 --- a/cookbook/tutorials/4_forge_generate.ipynb +++ b/cookbook/tutorials/4_forge_generate.ipynb @@ -53,7 +53,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." + "Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." ] }, { @@ -64,7 +64,7 @@ "source": [ "from getpass import getpass\n", "\n", - "token = getpass(\"Token from Forge console: \")\n", + "token = getpass(\"Token from Forge: \")\n", "model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)" ] }, diff --git a/cookbook/tutorials/5_guided_generation.ipynb b/cookbook/tutorials/5_guided_generation.ipynb index f04d6be3..d35a8c94 100644 --- a/cookbook/tutorials/5_guided_generation.ipynb +++ b/cookbook/tutorials/5_guided_generation.ipynb @@ -120,7 +120,7 @@ "\n", "from esm.sdk import client\n", "\n", - "token = getpass(\"Token from Forge console: \")\n", + "token = getpass(\"Token from Forge: \")\n", "model = client(\n", " model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n", ")" diff --git a/esm/__init__.py b/esm/__init__.py index 1e3bed4c..98a35b2d 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1 +1 @@ -__version__ = "3.2.2" +__version__ = "3.2.2.post2" diff --git a/esm/sdk/api.py b/esm/sdk/api.py index 0212ddcd..bd93f0cd 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -148,12 +148,35 @@ def to_protein_complex( gt_chains = list(copy_annotations_from_ground_truth.chain_iter()) else: gt_chains = None + + # Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks + # This handles the case where the server doesn't include chain breaks in pLDDT + # We should fix this in the server side. + if self.plddt is not None and len(self.plddt) != len(self.sequence): + # Only expand if there's a mismatch (likely due to chain breaks) + if "|" in self.sequence: + # Create expanded pLDDT with NaN at chain break positions + expanded_plddt = torch.full((len(self.sequence),), float("nan")) + plddt_idx = 0 + for i, aa in enumerate(self.sequence): + if aa != "|": + if plddt_idx < len(self.plddt): + expanded_plddt[i] = self.plddt[plddt_idx] + plddt_idx += 1 + plddt = expanded_plddt + else: + # Mismatch but no chain breaks - shouldn't happen but preserve original + plddt = self.plddt + else: + plddt = self.plddt + pred_chains = [] for i, (start, end) in enumerate(chain_boundaries): if i >= len(SINGLE_LETTER_CHAIN_IDS): raise ValueError( f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}" ) + pred_chain = ProteinChain.from_atom37( atom37_positions=coords[start:end], sequence=self.sequence[start:end], @@ -161,7 +184,7 @@ def to_protein_complex( if gt_chains is not None else SINGLE_LETTER_CHAIN_IDS[i], entity_id=gt_chains[i].entity_id if gt_chains is not None else None, - confidence=self.plddt[start:end] if self.plddt is not None else None, + confidence=plddt[start:end] if plddt is not None else None, ) pred_chains.append(pred_chain) return ProteinComplex.from_chains(pred_chains) @@ -298,13 +321,6 @@ def use_generative_unmasking_strategy(self): self.temperature_annealing = True -@define -class MSA: - # Paired MSA sequences. - # One would typically compute these using, for example, ColabFold. - sequences: list[str] - - @define class InverseFoldingConfig: invalid_ids: Sequence[int] = [] diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index e9bb2e3d..67c94b07 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import base64 import pickle @@ -7,7 +9,6 @@ import torch from esm.sdk.api import ( - MSA, ESM3InferenceClient, ESMCInferenceClient, ESMProtein, @@ -27,6 +28,15 @@ from esm.sdk.retry import retry_decorator from esm.utils.constants.api import MIMETYPE_ES_PICKLE from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor +from esm.utils.msa import MSA +from esm.utils.structure.input_builder import ( + StructurePredictionInput, + serialize_structure_prediction_input, +) +from esm.utils.structure.molecular_complex import ( + MolecularComplex, + MolecularComplexResult, +) from esm.utils.types import FunctionAnnotation @@ -36,10 +46,8 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None: return [FunctionAnnotation(*t) for t in l] -def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False): - ret = data.get("logits", {}).get(track, None) - # TODO(s22chan): just return this when removing return_bytes - return ret if ret is None or not return_bytes else maybe_tensor(ret) +def _maybe_logits(data: dict[str, Any], track: str): + return maybe_tensor(data.get("logits", {}).get(track, None)) def _maybe_b64_decode(obj, return_bytes: bool): @@ -137,7 +145,7 @@ async def _async_fetch_msa(self, sequence: str) -> MSA: data = await self._async_post( "msa", request={}, params={"sequence": sequence, "use_env": False} ) - return MSA(sequences=data["msa"]) + return MSA.from_sequences(sequences=data["msa"]) def _fetch_msa(self, sequence: str) -> MSA: print("Fetching MSA ... this may take a few minutes") @@ -146,7 +154,7 @@ def _fetch_msa(self, sequence: str) -> MSA: data = self._post( "msa", request={}, params={"sequence": sequence, "use_env": False} ) - return MSA(sequences=data["msa"]) + return MSA.from_sequences(sequences=data["msa"]) @retry_decorator async def async_fold( @@ -209,6 +217,70 @@ def fold( return self._process_fold_response(data, sequence) + @retry_decorator + async def async_fold_all_atom( + self, all_atom_input: StructurePredictionInput, model_name: str | None = None + ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: + """Fold a molecular complex containing proteins, nucleic acids, and/or ligands. + + Args: + all_atom_input: StructurePredictionInput containing sequences for different molecule types + model_name: Override the client level model name if needed + """ + request = self._process_fold_all_atom_request( + all_atom_input, model_name if model_name is not None else self.model + ) + + try: + data = await self._async_post("fold_all_atom", request) + except ESMProteinError as e: + return e + + return self._process_fold_all_atom_response(data) + + @retry_decorator + def fold_all_atom( + self, all_atom_input: StructurePredictionInput, model_name: str | None = None + ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: + """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands. + + Args: + all_atom_input: StructurePredictionInput containing sequences for different molecule types + model_name: Override the client level model name if needed + """ + request = self._process_fold_all_atom_request( + all_atom_input, model_name if model_name is not None else self.model + ) + + try: + data = self._post("fold_all_atom", request) + except ESMProteinError as e: + return e + + return self._process_fold_all_atom_response(data) + + @staticmethod + def _process_fold_all_atom_request( + all_atom_input: StructurePredictionInput, model_name: str | None = None + ) -> dict[str, Any]: + request: dict[str, Any] = { + "all_atom_input": serialize_structure_prediction_input(all_atom_input), + "model": model_name, + } + + return request + + @staticmethod + def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult: + complex_data = data.get("complex") + molecular_complex = MolecularComplex.from_state_dict(complex_data) + return MolecularComplexResult( + complex=molecular_complex, + plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True), + ptm=data.get("ptm", None), + distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True), + ) + @retry_decorator async def async_inverse_fold( self, @@ -602,19 +674,15 @@ def _process_logits_response( return LogitsOutput( logits=ForwardTrackData( - sequence=_maybe_logits(data, "sequence", return_bytes), - structure=_maybe_logits(data, "structure", return_bytes), - secondary_structure=_maybe_logits( - data, "secondary_structure", return_bytes - ), - sasa=_maybe_logits(data, "sasa", return_bytes), - function=_maybe_logits(data, "function", return_bytes), + sequence=_maybe_logits(data, "sequence"), + structure=_maybe_logits(data, "structure"), + secondary_structure=_maybe_logits(data, "secondary_structure"), + sasa=_maybe_logits(data, "sasa"), + function=_maybe_logits(data, "function"), ), embeddings=maybe_tensor(data["embeddings"]), mean_embedding=data["mean_embedding"], - residue_annotation_logits=_maybe_logits( - data, "residue_annotation", return_bytes - ), + residue_annotation_logits=_maybe_logits(data, "residue_annotation"), hidden_states=maybe_tensor(data["hidden_states"]), mean_hidden_state=maybe_tensor(data["mean_hidden_state"]), ) @@ -965,6 +1033,7 @@ def _process_logits_request( "sequence": config.sequence, "return_embeddings": config.return_embeddings, "return_mean_embedding": config.return_mean_embedding, + "return_mean_hidden_states": config.return_mean_hidden_states, "return_hidden_states": config.return_hidden_states, "ith_hidden_layer": config.ith_hidden_layer, } @@ -981,12 +1050,11 @@ def _process_logits_response( data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes) output = LogitsOutput( - logits=ForwardTrackData( - sequence=_maybe_logits(data, "sequence", return_bytes) - ), + logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")), embeddings=maybe_tensor(data["embeddings"]), mean_embedding=data["mean_embedding"], hidden_states=maybe_tensor(data["hidden_states"]), + mean_hidden_state=maybe_tensor(data["mean_hidden_state"]), ) return output diff --git a/esm/sdk/retry.py b/esm/sdk/retry.py index 16c354b6..302d6cf0 100644 --- a/esm/sdk/retry.py +++ b/esm/sdk/retry.py @@ -2,10 +2,9 @@ from contextvars import ContextVar from functools import wraps -import httpx from tenacity import ( retry, - retry_if_exception_type, + retry_if_exception, retry_if_result, stop_after_attempt, wait_incrementing, @@ -30,8 +29,12 @@ def retry_if_specific_error(exception): def log_retry_attempt(retry_state): + try: + outcome = retry_state.outcome.result() + except Exception: + outcome = retry_state.outcome.exception() print( - f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}" + f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}" ) @@ -41,13 +44,18 @@ def retry_decorator(func): instance's retry settings. """ + def return_last_value(retry_state): + """Return the result of the last call attempt.""" + return retry_state.outcome.result() + @wraps(func) async def async_wrapper(instance, *args, **kwargs): if skip_retries_var.get(): return await func(instance, *args, **kwargs) retry_decorator = retry( + retry_error_callback=return_last_value, retry=retry_if_result(retry_if_specific_error) - | retry_if_exception_type(httpx.ConnectTimeout), # ADDED + | retry_if_exception(retry_if_specific_error), wait=wait_incrementing( increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait ), @@ -62,8 +70,9 @@ def wrapper(instance, *args, **kwargs): if skip_retries_var.get(): return func(instance, *args, **kwargs) retry_decorator = retry( + retry_error_callback=return_last_value, retry=retry_if_result(retry_if_specific_error) - | retry_if_exception_type(httpx.ConnectTimeout), # ADDED + | retry_if_exception(retry_if_specific_error), wait=wait_incrementing( increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait ), diff --git a/esm/utils/generation.py b/esm/utils/generation.py index b2d4ede4..4d35e7ee 100644 --- a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -43,7 +43,9 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int): sliced = {} for k, v in attr.asdict(o, recurse=False).items(): - if v is None: + if k in ["mean_hidden_state", "mean_embedding"]: + sliced[k] = v + elif v is None: sliced[k] = None elif isinstance(v, torch.Tensor): # Trim padding. diff --git a/esm/utils/misc.py b/esm/utils/misc.py index f0d7a602..409ba1bf 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -1,7 +1,19 @@ +from __future__ import annotations + import os from collections import defaultdict +from dataclasses import is_dataclass from io import BytesIO -from typing import Any, ContextManager, Sequence, TypeVar +from typing import ( + Any, + ContextManager, + Generator, + Iterable, + Protocol, + Sequence, + TypeVar, + runtime_checkable, +) from warnings import warn import huggingface_hub @@ -18,6 +30,12 @@ TSequence = TypeVar("TSequence", bound=Sequence) +@runtime_checkable +class Concatable(Protocol): + @classmethod + def concat(cls, objs: list[Concatable]) -> Concatable: ... + + def slice_python_object_as_numpy( obj: TSequence, idx: int | list[int] | slice | np.ndarray ) -> TSequence: @@ -52,6 +70,37 @@ def slice_python_object_as_numpy( return sliced_obj # type: ignore +def slice_any_object( + obj: TSequence, idx: int | list[int] | slice | np.ndarray +) -> TSequence: + """ + Slice a arbitrary object (like a list, string, or tuple) as if it was a numpy object. Similar to `slice_python_object_as_numpy`, but detects if it's a numpy array or Tensor and uses the existing slice method if so. + + If the object is a dataclass, it will simply apply the index to the object, under the assumption that the object has correcty implemented numpy indexing. + + Example: + >>> obj = "ABCDE" + >>> slice_any_object(obj, [1, 3, 4]) + "BDE" + + >>> obj = np.array([1, 2, 3, 4, 5]) + >>> slice_any_object(obj, np.arange(5) < 3) + np.array([1, 2, 3]) + + >>> obj = ProteinChain.from_rcsb("1a3a", "A") + >>> slice_any_object(obj, np.arange(len(obj)) < 10) + # ProteinChain w/ length 10 + + """ + if isinstance(obj, (np.ndarray, torch.Tensor)): + return obj[idx] # type: ignore + elif is_dataclass(obj): + # if passing a dataclass, assume it implements a custom slice + return obj[idx] # type: ignore + else: + return slice_python_object_as_numpy(obj, idx) + + def rbf(values, v_min, v_max, n_bins=16): """ Returns RBF encodings in a new dimension at the end. @@ -298,6 +347,8 @@ def replace_inf(data): def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: if x is None: return None + if isinstance(x, torch.Tensor): + return x if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x): return torch.stack(x) if convert_none_to_nan: @@ -357,3 +408,90 @@ def deserialize_tensors(b: bytes) -> Any: buf = BytesIO(zstd.ZSTD_uncompress(b)) d = torch.load(buf, map_location="cpu", weights_only=False) return d + + +def join_lists( + lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None +) -> list[Any]: + """Joins multiple lists with separator element. Like str.join but for lists. + + Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4] + + Args: + lists: Lists of elements to chain + separator: separators to intsert between chained output. + Returns: + Joined lists. + """ + if not lists: + return [] + joined = [] + joined.extend(lists[0]) + for l in lists[1:]: + if separator: + joined.extend(separator) + joined.extend(l) + return joined + + +def iterate_with_intermediate( + lists: Iterable, intermediate +) -> Generator[Any, None, None]: + """ + Iterate over the iterable, yielding the intermediate value between + every element of the intermediate. Useful for joining objects with + separator tokens. + """ + it = iter(lists) + yield next(it) + for l in it: + yield intermediate + yield l + + +def concat_objects(objs: Sequence[Any], separator: Any | None = None): + """ + Concat objects with each other using a separator token. + + Supports: + - Concatable (objects that implement `concat` classmethod) + - strings + - lists + - numpy arrays + - torch Tensors + + Example: + >>> foo = "abc" + >>> bar = "def" + >>> concat_objects([foo, bar], "|") + "abc|def" + """ + match objs[0]: + case Concatable(): + return objs[0].__class__.concat(objs) # type: ignore + case str(): + assert isinstance( + separator, str + ), "Trying to join strings but separator is not a string" + return separator.join(objs) + case list(): + if separator is not None: + return join_lists(objs, [separator]) + else: + return join_lists(objs) + case np.ndarray(): + if separator is not None: + return np.concatenate( + list(iterate_with_intermediate(objs, np.array([separator]))) + ) + else: + return np.concatenate(objs) + case torch.Tensor(): + if separator is not None: + return torch.cat( + list(iterate_with_intermediate(objs, torch.tensor([separator]))) + ) + else: + return torch.cat(objs) # type: ignore + case _: + raise TypeError(type(objs[0])) diff --git a/esm/utils/msa/__init__.py b/esm/utils/msa/__init__.py new file mode 100644 index 00000000..5dd9965b --- /dev/null +++ b/esm/utils/msa/__init__.py @@ -0,0 +1,3 @@ +from esm.utils.msa.msa import MSA, FastMSA, remove_insertions_from_sequence + +__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"] diff --git a/esm/utils/msa/filter_sequences.py b/esm/utils/msa/filter_sequences.py new file mode 100644 index 00000000..d44549d5 --- /dev/null +++ b/esm/utils/msa/filter_sequences.py @@ -0,0 +1,79 @@ +import tempfile +from pathlib import Path + +import numpy as np +from scipy.spatial.distance import cdist + +from esm.utils.system import run_subprocess_with_errorcheck + + +def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]: + """Greedily select sequences that either maximize or minimize hamming distance. + + Algorithm proposed in the MSA Transformer paper. Starting from the query sequence, + iteratively add sequences to the list with the maximum (minimum) average Hamming + distance to the existing set of sequences. + + Args: + array (np.ndarray): Character array representing the sequences in the MSA + num_seqs (int): Number of sequences to select. + mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless + you're doing it to prove a point for a paper. + + Returns: + list[int]: List of indices to select from the array + """ + assert mode in ("max", "min") + depth = array.shape[0] + if depth <= num_seqs: + return list(range(depth)) + array = array.view(np.uint8) + + optfunc = np.argmax if mode == "max" else np.argmin + all_indices = np.arange(depth) + indices = [0] + pairwise_distances = np.zeros((0, depth)) + for _ in range(num_seqs - 1): + dist = cdist(array[indices[-1:]], array, "hamming") + pairwise_distances = np.concatenate([pairwise_distances, dist]) + shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0) + shifted_index = optfunc(shifted_distance) + index = np.delete(all_indices, indices)[shifted_index] + indices.append(index) + indices = sorted(indices) + return indices + + +def hhfilter( + sequences: list[str], + seqid: int = 90, + diff: int = 0, + cov: int = 0, + qid: int = 0, + qsc: float = -20.0, + binary: str = "hhfilter", +) -> list[int]: + with tempfile.TemporaryDirectory(dir="/dev/shm") as tempdirname: + tempdir = Path(tempdirname) + fasta_file = tempdir / "input.fasta" + fasta_file.write_text( + "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences)) + ) + output_file = tempdir / "output.fasta" + command = " ".join( + [ + f"{binary}", + f"-i {fasta_file}", + "-M a3m", + f"-o {output_file}", + f"-id {seqid}", + f"-diff {diff}", + f"-cov {cov}", + f"-qid {qid}", + f"-qsc {qsc}", + ] + ).split(" ") + run_subprocess_with_errorcheck(command, capture_output=True) + with output_file.open() as f: + indices = [int(line[1:].strip()) for line in f if line.startswith(">")] + return indices diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py new file mode 100644 index 00000000..838e5b4b --- /dev/null +++ b/esm/utils/msa/msa.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +import dataclasses +import string +from dataclasses import dataclass +from functools import cached_property +from itertools import islice +from typing import Sequence + +import numpy as np +from Bio import SeqIO +from scipy.spatial.distance import cdist + +from esm.utils.misc import slice_any_object +from esm.utils.msa.filter_sequences import greedy_select_indices, hhfilter +from esm.utils.parsing import FastaEntry, read_sequences, write_sequences +from esm.utils.sequential_dataclass import SequentialDataclass +from esm.utils.system import PathOrBuffer + +REMOVE_LOWERCASE_TRANSLATION = str.maketrans(dict.fromkeys(string.ascii_lowercase)) + + +def remove_insertions_from_sequence(seq: str) -> str: + return seq.translate(REMOVE_LOWERCASE_TRANSLATION) + + +@dataclass(frozen=True) +class MSA(SequentialDataclass): + """Object-oriented interface to an MSA. + + Args: + sequences (list[str]): List of protein sequences + headers (list[str]): List of headers describing the sequences + + """ + + entries: list[FastaEntry] + + @cached_property + def sequences(self) -> list[str]: + return [entry.sequence for entry in self.entries] + + @cached_property + def headers(self) -> list[str]: + return [entry.header for entry in self.entries] + + def __repr__(self): + return ( + f"MSA({self.entries[0].header}: Depth={self.depth}, Length={self.seqlen})" + ) + + def to_fast_msa(self) -> FastMSA: + return FastMSA(self.array, self.headers) + + @classmethod + def from_a3m( + cls, + path: PathOrBuffer, + remove_insertions: bool = True, + max_sequences: int | None = None, + ) -> MSA: + entries = [] + for header, seq in islice(read_sequences(path), max_sequences): + if remove_insertions: + seq = remove_insertions_from_sequence(seq) + if entries: + assert ( + len(seq) == len(entries[0].sequence) + ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" + entries.append(FastaEntry(header, seq)) + return cls(entries) + + def to_a3m(self, path: PathOrBuffer) -> None: + write_sequences(self.entries, path) + + @classmethod + def from_stockholm( + cls, + path: PathOrBuffer, + remove_insertions: bool = True, + max_sequences: int | None = None, + ) -> MSA: + entries = [] + for record in islice(SeqIO.parse(path, "stockholm"), max_sequences): + header = f"{record.id} {record.description}" + seq = str(record.seq) + if entries: + assert ( + len(seq) == len(entries[0].sequence) + ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" + entries.append(FastaEntry(header, seq)) + msa = cls(entries) + if remove_insertions: + keep_inds = [i for i, aa in enumerate(msa.query) if aa != "-"] + msa = msa.select_positions(keep_inds) + return msa + + def to_bytes(self) -> bytes: + version = 1 + version_bytes = version.to_bytes(1, "little") + seqlen_bytes = self.seqlen.to_bytes(4, "little") + depth_bytes = self.depth.to_bytes(4, "little") + array_bytes = self.array.tobytes() + header_bytes = "\n".join(entry.header for entry in self.entries).encode() + all_bytes = ( + version_bytes + seqlen_bytes + depth_bytes + array_bytes + header_bytes + ) + return all_bytes + + @classmethod + def from_bytes(cls, data: bytes) -> MSA: + version_bytes, seqlen_bytes, depth_bytes, data = ( + data[:1], + data[1:5], + data[5:9], + data[9:], + ) + version = int.from_bytes(version_bytes, "little") + if version != 1: + raise ValueError(f"Unsupported version: {version}") + seqlen = int.from_bytes(seqlen_bytes, "little") + depth = int.from_bytes(depth_bytes, "little") + array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :] + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(depth, seqlen) + headers = header_bytes.decode().split("\n") + # Sometimes the separation is two newlines, which results in an empty header. + headers = [header for header in headers if header] + entries = [ + FastaEntry(header, b"".join(row).decode()) + for header, row in zip(headers, array) + ] + return cls(entries) + + # TODO(jmaccarl): set remove_insertions to True by default here to match other utils + @classmethod + def from_sequences( + cls, sequences: list[str], remove_insertions: bool = False + ) -> MSA: + if remove_insertions: + entries = [ + FastaEntry("", remove_insertions_from_sequence(seq)) + for seq in sequences + ] + else: + entries = [FastaEntry("", seq) for seq in sequences] + return cls(entries) + + def to_sequence_bytes(self) -> bytes: + """Stores ONLY SEQUENCES in array format as bytes. Header information will be lost.""" + seqlen_bytes = self.seqlen.to_bytes(4, "little") + array_bytes = self.array.tobytes() + all_bytes = seqlen_bytes + array_bytes + return all_bytes + + @classmethod + def from_sequence_bytes(cls, data: bytes) -> MSA: + seqlen_bytes, array_bytes = data[:4], data[4:] + seqlen = int.from_bytes(seqlen_bytes, "little") + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(-1, seqlen) + entries = [FastaEntry("", b"".join(row).decode()) for row in array] + return cls(entries) + + @property + def depth(self) -> int: + return len(self.entries) + + @property + def seqlen(self) -> int: + return len(self.entries[0].sequence) + + @cached_property + def array(self) -> np.ndarray: + return np.array([list(seq) for seq in self.sequences], dtype="|S1") + + @property + def query(self) -> str: + return self.entries[0].sequence + + def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA: + """Subselect rows of the MSA.""" + entries = [self.entries[idx] for idx in indices] + return dataclasses.replace(self, entries=entries) + + def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA: + """Subselect columns of the MSA.""" + entries = [ + FastaEntry(header, "".join(seq[idx] for idx in indices)) + for header, seq in self.entries + ] + return dataclasses.replace(self, entries=entries) + + def __getitem__(self, indices: int | list[int] | slice | np.ndarray): + if isinstance(indices, int): + indices = [indices] + + entries = [ + FastaEntry(header, slice_any_object(seq, indices)) + for header, seq in self.entries + ] + return dataclasses.replace(self, entries=entries) + + def __len__(self): + return self.seqlen + + def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA: + """Greedily select sequences that either maximize or minimize hamming distance. + + Algorithm proposed in the MSA Transformer paper. Starting from the query sequence, + iteratively add sequences to the list with the maximum (minimum) average Hamming + distance to the existing set of sequences. + + Args: + num_seqs (int): Number of sequences to select. + mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless + you're doing it to prove a point for a paper. + + Returns: + MSA object w/ subselected sequences. + """ + assert mode in ("max", "min") + if self.depth <= num_seqs: + return self + + indices = greedy_select_indices(self.array, num_seqs, mode) + return self.select_sequences(indices) + + def hhfilter( + self, + seqid: int = 90, + diff: int = 0, + cov: int = 0, + qid: int = 0, + qsc: float = -20.0, + binary: str = "hhfilter", + ) -> MSA: + """Apply hhfilter to the sequences in the MSA and return a filtered MSA.""" + + indices = hhfilter( + self.sequences, + seqid=seqid, + diff=diff, + cov=cov, + qid=qid, + qsc=qsc, + binary=binary, + ) + return self.select_sequences(indices) + + def select_random_sequences(self, num_seqs: int) -> MSA: + """Uses random sampling to subselect sequences from the MSA. Always + keeps the query sequence. + """ + if num_seqs >= self.depth: + return self + + # Subselect random, always keeping the query sequence. + indices = np.sort( + np.append( + 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 + ) + ) + msa = self.select_sequences(indices) # type: ignore + return msa + + def select_diverse_sequences(self, num_seqs: int) -> MSA: + """Applies hhfilter to select ~num_seqs sequences, then uses random sampling + to subselect if necessary. + """ + if num_seqs >= self.depth: + return self + + msa = self.hhfilter(diff=num_seqs) + if num_seqs < msa.depth: + msa = msa.select_random_sequences(num_seqs) + return msa + + def pad_to_depth(self, depth: int) -> MSA: + if depth < self.depth: + raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}") + elif depth == self.depth: + return self + + num_to_add = depth - self.depth + extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)] + return dataclasses.replace(self, entries=self.entries + extra_entries) + + @classmethod + def stack( + cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True + ) -> MSA: + """Stack a series of MSAs. Optionally remove the query from msas after the first.""" + all_entries = [] + for i, msa in enumerate(msas): + entries = msa.entries + if i > 0 and remove_query_from_later_msas: + entries = entries[1:] + all_entries.extend(entries) + return cls(entries=all_entries) + + @cached_property + def seqid(self) -> np.ndarray: + array = self.array.view(np.uint8) + seqid = 1 - cdist(array[0][None], array, "hamming") + return seqid[0] + + @classmethod + def concat( + cls, + msas: Sequence[MSA], + join_token: str | None = "|", + allow_depth_mismatch: bool = False, + ) -> MSA: + """Concatenate a series of MSAs horizontally, along the sequence dimension.""" + if not msas: + raise ValueError("Cannot concatenate an empty list of MSAs") + msa_depths = [msa.depth for msa in msas] + if len(set(msa_depths)) != 1: + if not allow_depth_mismatch: + raise ValueError("Depth mismatch in concatenating MSAs") + else: + max_depth = max(msa_depths) + msas = [msa.pad_to_depth(max_depth) for msa in msas] + headers = [ + "|".join([str(h) for h in headers]) + for headers in zip(*(msa.headers for msa in msas)) + ] + + if join_token is None: + join_token = "" + + seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))] + entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)] + return cls(entries) + + +@dataclass(frozen=True) +class FastMSA(SequentialDataclass): + """Object-oriented interface to an MSA stored as a numpy uint8 array.""" + + array: np.ndarray + headers: list[str] | None = None + + def __post_init__(self): + if self.headers is not None: + assert ( + len(self.headers) == self.depth + ), "Number of headers must match depth." + + @classmethod + def from_bytes(cls, data: bytes) -> FastMSA: + version_bytes, seqlen_bytes, depth_bytes, data = ( + data[:1], + data[1:5], + data[5:9], + data[9:], + ) + version = int.from_bytes(version_bytes, "little") + if version != 1: + raise ValueError(f"Unsupported version: {version}") + seqlen = int.from_bytes(seqlen_bytes, "little") + depth = int.from_bytes(depth_bytes, "little") + array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :] + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(depth, seqlen) + headers = header_bytes.decode().split("\n") + # Sometimes the separation is two newlines, which results in an empty header. + headers = [header for header in headers if header] + return cls(array, headers) + + @classmethod + def from_sequence_bytes(cls, data: bytes) -> FastMSA: + seqlen_bytes, array_bytes = data[:4], data[4:] + seqlen = int.from_bytes(seqlen_bytes, "little") + array = np.frombuffer(array_bytes, dtype="|S1") + array = array.reshape(-1, seqlen) + return cls(array) + + @property + def depth(self) -> int: + return self.array.shape[0] + + @property + def seqlen(self) -> int: + return self.array.shape[1] + + def __len__(self): + return self.seqlen + + def __getitem__(self, indices: int | list[int] | slice | np.ndarray): + if isinstance(indices, int): + indices = [indices] + + return dataclasses.replace(self, array=self.array[:, indices]) + + def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA: + """Subselect rows of the MSA.""" + array = self.array[indices] + headers = ( + [self.headers[idx] for idx in indices] if self.headers is not None else None + ) + return dataclasses.replace(self, array=array, headers=headers) + + def select_random_sequences(self, num_seqs: int) -> FastMSA: + """Uses random sampling to subselect sequences from the MSA. Always + keeps the query sequence. + """ + if num_seqs >= self.depth: + return self + + # Subselect random, always keeping the query sequence. + indices = np.sort( + np.append( + 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 + ) + ) + msa = self.select_sequences(indices) # type: ignore + return msa + + def pad_to_depth(self, depth: int) -> FastMSA: + if depth < self.depth: + raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}") + elif depth == self.depth: + return self + + num_to_add = depth - self.depth + array = np.pad( + self.array, + [(0, num_to_add), (0, 0)], + constant_values=ord("-") if self.array.dtype == np.uint8 else b"-", + ) + headers = self.headers + if headers is not None: + headers = headers + [""] * num_to_add + return dataclasses.replace(self, array=array, headers=headers) + + @classmethod + def concat( + cls, + msas: Sequence[FastMSA], + join_token: str | None = None, + allow_depth_mismatch: bool = False, + ) -> FastMSA: + """Concatenate a series of MSAs horizontally, along the sequence dimension.""" + if not msas: + raise ValueError("Cannot concatenate an empty list of MSAs") + if join_token is not None and join_token != "": + raise NotImplementedError("join_token is not supported for FastMSA") + + msa_depths = [msa.depth for msa in msas] + if len(set(msa_depths)) != 1: + if not allow_depth_mismatch: + raise ValueError("Depth mismatch in concatenating MSAs") + else: + max_depth = max(msa_depths) + msas = [msa.pad_to_depth(max_depth) for msa in msas] + headers = [ + "|".join([str(h) for h in headers]) + for headers in zip( + *( + msa.headers if msa.headers is not None else [""] * msa.depth + for msa in msas + ) + ) + ] + + array = np.concatenate([msa.array for msa in msas], axis=1) + return cls(array, headers) + + def to_msa(self) -> MSA: + headers = ( + self.headers + if self.headers is not None + else [f"seq{i}" for i in range(self.depth)] + ) + entries = [ + FastaEntry(header, b"".join(row).decode()) + for header, row in zip(headers, self.array) + ] + return MSA(entries) + + @classmethod + def stack( + cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True + ) -> FastMSA: + """Stack a series of MSAs. Optionally remove the query from msas after the first.""" + arrays = [] + all_headers = [] + for i, msa in enumerate(msas): + array = msa.array + headers = msa.headers + if i > 0 and remove_query_from_later_msas: + array = array[1:] + if headers is not None: + headers = headers[1:] + arrays.append(array) + if headers is not None: + all_headers.extend(headers) + return cls(np.concatenate(arrays, axis=0), all_headers) diff --git a/esm/utils/parsing.py b/esm/utils/parsing.py new file mode 100644 index 00000000..c47938ab --- /dev/null +++ b/esm/utils/parsing.py @@ -0,0 +1,83 @@ +import io +from pathlib import Path +from typing import Generator, Iterable, NamedTuple + +PathOrBuffer = str | Path | io.TextIOBase +FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)]) + + +def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]: + """ + Parses a fasta file and yields FastaEntry objects + + Args: + fasta_string: The fasta file as a string + Returns: + A generator of FastaEntry objects + """ + header = None + seq = [] + num_sequences = 0 + for line in fasta_string.splitlines(): + if not line or line[0] == "#": + continue + if line.startswith(">"): + if header is not None: + yield FastaEntry(header, "".join(seq)) + seq = [] + header = line[1:].strip() + else: + seq.append(line) + if header is not None: + num_sequences += 1 + yield FastaEntry(header, "".join(seq)) + + if num_sequences == 0: + raise ValueError("Found no sequences in input") + + +def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]: + # Uses duck typing to try and call the right method + # Doesn't use explicit isinstance check to support + # inputs that are not explicitly str/Path/TextIOBase but + # may support similar functionality + data = None # type: ignore + try: + if str(path).endswith(".gz"): + import gzip + + data = gzip.open(path, "rt") # type: ignore + else: + try: + data = open(path) # type: ignore + except TypeError: + data: io.TextIOBase = path # type: ignore + + yield from parse_fasta(data.read()) + finally: + if data is not None: + data.close() + + +def read_first_sequence(path: PathOrBuffer) -> FastaEntry: + return next(iter(read_sequences(path))) + + +def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None: + needs_closing = False + handle = None + try: + try: + handle = open(path, "w") # type: ignore + needs_closing = True + except TypeError: + handle = path + has_prev = False + for header, seq in sequences: + if has_prev: + handle.write("\n") # type: ignore + handle.write(f">{header}\n{seq}") # type: ignore + has_prev = True + finally: + if needs_closing: + handle.close() # type: ignore diff --git a/esm/utils/sequential_dataclass.py b/esm/utils/sequential_dataclass.py new file mode 100644 index 00000000..8e935798 --- /dev/null +++ b/esm/utils/sequential_dataclass.py @@ -0,0 +1,157 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields, replace +from typing import TypeVar + +import numpy as np + +from esm.utils.misc import concat_objects, slice_any_object + +T = TypeVar("T") + + +@dataclass(frozen=True) +class SequentialDataclass(ABC): + """ + This is a builder on a dataclass that allows for automatic slicing and concatenation. + + When representing multimodal data, we often have multiple datatypes which have sequence dimensions that are the same (e.g. the length of the protein). + + When applying a transformation like a crop, we want to apply this to all tensors at the same time (e.g. crop the sequence, structure, and function). + + We also have some fields that are not sequential (like an id, or data source), which we don't want to crop. + + The SequentialDataclass abstracts this cropping away, allowing you to define dataclasses that implement `__len__`, `__getitem__` and `concat` automatically. + + This is done through the `metadata` field, which can take 3 values: + `sequence` (bool): True or False, tells the dataclass whether this field is a sequential type. Default: False. + `sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0. + `join_token` (Any): What token to use to join when concatenating elements. Default: None. + + + Example: + + @dataclass(frozen=True) + class Foo(SequentialDataclass): + id: str + sequence: str = field(metadata={"sequence": True, "join_token": "|"}) + tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan}) + + def __len__(self): + # Must implement the __len__ method + return len(self.sequence) + + >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5)) + Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143, 0.0251, -1.0717])) + + >>> foo[1:4] + Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143, 0.0251])) + + >>> foo[np.arange(5) < 3] + Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143])) + + >>> Foo.concat([foo[:2], foo[3:]]) + Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335, nan, 0.0251, -1.0717])) + + # Trying to create a type where the sequence lengths do not match raises an error + >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6)) + ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6 + + """ + + def __post_init__(self): + self._check_sequence_lengths_match() + + @abstractmethod + def __len__(self): + raise NotImplementedError + + def __getitem__(self, idx: int | list[int] | slice | np.ndarray): + updated_fields = {} + if isinstance(idx, int): + # make it so that things remain sequential + idx = [idx] + + for fld in fields(self): + if fld.metadata.get("sequence", False): + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + value = getattr(self, fld.name) + if value is None: + continue + match sequence_dim: + case 0: + # sequence is first dimension + value = getattr(self, fld.name) + value = slice_any_object(value, idx) + updated_fields[fld.name] = value + case 1: + new_value = [slice_any_object(item, idx) for item in value] + updated_fields[fld.name] = value.__class__(new_value) + case _: + raise NotImplementedError( + "Arbitrary slicing for different sequence length fields is not implemented" + ) + + return replace(self, **updated_fields) + + def _check_sequence_lengths_match(self): + """Checks if sequence lengths of all "sequence" fields match.""" + for fld in fields(self): + if fld.metadata.get("sequence", False) and fld.name != "complex": + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + value = getattr(self, fld.name) + if value is None: + continue + match sequence_dim: + case 0: + # sequence is first dimension + value = getattr(self, fld.name) + if len(value) != len(self): + raise ValueError( + f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(value)}" + ) + case 1: + for item in value: + if len(item) != len(self): + raise ValueError( + f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}" + ) + case _: + raise NotImplementedError( + "Arbitrary matching for different sequence length fields is not implemented" + ) + + @classmethod + def concat(cls, items: list[T], **kwargs) -> T: + updated_fields = {} + for fld in fields(cls): + if fld.metadata.get("sequence", False): + # this is a sequence, should be the same length as all other sequences + sequence_dim = fld.metadata.get("sequence_dim", 0) + join_value = fld.metadata.get("join_token", None) + if getattr(items[0], fld.name) is None: + continue + values = [getattr(item, fld.name) for item in items] + match sequence_dim: + case 0: + # sequence is first dimension + value = concat_objects(values, join_value) + updated_fields[fld.name] = value + case 1: + new_value = [ + concat_objects(item, join_value) for item in zip(*values) + ] + updated_fields[fld.name] = getattr( + items[0], fld.name + ).__class__(new_value) + case _: + raise NotImplementedError( + "Arbitrary joining for different sequence length fields is not implemented" + ) + updated_fields.update(kwargs) + + return replace( + items[0], # type: ignore + **updated_fields, + ) diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py new file mode 100644 index 00000000..026912fc --- /dev/null +++ b/esm/utils/structure/input_builder.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass +from typing import Any, Sequence + +import numpy as np + + +@dataclass +class Modification: + position: int # zero-indexed + ccd: str + + +@dataclass +class ProteinInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class RNAInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class DNAInput: + id: str | list[str] + sequence: str + modifications: list[Modification] | None = None + + +@dataclass +class LigandInput: + id: str | list[str] + smiles: str + ccd: list[str] | None = None + + +@dataclass +class DistogramConditioning: + chain_id: str + distogram: np.ndarray + + +@dataclass +class PocketConditioning: + binder_chain_id: str + contacts: list[tuple[str, int]] + + +@dataclass +class StructurePredictionInput: + sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput] + pocket: PocketConditioning | None = None + distogram_conditioning: list[DistogramConditioning] | None = None + + +def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput): + def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]: + chain_data: dict[str, Any] = { + "sequence": seq_input.sequence, + "id": seq_input.id, + "type": chain_type, + } + if hasattr(seq_input, "modifications") and seq_input.modifications: + mods = [ + {"position": mod.position, "ccd": mod.ccd} + for mod in seq_input.modifications + ] + chain_data["modifications"] = mods + return chain_data + + sequences = [] + for seq_input in all_atom_input.sequences: + if isinstance(seq_input, ProteinInput): + sequences.append(create_chain_data(seq_input, "protein")) + elif isinstance(seq_input, RNAInput): + sequences.append(create_chain_data(seq_input, "rna")) + elif isinstance(seq_input, DNAInput): + sequences.append(create_chain_data(seq_input, "dna")) + elif isinstance(seq_input, LigandInput): + sequences.append( + { + "smiles": seq_input.smiles, + "id": seq_input.id, + "ccd": seq_input.ccd, + "type": "ligand", + } + ) + else: + raise ValueError(f"Unsupported sequence input type: {type(seq_input)}") + + return {"sequences": sequences} diff --git a/esm/utils/structure/metrics.py b/esm/utils/structure/metrics.py index d76ed766..b2e590db 100644 --- a/esm/utils/structure/metrics.py +++ b/esm/utils/structure/metrics.py @@ -264,7 +264,7 @@ def compute_lddt_ca( if all_atom_pred_pos.dim() != 3: all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :] - all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim + all_atom_mask = all_atom_mask[..., ca_pos] return compute_lddt( all_atom_pred_pos, diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py new file mode 100644 index 00000000..f53ab9c4 --- /dev/null +++ b/esm/utils/structure/molecular_complex.py @@ -0,0 +1,938 @@ +from __future__ import annotations + +import io +import os +import re +from dataclasses import asdict, dataclass +from pathlib import Path +from subprocess import check_output +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, List + +import biotite.structure.io.pdbx as pdbx +import brotli +import msgpack +import numpy as np +import torch + +from esm.utils import residue_constants +from esm.utils.structure.metrics import compute_lddt, compute_rmsd +from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata + + +@dataclass +class MolecularComplexResult: + """Result of molecular complex folding""" + + complex: MolecularComplex + plddt: torch.Tensor | None = None + ptm: float | None = None + iptm: float | None = None + pae: torch.Tensor | None = None + distogram: torch.Tensor | None = None + pair_chains_iptm: torch.Tensor | None = None + output_embedding_sequence: torch.Tensor | None = None + output_embedding_pair_pooled: torch.Tensor | None = None + + +@dataclass +class MolecularComplexMetadata: + """Metadata for MolecularComplex objects.""" + + entity_lookup: dict[int, str] + chain_lookup: dict[int, str] + assembly_composition: dict[str, list[str]] | None = None + + +@dataclass +class Molecule: + """Represents a single molecule/token within a MolecularComplex.""" + + token: str + token_idx: int + atom_positions: np.ndarray # [N_atoms, 3] + atom_elements: np.ndarray # [N_atoms] element strings + residue_type: int + molecule_type: int # PROTEIN=0, RNA=1, DNA=2, LIGAND=3 + confidence: float + + +@dataclass(frozen=True) +class MolecularComplex: + """ + Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands. + + Uses a flat atom representation with token-based sequence indexing, supporting all atom types + beyond the traditional atom37 protein representation. + """ + + id: str + sequence: List[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP'] + + # Flat atom arrays - simplified representation + atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates + atom_elements: np.ndarray # [N_atoms] element strings + + # Token-to-atom mapping for efficient access + token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array + + # Confidence data + plddt: np.ndarray # Per-token confidence scores [N_tokens] + + # Metadata + metadata: MolecularComplexMetadata + + def __post_init__(self): + """Validate array dimensions.""" + n_tokens = len(self.sequence) + assert ( + self.token_to_atoms.shape[0] == n_tokens + ), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens" + assert ( + self.plddt.shape[0] == n_tokens + ), f"plddt shape {self.plddt.shape} != {n_tokens} tokens" + + def __len__(self) -> int: + """Return number of tokens.""" + return len(self.sequence) + + def __getitem__(self, idx: int) -> Molecule: + """Access individual molecules/tokens by index.""" + if idx >= len(self.sequence) or idx < 0: + raise IndexError( + f"Token index {idx} out of range for {len(self.sequence)} tokens" + ) + + token = self.sequence[idx] + start_atom, end_atom = self.token_to_atoms[idx] + + # Extract atom data for this token + token_atom_positions = self.atom_positions[start_atom:end_atom] + token_atom_elements = self.atom_elements[start_atom:end_atom] + + # Default values for residue/molecule type (would be extended based on actual implementation) + residue_type = 0 # Default to standard residue + molecule_type = 0 # Default to protein + + return Molecule( + token=token, + token_idx=idx, + atom_positions=token_atom_positions, + atom_elements=token_atom_elements, + residue_type=residue_type, + molecule_type=molecule_type, + confidence=self.plddt[idx], + ) + + @property + def atom_coordinates(self) -> np.ndarray: + """Get flat array of all atom coordinates [N_atoms, 3].""" + return self.atom_positions + + # Conversion methods + @classmethod + def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex": + """Convert a ProteinComplex to MolecularComplex. + + Args: + pc: ProteinComplex object with atom37 representation + + Returns: + MolecularComplex with flat atom arrays and token-based indexing + """ + from esm.utils import residue_constants + + # Extract sequence without chain breaks + sequence_no_breaks = pc.sequence.replace("|", "") + sequence_tokens = [ + residue_constants.restype_1to3.get(aa, "UNK") for aa in sequence_no_breaks + ] + + # Convert atom37 to flat arrays + flat_positions = [] + flat_elements = [] + token_to_atoms = [] + + atom_idx = 0 + residue_idx = 0 + + for i, aa in enumerate(pc.sequence): + if aa == "|": + # Skip chain break tokens + continue + + # Get atom37 positions and mask for this residue + res_positions = pc.atom37_positions[residue_idx] # [37, 3] + res_mask = pc.atom37_mask[residue_idx] # [37] + + # Track start position for this token + token_start = atom_idx + + # Process each atom type in atom37 representation + for atom_type_idx, atom_name in enumerate(residue_constants.atom_types): + if res_mask[atom_type_idx]: # Atom is present + # Add position + flat_positions.append(res_positions[atom_type_idx]) + + # Determine element from atom name + element = ( + atom_name[0] if atom_name else "C" + ) # First character is element + flat_elements.append(element) + + atom_idx += 1 + + # Record token-to-atom mapping [start_idx, end_idx) + token_to_atoms.append([token_start, atom_idx]) + residue_idx += 1 + + # Convert to numpy arrays + atom_positions = np.array(flat_positions, dtype=np.float32) + atom_elements = np.array(flat_elements, dtype=object) + token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) + + # Extract confidence scores (skip chain breaks) + confidence_scores = [] + residue_idx = 0 + for aa in pc.sequence: + if aa != "|": + confidence_scores.append(pc.confidence[residue_idx]) + residue_idx += 1 + + confidence_array = np.array(confidence_scores, dtype=np.float32) + + # Create metadata - convert entity IDs to strings for MolecularComplexMetadata + entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()} + metadata = MolecularComplexMetadata( + entity_lookup=entity_lookup_str, + chain_lookup=pc.metadata.chain_lookup, + assembly_composition=pc.metadata.assembly_composition, + ) + + return cls( + id=pc.id, + sequence=sequence_tokens, + atom_positions=atom_positions, + atom_elements=atom_elements, + token_to_atoms=token_to_atoms_array, + plddt=confidence_array, + metadata=metadata, + ) + + def to_protein_complex(self) -> ProteinComplex: + """Convert MolecularComplex back to ProteinComplex format. + + Extracts only protein tokens and converts from flat atom representation + back to atom37 format used by ProteinComplex. + + Returns: + ProteinComplex with protein residues only, excluding ligands/nucleic acids + """ + from esm.utils import residue_constants + + # No need for element mapping - already using element characters + + # Filter for protein tokens only (skip ligands, nucleic acids) + protein_tokens = [] + protein_indices = [] + + for i, token in enumerate(self.sequence): + # Check if token is a standard 3-letter amino acid code + if token in residue_constants.restype_3to1: + protein_tokens.append(token) + protein_indices.append(i) + + if not protein_tokens: + raise ValueError("No protein tokens found in MolecularComplex") + + n_residues = len(protein_tokens) + + # Initialize atom37 arrays + atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32) + atom37_mask = np.zeros((n_residues, 37), dtype=bool) + + # Convert tokens back to single-letter sequence + single_letter_sequence = "".join( + [residue_constants.restype_3to1[token] for token in protein_tokens] + ) + + # Extract confidence scores for protein residues only + protein_confidence = self.plddt[protein_indices] + + # Convert flat atoms back to atom37 representation + for res_idx, token_idx in enumerate(protein_indices): + token = self.sequence[token_idx] + start_atom, end_atom = self.token_to_atoms[token_idx] + + # Get atom data for this residue + res_atom_positions = self.atom_positions[start_atom:end_atom] + + # Reconstruct atom37 representation by exactly reversing the forward conversion logic + # In from_protein_complex, atoms are added in atom_types order if present in mask + # So we need to reconstruct the mask and positions in the same order + atom_count = 0 + for atom_type_idx, atom_name in enumerate(residue_constants.atom_types): + # Check if this atom type exists for this residue and was present + residue_atoms = residue_constants.residue_atoms.get(token, []) + if atom_name in residue_atoms: + # This atom type exists for this residue, so it should have been included + if atom_count < len(res_atom_positions): + atom37_positions[res_idx, atom_type_idx] = res_atom_positions[ + atom_count + ] + atom37_mask[res_idx, atom_type_idx] = True + atom_count += 1 + + # Create other required arrays for ProteinComplex + # For simplicity, assume all protein residues belong to the same entity/chain + entity_id = np.zeros(n_residues, dtype=np.int64) + chain_id = np.zeros(n_residues, dtype=np.int64) + sym_id = np.zeros(n_residues, dtype=np.int64) + residue_index = np.arange(1, n_residues + 1, dtype=np.int64) + insertion_code = np.array([""] * n_residues, dtype=object) + + # Create simplified protein complex metadata + # Map the first entity/chain from molecular complex metadata + protein_metadata = ProteinComplexMetadata( + entity_lookup={0: 1}, # Single entity (int for ProteinComplexMetadata) + chain_lookup={0: "A"}, # Single chain + assembly_composition=self.metadata.assembly_composition, + ) + + return ProteinComplex( + id=self.id, + sequence=single_letter_sequence, + entity_id=entity_id, + chain_id=chain_id, + sym_id=sym_id, + residue_index=residue_index, + insertion_code=insertion_code, + atom37_positions=atom37_positions, + atom37_mask=atom37_mask, + confidence=protein_confidence, + metadata=protein_metadata, + ) + + @classmethod + def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": + """Read MolecularComplex from mmcif file or string. + + Args: + inp: Path to mmCIF file or mmCIF content as string + id: Optional identifier to assign to the complex + + Returns: + MolecularComplex with all molecules (proteins, ligands, nucleic acids) + """ + from io import StringIO + + # Check if input is a file path or mmCIF string content + if os.path.exists(inp): + # Input is a file path + mmcif_file = pdbx.CIFFile.read(inp) + else: + # Input is mmCIF string content + mmcif_file = pdbx.CIFFile.read(StringIO(inp)) + + # Get structure - handle missing model information gracefully + try: + structure = pdbx.get_structure(mmcif_file, model=1) + except (KeyError, ValueError): + # Fallback for mmCIF files without model information + try: + structure = pdbx.get_structure(mmcif_file) + except Exception: + # Last resort: use the first available model or all atoms + structure = pdbx.get_structure(mmcif_file, model=None) + # Type hint for pyright - structure is an AtomArray which is iterable + if TYPE_CHECKING: + structure: Any = structure + + # Get entity information from mmCIF + entity_info = {} + try: + # Access the first block in CIFFile + block = mmcif_file[0] + if "entity" in block: + entity_category = block["entity"] + if "id" in entity_category and "type" in entity_category: + entity_ids = entity_category["id"] + entity_types = entity_category["type"] + # Convert CIFColumn to list for iteration + if hasattr(entity_ids, "__iter__") and hasattr( + entity_types, "__iter__" + ): + # Type annotation to help pyright understand these are iterable + entity_ids_list = list(entity_ids) # type: ignore + entity_types_list = list(entity_types) # type: ignore + for eid, etype in zip(entity_ids_list, entity_types_list): + entity_info[eid] = etype + except Exception: + pass + + # Initialize arrays for flat atom representation + sequence_tokens = [] + flat_positions = [] + flat_elements = [] + token_to_atoms = [] + confidence_scores = [] + + atom_idx = 0 + + # Group atoms by chain and residue + chain_residue_groups = {} + for atom in structure: + chain_id = atom.chain_id + res_id = atom.res_id + res_name = atom.res_name + + if chain_id not in chain_residue_groups: + chain_residue_groups[chain_id] = {} + if res_id not in chain_residue_groups[chain_id]: + chain_residue_groups[chain_id][res_id] = { + "atoms": [], + "res_name": res_name, + "is_hetero": atom.hetero, + } + chain_residue_groups[chain_id][res_id]["atoms"].append(atom) + + # Process each chain and residue + for chain_id in sorted(chain_residue_groups.keys()): + residues = chain_residue_groups[chain_id] + + for res_id in sorted(residues.keys()): + residue_data = residues[res_id] + res_name = residue_data["res_name"] + atoms = residue_data["atoms"] + is_hetero = residue_data["is_hetero"] + + # Skip water molecules + if res_name == "HOH": + continue + + # Determine token name + if not is_hetero and res_name in residue_constants.restype_3to1: + # Standard amino acid + token_name = res_name + elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]: + # Nucleotide + token_name = res_name + else: + # Ligand or other molecule + token_name = res_name + + sequence_tokens.append(token_name) + token_start = atom_idx + + # Add all atoms from this residue + for atom in atoms: + flat_positions.append(atom.coord) + + # Get element character + element = atom.element + flat_elements.append(element) + + atom_idx += 1 + + # Record token-to-atom mapping + token_to_atoms.append([token_start, atom_idx]) + + # Add confidence score (B-factor if available, otherwise 1.0) + bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0 + confidence_scores.append(min(bfactor / 100.0, 1.0)) + + # Convert to numpy arrays + if not flat_positions: + # Create minimal arrays if no atoms found + atom_positions = np.zeros((0, 3), dtype=np.float32) + atom_elements = np.zeros(0, dtype=object) + token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32) + else: + atom_positions = np.array(flat_positions, dtype=np.float32) + atom_elements = np.array(flat_elements, dtype=object) + token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) + + confidence_array = np.array(confidence_scores, dtype=np.float32) + + # Create metadata + metadata = MolecularComplexMetadata( + entity_lookup=entity_info, + chain_lookup={ + i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys()) + }, + assembly_composition=None, + ) + + # Set complex ID - if input was a path, use the stem; otherwise use default + if os.path.exists(inp): + complex_id = id or Path(inp).stem + else: + complex_id = id or "complex_from_string" + + return cls( + id=complex_id, + sequence=sequence_tokens, + atom_positions=atom_positions, + atom_elements=atom_elements, + token_to_atoms=token_to_atoms_array, + plddt=confidence_array, + metadata=metadata, + ) + + def to_mmcif(self) -> str: + """Write MolecularComplex to mmcif string. + + Returns: + String representation of the complex in mmCIF format + """ + # No need for element mapping - already using element characters + + lines = [] + + # Header + lines.append(f"data_{self.id}") + lines.append("#") + lines.append(f"_entry.id {self.id}") + lines.append("#") + + # Structure metadata + lines.append("_struct.entry_id {}".format(self.id)) + lines.append("_struct.title 'Protein Structure'") + lines.append("#") + + # Entity information + entity_id = 1 + chain_counter = 0 + lines.append("loop_") + lines.append("_entity.id") + lines.append("_entity.type") + lines.append("_entity.pdbx_description") + + # Determine entities based on sequence + protein_tokens = [] + other_tokens = [] + + for i, token in enumerate(self.sequence): + if token in residue_constants.restype_3to1: + protein_tokens.append((i, token)) + else: + other_tokens.append((i, token)) + + if protein_tokens: + lines.append(f"{entity_id} polymer 'Protein chain'") + entity_id += 1 + + for token in set(token for _, token in other_tokens): + lines.append(f"{entity_id} non-polymer 'Ligand {token}'") + entity_id += 1 + + lines.append("#") + + # Chain assignments + lines.append("loop_") + lines.append("_struct_asym.id") + lines.append("_struct_asym.entity_id") + + chain_id = "A" + if protein_tokens: + lines.append(f"{chain_id} 1") + chain_counter += 1 + chain_id = chr(ord(chain_id) + 1) + + entity_id = 2 + for token in set(token for _, token in other_tokens): + lines.append(f"{chain_id} {entity_id}") + entity_id += 1 + chain_counter += 1 + if chain_counter < 26: + chain_id = chr(ord(chain_id) + 1) + + lines.append("#") + + # Atom site information + lines.append("loop_") + lines.append("_atom_site.group_PDB") + lines.append("_atom_site.id") + lines.append("_atom_site.type_symbol") + lines.append("_atom_site.label_atom_id") + lines.append("_atom_site.label_alt_id") + lines.append("_atom_site.label_comp_id") + lines.append("_atom_site.label_asym_id") + lines.append("_atom_site.label_entity_id") + lines.append("_atom_site.label_seq_id") + lines.append("_atom_site.pdbx_PDB_ins_code") + lines.append("_atom_site.Cartn_x") + lines.append("_atom_site.Cartn_y") + lines.append("_atom_site.Cartn_z") + lines.append("_atom_site.occupancy") + lines.append("_atom_site.B_iso_or_equiv") + lines.append("_atom_site.pdbx_PDB_model_num") + lines.append("_atom_site.auth_seq_id") + lines.append("_atom_site.auth_comp_id") + lines.append("_atom_site.auth_asym_id") + lines.append("_atom_site.auth_atom_id") + + atom_id = 1 + seq_id = 1 + chain_id = "A" + entity_id = 1 + + for token_idx, token in enumerate(self.sequence): + start_atom, end_atom = self.token_to_atoms[token_idx] + + # Determine if this is a protein residue or ligand + is_protein = token in residue_constants.restype_3to1 + group_pdb = "ATOM" if is_protein else "HETATM" + current_entity_id = 1 if is_protein else 2 # Simplified entity assignment + current_chain_id = "A" if is_protein else "B" # Simplified chain assignment + + # Create atom names for this token + atom_names = [] + if is_protein: + # Use standard protein atom names + res_atoms = residue_constants.residue_atoms.get( + token, ["N", "CA", "C", "O"] + ) + atom_names = res_atoms[: end_atom - start_atom] + else: + # Generate generic atom names for ligands + for i in range(end_atom - start_atom): + atom_names.append(f"C{i+1}") + + # Pad atom names if needed + while len(atom_names) < (end_atom - start_atom): + atom_names.append(f"X{len(atom_names)+1}") + + # Write atoms for this token + for atom_idx_in_token, global_atom_idx in enumerate( + range(start_atom, end_atom) + ): + pos = self.atom_positions[global_atom_idx] + element_char = self.atom_elements[global_atom_idx] + element_symbol = element_char if isinstance(element_char, str) else "C" + + atom_name = ( + atom_names[atom_idx_in_token] + if atom_idx_in_token < len(atom_names) + else f"X{atom_idx_in_token+1}" + ) + + # Format atom site line + bfactor = ( + self.plddt[token_idx] * 100.0 + if len(self.plddt) > token_idx + else 50.0 + ) + + line = ( + f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . " + f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? " + f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 " + f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}" + ) + lines.append(line) + atom_id += 1 + + seq_id += 1 + + lines.append("#") + return "\n".join(lines) + + def dockq(self, native: "MolecularComplex") -> Any: + """Compute DockQ score against native structure. + + Args: + native: Native MolecularComplex to compute DockQ against + + Returns: + DockQ result containing score and alignment information + """ + # Imports moved to top of file + + # Convert both complexes to ProteinComplex format for DockQ computation + # This extracts only the protein portion and converts to PDB format + try: + self_pc = self.to_protein_complex() + native_pc = native.to_protein_complex() + except ValueError as e: + raise ValueError( + f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}" + ) + + # Normalize chain IDs for PDB compatibility + self_pc = self_pc.normalize_chain_ids_for_pdb() + native_pc = native_pc.normalize_chain_ids_for_pdb() + + # Use the existing ProteinComplex.dockq() method + try: + dockq_result = self_pc.dockq(native_pc) + return dockq_result + except Exception: + # Fallback to manual DockQ computation if ProteinComplex.dockq() fails + return self._compute_dockq_manual(native) + + def _compute_dockq_manual(self, native: "MolecularComplex") -> Any: + """Manual DockQ computation fallback.""" + # Imports moved to top of file + + # Convert both complexes to ProteinComplex format + try: + self_pc = self.to_protein_complex() + native_pc = native.to_protein_complex() + except ValueError as e: + raise ValueError( + f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}" + ) + + # Normalize chain IDs for PDB compatibility + self_pc = self_pc.normalize_chain_ids_for_pdb() + native_pc = native_pc.normalize_chain_ids_for_pdb() + + # Write temporary PDB files and run DockQ + with TemporaryDirectory() as tdir: + dir_path = Path(tdir) + self_pdb = dir_path / "self.pdb" + native_pdb = dir_path / "native.pdb" + + # Write PDB files + self_pc.to_pdb(self_pdb) + native_pc.to_pdb(native_pdb) + + # Run DockQ + try: + output = check_output(["DockQ", str(self_pdb), str(native_pdb)]) + output_text = output.decode() + + # Parse DockQ output + lines = output_text.split("\n") + + # Find the total DockQ score + dockq_score = None + for line in lines: + if "Total DockQ" in line: + match = re.search(r"Total DockQ.*: ([\d.]+)", line) + if match: + dockq_score = float(match.group(1)) + break + + if dockq_score is None: + # Try to find individual DockQ scores + for line in lines: + if line.startswith("DockQ") and ":" in line: + try: + dockq_score = float(line.split(":")[1].strip()) + break + except (ValueError, IndexError): + continue + + if dockq_score is None: + raise ValueError("Could not parse DockQ score from output") + + # Return a simple result structure + return { + "total_dockq": dockq_score, + "raw_output": output_text, + "aligned": self, # Return self as aligned structure + } + + except FileNotFoundError: + raise RuntimeError( + "DockQ is not installed. Please install DockQ to use this method." + ) + except Exception as e: + raise RuntimeError(f"DockQ computation failed: {e}") + + def rmsd(self, target: "MolecularComplex", **kwargs) -> float: + """Compute RMSD against target structure. + + Args: + target: Target MolecularComplex to compute RMSD against + **kwargs: Additional arguments passed to compute_rmsd + + Returns: + float: RMSD value between the two structures + """ + # Imports moved to top of file + + # Ensure both complexes have the same number of tokens + if len(self) != len(target): + raise ValueError( + f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}" + ) + + # Extract center positions for each token (using centroid of atoms) + mobile_coords = [] + target_coords = [] + atom_mask = [] + + for i in range(len(self)): + # Get atom positions for this token + mobile_start, mobile_end = self.token_to_atoms[i] + target_start, target_end = target.token_to_atoms[i] + + # Extract atom positions + mobile_atoms = self.atom_positions[mobile_start:mobile_end] + target_atoms = target.atom_positions[target_start:target_end] + + # Check if both tokens have atoms + if len(mobile_atoms) == 0 or len(target_atoms) == 0: + # Skip tokens with no atoms + continue + + # For simplicity, use the centroid of atoms as the representative position + mobile_center = mobile_atoms.mean(axis=0) + target_center = target_atoms.mean(axis=0) + + mobile_coords.append(mobile_center) + target_coords.append(target_center) + atom_mask.append(True) + + if len(mobile_coords) == 0: + raise ValueError("No valid atoms found for RMSD computation") + + # Convert to tensors + mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N] + + # Compute RMSD using existing infrastructure + rmsd_value = compute_rmsd( + mobile=mobile_tensor, + target=target_tensor, + atom_exists_mask=mask_tensor, + reduction="batch", + **kwargs, + ) + + return float(rmsd_value) + + def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float: + """Compute LDDT score against target structure. + + Args: + target: Target MolecularComplex to compute LDDT against + **kwargs: Additional arguments passed to compute_lddt + + Returns: + float: LDDT value between the two structures + """ + # Imports moved to top of file + + # Ensure both complexes have the same number of tokens + if len(self) != len(target): + raise ValueError( + f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}" + ) + + # Extract center positions for each token (using centroid of atoms) + mobile_coords = [] + target_coords = [] + atom_mask = [] + + for i in range(len(self)): + # Get atom positions for this token + mobile_start, mobile_end = self.token_to_atoms[i] + target_start, target_end = target.token_to_atoms[i] + + # Extract atom positions + mobile_atoms = self.atom_positions[mobile_start:mobile_end] + target_atoms = target.atom_positions[target_start:target_end] + + # Check if both tokens have atoms + if len(mobile_atoms) == 0 or len(target_atoms) == 0: + # Skip tokens with no atoms + mobile_coords.append(np.full(3, np.nan)) + target_coords.append(np.full(3, np.nan)) + atom_mask.append(False) + continue + + # For simplicity, use the centroid of atoms as the representative position + mobile_center = mobile_atoms.mean(axis=0) + target_center = target_atoms.mean(axis=0) + + mobile_coords.append(mobile_center) + target_coords.append(target_center) + atom_mask.append(True) + + if not any(atom_mask): + raise ValueError("No valid atoms found for LDDT computation") + + # Convert to tensors + mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze( + 0 + ) # [1, N, 3] + mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N] + + # Compute LDDT using existing infrastructure + lddt_value = compute_lddt( + all_atom_pred_pos=mobile_tensor, + all_atom_positions=target_tensor, + all_atom_mask=mask_tensor, + per_residue=False, # Return overall LDDT score + **kwargs, + ) + + return float(lddt_value) + + def state_dict(self): + """This state dict is optimized for storage, so it turns things to fp16 whenever + possible and converts numpy arrays to lists for JSON serialization. + """ + dct = {k: v for k, v in vars(self).items()} + for k, v in dct.items(): + if isinstance(v, np.ndarray): + match v.dtype: + case np.int64: + dct[k] = v.astype(np.int32).tolist() + case np.float64 | np.float32: + dct[k] = v.astype(np.float16).tolist() + case _: + dct[k] = v.tolist() + elif isinstance(v, MolecularComplexMetadata): + dct[k] = asdict(v) + + return dct + + def to_blob(self) -> bytes: + return brotli.compress(msgpack.dumps(self.state_dict()), quality=5) + + @classmethod + def from_state_dict(cls, dct): + for k, v in dct.items(): + if isinstance(v, list) and k in [ + "atom_positions", + "atom_elements", + "token_to_atoms", + "plddt", + ]: + dct[k] = np.array(v) + + for k, v in dct.items(): + if isinstance(v, np.ndarray): + if k in ["atom_positions", "plddt"]: + dct[k] = v.astype(np.float32) + elif k in ["token_to_atoms"]: + dct[k] = v.astype(np.int32) + + dct["metadata"] = MolecularComplexMetadata(**dct["metadata"]) + return cls(**dct) + + @classmethod + def from_blob(cls, input: Path | str | io.BytesIO | bytes): + match input: + case Path() | str(): + bytes = Path(input).read_bytes() + case io.BytesIO(): + bytes = input.getvalue() + case _: + bytes = input + return cls.from_state_dict( + msgpack.loads(brotli.decompress(bytes), strict_map_key=False) + ) diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py index 7e2c38de..306c8832 100644 --- a/esm/utils/structure/protein_complex.py +++ b/esm/utils/structure/protein_complex.py @@ -377,11 +377,14 @@ def switch_assembly(self, id: str): assert self.metadata.mmcif is not None return get_assembly_fast(self.metadata.mmcif, assembly_id=id) - def state_dict(self, backbone_only=False): + def state_dict(self, backbone_only=False, json_serializable=False): """This state dict is optimized for storage, so it turns things to fp16 whenever possible. Note that we also only support int32 residue indices, I'm hoping we don't need more than 2**32 residues...""" dct = {k: v for k, v in vars(self).items()} + if backbone_only: + dct["atom37_mask"][:, 3:] = False + dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]] for k, v in dct.items(): if isinstance(v, np.ndarray): match v.dtype: @@ -391,9 +394,10 @@ def state_dict(self, backbone_only=False): dct[k] = v.astype(np.float16) case _: pass + if json_serializable: + dct[k] = v.tolist() elif isinstance(v, ProteinComplexMetadata): dct[k] = asdict(v) - dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]] dct["metadata"]["mmcif"] = None # These can be populated with non-serializable objects and are not needed for reconstruction dct.pop("atoms", None) @@ -406,6 +410,10 @@ def to_blob(self, backbone_only=False) -> bytes: @classmethod def from_state_dict(cls, dct): + for k, v in dct.items(): + if isinstance(v, list): + dct[k] = np.array(v) + atom37 = np.full((*dct["atom37_mask"].shape, 3), np.nan) atom37[dct["atom37_mask"]] = dct["atom37_positions"] dct["atom37_positions"] = atom37 diff --git a/esm/utils/system.py b/esm/utils/system.py new file mode 100644 index 00000000..c2800e57 --- /dev/null +++ b/esm/utils/system.py @@ -0,0 +1,45 @@ +import io +import subprocess +import typing as T +from pathlib import Path + +PathLike = T.Union[str, Path] +PathOrBuffer = T.Union[PathLike, io.StringIO] + + +def run_subprocess_with_errorcheck( + *popenargs, + capture_output: bool = False, + quiet: bool = False, + env: dict[str, str] | None = None, + shell: bool = False, + executable: str | None = None, + **kws, +) -> subprocess.CompletedProcess: + """A command similar to subprocess.run, however the errormessage will + contain the stderr when using this function. This makes it significantly + easier to diagnose issues. + """ + try: + if capture_output: + stdout = subprocess.PIPE + elif quiet: + stdout = subprocess.DEVNULL + else: + stdout = None + + p = subprocess.run( + *popenargs, + stderr=subprocess.PIPE, + stdout=stdout, + check=True, + env=env, + shell=shell, + executable=executable, + **kws, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"Command failed with errorcode {e.returncode}." f"\n\n{e.stderr.decode()}" + ) + return p diff --git a/pixi.lock b/pixi.lock index be3b375f..897f11a5 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1726,8 +1726,8 @@ packages: requires_python: '>=3.8' - pypi: ./ name: esm - version: 3.2.2 - sha256: c14e2546bda5f0910c14acfabb7ea334e7171905c6799b43178f0420a92d6f3e + version: 3.2.2.post2 + sha256: 3f59a2977c85d35b4b1353902fa90e35d02acbabe6ffb506727bd406ec987ad1 requires_dist: - torch>=2.2.0 - torchvision diff --git a/pyproject.toml b/pyproject.toml index 2f8008f8..0dcc398d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "esm" -version = "3.2.2" +version = "3.2.2.post2" description = "EvolutionaryScale open model repository" readme = "README.md" requires-python = ">=3.12,<3.13" @@ -45,7 +45,6 @@ dependencies = [ "pygtrie", "dna_features_viewer", ] - # Pytest [tool.pytest.ini_options] addopts = """ diff --git a/tests/oss_pytests/requirements.txt b/tests/oss_pytests/requirements.txt index d120ce35..143baa7a 100644 --- a/tests/oss_pytests/requirements.txt +++ b/tests/oss_pytests/requirements.txt @@ -1,3 +1,2 @@ -esm +esm >=3.2.1post1,<4.0.0 pytest -httpx # TODO(williamxi): Remove this after the esm repo is fixed diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index 9dd5b188..96940135 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -1,6 +1,7 @@ import os import pytest +import torch from esm.sdk import client # pyright: ignore from esm.sdk.api import ( # pyright: ignore @@ -37,6 +38,7 @@ def test_oss_esm3_client(): logits_config = LogitsConfig(sequence=True, return_embeddings=True) result = esm3_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) + assert isinstance(result.logits.sequence, torch.Tensor) sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1)) result = esm3_client.forward_and_sample( @@ -69,6 +71,7 @@ def test_oss_esmc_client(): ) result = esmc_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) + assert isinstance(result.logits.sequence, torch.Tensor) @pytest.mark.sdk From 0382151104e2a3c5e8c14ea14e5a391c9bdfb936 Mon Sep 17 00:00:00 2001 From: Neil Thomas Date: Fri, 19 Sep 2025 20:59:11 +0000 Subject: [PATCH 2/4] Add none check for pyright --- tests/oss_pytests/test_oss_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index 96940135..4e7638f6 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -38,6 +38,7 @@ def test_oss_esm3_client(): logits_config = LogitsConfig(sequence=True, return_embeddings=True) result = esm3_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) + assert result.logits is not None assert isinstance(result.logits.sequence, torch.Tensor) sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1)) @@ -71,6 +72,7 @@ def test_oss_esmc_client(): ) result = esmc_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) + assert result.logits is not None assert isinstance(result.logits.sequence, torch.Tensor) From 34c6638b587c705e0d86fc851335e51bb3e0c578 Mon Sep 17 00:00:00 2001 From: Neil Thomas Date: Fri, 19 Sep 2025 21:46:14 +0000 Subject: [PATCH 3/4] Sync --- cookbook/local/open_generate.ipynb | 1 - cookbook/local/raw_forwards.py | 5 +- cookbook/snippets/fold_invfold.py | 1 - cookbook/tutorials/1_esmprotein.ipynb | 4 +- cookbook/tutorials/3_gfp_design.ipynb | 1 - cookbook/tutorials/4_forge_generate.ipynb | 1 - cookbook/tutorials/5_guided_generation.ipynb | 1 - esm/__init__.py | 1 + esm/layers/attention.py | 5 +- esm/layers/blocks.py | 9 ++- esm/layers/structure_proj.py | 5 +- esm/models/esm3.py | 14 +++- esm/models/function_decoder.py | 4 +- esm/models/vqvae.py | 5 +- esm/pretrained.py | 10 ++- esm/sdk/api.py | 22 ++++-- esm/sdk/base_forge_client.py | 16 ++++- esm/sdk/forge.py | 67 ++++++++++++++----- esm/tokenization/__init__.py | 5 +- esm/utils/decoding.py | 24 +++++-- esm/utils/encoding.py | 29 ++++++-- esm/utils/forge_context_manager.py | 5 +- esm/utils/function/encode_decode.py | 13 +++- esm/utils/generation.py | 9 ++- esm/utils/misc.py | 6 +- esm/utils/msa/__init__.py | 6 +- esm/utils/msa/msa.py | 11 ++- esm/utils/sampling.py | 15 ++++- esm/utils/structure/aligner.py | 4 +- esm/utils/structure/atom_indexer.py | 4 +- esm/utils/structure/input_builder.py | 1 + esm/utils/structure/molecular_complex.py | 10 ++- esm/utils/structure/protein_chain.py | 16 ++++- esm/utils/structure/protein_complex.py | 10 ++- esm/widgets/components/function_annotator.py | 4 +- esm/widgets/components/results_visualizer.py | 8 ++- .../components/sasa_prompt_selector.py | 13 +++- .../secondary_structure_prompt_selector.py | 13 +++- .../components/sequence_prompt_selector.py | 4 +- .../components/structure_prompt_selector.py | 8 ++- .../drawing/draw_function_annotations.py | 5 +- esm/widgets/utils/prompting.py | 4 +- esm/widgets/views/esm3_generation_launcher.py | 8 ++- esm/widgets/views/esm3_prompt_selector.py | 4 +- esm/widgets/views/generation.py | 16 +++-- esm/widgets/views/inverse_folding.py | 4 +- esm/widgets/views/login.py | 5 +- esm/widgets/views/prediction.py | 4 +- tests/oss_pytests/test_oss_client.py | 1 - 49 files changed, 339 insertions(+), 102 deletions(-) diff --git a/cookbook/local/open_generate.ipynb b/cookbook/local/open_generate.ipynb index 32a72c38..c8cec4c1 100644 --- a/cookbook/local/open_generate.ipynb +++ b/cookbook/local/open_generate.ipynb @@ -38,7 +38,6 @@ "\n", "!pip install py3Dmol\n", "import py3Dmol\n", - "\n", "from esm.models.esm3 import ESM3\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" diff --git a/cookbook/local/raw_forwards.py b/cookbook/local/raw_forwards.py index baad28ee..5701fa2a 100644 --- a/cookbook/local/raw_forwards.py +++ b/cookbook/local/raw_forwards.py @@ -2,7 +2,6 @@ import torch import torch.nn.functional as F - from esm.pretrained import ( ESM3_function_decoder_v0, ESM3_sm_open_v0, @@ -13,7 +12,9 @@ from esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer as EsmFunctionTokenizer, ) -from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation diff --git a/cookbook/snippets/fold_invfold.py b/cookbook/snippets/fold_invfold.py index e6be5485..cb24db6a 100644 --- a/cookbook/snippets/fold_invfold.py +++ b/cookbook/snippets/fold_invfold.py @@ -2,7 +2,6 @@ from typing import cast import numpy as np - from esm.sdk.api import ( ESM3InferenceClient, ESMProtein, diff --git a/cookbook/tutorials/1_esmprotein.ipynb b/cookbook/tutorials/1_esmprotein.ipynb index 13017733..e143ff13 100644 --- a/cookbook/tutorials/1_esmprotein.ipynb +++ b/cookbook/tutorials/1_esmprotein.ipynb @@ -72,7 +72,6 @@ "outputs": [], "source": [ "from biotite.database import rcsb\n", - "\n", "from esm.sdk.api import ESMProtein\n", "from esm.utils.structure.protein_chain import ProteinChain\n", "from esm.utils.types import FunctionAnnotation\n", @@ -497,9 +496,8 @@ "# Functions for visualizing InterPro function annotations\n", "\n", "from dna_features_viewer import GraphicFeature, GraphicRecord\n", - "from matplotlib import colormaps\n", - "\n", "from esm.utils.function.interpro import InterPro, InterProEntryType\n", + "from matplotlib import colormaps\n", "\n", "\n", "def visualize_function_annotations(\n", diff --git a/cookbook/tutorials/3_gfp_design.ipynb b/cookbook/tutorials/3_gfp_design.ipynb index f2bb85b9..95b42418 100644 --- a/cookbook/tutorials/3_gfp_design.ipynb +++ b/cookbook/tutorials/3_gfp_design.ipynb @@ -64,7 +64,6 @@ "import matplotlib.pyplot as pl\n", "import py3Dmol\n", "import torch\n", - "\n", "from esm.sdk import client\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" diff --git a/cookbook/tutorials/4_forge_generate.ipynb b/cookbook/tutorials/4_forge_generate.ipynb index 9962e54d..5fb6e676 100644 --- a/cookbook/tutorials/4_forge_generate.ipynb +++ b/cookbook/tutorials/4_forge_generate.ipynb @@ -36,7 +36,6 @@ "\n", "!pip install py3Dmol\n", "import py3Dmol\n", - "\n", "from esm.sdk import client\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" diff --git a/cookbook/tutorials/5_guided_generation.ipynb b/cookbook/tutorials/5_guided_generation.ipynb index d35a8c94..b2d8c7dc 100644 --- a/cookbook/tutorials/5_guided_generation.ipynb +++ b/cookbook/tutorials/5_guided_generation.ipynb @@ -49,7 +49,6 @@ "source": [ "import biotite.structure as bs\n", "import py3Dmol\n", - "\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction" ] diff --git a/esm/__init__.py b/esm/__init__.py index 98a35b2d..c37e4c07 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1 +1,2 @@ __version__ = "3.2.2.post2" + diff --git a/esm/layers/attention.py b/esm/layers/attention.py index 964d4282..b0f7c2b5 100644 --- a/esm/layers/attention.py +++ b/esm/layers/attention.py @@ -5,7 +5,10 @@ import torch.nn.functional as F from torch import nn -from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding +from esm.layers.rotary import ( + RotaryEmbedding, + TritonRotaryEmbedding, +) try: from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore diff --git a/esm/layers/blocks.py b/esm/layers/blocks.py index 76ebbe06..593b277f 100644 --- a/esm/layers/blocks.py +++ b/esm/layers/blocks.py @@ -2,8 +2,13 @@ import torch.nn as nn import torch.nn.functional as F -from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention -from esm.layers.geom_attention import GeometricReasoningOriginalImpl +from esm.layers.attention import ( + FlashMultiHeadAttention, + MultiHeadAttention, +) +from esm.layers.geom_attention import ( + GeometricReasoningOriginalImpl, +) from esm.utils.structure.affine3d import Affine3D diff --git a/esm/layers/structure_proj.py b/esm/layers/structure_proj.py index faad0fe9..783ddeb4 100644 --- a/esm/layers/structure_proj.py +++ b/esm/layers/structure_proj.py @@ -2,7 +2,10 @@ import torch.nn as nn from esm.utils.constants.physics import BB_COORDINATES -from esm.utils.structure.affine3d import Affine3D, RotationMatrix +from esm.utils.structure.affine3d import ( + Affine3D, + RotationMatrix, +) class Dim6RotStructureHead(nn.Module): diff --git a/esm/models/esm3.py b/esm/models/esm3.py index 218a8e90..cbe02ddd 100644 --- a/esm/models/esm3.py +++ b/esm/models/esm3.py @@ -13,7 +13,10 @@ from esm.layers.regression_head import RegressionHead from esm.layers.transformer_stack import TransformerStack from esm.models.function_decoder import FunctionTokenDecoder -from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder +from esm.models.vqvae import ( + StructureTokenDecoder, + StructureTokenEncoder, +) from esm.sdk.api import ( ESM3InferenceClient, ESMProtein, @@ -29,7 +32,10 @@ from esm.tokenization import TokenizerCollectionProtocol from esm.utils import encoding from esm.utils.constants import esm3 as C -from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name +from esm.utils.constants.models import ( + ESM3_OPEN_SMALL, + normalize_model_name, +) from esm.utils.decoding import decode_protein_tensor from esm.utils.generation import ( _batch_forward, @@ -44,7 +50,9 @@ get_default_sampling_config, validate_sampling_config, ) -from esm.utils.structure.affine3d import build_affine3d_from_coordinates +from esm.utils.structure.affine3d import ( + build_affine3d_from_coordinates, +) @dataclass diff --git a/esm/models/function_decoder.py b/esm/models/function_decoder.py index e5f1fb28..c4f32992 100644 --- a/esm/models/function_decoder.py +++ b/esm/models/function_decoder.py @@ -12,7 +12,9 @@ from esm.layers.regression_head import RegressionHead from esm.layers.transformer_stack import TransformerStack -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) from esm.utils.constants import esm3 as C from esm.utils.misc import merge_annotations, merge_ranges from esm.utils.types import FunctionAnnotation diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py index 37bc3945..0f5226a4 100644 --- a/esm/models/vqvae.py +++ b/esm/models/vqvae.py @@ -7,7 +7,10 @@ from esm.layers.transformer_stack import TransformerStack from esm.utils.constants import esm3 as C from esm.utils.misc import knn_graph -from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates +from esm.utils.structure.affine3d import ( + Affine3D, + build_affine3d_from_coordinates, +) from esm.utils.structure.predicted_aligned_error import ( compute_predicted_aligned_error, compute_tm, diff --git a/esm/pretrained.py b/esm/pretrained.py index e452e1d2..b9121511 100644 --- a/esm/pretrained.py +++ b/esm/pretrained.py @@ -6,8 +6,14 @@ from esm.models.esm3 import ESM3 from esm.models.esmc import ESMC from esm.models.function_decoder import FunctionTokenDecoder -from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder -from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers +from esm.models.vqvae import ( + StructureTokenDecoder, + StructureTokenEncoder, +) +from esm.tokenization import ( + get_esm3_model_tokenizers, + get_esmc_model_tokenizers, +) from esm.utils.constants.esm3 import data_root from esm.utils.constants.models import ( ESM3_FUNCTION_DECODER_V0, diff --git a/esm/sdk/api.py b/esm/sdk/api.py index bd93f0cd..6b152556 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -2,19 +2,27 @@ from abc import ABC from copy import deepcopy -from typing import Sequence +from typing import List, Sequence import attr import torch from attr import asdict, define import esm.utils.constants.api as C -from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers +from esm.tokenization import ( + TokenizerCollectionProtocol, + get_esm3_model_tokenizers, +) from esm.utils import encoding from esm.utils.constants.models import ESM3_OPEN_SMALL -from esm.utils.misc import get_chainbreak_boundaries_from_sequence +from esm.utils.misc import ( + get_chainbreak_boundaries_from_sequence, +) from esm.utils.structure.protein_chain import ProteinChain -from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex +from esm.utils.structure.protein_complex import ( + SINGLE_LETTER_CHAIN_IDS, + ProteinComplex, +) from esm.utils.types import FunctionAnnotation, PathOrBuffer @@ -35,6 +43,7 @@ class ESMProtein(ProteinType): plddt: torch.Tensor | None = None ptm: torch.Tensor | None = None + # When calling EvolutionaryScale API, use this flag to disclose any # sequences that may potentially have concerns. # Such sequences may not go through standard safety filter for approved users. @@ -327,6 +336,8 @@ class InverseFoldingConfig: temperature: float = 1.0 + + ## Low Level Endpoint Types @define class SamplingTrackConfig: @@ -391,6 +402,9 @@ class LogitsConfig: ith_hidden_layer: int = -1 + + + @define class LogitsOutput: logits: ForwardTrackData | None = None diff --git a/esm/sdk/base_forge_client.py b/esm/sdk/base_forge_client.py index ff05b541..9d27e647 100644 --- a/esm/sdk/base_forge_client.py +++ b/esm/sdk/base_forge_client.py @@ -1,9 +1,13 @@ +import asyncio +import time +from abc import ABC, abstractmethod from typing import Any from urllib.parse import urljoin import httpx from esm.sdk.api import ESMProteinError +from esm.sdk.retry import retry_decorator from esm.utils.decoding import assemble_message @@ -112,7 +116,10 @@ async def _async_post( ): try: request, headers = self.prepare_request( - request, potential_sequence_of_concern, return_bytes, headers + request, + potential_sequence_of_concern, + return_bytes, + headers, ) response = await self.async_client.post( url=urljoin(self.url, f"/api/v1/{endpoint}"), @@ -142,7 +149,10 @@ def _post( ): try: request, headers = self.prepare_request( - request, potential_sequence_of_concern, return_bytes, headers + request, + potential_sequence_of_concern, + return_bytes, + headers, ) response = self.client.post( url=urljoin(self.url, f"/api/v1/{endpoint}"), @@ -160,3 +170,5 @@ def _post( error_code=500, error_msg=f"Failed to submit request to {endpoint}. Error: {e}", ) + + diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index 67c94b07..ac105c03 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -4,7 +4,7 @@ import base64 import pickle from concurrent.futures import ThreadPoolExecutor -from typing import Any, Sequence +from typing import Any, Literal, Sequence, cast import torch @@ -20,14 +20,21 @@ InverseFoldingConfig, LogitsConfig, LogitsOutput, + ProteinChain, ProteinType, SamplingConfig, SamplingTrackConfig, ) -from esm.sdk.base_forge_client import _BaseForgeInferenceClient +from esm.sdk.base_forge_client import ( + _BaseForgeInferenceClient, +) from esm.sdk.retry import retry_decorator from esm.utils.constants.api import MIMETYPE_ES_PICKLE -from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor +from esm.utils.misc import ( + deserialize_tensors, + maybe_list, + maybe_tensor, +) from esm.utils.msa import MSA from esm.utils.structure.input_builder import ( StructurePredictionInput, @@ -101,9 +108,13 @@ def __init__( ) @staticmethod - def _process_fold_request(sequence: str, model_name: str | None): + def _process_fold_request( + sequence: str, + model_name: str | None, + ): request: dict[str, Any] = {"sequence": sequence} + request["model"] = model_name return request @@ -138,6 +149,7 @@ def process_inverse_fold_request( return request + async def _async_fetch_msa(self, sequence: str) -> MSA: print("Fetching MSA ... this may take a few minutes") # Accept both "|" and ":" as the chainbreak token. @@ -176,11 +188,15 @@ async def async_fold( del potential_sequence_of_concern request = self._process_fold_request( - sequence, model_name if model_name is not None else self.model + sequence, + model_name if model_name is not None else self.model, ) try: - data = await self._async_post("fold", request) + data = await self._async_post( + "fold", + request, + ) except ESMProteinError as e: return e @@ -207,11 +223,15 @@ def fold( del potential_sequence_of_concern request = self._process_fold_request( - sequence, model_name if model_name is not None else self.model + sequence, + model_name if model_name is not None else self.model, ) try: - data = self._post("fold", request) + data = self._post( + "fold", + request, + ) except ESMProteinError as e: return e @@ -219,7 +239,9 @@ def fold( @retry_decorator async def async_fold_all_atom( - self, all_atom_input: StructurePredictionInput, model_name: str | None = None + self, + all_atom_input: StructurePredictionInput, + model_name: str | None = None, ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: """Fold a molecular complex containing proteins, nucleic acids, and/or ligands. @@ -228,11 +250,15 @@ async def async_fold_all_atom( model_name: Override the client level model name if needed """ request = self._process_fold_all_atom_request( - all_atom_input, model_name if model_name is not None else self.model + all_atom_input, + model_name if model_name is not None else self.model, ) try: - data = await self._async_post("fold_all_atom", request) + data = await self._async_post( + "fold_all_atom", + request, + ) except ESMProteinError as e: return e @@ -240,7 +266,9 @@ async def async_fold_all_atom( @retry_decorator def fold_all_atom( - self, all_atom_input: StructurePredictionInput, model_name: str | None = None + self, + all_atom_input: StructurePredictionInput, + model_name: str | None = None, ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands. @@ -249,11 +277,15 @@ def fold_all_atom( model_name: Override the client level model name if needed """ request = self._process_fold_all_atom_request( - all_atom_input, model_name if model_name is not None else self.model + all_atom_input, + model_name if model_name is not None else self.model, ) try: - data = self._post("fold_all_atom", request) + data = self._post( + "fold_all_atom", + request, + ) except ESMProteinError as e: return e @@ -261,13 +293,15 @@ def fold_all_atom( @staticmethod def _process_fold_all_atom_request( - all_atom_input: StructurePredictionInput, model_name: str | None = None + all_atom_input: StructurePredictionInput, + model_name: str | None = None, ) -> dict[str, Any]: request: dict[str, Any] = { "all_atom_input": serialize_structure_prediction_input(all_atom_input), "model": model_name, } + return request @staticmethod @@ -352,6 +386,7 @@ def inverse_fold( return ESMProtein(sequence=data["sequence"]) + class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient): def __init__( self, @@ -1177,3 +1212,5 @@ def raw_model(self): raise NotImplementedError( f"Can not get underlying remote model {self.model} from a Forge client." ) + + diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py index 6db76554..ea609225 100644 --- a/esm/tokenization/__init__.py +++ b/esm/tokenization/__init__.py @@ -1,7 +1,10 @@ from dataclasses import dataclass from typing import Protocol -from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name +from esm.utils.constants.models import ( + ESM3_OPEN_SMALL, + normalize_model_name, +) from .function_tokenizer import InterProQuantizedTokenizer from .residue_tokenizer import ResidueAnnotationsTokenizer diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py index 1fe256b6..b5588527 100644 --- a/esm/utils/decoding.py +++ b/esm/utils/decoding.py @@ -10,12 +10,24 @@ from esm.models.vqvae import StructureTokenDecoder from esm.sdk.api import ESMProtein, ESMProteinTensor from esm.tokenization import TokenizerCollectionProtocol -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer -from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer -from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer -from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer -from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer -from esm.tokenization.structure_tokenizer import StructureTokenizer +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) +from esm.tokenization.sasa_tokenizer import ( + SASADiscretizingTokenizer, +) +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) +from esm.tokenization.ss_tokenizer import ( + SecondaryStructureTokenizer, +) +from esm.tokenization.structure_tokenizer import ( + StructureTokenizer, +) from esm.tokenization.tokenizer_base import EsmTokenizerBase from esm.utils.constants import api as api_constants from esm.utils.constants import esm3 as C diff --git a/esm/utils/encoding.py b/esm/utils/encoding.py index 8461709d..83c9d033 100644 --- a/esm/utils/encoding.py +++ b/esm/utils/encoding.py @@ -7,13 +7,26 @@ from esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer as EsmFunctionTokenizer, ) -from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer -from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer -from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer -from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer -from esm.tokenization.structure_tokenizer import StructureTokenizer + +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) +from esm.tokenization.sasa_tokenizer import ( + SASADiscretizingTokenizer, +) +from esm.tokenization.sequence_tokenizer import ( + EsmSequenceTokenizer, +) +from esm.tokenization.ss_tokenizer import ( + SecondaryStructureTokenizer, +) +from esm.tokenization.structure_tokenizer import ( + StructureTokenizer, +) from esm.utils.constants import esm3 as C -from esm.utils.function.encode_decode import encode_function_annotations +from esm.utils.function.encode_decode import ( + encode_function_annotations, +) from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation @@ -152,6 +165,8 @@ def tokenize_function_annotations( return function_tokens, residue_annotation_tokens + + # Tokenized Defaults def get_default_sequence_tokens( sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer @@ -227,3 +242,5 @@ def get_default_residue_annotation_tokens( residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id return residue_annotation_tokens + + diff --git a/esm/utils/forge_context_manager.py b/esm/utils/forge_context_manager.py index b1c2bdf3..fac0c3bd 100644 --- a/esm/utils/forge_context_manager.py +++ b/esm/utils/forge_context_manager.py @@ -7,7 +7,10 @@ from tqdm import tqdm from esm.sdk.api import ESMProteinError -from esm.sdk.retry import retry_if_specific_error, skip_retries_var +from esm.sdk.retry import ( + retry_if_specific_error, + skip_retries_var, +) TQDM_BAR_FORMAT = ( "{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} " diff --git a/esm/utils/function/encode_decode.py b/esm/utils/function/encode_decode.py index a4029858..29534e34 100644 --- a/esm/utils/function/encode_decode.py +++ b/esm/utils/function/encode_decode.py @@ -3,9 +3,16 @@ import torch -from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer -from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer +from esm.models.function_decoder import ( + FunctionTokenDecoder, + merge_annotations, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) +from esm.tokenization.residue_tokenizer import ( + ResidueAnnotationsTokenizer, +) from esm.utils.constants import esm3 as C from esm.utils.types import FunctionAnnotation diff --git a/esm/utils/generation.py b/esm/utils/generation.py index 4d35e7ee..cbc4306b 100644 --- a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -19,8 +19,13 @@ SamplingConfig, SamplingTrackConfig, ) -from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization import ( + EsmTokenizerBase, + TokenizerCollectionProtocol, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) from esm.utils.constants import esm3 as C from esm.utils.misc import stack_variable_length_tensors from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY diff --git a/esm/utils/misc.py b/esm/utils/misc.py index 409ba1bf..a9190041 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -2,6 +2,7 @@ import os from collections import defaultdict +from contextlib import nullcontext from dataclasses import is_dataclass from io import BytesIO from typing import ( @@ -261,7 +262,7 @@ def unbinpack( return stack_variable_length_tensors(unpacked_tensors, pad_value) -def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: # type: ignore +def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore """ Returns an autocast context manager that disables downcasting by AMP. @@ -273,6 +274,9 @@ def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast """ if device_type == "cpu": return torch.amp.autocast(device_type, enabled=False) # type: ignore + elif device_type == "mps": + # For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast. + return nullcontext() elif device_type == "cuda": return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore else: diff --git a/esm/utils/msa/__init__.py b/esm/utils/msa/__init__.py index 5dd9965b..3804a365 100644 --- a/esm/utils/msa/__init__.py +++ b/esm/utils/msa/__init__.py @@ -1,3 +1,7 @@ -from esm.utils.msa.msa import MSA, FastMSA, remove_insertions_from_sequence +from esm.utils.msa.msa import ( + MSA, + FastMSA, + remove_insertions_from_sequence, +) __all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"] diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py index 838e5b4b..8722b9d6 100644 --- a/esm/utils/msa/msa.py +++ b/esm/utils/msa/msa.py @@ -12,8 +12,15 @@ from scipy.spatial.distance import cdist from esm.utils.misc import slice_any_object -from esm.utils.msa.filter_sequences import greedy_select_indices, hhfilter -from esm.utils.parsing import FastaEntry, read_sequences, write_sequences +from esm.utils.msa.filter_sequences import ( + greedy_select_indices, + hhfilter, +) +from esm.utils.parsing import ( + FastaEntry, + read_sequences, + write_sequences, +) from esm.utils.sequential_dataclass import SequentialDataclass from esm.utils.system import PathOrBuffer diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py index fdf8658d..68c5c868 100644 --- a/esm/utils/sampling.py +++ b/esm/utils/sampling.py @@ -5,9 +5,18 @@ import torch import torch.nn.functional as F -from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig -from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.sdk.api import ( + ESMProteinTensor, + SamplingConfig, + SamplingTrackConfig, +) +from esm.tokenization import ( + TokenizerCollectionProtocol, + get_invalid_tokenizer_ids, +) +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) from esm.utils.constants.esm3 import ( MAX_RESIDUE_ANNOTATIONS, SASA_DISCRETIZATION_BOUNDARIES, diff --git a/esm/utils/structure/aligner.py b/esm/utils/structure/aligner.py index f25d9987..dd6702aa 100644 --- a/esm/utils/structure/aligner.py +++ b/esm/utils/structure/aligner.py @@ -6,7 +6,9 @@ import numpy as np import torch -from esm.utils.structure.protein_structure import compute_affine_and_rmsd +from esm.utils.structure.protein_structure import ( + compute_affine_and_rmsd, +) class Alignable(Protocol): diff --git a/esm/utils/structure/atom_indexer.py b/esm/utils/structure/atom_indexer.py index d62f05c9..2f588b98 100644 --- a/esm/utils/structure/atom_indexer.py +++ b/esm/utils/structure/atom_indexer.py @@ -1,6 +1,8 @@ import numpy as np -from esm.utils.structure.protein_structure import index_by_atom_name +from esm.utils.structure.protein_structure import ( + index_by_atom_name, +) class AtomIndexer: diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py index 026912fc..158c887c 100644 --- a/esm/utils/structure/input_builder.py +++ b/esm/utils/structure/input_builder.py @@ -4,6 +4,7 @@ import numpy as np + @dataclass class Modification: position: int # zero-indexed diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py index f53ab9c4..40bfae4e 100644 --- a/esm/utils/structure/molecular_complex.py +++ b/esm/utils/structure/molecular_complex.py @@ -16,8 +16,14 @@ import torch from esm.utils import residue_constants -from esm.utils.structure.metrics import compute_lddt, compute_rmsd -from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata +from esm.utils.structure.metrics import ( + compute_lddt, + compute_rmsd, +) +from esm.utils.structure.protein_complex import ( + ProteinComplex, + ProteinComplexMetadata, +) @dataclass diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py index 4889886e..b4db7081 100644 --- a/esm/utils/structure/protein_chain.py +++ b/esm/utils/structure/protein_chain.py @@ -25,13 +25,21 @@ from esm.utils.structure.affine3d import Affine3D from esm.utils.structure.aligner import Aligner from esm.utils.structure.atom_indexer import AtomIndexer -from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca -from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue +from esm.utils.structure.metrics import ( + compute_gdt_ts, + compute_lddt_ca, +) +from esm.utils.structure.mmcif_parsing import ( + MmcifWrapper, + Residue, +) from esm.utils.structure.normalize_coordinates import ( apply_frame_to_coords, get_protein_normalization_frame, ) -from esm.utils.structure.protein_structure import index_by_atom_name +from esm.utils.structure.protein_structure import ( + index_by_atom_name, +) from esm.utils.types import PathOrBuffer msgpack_numpy.patch() @@ -393,6 +401,7 @@ def from_blob(cls, input: Path | str | io.BytesIO | bytes): bytes = input return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes))) + def sasa(self, by_residue: bool = True): arr = self.atom_array_no_insertions sasa_per_atom = bs.sasa(arr) # type: ignore @@ -698,6 +707,7 @@ def gdt_ts( ) return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten() + @classmethod def chain_iterable_from_mmcif( cls, diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py index 306c8832..6c96a584 100644 --- a/esm/utils/structure/protein_complex.py +++ b/esm/utils/structure/protein_complex.py @@ -32,8 +32,14 @@ from esm.utils.structure.affine3d import Affine3D from esm.utils.structure.aligner import Aligner from esm.utils.structure.atom_indexer import AtomIndexer -from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca -from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError +from esm.utils.structure.metrics import ( + compute_gdt_ts, + compute_lddt_ca, +) +from esm.utils.structure.mmcif_parsing import ( + MmcifWrapper, + NoProteinError, +) from esm.utils.structure.protein_chain import ( ProteinChain, chain_to_ndarray, diff --git a/esm/widgets/components/function_annotator.py b/esm/widgets/components/function_annotator.py index f567f94d..714238f9 100644 --- a/esm/widgets/components/function_annotator.py +++ b/esm/widgets/components/function_annotator.py @@ -4,7 +4,9 @@ from ipywidgets import widgets from esm.sdk.api import FunctionAnnotation -from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization.function_tokenizer import ( + InterProQuantizedTokenizer, +) TRIE: pygtrie.CharTrie | None = None diff --git a/esm/widgets/components/results_visualizer.py b/esm/widgets/components/results_visualizer.py index 261c0a5d..99692e51 100644 --- a/esm/widgets/components/results_visualizer.py +++ b/esm/widgets/components/results_visualizer.py @@ -7,11 +7,15 @@ import matplotlib.pyplot as plt from esm.sdk.api import ESMProtein -from esm.widgets.utils.drawing.draw_category_array import draw_data_array +from esm.widgets.utils.drawing.draw_category_array import ( + draw_data_array, +) from esm.widgets.utils.drawing.draw_function_annotations import ( draw_function_annotations, ) -from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure +from esm.widgets.utils.drawing.draw_protein_structure import ( + draw_protein_structure, +) from esm.widgets.utils.serialization import ( create_download_button_from_buffer, protein_to_pdb_buffer, diff --git a/esm/widgets/components/sasa_prompt_selector.py b/esm/widgets/components/sasa_prompt_selector.py index 9c026500..ecc5c1b4 100644 --- a/esm/widgets/components/sasa_prompt_selector.py +++ b/esm/widgets/components/sasa_prompt_selector.py @@ -3,9 +3,16 @@ import ipywidgets as widgets from esm.utils.structure.protein_chain import ProteinChain -from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex -from esm.widgets.utils.drawing.draw_category_array import draw_data_array -from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges +from esm.widgets.utils.drawing.colors import ( + hex_to_rgba_tuple, + rgba_tuple_to_hex, +) +from esm.widgets.utils.drawing.draw_category_array import ( + draw_data_array, +) +from esm.widgets.utils.parsing import ( + convert_range_string_to_list_of_ranges, +) from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/components/secondary_structure_prompt_selector.py b/esm/widgets/components/secondary_structure_prompt_selector.py index b8007d69..020180fe 100644 --- a/esm/widgets/components/secondary_structure_prompt_selector.py +++ b/esm/widgets/components/secondary_structure_prompt_selector.py @@ -4,9 +4,16 @@ import pydssp from esm.utils.structure.protein_chain import ProteinChain -from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex -from esm.widgets.utils.drawing.draw_category_array import draw_data_array -from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges +from esm.widgets.utils.drawing.colors import ( + hex_to_rgba_tuple, + rgba_tuple_to_hex, +) +from esm.widgets.utils.drawing.draw_category_array import ( + draw_data_array, +) +from esm.widgets.utils.parsing import ( + convert_range_string_to_list_of_ranges, +) from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/components/sequence_prompt_selector.py b/esm/widgets/components/sequence_prompt_selector.py index c5e3526f..d538e670 100644 --- a/esm/widgets/components/sequence_prompt_selector.py +++ b/esm/widgets/components/sequence_prompt_selector.py @@ -6,7 +6,9 @@ hex_to_rgba_tuple, rgba_tuple_to_rgba_html_string, ) -from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges +from esm.widgets.utils.parsing import ( + convert_range_string_to_list_of_ranges, +) from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/components/structure_prompt_selector.py b/esm/widgets/components/structure_prompt_selector.py index 13a9df78..f4b497c0 100644 --- a/esm/widgets/components/structure_prompt_selector.py +++ b/esm/widgets/components/structure_prompt_selector.py @@ -10,8 +10,12 @@ from esm.utils.structure.protein_chain import ProteinChain from esm.widgets.utils import indexing -from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure -from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges +from esm.widgets.utils.drawing.draw_protein_structure import ( + draw_protein_structure, +) +from esm.widgets.utils.parsing import ( + convert_range_string_to_list_of_ranges, +) from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/utils/drawing/draw_function_annotations.py b/esm/widgets/utils/drawing/draw_function_annotations.py index 59e9f7cf..c71e5434 100644 --- a/esm/widgets/utils/drawing/draw_function_annotations.py +++ b/esm/widgets/utils/drawing/draw_function_annotations.py @@ -9,7 +9,10 @@ from PIL import Image from esm.sdk.api import FunctionAnnotation -from esm.utils.function.interpro import InterPro, InterProEntryType +from esm.utils.function.interpro import ( + InterPro, + InterProEntryType, +) @contextmanager diff --git a/esm/widgets/utils/prompting.py b/esm/widgets/utils/prompting.py index 1e89bb64..1ce6d9c6 100644 --- a/esm/widgets/utils/prompting.py +++ b/esm/widgets/utils/prompting.py @@ -9,7 +9,9 @@ from esm.utils import encoding from esm.widgets.utils import indexing from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex -from esm.widgets.utils.drawing.draw_category_array import draw_data_array +from esm.widgets.utils.drawing.draw_category_array import ( + draw_data_array, +) from esm.widgets.utils.printing import wrapped_print diff --git a/esm/widgets/views/esm3_generation_launcher.py b/esm/widgets/views/esm3_generation_launcher.py index f8bf8f3b..e94c60ff 100644 --- a/esm/widgets/views/esm3_generation_launcher.py +++ b/esm/widgets/views/esm3_generation_launcher.py @@ -13,9 +13,13 @@ GenerationConfig, ) from esm.utils.constants import models -from esm.widgets.components.results_visualizer import create_results_visualizer +from esm.widgets.components.results_visualizer import ( + create_results_visualizer, +) from esm.widgets.utils.printing import wrapped_print -from esm.widgets.utils.serialization import create_download_results_button +from esm.widgets.utils.serialization import ( + create_download_results_button, +) def create_esm3_generation_launcher( diff --git a/esm/widgets/views/esm3_prompt_selector.py b/esm/widgets/views/esm3_prompt_selector.py index f7e60686..035db28e 100644 --- a/esm/widgets/views/esm3_prompt_selector.py +++ b/esm/widgets/views/esm3_prompt_selector.py @@ -1,6 +1,8 @@ from ipywidgets import widgets -from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector +from esm.widgets.components.sasa_prompt_selector import ( + create_sasa_prompt_selector, +) from esm.widgets.components.secondary_structure_prompt_selector import ( create_secondary_structure_prompt_selector, ) diff --git a/esm/widgets/views/generation.py b/esm/widgets/views/generation.py index fdec2094..19015f60 100644 --- a/esm/widgets/views/generation.py +++ b/esm/widgets/views/generation.py @@ -4,12 +4,20 @@ from esm.sdk.api import ESM3InferenceClient, ESMProtein from esm.utils.constants import esm3 as C -from esm.widgets.components.function_annotator import create_function_annotator +from esm.widgets.components.function_annotator import ( + create_function_annotator, +) from esm.widgets.utils.prompting import PromptManagerCollection from esm.widgets.utils.protein_import import ProteinImporter -from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher -from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview -from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector +from esm.widgets.views.esm3_generation_launcher import ( + create_esm3_generation_launcher, +) +from esm.widgets.views.esm3_prompt_preview import ( + create_esm3_prompt_preview, +) +from esm.widgets.views.esm3_prompt_selector import ( + create_esm3_prompt_selector, +) def create_generation_ui( diff --git a/esm/widgets/views/inverse_folding.py b/esm/widgets/views/inverse_folding.py index 50d9a128..8becb8eb 100644 --- a/esm/widgets/views/inverse_folding.py +++ b/esm/widgets/views/inverse_folding.py @@ -6,7 +6,9 @@ ESMProteinError, GenerationConfig, ) -from esm.widgets.components.results_visualizer import create_results_visualizer +from esm.widgets.components.results_visualizer import ( + create_results_visualizer, +) from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.protein_import import ProteinImporter diff --git a/esm/widgets/views/login.py b/esm/widgets/views/login.py index 5c7b6706..2d8be5a3 100644 --- a/esm/widgets/views/login.py +++ b/esm/widgets/views/login.py @@ -4,7 +4,10 @@ from ipywidgets import widgets -from esm.widgets.utils.clients import get_forge_client, get_local_client +from esm.widgets.utils.clients import ( + get_forge_client, + get_local_client, +) from esm.widgets.utils.types import ClientInitContainer diff --git a/esm/widgets/views/prediction.py b/esm/widgets/views/prediction.py index 94ff49dc..de6666d2 100644 --- a/esm/widgets/views/prediction.py +++ b/esm/widgets/views/prediction.py @@ -6,7 +6,9 @@ ESMProteinError, GenerationConfig, ) -from esm.widgets.components.results_visualizer import create_results_visualizer +from esm.widgets.components.results_visualizer import ( + create_results_visualizer, +) from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.protein_import import ProteinImporter diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index 4e7638f6..d597a220 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -2,7 +2,6 @@ import pytest import torch - from esm.sdk import client # pyright: ignore from esm.sdk.api import ( # pyright: ignore ESMProtein, From 38f50fb9223346ec049f24a9fe084305cb27472b Mon Sep 17 00:00:00 2001 From: Neil Thomas Date: Fri, 19 Sep 2025 21:51:10 +0000 Subject: [PATCH 4/4] ruff checks --- cookbook/local/open_generate.ipynb | 1 + cookbook/local/raw_forwards.py | 5 +- cookbook/snippets/fold_invfold.py | 1 + cookbook/tutorials/1_esmprotein.ipynb | 4 +- cookbook/tutorials/3_gfp_design.ipynb | 1 + cookbook/tutorials/4_forge_generate.ipynb | 1 + cookbook/tutorials/5_guided_generation.ipynb | 1 + esm/__init__.py | 1 - esm/layers/attention.py | 5 +- esm/layers/blocks.py | 9 +-- esm/layers/structure_proj.py | 5 +- esm/models/esm3.py | 14 +--- esm/models/function_decoder.py | 4 +- esm/models/vqvae.py | 5 +- esm/pretrained.py | 10 +-- esm/sdk/api.py | 22 ++---- esm/sdk/base_forge_client.py | 16 +---- esm/sdk/forge.py | 67 +++++-------------- esm/tokenization/__init__.py | 5 +- esm/utils/decoding.py | 24 ++----- esm/utils/encoding.py | 29 ++------ esm/utils/forge_context_manager.py | 5 +- esm/utils/function/encode_decode.py | 13 +--- esm/utils/generation.py | 9 +-- esm/utils/msa/__init__.py | 6 +- esm/utils/msa/msa.py | 11 +-- esm/utils/sampling.py | 15 +---- esm/utils/structure/aligner.py | 4 +- esm/utils/structure/atom_indexer.py | 4 +- esm/utils/structure/input_builder.py | 1 - esm/utils/structure/molecular_complex.py | 10 +-- esm/utils/structure/protein_chain.py | 16 +---- esm/utils/structure/protein_complex.py | 10 +-- esm/widgets/components/function_annotator.py | 4 +- esm/widgets/components/results_visualizer.py | 8 +-- .../components/sasa_prompt_selector.py | 13 +--- .../secondary_structure_prompt_selector.py | 13 +--- .../components/sequence_prompt_selector.py | 4 +- .../components/structure_prompt_selector.py | 8 +-- .../drawing/draw_function_annotations.py | 5 +- esm/widgets/utils/prompting.py | 4 +- esm/widgets/views/esm3_generation_launcher.py | 8 +-- esm/widgets/views/esm3_prompt_selector.py | 4 +- esm/widgets/views/generation.py | 16 ++--- esm/widgets/views/inverse_folding.py | 4 +- esm/widgets/views/login.py | 5 +- esm/widgets/views/prediction.py | 4 +- tests/oss_pytests/test_oss_client.py | 1 + 48 files changed, 101 insertions(+), 334 deletions(-) diff --git a/cookbook/local/open_generate.ipynb b/cookbook/local/open_generate.ipynb index c8cec4c1..32a72c38 100644 --- a/cookbook/local/open_generate.ipynb +++ b/cookbook/local/open_generate.ipynb @@ -38,6 +38,7 @@ "\n", "!pip install py3Dmol\n", "import py3Dmol\n", + "\n", "from esm.models.esm3 import ESM3\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" diff --git a/cookbook/local/raw_forwards.py b/cookbook/local/raw_forwards.py index 5701fa2a..baad28ee 100644 --- a/cookbook/local/raw_forwards.py +++ b/cookbook/local/raw_forwards.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F + from esm.pretrained import ( ESM3_function_decoder_v0, ESM3_sm_open_v0, @@ -12,9 +13,7 @@ from esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer as EsmFunctionTokenizer, ) -from esm.tokenization.sequence_tokenizer import ( - EsmSequenceTokenizer, -) +from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation diff --git a/cookbook/snippets/fold_invfold.py b/cookbook/snippets/fold_invfold.py index cb24db6a..e6be5485 100644 --- a/cookbook/snippets/fold_invfold.py +++ b/cookbook/snippets/fold_invfold.py @@ -2,6 +2,7 @@ from typing import cast import numpy as np + from esm.sdk.api import ( ESM3InferenceClient, ESMProtein, diff --git a/cookbook/tutorials/1_esmprotein.ipynb b/cookbook/tutorials/1_esmprotein.ipynb index e143ff13..13017733 100644 --- a/cookbook/tutorials/1_esmprotein.ipynb +++ b/cookbook/tutorials/1_esmprotein.ipynb @@ -72,6 +72,7 @@ "outputs": [], "source": [ "from biotite.database import rcsb\n", + "\n", "from esm.sdk.api import ESMProtein\n", "from esm.utils.structure.protein_chain import ProteinChain\n", "from esm.utils.types import FunctionAnnotation\n", @@ -496,9 +497,10 @@ "# Functions for visualizing InterPro function annotations\n", "\n", "from dna_features_viewer import GraphicFeature, GraphicRecord\n", - "from esm.utils.function.interpro import InterPro, InterProEntryType\n", "from matplotlib import colormaps\n", "\n", + "from esm.utils.function.interpro import InterPro, InterProEntryType\n", + "\n", "\n", "def visualize_function_annotations(\n", " annotations: list[FunctionAnnotation],\n", diff --git a/cookbook/tutorials/3_gfp_design.ipynb b/cookbook/tutorials/3_gfp_design.ipynb index 95b42418..f2bb85b9 100644 --- a/cookbook/tutorials/3_gfp_design.ipynb +++ b/cookbook/tutorials/3_gfp_design.ipynb @@ -64,6 +64,7 @@ "import matplotlib.pyplot as pl\n", "import py3Dmol\n", "import torch\n", + "\n", "from esm.sdk import client\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" diff --git a/cookbook/tutorials/4_forge_generate.ipynb b/cookbook/tutorials/4_forge_generate.ipynb index 5fb6e676..9962e54d 100644 --- a/cookbook/tutorials/4_forge_generate.ipynb +++ b/cookbook/tutorials/4_forge_generate.ipynb @@ -36,6 +36,7 @@ "\n", "!pip install py3Dmol\n", "import py3Dmol\n", + "\n", "from esm.sdk import client\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.utils.structure.protein_chain import ProteinChain" diff --git a/cookbook/tutorials/5_guided_generation.ipynb b/cookbook/tutorials/5_guided_generation.ipynb index b2d8c7dc..d35a8c94 100644 --- a/cookbook/tutorials/5_guided_generation.ipynb +++ b/cookbook/tutorials/5_guided_generation.ipynb @@ -49,6 +49,7 @@ "source": [ "import biotite.structure as bs\n", "import py3Dmol\n", + "\n", "from esm.sdk.api import ESMProtein, GenerationConfig\n", "from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction" ] diff --git a/esm/__init__.py b/esm/__init__.py index c37e4c07..98a35b2d 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1,2 +1 @@ __version__ = "3.2.2.post2" - diff --git a/esm/layers/attention.py b/esm/layers/attention.py index b0f7c2b5..964d4282 100644 --- a/esm/layers/attention.py +++ b/esm/layers/attention.py @@ -5,10 +5,7 @@ import torch.nn.functional as F from torch import nn -from esm.layers.rotary import ( - RotaryEmbedding, - TritonRotaryEmbedding, -) +from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding try: from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore diff --git a/esm/layers/blocks.py b/esm/layers/blocks.py index 593b277f..76ebbe06 100644 --- a/esm/layers/blocks.py +++ b/esm/layers/blocks.py @@ -2,13 +2,8 @@ import torch.nn as nn import torch.nn.functional as F -from esm.layers.attention import ( - FlashMultiHeadAttention, - MultiHeadAttention, -) -from esm.layers.geom_attention import ( - GeometricReasoningOriginalImpl, -) +from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention +from esm.layers.geom_attention import GeometricReasoningOriginalImpl from esm.utils.structure.affine3d import Affine3D diff --git a/esm/layers/structure_proj.py b/esm/layers/structure_proj.py index 783ddeb4..faad0fe9 100644 --- a/esm/layers/structure_proj.py +++ b/esm/layers/structure_proj.py @@ -2,10 +2,7 @@ import torch.nn as nn from esm.utils.constants.physics import BB_COORDINATES -from esm.utils.structure.affine3d import ( - Affine3D, - RotationMatrix, -) +from esm.utils.structure.affine3d import Affine3D, RotationMatrix class Dim6RotStructureHead(nn.Module): diff --git a/esm/models/esm3.py b/esm/models/esm3.py index cbe02ddd..218a8e90 100644 --- a/esm/models/esm3.py +++ b/esm/models/esm3.py @@ -13,10 +13,7 @@ from esm.layers.regression_head import RegressionHead from esm.layers.transformer_stack import TransformerStack from esm.models.function_decoder import FunctionTokenDecoder -from esm.models.vqvae import ( - StructureTokenDecoder, - StructureTokenEncoder, -) +from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder from esm.sdk.api import ( ESM3InferenceClient, ESMProtein, @@ -32,10 +29,7 @@ from esm.tokenization import TokenizerCollectionProtocol from esm.utils import encoding from esm.utils.constants import esm3 as C -from esm.utils.constants.models import ( - ESM3_OPEN_SMALL, - normalize_model_name, -) +from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name from esm.utils.decoding import decode_protein_tensor from esm.utils.generation import ( _batch_forward, @@ -50,9 +44,7 @@ get_default_sampling_config, validate_sampling_config, ) -from esm.utils.structure.affine3d import ( - build_affine3d_from_coordinates, -) +from esm.utils.structure.affine3d import build_affine3d_from_coordinates @dataclass diff --git a/esm/models/function_decoder.py b/esm/models/function_decoder.py index c4f32992..e5f1fb28 100644 --- a/esm/models/function_decoder.py +++ b/esm/models/function_decoder.py @@ -12,9 +12,7 @@ from esm.layers.regression_head import RegressionHead from esm.layers.transformer_stack import TransformerStack -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer from esm.utils.constants import esm3 as C from esm.utils.misc import merge_annotations, merge_ranges from esm.utils.types import FunctionAnnotation diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py index 0f5226a4..37bc3945 100644 --- a/esm/models/vqvae.py +++ b/esm/models/vqvae.py @@ -7,10 +7,7 @@ from esm.layers.transformer_stack import TransformerStack from esm.utils.constants import esm3 as C from esm.utils.misc import knn_graph -from esm.utils.structure.affine3d import ( - Affine3D, - build_affine3d_from_coordinates, -) +from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates from esm.utils.structure.predicted_aligned_error import ( compute_predicted_aligned_error, compute_tm, diff --git a/esm/pretrained.py b/esm/pretrained.py index b9121511..e452e1d2 100644 --- a/esm/pretrained.py +++ b/esm/pretrained.py @@ -6,14 +6,8 @@ from esm.models.esm3 import ESM3 from esm.models.esmc import ESMC from esm.models.function_decoder import FunctionTokenDecoder -from esm.models.vqvae import ( - StructureTokenDecoder, - StructureTokenEncoder, -) -from esm.tokenization import ( - get_esm3_model_tokenizers, - get_esmc_model_tokenizers, -) +from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder +from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers from esm.utils.constants.esm3 import data_root from esm.utils.constants.models import ( ESM3_FUNCTION_DECODER_V0, diff --git a/esm/sdk/api.py b/esm/sdk/api.py index 6b152556..bd93f0cd 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -2,27 +2,19 @@ from abc import ABC from copy import deepcopy -from typing import List, Sequence +from typing import Sequence import attr import torch from attr import asdict, define import esm.utils.constants.api as C -from esm.tokenization import ( - TokenizerCollectionProtocol, - get_esm3_model_tokenizers, -) +from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers from esm.utils import encoding from esm.utils.constants.models import ESM3_OPEN_SMALL -from esm.utils.misc import ( - get_chainbreak_boundaries_from_sequence, -) +from esm.utils.misc import get_chainbreak_boundaries_from_sequence from esm.utils.structure.protein_chain import ProteinChain -from esm.utils.structure.protein_complex import ( - SINGLE_LETTER_CHAIN_IDS, - ProteinComplex, -) +from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex from esm.utils.types import FunctionAnnotation, PathOrBuffer @@ -43,7 +35,6 @@ class ESMProtein(ProteinType): plddt: torch.Tensor | None = None ptm: torch.Tensor | None = None - # When calling EvolutionaryScale API, use this flag to disclose any # sequences that may potentially have concerns. # Such sequences may not go through standard safety filter for approved users. @@ -336,8 +327,6 @@ class InverseFoldingConfig: temperature: float = 1.0 - - ## Low Level Endpoint Types @define class SamplingTrackConfig: @@ -402,9 +391,6 @@ class LogitsConfig: ith_hidden_layer: int = -1 - - - @define class LogitsOutput: logits: ForwardTrackData | None = None diff --git a/esm/sdk/base_forge_client.py b/esm/sdk/base_forge_client.py index 9d27e647..ff05b541 100644 --- a/esm/sdk/base_forge_client.py +++ b/esm/sdk/base_forge_client.py @@ -1,13 +1,9 @@ -import asyncio -import time -from abc import ABC, abstractmethod from typing import Any from urllib.parse import urljoin import httpx from esm.sdk.api import ESMProteinError -from esm.sdk.retry import retry_decorator from esm.utils.decoding import assemble_message @@ -116,10 +112,7 @@ async def _async_post( ): try: request, headers = self.prepare_request( - request, - potential_sequence_of_concern, - return_bytes, - headers, + request, potential_sequence_of_concern, return_bytes, headers ) response = await self.async_client.post( url=urljoin(self.url, f"/api/v1/{endpoint}"), @@ -149,10 +142,7 @@ def _post( ): try: request, headers = self.prepare_request( - request, - potential_sequence_of_concern, - return_bytes, - headers, + request, potential_sequence_of_concern, return_bytes, headers ) response = self.client.post( url=urljoin(self.url, f"/api/v1/{endpoint}"), @@ -170,5 +160,3 @@ def _post( error_code=500, error_msg=f"Failed to submit request to {endpoint}. Error: {e}", ) - - diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index ac105c03..67c94b07 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -4,7 +4,7 @@ import base64 import pickle from concurrent.futures import ThreadPoolExecutor -from typing import Any, Literal, Sequence, cast +from typing import Any, Sequence import torch @@ -20,21 +20,14 @@ InverseFoldingConfig, LogitsConfig, LogitsOutput, - ProteinChain, ProteinType, SamplingConfig, SamplingTrackConfig, ) -from esm.sdk.base_forge_client import ( - _BaseForgeInferenceClient, -) +from esm.sdk.base_forge_client import _BaseForgeInferenceClient from esm.sdk.retry import retry_decorator from esm.utils.constants.api import MIMETYPE_ES_PICKLE -from esm.utils.misc import ( - deserialize_tensors, - maybe_list, - maybe_tensor, -) +from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor from esm.utils.msa import MSA from esm.utils.structure.input_builder import ( StructurePredictionInput, @@ -108,13 +101,9 @@ def __init__( ) @staticmethod - def _process_fold_request( - sequence: str, - model_name: str | None, - ): + def _process_fold_request(sequence: str, model_name: str | None): request: dict[str, Any] = {"sequence": sequence} - request["model"] = model_name return request @@ -149,7 +138,6 @@ def process_inverse_fold_request( return request - async def _async_fetch_msa(self, sequence: str) -> MSA: print("Fetching MSA ... this may take a few minutes") # Accept both "|" and ":" as the chainbreak token. @@ -188,15 +176,11 @@ async def async_fold( del potential_sequence_of_concern request = self._process_fold_request( - sequence, - model_name if model_name is not None else self.model, + sequence, model_name if model_name is not None else self.model ) try: - data = await self._async_post( - "fold", - request, - ) + data = await self._async_post("fold", request) except ESMProteinError as e: return e @@ -223,15 +207,11 @@ def fold( del potential_sequence_of_concern request = self._process_fold_request( - sequence, - model_name if model_name is not None else self.model, + sequence, model_name if model_name is not None else self.model ) try: - data = self._post( - "fold", - request, - ) + data = self._post("fold", request) except ESMProteinError as e: return e @@ -239,9 +219,7 @@ def fold( @retry_decorator async def async_fold_all_atom( - self, - all_atom_input: StructurePredictionInput, - model_name: str | None = None, + self, all_atom_input: StructurePredictionInput, model_name: str | None = None ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: """Fold a molecular complex containing proteins, nucleic acids, and/or ligands. @@ -250,15 +228,11 @@ async def async_fold_all_atom( model_name: Override the client level model name if needed """ request = self._process_fold_all_atom_request( - all_atom_input, - model_name if model_name is not None else self.model, + all_atom_input, model_name if model_name is not None else self.model ) try: - data = await self._async_post( - "fold_all_atom", - request, - ) + data = await self._async_post("fold_all_atom", request) except ESMProteinError as e: return e @@ -266,9 +240,7 @@ async def async_fold_all_atom( @retry_decorator def fold_all_atom( - self, - all_atom_input: StructurePredictionInput, - model_name: str | None = None, + self, all_atom_input: StructurePredictionInput, model_name: str | None = None ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError: """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands. @@ -277,15 +249,11 @@ def fold_all_atom( model_name: Override the client level model name if needed """ request = self._process_fold_all_atom_request( - all_atom_input, - model_name if model_name is not None else self.model, + all_atom_input, model_name if model_name is not None else self.model ) try: - data = self._post( - "fold_all_atom", - request, - ) + data = self._post("fold_all_atom", request) except ESMProteinError as e: return e @@ -293,15 +261,13 @@ def fold_all_atom( @staticmethod def _process_fold_all_atom_request( - all_atom_input: StructurePredictionInput, - model_name: str | None = None, + all_atom_input: StructurePredictionInput, model_name: str | None = None ) -> dict[str, Any]: request: dict[str, Any] = { "all_atom_input": serialize_structure_prediction_input(all_atom_input), "model": model_name, } - return request @staticmethod @@ -386,7 +352,6 @@ def inverse_fold( return ESMProtein(sequence=data["sequence"]) - class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient): def __init__( self, @@ -1212,5 +1177,3 @@ def raw_model(self): raise NotImplementedError( f"Can not get underlying remote model {self.model} from a Forge client." ) - - diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py index ea609225..6db76554 100644 --- a/esm/tokenization/__init__.py +++ b/esm/tokenization/__init__.py @@ -1,10 +1,7 @@ from dataclasses import dataclass from typing import Protocol -from esm.utils.constants.models import ( - ESM3_OPEN_SMALL, - normalize_model_name, -) +from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name from .function_tokenizer import InterProQuantizedTokenizer from .residue_tokenizer import ResidueAnnotationsTokenizer diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py index b5588527..1fe256b6 100644 --- a/esm/utils/decoding.py +++ b/esm/utils/decoding.py @@ -10,24 +10,12 @@ from esm.models.vqvae import StructureTokenDecoder from esm.sdk.api import ESMProtein, ESMProteinTensor from esm.tokenization import TokenizerCollectionProtocol -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) -from esm.tokenization.residue_tokenizer import ( - ResidueAnnotationsTokenizer, -) -from esm.tokenization.sasa_tokenizer import ( - SASADiscretizingTokenizer, -) -from esm.tokenization.sequence_tokenizer import ( - EsmSequenceTokenizer, -) -from esm.tokenization.ss_tokenizer import ( - SecondaryStructureTokenizer, -) -from esm.tokenization.structure_tokenizer import ( - StructureTokenizer, -) +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer +from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer +from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer +from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer +from esm.tokenization.structure_tokenizer import StructureTokenizer from esm.tokenization.tokenizer_base import EsmTokenizerBase from esm.utils.constants import api as api_constants from esm.utils.constants import esm3 as C diff --git a/esm/utils/encoding.py b/esm/utils/encoding.py index 83c9d033..8461709d 100644 --- a/esm/utils/encoding.py +++ b/esm/utils/encoding.py @@ -7,26 +7,13 @@ from esm.tokenization.function_tokenizer import ( InterProQuantizedTokenizer as EsmFunctionTokenizer, ) - -from esm.tokenization.residue_tokenizer import ( - ResidueAnnotationsTokenizer, -) -from esm.tokenization.sasa_tokenizer import ( - SASADiscretizingTokenizer, -) -from esm.tokenization.sequence_tokenizer import ( - EsmSequenceTokenizer, -) -from esm.tokenization.ss_tokenizer import ( - SecondaryStructureTokenizer, -) -from esm.tokenization.structure_tokenizer import ( - StructureTokenizer, -) +from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer +from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer +from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer +from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer +from esm.tokenization.structure_tokenizer import StructureTokenizer from esm.utils.constants import esm3 as C -from esm.utils.function.encode_decode import ( - encode_function_annotations, -) +from esm.utils.function.encode_decode import encode_function_annotations from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation @@ -165,8 +152,6 @@ def tokenize_function_annotations( return function_tokens, residue_annotation_tokens - - # Tokenized Defaults def get_default_sequence_tokens( sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer @@ -242,5 +227,3 @@ def get_default_residue_annotation_tokens( residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id return residue_annotation_tokens - - diff --git a/esm/utils/forge_context_manager.py b/esm/utils/forge_context_manager.py index fac0c3bd..b1c2bdf3 100644 --- a/esm/utils/forge_context_manager.py +++ b/esm/utils/forge_context_manager.py @@ -7,10 +7,7 @@ from tqdm import tqdm from esm.sdk.api import ESMProteinError -from esm.sdk.retry import ( - retry_if_specific_error, - skip_retries_var, -) +from esm.sdk.retry import retry_if_specific_error, skip_retries_var TQDM_BAR_FORMAT = ( "{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} " diff --git a/esm/utils/function/encode_decode.py b/esm/utils/function/encode_decode.py index 29534e34..a4029858 100644 --- a/esm/utils/function/encode_decode.py +++ b/esm/utils/function/encode_decode.py @@ -3,16 +3,9 @@ import torch -from esm.models.function_decoder import ( - FunctionTokenDecoder, - merge_annotations, -) -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) -from esm.tokenization.residue_tokenizer import ( - ResidueAnnotationsTokenizer, -) +from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer +from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer from esm.utils.constants import esm3 as C from esm.utils.types import FunctionAnnotation diff --git a/esm/utils/generation.py b/esm/utils/generation.py index cbc4306b..4d35e7ee 100644 --- a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -19,13 +19,8 @@ SamplingConfig, SamplingTrackConfig, ) -from esm.tokenization import ( - EsmTokenizerBase, - TokenizerCollectionProtocol, -) -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) +from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer from esm.utils.constants import esm3 as C from esm.utils.misc import stack_variable_length_tensors from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY diff --git a/esm/utils/msa/__init__.py b/esm/utils/msa/__init__.py index 3804a365..5dd9965b 100644 --- a/esm/utils/msa/__init__.py +++ b/esm/utils/msa/__init__.py @@ -1,7 +1,3 @@ -from esm.utils.msa.msa import ( - MSA, - FastMSA, - remove_insertions_from_sequence, -) +from esm.utils.msa.msa import MSA, FastMSA, remove_insertions_from_sequence __all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"] diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py index 8722b9d6..838e5b4b 100644 --- a/esm/utils/msa/msa.py +++ b/esm/utils/msa/msa.py @@ -12,15 +12,8 @@ from scipy.spatial.distance import cdist from esm.utils.misc import slice_any_object -from esm.utils.msa.filter_sequences import ( - greedy_select_indices, - hhfilter, -) -from esm.utils.parsing import ( - FastaEntry, - read_sequences, - write_sequences, -) +from esm.utils.msa.filter_sequences import greedy_select_indices, hhfilter +from esm.utils.parsing import FastaEntry, read_sequences, write_sequences from esm.utils.sequential_dataclass import SequentialDataclass from esm.utils.system import PathOrBuffer diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py index 68c5c868..fdf8658d 100644 --- a/esm/utils/sampling.py +++ b/esm/utils/sampling.py @@ -5,18 +5,9 @@ import torch import torch.nn.functional as F -from esm.sdk.api import ( - ESMProteinTensor, - SamplingConfig, - SamplingTrackConfig, -) -from esm.tokenization import ( - TokenizerCollectionProtocol, - get_invalid_tokenizer_ids, -) -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) +from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig +from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer from esm.utils.constants.esm3 import ( MAX_RESIDUE_ANNOTATIONS, SASA_DISCRETIZATION_BOUNDARIES, diff --git a/esm/utils/structure/aligner.py b/esm/utils/structure/aligner.py index dd6702aa..f25d9987 100644 --- a/esm/utils/structure/aligner.py +++ b/esm/utils/structure/aligner.py @@ -6,9 +6,7 @@ import numpy as np import torch -from esm.utils.structure.protein_structure import ( - compute_affine_and_rmsd, -) +from esm.utils.structure.protein_structure import compute_affine_and_rmsd class Alignable(Protocol): diff --git a/esm/utils/structure/atom_indexer.py b/esm/utils/structure/atom_indexer.py index 2f588b98..d62f05c9 100644 --- a/esm/utils/structure/atom_indexer.py +++ b/esm/utils/structure/atom_indexer.py @@ -1,8 +1,6 @@ import numpy as np -from esm.utils.structure.protein_structure import ( - index_by_atom_name, -) +from esm.utils.structure.protein_structure import index_by_atom_name class AtomIndexer: diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py index 158c887c..026912fc 100644 --- a/esm/utils/structure/input_builder.py +++ b/esm/utils/structure/input_builder.py @@ -4,7 +4,6 @@ import numpy as np - @dataclass class Modification: position: int # zero-indexed diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py index 40bfae4e..f53ab9c4 100644 --- a/esm/utils/structure/molecular_complex.py +++ b/esm/utils/structure/molecular_complex.py @@ -16,14 +16,8 @@ import torch from esm.utils import residue_constants -from esm.utils.structure.metrics import ( - compute_lddt, - compute_rmsd, -) -from esm.utils.structure.protein_complex import ( - ProteinComplex, - ProteinComplexMetadata, -) +from esm.utils.structure.metrics import compute_lddt, compute_rmsd +from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata @dataclass diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py index b4db7081..4889886e 100644 --- a/esm/utils/structure/protein_chain.py +++ b/esm/utils/structure/protein_chain.py @@ -25,21 +25,13 @@ from esm.utils.structure.affine3d import Affine3D from esm.utils.structure.aligner import Aligner from esm.utils.structure.atom_indexer import AtomIndexer -from esm.utils.structure.metrics import ( - compute_gdt_ts, - compute_lddt_ca, -) -from esm.utils.structure.mmcif_parsing import ( - MmcifWrapper, - Residue, -) +from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca +from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue from esm.utils.structure.normalize_coordinates import ( apply_frame_to_coords, get_protein_normalization_frame, ) -from esm.utils.structure.protein_structure import ( - index_by_atom_name, -) +from esm.utils.structure.protein_structure import index_by_atom_name from esm.utils.types import PathOrBuffer msgpack_numpy.patch() @@ -401,7 +393,6 @@ def from_blob(cls, input: Path | str | io.BytesIO | bytes): bytes = input return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes))) - def sasa(self, by_residue: bool = True): arr = self.atom_array_no_insertions sasa_per_atom = bs.sasa(arr) # type: ignore @@ -707,7 +698,6 @@ def gdt_ts( ) return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten() - @classmethod def chain_iterable_from_mmcif( cls, diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py index 6c96a584..306c8832 100644 --- a/esm/utils/structure/protein_complex.py +++ b/esm/utils/structure/protein_complex.py @@ -32,14 +32,8 @@ from esm.utils.structure.affine3d import Affine3D from esm.utils.structure.aligner import Aligner from esm.utils.structure.atom_indexer import AtomIndexer -from esm.utils.structure.metrics import ( - compute_gdt_ts, - compute_lddt_ca, -) -from esm.utils.structure.mmcif_parsing import ( - MmcifWrapper, - NoProteinError, -) +from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca +from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError from esm.utils.structure.protein_chain import ( ProteinChain, chain_to_ndarray, diff --git a/esm/widgets/components/function_annotator.py b/esm/widgets/components/function_annotator.py index 714238f9..f567f94d 100644 --- a/esm/widgets/components/function_annotator.py +++ b/esm/widgets/components/function_annotator.py @@ -4,9 +4,7 @@ from ipywidgets import widgets from esm.sdk.api import FunctionAnnotation -from esm.tokenization.function_tokenizer import ( - InterProQuantizedTokenizer, -) +from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer TRIE: pygtrie.CharTrie | None = None diff --git a/esm/widgets/components/results_visualizer.py b/esm/widgets/components/results_visualizer.py index 99692e51..261c0a5d 100644 --- a/esm/widgets/components/results_visualizer.py +++ b/esm/widgets/components/results_visualizer.py @@ -7,15 +7,11 @@ import matplotlib.pyplot as plt from esm.sdk.api import ESMProtein -from esm.widgets.utils.drawing.draw_category_array import ( - draw_data_array, -) +from esm.widgets.utils.drawing.draw_category_array import draw_data_array from esm.widgets.utils.drawing.draw_function_annotations import ( draw_function_annotations, ) -from esm.widgets.utils.drawing.draw_protein_structure import ( - draw_protein_structure, -) +from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure from esm.widgets.utils.serialization import ( create_download_button_from_buffer, protein_to_pdb_buffer, diff --git a/esm/widgets/components/sasa_prompt_selector.py b/esm/widgets/components/sasa_prompt_selector.py index ecc5c1b4..9c026500 100644 --- a/esm/widgets/components/sasa_prompt_selector.py +++ b/esm/widgets/components/sasa_prompt_selector.py @@ -3,16 +3,9 @@ import ipywidgets as widgets from esm.utils.structure.protein_chain import ProteinChain -from esm.widgets.utils.drawing.colors import ( - hex_to_rgba_tuple, - rgba_tuple_to_hex, -) -from esm.widgets.utils.drawing.draw_category_array import ( - draw_data_array, -) -from esm.widgets.utils.parsing import ( - convert_range_string_to_list_of_ranges, -) +from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex +from esm.widgets.utils.drawing.draw_category_array import draw_data_array +from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/components/secondary_structure_prompt_selector.py b/esm/widgets/components/secondary_structure_prompt_selector.py index 020180fe..b8007d69 100644 --- a/esm/widgets/components/secondary_structure_prompt_selector.py +++ b/esm/widgets/components/secondary_structure_prompt_selector.py @@ -4,16 +4,9 @@ import pydssp from esm.utils.structure.protein_chain import ProteinChain -from esm.widgets.utils.drawing.colors import ( - hex_to_rgba_tuple, - rgba_tuple_to_hex, -) -from esm.widgets.utils.drawing.draw_category_array import ( - draw_data_array, -) -from esm.widgets.utils.parsing import ( - convert_range_string_to_list_of_ranges, -) +from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex +from esm.widgets.utils.drawing.draw_category_array import draw_data_array +from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/components/sequence_prompt_selector.py b/esm/widgets/components/sequence_prompt_selector.py index d538e670..c5e3526f 100644 --- a/esm/widgets/components/sequence_prompt_selector.py +++ b/esm/widgets/components/sequence_prompt_selector.py @@ -6,9 +6,7 @@ hex_to_rgba_tuple, rgba_tuple_to_rgba_html_string, ) -from esm.widgets.utils.parsing import ( - convert_range_string_to_list_of_ranges, -) +from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/components/structure_prompt_selector.py b/esm/widgets/components/structure_prompt_selector.py index f4b497c0..13a9df78 100644 --- a/esm/widgets/components/structure_prompt_selector.py +++ b/esm/widgets/components/structure_prompt_selector.py @@ -10,12 +10,8 @@ from esm.utils.structure.protein_chain import ProteinChain from esm.widgets.utils import indexing -from esm.widgets.utils.drawing.draw_protein_structure import ( - draw_protein_structure, -) -from esm.widgets.utils.parsing import ( - convert_range_string_to_list_of_ranges, -) +from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure +from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.prompting import PromptManager diff --git a/esm/widgets/utils/drawing/draw_function_annotations.py b/esm/widgets/utils/drawing/draw_function_annotations.py index c71e5434..59e9f7cf 100644 --- a/esm/widgets/utils/drawing/draw_function_annotations.py +++ b/esm/widgets/utils/drawing/draw_function_annotations.py @@ -9,10 +9,7 @@ from PIL import Image from esm.sdk.api import FunctionAnnotation -from esm.utils.function.interpro import ( - InterPro, - InterProEntryType, -) +from esm.utils.function.interpro import InterPro, InterProEntryType @contextmanager diff --git a/esm/widgets/utils/prompting.py b/esm/widgets/utils/prompting.py index 1ce6d9c6..1e89bb64 100644 --- a/esm/widgets/utils/prompting.py +++ b/esm/widgets/utils/prompting.py @@ -9,9 +9,7 @@ from esm.utils import encoding from esm.widgets.utils import indexing from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex -from esm.widgets.utils.drawing.draw_category_array import ( - draw_data_array, -) +from esm.widgets.utils.drawing.draw_category_array import draw_data_array from esm.widgets.utils.printing import wrapped_print diff --git a/esm/widgets/views/esm3_generation_launcher.py b/esm/widgets/views/esm3_generation_launcher.py index e94c60ff..f8bf8f3b 100644 --- a/esm/widgets/views/esm3_generation_launcher.py +++ b/esm/widgets/views/esm3_generation_launcher.py @@ -13,13 +13,9 @@ GenerationConfig, ) from esm.utils.constants import models -from esm.widgets.components.results_visualizer import ( - create_results_visualizer, -) +from esm.widgets.components.results_visualizer import create_results_visualizer from esm.widgets.utils.printing import wrapped_print -from esm.widgets.utils.serialization import ( - create_download_results_button, -) +from esm.widgets.utils.serialization import create_download_results_button def create_esm3_generation_launcher( diff --git a/esm/widgets/views/esm3_prompt_selector.py b/esm/widgets/views/esm3_prompt_selector.py index 035db28e..f7e60686 100644 --- a/esm/widgets/views/esm3_prompt_selector.py +++ b/esm/widgets/views/esm3_prompt_selector.py @@ -1,8 +1,6 @@ from ipywidgets import widgets -from esm.widgets.components.sasa_prompt_selector import ( - create_sasa_prompt_selector, -) +from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector from esm.widgets.components.secondary_structure_prompt_selector import ( create_secondary_structure_prompt_selector, ) diff --git a/esm/widgets/views/generation.py b/esm/widgets/views/generation.py index 19015f60..fdec2094 100644 --- a/esm/widgets/views/generation.py +++ b/esm/widgets/views/generation.py @@ -4,20 +4,12 @@ from esm.sdk.api import ESM3InferenceClient, ESMProtein from esm.utils.constants import esm3 as C -from esm.widgets.components.function_annotator import ( - create_function_annotator, -) +from esm.widgets.components.function_annotator import create_function_annotator from esm.widgets.utils.prompting import PromptManagerCollection from esm.widgets.utils.protein_import import ProteinImporter -from esm.widgets.views.esm3_generation_launcher import ( - create_esm3_generation_launcher, -) -from esm.widgets.views.esm3_prompt_preview import ( - create_esm3_prompt_preview, -) -from esm.widgets.views.esm3_prompt_selector import ( - create_esm3_prompt_selector, -) +from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher +from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview +from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector def create_generation_ui( diff --git a/esm/widgets/views/inverse_folding.py b/esm/widgets/views/inverse_folding.py index 8becb8eb..50d9a128 100644 --- a/esm/widgets/views/inverse_folding.py +++ b/esm/widgets/views/inverse_folding.py @@ -6,9 +6,7 @@ ESMProteinError, GenerationConfig, ) -from esm.widgets.components.results_visualizer import ( - create_results_visualizer, -) +from esm.widgets.components.results_visualizer import create_results_visualizer from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.protein_import import ProteinImporter diff --git a/esm/widgets/views/login.py b/esm/widgets/views/login.py index 2d8be5a3..5c7b6706 100644 --- a/esm/widgets/views/login.py +++ b/esm/widgets/views/login.py @@ -4,10 +4,7 @@ from ipywidgets import widgets -from esm.widgets.utils.clients import ( - get_forge_client, - get_local_client, -) +from esm.widgets.utils.clients import get_forge_client, get_local_client from esm.widgets.utils.types import ClientInitContainer diff --git a/esm/widgets/views/prediction.py b/esm/widgets/views/prediction.py index de6666d2..94ff49dc 100644 --- a/esm/widgets/views/prediction.py +++ b/esm/widgets/views/prediction.py @@ -6,9 +6,7 @@ ESMProteinError, GenerationConfig, ) -from esm.widgets.components.results_visualizer import ( - create_results_visualizer, -) +from esm.widgets.components.results_visualizer import create_results_visualizer from esm.widgets.utils.printing import wrapped_print from esm.widgets.utils.protein_import import ProteinImporter diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index d597a220..4e7638f6 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -2,6 +2,7 @@ import pytest import torch + from esm.sdk import client # pyright: ignore from esm.sdk.api import ( # pyright: ignore ESMProtein,