Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cookbook/local/open_generate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand Down
5 changes: 2 additions & 3 deletions cookbook/local/raw_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
import torch.nn.functional as F

from esm.pretrained import (
ESM3_function_decoder_v0,
ESM3_sm_open_v0,
Expand All @@ -12,9 +13,7 @@
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

Expand Down
1 change: 1 addition & 0 deletions cookbook/snippets/fold_invfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import cast

import numpy as np

from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
Expand Down
4 changes: 3 additions & 1 deletion cookbook/tutorials/1_esmprotein.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"outputs": [],
"source": [
"from biotite.database import rcsb\n",
"\n",
"from esm.sdk.api import ESMProtein\n",
"from esm.utils.structure.protein_chain import ProteinChain\n",
"from esm.utils.types import FunctionAnnotation\n",
Expand Down Expand Up @@ -496,9 +497,10 @@
"# Functions for visualizing InterPro function annotations\n",
"\n",
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"from matplotlib import colormaps\n",
"\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"\n",
"\n",
"def visualize_function_annotations(\n",
" annotations: list[FunctionAnnotation],\n",
Expand Down
6 changes: 3 additions & 3 deletions cookbook/tutorials/2_embed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge: \")"
"token = getpass(\"Token from Forge console: \")"
]
},
{
Expand Down
7 changes: 4 additions & 3 deletions cookbook/tutorials/3_gfp_design.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"import matplotlib.pyplot as pl\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand All @@ -79,18 +80,18 @@
"\n",
"The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n",
"\n",
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"id": "zNrU9Q2SYonX"
},
"outputs": [],
"source": [
"token = getpass(\"Token from Forge: \")"
"token = getpass(\"Token from Forge console: \")"
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions cookbook/tutorials/4_forge_generate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand All @@ -52,7 +53,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
Expand All @@ -63,7 +64,7 @@
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge: \")\n",
"token = getpass(\"Token from Forge console: \")\n",
"model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},
Expand Down
3 changes: 2 additions & 1 deletion cookbook/tutorials/5_guided_generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
]
Expand Down Expand Up @@ -119,7 +120,7 @@
"\n",
"from esm.sdk import client\n",
"\n",
"token = getpass(\"Token from Forge: \")\n",
"token = getpass(\"Token from Forge console: \")\n",
"model = client(\n",
" model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n",
")"
Expand Down
3 changes: 1 addition & 2 deletions esm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
__version__ = "3.2.2.post1"

__version__ = "3.2.2"
5 changes: 1 addition & 4 deletions esm/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import torch.nn.functional as F
from torch import nn

from esm.layers.rotary import (
RotaryEmbedding,
TritonRotaryEmbedding,
)
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding

try:
from flash_attn import flash_attn_varlen_qkvpacked_func
Expand Down
9 changes: 2 additions & 7 deletions esm/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F

from esm.layers.attention import (
FlashMultiHeadAttention,
MultiHeadAttention,
)
from esm.layers.geom_attention import (
GeometricReasoningOriginalImpl,
)
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
from esm.utils.structure.affine3d import Affine3D


Expand Down
5 changes: 1 addition & 4 deletions esm/layers/structure_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
import torch.nn as nn

from esm.utils.constants.physics import BB_COORDINATES
from esm.utils.structure.affine3d import (
Affine3D,
RotationMatrix,
)
from esm.utils.structure.affine3d import Affine3D, RotationMatrix


class Dim6RotStructureHead(nn.Module):
Expand Down
14 changes: 3 additions & 11 deletions esm/models/esm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
Expand All @@ -32,10 +29,7 @@
from esm.tokenization import TokenizerCollectionProtocol
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import (
ESM3_OPEN_SMALL,
normalize_model_name,
)
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
from esm.utils.decoding import decode_protein_tensor
from esm.utils.generation import (
_batch_forward,
Expand All @@ -50,9 +44,7 @@
get_default_sampling_config,
validate_sampling_config,
)
from esm.utils.structure.affine3d import (
build_affine3d_from_coordinates,
)
from esm.utils.structure.affine3d import build_affine3d_from_coordinates


