6 changes: 3 additions & 3 deletions cookbook/tutorials/2_embed.ipynb
@@ -49,18 +49,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")"
"token = getpass(\"Token from Forge: \")"
]
},
{
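Note: the tutorial cells above all follow the same token-handling advice — keep the token out of the notebook via an environment variable or `getpass`. A minimal sketch of that pattern; the variable name `ESM_FORGE_TOKEN` is illustrative, not an SDK convention:

```python
import os
from getpass import getpass

# Prefer an environment variable so the token is never stored in the notebook;
# fall back to an interactive prompt otherwise.
token = os.environ.get("ESM_FORGE_TOKEN") or getpass("Token from Forge: ")
```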
6 changes: 3 additions & 3 deletions cookbook/tutorials/3_gfp_design.ipynb
@@ -80,18 +80,18 @@
"\n",
"The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n",
"\n",
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"id": "zNrU9Q2SYonX"
},
"outputs": [],
"source": [
"token = getpass(\"Token from Forge console: \")"
"token = getpass(\"Token from Forge: \")"
]
},
{
4 changes: 2 additions & 2 deletions cookbook/tutorials/4_forge_generate.ipynb
@@ -53,7 +53,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
@@ -64,7 +64,7 @@
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"token = getpass(\"Token from Forge: \")\n",
"model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},
2 changes: 1 addition & 1 deletion cookbook/tutorials/5_guided_generation.ipynb
@@ -120,7 +120,7 @@
"\n",
"from esm.sdk import client\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"token = getpass(\"Token from Forge: \")\n",
"model = client(\n",
" model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n",
")"
2 changes: 1 addition & 1 deletion esm/__init__.py
@@ -1 +1 @@
__version__ = "3.2.2"
__version__ = "3.2.2.post2"
32 changes: 24 additions & 8 deletions esm/sdk/api.py
@@ -148,20 +148,43 @@ def to_protein_complex(
gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
else:
gt_chains = None

+        # Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks
+        # This handles the case where the server doesn't include chain breaks in pLDDT
+        # We should fix this in the server side.
+        if self.plddt is not None and len(self.plddt) != len(self.sequence):
+            # Only expand if there's a mismatch (likely due to chain breaks)
+            if "|" in self.sequence:
+                # Create expanded pLDDT with NaN at chain break positions
+                expanded_plddt = torch.full((len(self.sequence),), float("nan"))
+                plddt_idx = 0
+                for i, aa in enumerate(self.sequence):
+                    if aa != "|":
+                        if plddt_idx < len(self.plddt):
+                            expanded_plddt[i] = self.plddt[plddt_idx]
+                        plddt_idx += 1
+                plddt = expanded_plddt
+            else:
+                # Mismatch but no chain breaks - shouldn't happen but preserve original
+                plddt = self.plddt
+        else:
+            plddt = self.plddt

pred_chains = []
for i, (start, end) in enumerate(chain_boundaries):
if i >= len(SINGLE_LETTER_CHAIN_IDS):
raise ValueError(
f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}"
)

pred_chain = ProteinChain.from_atom37(
atom37_positions=coords[start:end],
sequence=self.sequence[start:end],
chain_id=gt_chains[i].chain_id
if gt_chains is not None
else SINGLE_LETTER_CHAIN_IDS[i],
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
-                confidence=self.plddt[start:end] if self.plddt is not None else None,
+                confidence=plddt[start:end] if plddt is not None else None,
)
pred_chains.append(pred_chain)
return ProteinComplex.from_chains(pred_chains)
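A small standalone illustration of the pLDDT expansion added above, using a toy two-chain sequence (values invented for the example):

```python
import torch

sequence = "MKV|GA"  # two chains joined by a "|" chain-break token
plddt = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5])  # one value per residue, breaks omitted

expanded = torch.full((len(sequence),), float("nan"))
idx = 0
for i, aa in enumerate(sequence):
    if aa != "|":
        expanded[i] = plddt[idx]
        idx += 1

# expanded == [0.9, 0.8, 0.7, nan, 0.6, 0.5]: per-chain slices taken with
# boundaries computed over the full sequence now line up with the residues.
```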
@@ -298,13 +321,6 @@ def use_generative_unmasking_strategy(self):
self.temperature_annealing = True


-@define
-class MSA:
-    # Paired MSA sequences.
-    # One would typically compute these using, for example, ColabFold.
-    sequences: list[str]


@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
108 changes: 88 additions & 20 deletions esm/sdk/forge.py
@@ -1,3 +1,5 @@
+from __future__ import annotations

import asyncio
import base64
import pickle
@@ -7,7 +9,6 @@
import torch

from esm.sdk.api import (
-    MSA,
ESM3InferenceClient,
ESMCInferenceClient,
ESMProtein,
@@ -27,6 +28,15 @@
from esm.sdk.retry import retry_decorator
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
+from esm.utils.msa import MSA
+from esm.utils.structure.input_builder import (
+    StructurePredictionInput,
+    serialize_structure_prediction_input,
+)
+from esm.utils.structure.molecular_complex import (
+    MolecularComplex,
+    MolecularComplexResult,
+)
from esm.utils.types import FunctionAnnotation


@@ -36,10 +46,8 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
return [FunctionAnnotation(*t) for t in l]


-def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False):
-    ret = data.get("logits", {}).get(track, None)
-    # TODO(s22chan): just return this when removing return_bytes
-    return ret if ret is None or not return_bytes else maybe_tensor(ret)
+def _maybe_logits(data: dict[str, Any], track: str):
+    return maybe_tensor(data.get("logits", {}).get(track, None))


def _maybe_b64_decode(obj, return_bytes: bool):
@@ -137,7 +145,7 @@ async def _async_fetch_msa(self, sequence: str) -> MSA:
data = await self._async_post(
"msa", request={}, params={"sequence": sequence, "use_env": False}
)
-        return MSA(sequences=data["msa"])
+        return MSA.from_sequences(sequences=data["msa"])

def _fetch_msa(self, sequence: str) -> MSA:
print("Fetching MSA ... this may take a few minutes")
@@ -146,7 +154,7 @@ def _fetch_msa(self, sequence: str) -> MSA:
data = self._post(
"msa", request={}, params={"sequence": sequence, "use_env": False}
)
-        return MSA(sequences=data["msa"])
+        return MSA.from_sequences(sequences=data["msa"])

@retry_decorator
async def async_fold(
@@ -209,6 +217,70 @@ def fold(

return self._process_fold_response(data, sequence)

+    @retry_decorator
+    async def async_fold_all_atom(
+        self, all_atom_input: StructurePredictionInput, model_name: str | None = None
+    ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
+        """Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
+
+        Args:
+            all_atom_input: StructurePredictionInput containing sequences for different molecule types
+            model_name: Override the client level model name if needed
+        """
+        request = self._process_fold_all_atom_request(
+            all_atom_input, model_name if model_name is not None else self.model
+        )
+
+        try:
+            data = await self._async_post("fold_all_atom", request)
+        except ESMProteinError as e:
+            return e
+
+        return self._process_fold_all_atom_response(data)
+
+    @retry_decorator
+    def fold_all_atom(
+        self, all_atom_input: StructurePredictionInput, model_name: str | None = None
+    ) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
+        """Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
+
+        Args:
+            all_atom_input: StructurePredictionInput containing sequences for different molecule types
+            model_name: Override the client level model name if needed
+        """
+        request = self._process_fold_all_atom_request(
+            all_atom_input, model_name if model_name is not None else self.model
+        )
+
+        try:
+            data = self._post("fold_all_atom", request)
+        except ESMProteinError as e:
+            return e
+
+        return self._process_fold_all_atom_response(data)
+
+    @staticmethod
+    def _process_fold_all_atom_request(
+        all_atom_input: StructurePredictionInput, model_name: str | None = None
+    ) -> dict[str, Any]:
+        request: dict[str, Any] = {
+            "all_atom_input": serialize_structure_prediction_input(all_atom_input),
+            "model": model_name,
+        }
+
+        return request
+
+    @staticmethod
+    def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult:
+        complex_data = data.get("complex")
+        molecular_complex = MolecularComplex.from_state_dict(complex_data)
+        return MolecularComplexResult(
+            complex=molecular_complex,
+            plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True),
+            ptm=data.get("ptm", None),
+            distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True),
+        )

@retry_decorator
async def async_inverse_fold(
self,
@@ -602,19 +674,15 @@ def _process_logits_response(

return LogitsOutput(
logits=ForwardTrackData(
sequence=_maybe_logits(data, "sequence", return_bytes),
structure=_maybe_logits(data, "structure", return_bytes),
secondary_structure=_maybe_logits(
data, "secondary_structure", return_bytes
),
sasa=_maybe_logits(data, "sasa", return_bytes),
function=_maybe_logits(data, "function", return_bytes),
sequence=_maybe_logits(data, "sequence"),
structure=_maybe_logits(data, "structure"),
secondary_structure=_maybe_logits(data, "secondary_structure"),
sasa=_maybe_logits(data, "sasa"),
function=_maybe_logits(data, "function"),
),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
-            residue_annotation_logits=_maybe_logits(
-                data, "residue_annotation", return_bytes
-            ),
+            residue_annotation_logits=_maybe_logits(data, "residue_annotation"),
hidden_states=maybe_tensor(data["hidden_states"]),
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
)
@@ -965,6 +1033,7 @@ def _process_logits_request(
"sequence": config.sequence,
"return_embeddings": config.return_embeddings,
"return_mean_embedding": config.return_mean_embedding,
"return_mean_hidden_states": config.return_mean_hidden_states,
"return_hidden_states": config.return_hidden_states,
"ith_hidden_layer": config.ith_hidden_layer,
}
@@ -981,12 +1050,11 @@ def _process_logits_response(
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)

output = LogitsOutput(
-            logits=ForwardTrackData(
-                sequence=_maybe_logits(data, "sequence", return_bytes)
-            ),
+            logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
hidden_states=maybe_tensor(data["hidden_states"]),
+            mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
)
return output

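The new `fold_all_atom` entry points mirror `fold`: build a `StructurePredictionInput`, post it to the `fold_all_atom` endpoint, and unpack the response into a `MolecularComplexResult` (note that `ESMProteinError` comes back as a return value, not an exception). A hedged usage sketch: the `StructurePredictionInput` constructor field shown here is a guess, since this diff only shows how the input is serialized, and the model choice is illustrative:

```python
from getpass import getpass

from esm.sdk import client
from esm.sdk.api import ESMProteinError
from esm.utils.structure.input_builder import StructurePredictionInput

token = getpass("Token from Forge: ")
model = client(
    model="esm3-medium-2024-08", url="https://forge.evolutionaryscale.ai", token=token
)

# `protein_sequences` is a hypothetical field name, not confirmed by this diff.
all_atom_input = StructurePredictionInput(
    protein_sequences=["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"]
)

result = model.fold_all_atom(all_atom_input)
if isinstance(result, ESMProteinError):
    raise result  # errors come back as values; surface them explicitly
print(result.ptm, result.complex)
```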
19 changes: 14 additions & 5 deletions esm/sdk/retry.py
@@ -2,10 +2,9 @@
from contextvars import ContextVar
from functools import wraps

-import httpx
from tenacity import (
retry,
-    retry_if_exception_type,
+    retry_if_exception,
retry_if_result,
stop_after_attempt,
wait_incrementing,
@@ -30,8 +29,12 @@ def retry_if_specific_error(exception):


def log_retry_attempt(retry_state):
+    try:
+        outcome = retry_state.outcome.result()
+    except Exception:
+        outcome = retry_state.outcome.exception()
print(
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}"
)


@@ -41,13 +44,18 @@ def retry_decorator(func):
instance's retry settings.
"""

+    def return_last_value(retry_state):
+        """Return the result of the last call attempt."""
+        return retry_state.outcome.result()

@wraps(func)
async def async_wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return await func(instance, *args, **kwargs)
retry_decorator = retry(
+            retry_error_callback=return_last_value,
retry=retry_if_result(retry_if_specific_error)
-            | retry_if_exception_type(httpx.ConnectTimeout),  # ADDED
+            | retry_if_exception(retry_if_specific_error),
wait=wait_incrementing(
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
),
@@ -62,8 +70,9 @@ def wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return func(instance, *args, **kwargs)
retry_decorator = retry(
+            retry_error_callback=return_last_value,
retry=retry_if_result(retry_if_specific_error)
-            | retry_if_exception_type(httpx.ConnectTimeout),  # ADDED
+            | retry_if_exception(retry_if_specific_error),
wait=wait_incrementing(
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
),
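Two behavioral changes land here: a result matching `retry_if_specific_error` now triggers a retry whether it is returned or raised, and once attempts are exhausted `return_last_value` hands back the final outcome instead of letting tenacity raise `RetryError`. A minimal sketch of that tenacity pattern in isolation, with a toy predicate and function rather than the SDK client:

```python
from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt

def is_transient(x):
    # Stand-in for retry_if_specific_error: works on results and exceptions alike.
    return x == "busy" or isinstance(x, TimeoutError)

@retry(
    retry=retry_if_result(is_transient) | retry_if_exception(is_transient),
    stop=stop_after_attempt(3),
    retry_error_callback=lambda rs: rs.outcome.result(),  # return the last value
)
def flaky():
    return "busy"

print(flaky())  # prints "busy" after 3 attempts instead of raising RetryError
```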
4 changes: 3 additions & 1 deletion esm/utils/generation.py
@@ -43,7 +43,9 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):

sliced = {}
for k, v in attr.asdict(o, recurse=False).items():
-        if v is None:
+        if k in ["mean_hidden_state", "mean_embedding"]:
+            sliced[k] = v
+        elif v is None:
sliced[k] = None
elif isinstance(v, torch.Tensor):
# Trim padding.
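The special-casing is needed because the mean tensors are pooled over the sequence axis and have no padding to trim; a short sketch of why slicing them would be wrong (shapes invented for illustration):

```python
import torch

embeddings = torch.randn(10, 960)        # (padded_len, hidden): trimming to sequence_len is safe
mean_embedding = embeddings.mean(dim=0)  # (hidden,): no sequence axis remains

# mean_embedding[:sequence_len] would slice the hidden dimension instead of
# removing padding, silently truncating the embedding vector.
```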