@dataclass
Expand Down
4 changes: 1 addition & 3 deletions esm/models/function_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.utils.constants import esm3 as C
from esm.utils.misc import merge_annotations, merge_ranges
from esm.utils.types import FunctionAnnotation
Expand Down
5 changes: 1 addition & 4 deletions esm/models/vqvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from esm.layers.transformer_stack import TransformerStack
from esm.utils.constants import esm3 as C
from esm.utils.misc import knn_graph
from esm.utils.structure.affine3d import (
Affine3D,
build_affine3d_from_coordinates,
)
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
from esm.utils.structure.predicted_aligned_error import (
compute_predicted_aligned_error,
compute_tm,
Expand Down
10 changes: 2 additions & 8 deletions esm/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,8 @@
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.tokenization import (
get_esm3_model_tokenizers,
get_esmc_model_tokenizers,
)
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
from esm.utils.constants.esm3 import data_root
from esm.utils.constants.models import (
ESM3_FUNCTION_DECODER_V0,
Expand Down
54 changes: 12 additions & 42 deletions esm/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,19 @@

from abc import ABC
from copy import deepcopy
from typing import List, Sequence
from typing import Sequence

import attr
import torch
from attr import asdict, define

import esm.utils.constants.api as C
from esm.tokenization import (
TokenizerCollectionProtocol,
get_esm3_model_tokenizers,
)
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.misc import (
get_chainbreak_boundaries_from_sequence,
)
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import (
SINGLE_LETTER_CHAIN_IDS,
ProteinComplex,
)
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
from esm.utils.types import FunctionAnnotation, PathOrBuffer


Expand All @@ -43,7 +35,6 @@ class ESMProtein(ProteinType):
plddt: torch.Tensor | None = None
ptm: torch.Tensor | None = None


# When calling EvolutionaryScale API, use this flag to disclose any
# sequences that may potentially have concerns.
# Such sequences may not go through standard safety filter for approved users.
Expand Down Expand Up @@ -157,43 +148,20 @@ def to_protein_complex(
gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
else:
gt_chains = None

# Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks
# This handles the case where the server doesn't include chain breaks in pLDDT
# We should fix this in the server side.
if self.plddt is not None and len(self.plddt) != len(self.sequence):
# Only expand if there's a mismatch (likely due to chain breaks)
if "|" in self.sequence:
# Create expanded pLDDT with NaN at chain break positions
expanded_plddt = torch.full((len(self.sequence),), float("nan"))
plddt_idx = 0
for i, aa in enumerate(self.sequence):
if aa != "|":
if plddt_idx < len(self.plddt):
expanded_plddt[i] = self.plddt[plddt_idx]
plddt_idx += 1
plddt = expanded_plddt
else:
# Mismatch but no chain breaks - shouldn't happen but preserve original
plddt = self.plddt
else:
plddt = self.plddt

pred_chains = []
for i, (start, end) in enumerate(chain_boundaries):
if i >= len(SINGLE_LETTER_CHAIN_IDS):
raise ValueError(
f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}"
)

pred_chain = ProteinChain.from_atom37(
atom37_positions=coords[start:end],
sequence=self.sequence[start:end],
chain_id=gt_chains[i].chain_id
if gt_chains is not None
else SINGLE_LETTER_CHAIN_IDS[i],
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
confidence=plddt[start:end] if plddt is not None else None,
confidence=self.plddt[start:end] if self.plddt is not None else None,
)
pred_chains.append(pred_chain)
return ProteinComplex.from_chains(pred_chains)
Expand Down Expand Up @@ -330,14 +298,19 @@ def use_generative_unmasking_strategy(self):
self.temperature_annealing = True


@define
class MSA:
# Paired MSA sequences.
# One would typically compute these using, for example, ColabFold.
sequences: list[str]


@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
temperature: float = 1.0




## Low Level Endpoint Types
@define
class SamplingTrackConfig:
Expand Down Expand Up @@ -402,9 +375,6 @@ class LogitsConfig:
ith_hidden_layer: int = -1





@define
class LogitsOutput:
logits: ForwardTrackData | None = None
Expand Down
Loading
Loading