23 changes: 6 additions & 17 deletions .github/workflows/ci.yml
Expand Up @@ -18,7 +18,7 @@ concurrency:

jobs:
test-precommit:
runs-on: ubuntu-22.04-16core
runs-on: ubuntu-24.04

steps:
- name: Checkout code
Expand All @@ -32,20 +32,11 @@ jobs:
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}

- name: Check formatting and typing
run: |
set -e
# pyright seems to do something weird at initialization that causes it to error out
# We can ignore the first invocation here.
pyright esm/__init__.py || true
pre-commit install
env NODE_OPTIONS="--max-old-space-size=16384" pre-commit run --all-files --show-diff-on-failure
[ -z "$(git status --porcelain)" ] && true || (echo "❌❌❌ pre-commit hook failed! A few files changed ❌❌❌]" && git status --porcelain && false)
git reset --hard HEAD # test without the pre-commit changes
shell: pixi run bash -e {0}
run: pixi run lint-all


test-esm:
runs-on: ubuntu-22.04-16core
runs-on: ubuntu-24.04

steps:
- name: Checkout code
Expand All @@ -59,20 +50,18 @@ jobs:
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}

- name: Run tests
run: |
set -o pipefail
pytest -v --junitxml=pytest.xml tests/ | tee pytest-coverage.txt
shell: pixi run bash -e {0}
run: pixi run cov-test

- name: Run Docker tests
env:
DOCKER_TAG: ${{ github.sha }}
FORGE_URL: https://forge.evolutionaryscale.ai/
ESM3_FORGE_TOKEN: ${{ secrets.ESM3_FORGE_TOKEN }}
run: |
set -e
cd tests
make build-oss-ci
make start-docker-oss URL=${{ env.FORGE_URL }} DOCKER_TAG=${{ env.DOCKER_TAG }} ESM3_FORGE_TOKEN=${{ secrets.ESM3_FORGE_TOKEN }}
make start-docker-oss URL=${{ env.FORGE_URL }} DOCKER_TAG=${{ env.DOCKER_TAG }} ESM3_FORGE_TOKEN=${{ env.ESM3_FORGE_TOKEN }}
shell: pixi run bash -e {0}

- name: cleanup docker containers if they're hanging
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -2,3 +2,4 @@ esm.egg-info
# pixi environments
.pixi
*.egg-info
*.pyc
33 changes: 33 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,33 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
exclude: (fasta|pdb|cif|mds|json)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
exclude: pixi.lock
- id: check-merge-conflict
- repo: https://github.com/seddonym/import-linter
rev: v1.12.1
hooks:
- id: import-linter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.3
hooks:
- id: ruff # linter
args: [ --fix ]
- id: ruff-format # formatter
types_or: [python, jupyter]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.399
hooks:
- id: pyright
name: pyright
entry: pyright
language: system
types: [python]
pass_filenames: true # For speed, we only check the files that are changed
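pre-commit treats the top-level `exclude` value as a Python regular expression that is searched against each file path. The sketch below shows how the pattern above behaves under that assumption. The helper `is_excluded` and the sample paths are hypothetical and only serve as illustration. Because the pattern has no leading `\.`, it matches any path that merely ends in one of the listed strings, not only real file extensions.

```python
import re

# Top-level `exclude` pattern copied from the .pre-commit-config.yaml above.
EXCLUDE = re.compile(r"(fasta|pdb|cif|mds|json)$")


def is_excluded(path: str) -> bool:
    """Return True if a path would be skipped (assumes re.search semantics)."""
    return EXCLUDE.search(path) is not None


# Hypothetical paths, for illustration only.
for path in [
    "tests/data/example.pdb",
    "esm/__init__.py",
    "cookbook/tutorials/1_esmprotein.ipynb",
    "notes/notjson",  # no dot, but still ends in "json"
]:
    print(f"{path}: excluded={is_excluded(path)}")
```

If only true extensions should be skipped, a stricter pattern such as `\.(fasta|pdb|cif|mds|json)$` would be one option.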
26 changes: 26 additions & 0 deletions README.md
Expand Up @@ -9,6 +9,7 @@


- [Installation ](#installation-)
- [Available Models](#available-models-)
- [ESM 3](#esm-3-)
- [Quickstart for ESM3 Open](#esm3-quickstart-)
- [ESM3 98B via Forge API](#esm3-forge)
Expand All @@ -33,6 +34,31 @@ To get started with ESM, install the python library using pip:
pip install esm
```

## Available Models <a name="available-models"></a>

### ESM 3 Family

| Model | Model Size | Release Date | Note |
|-------|------------|--------------|------|
| **Flagship Models** | | | Most users will be interested in using one of these models. |
| esm3-large-2024-03 | 98B | 2024-03 | |
| esm3-medium-2024-08 | 7B | 2024-08 | |
| esm3-small-2024-08 | 1.4B | 2024-08 | |
| **Published Models** | | | These models were used to generate all of the results in the ESM3 paper and are provided to facilitate reproducibility. |
| esm3-large-2024-03 | 98B | 2024-03 | |
| esm3-medium-2024-03 | 7B | 2024-03 | |
| esm3-small-2024-03 | 1.4B | 2024-03 | |
| **Experimental Models** | | | These models are provided for early use by researchers and are still under development. |
| esm3-medium-multimer-2024-09 | 7B | 2024-09 | |

### ESM C Models

| Model | Model Size | Number of Layers | Release Date |
|-------|------------|------------------|--------------|
| esmc-6b-2024-12 | 6B | 80 | 2024-12 |
| esmc-600m-2024-12 | 600M | 36 | 2024-12 |
| esmc-300m-2024-12 | 300M | 30 | 2024-12 |

## ESM 3 <a name="esm3"></a>

[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.
Expand Down
1 change: 1 addition & 0 deletions cookbook/local/open_generate.ipynb
Expand Up @@ -38,6 +38,7 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand Down
4 changes: 1 addition & 3 deletions cookbook/local/raw_forwards.py
Expand Up @@ -13,9 +13,7 @@
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

Expand Down
4 changes: 3 additions & 1 deletion cookbook/tutorials/1_esmprotein.ipynb
Expand Up @@ -72,6 +72,7 @@
"outputs": [],
"source": [
"from biotite.database import rcsb\n",
"\n",
"from esm.sdk.api import ESMProtein\n",
"from esm.utils.structure.protein_chain import ProteinChain\n",
"from esm.utils.types import FunctionAnnotation\n",
Expand Down Expand Up @@ -496,9 +497,10 @@
"# Functions for visualizing InterPro function annotations\n",
"\n",
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"from matplotlib import colormaps\n",
"\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"\n",
"\n",
"def visualize_function_annotations(\n",
" annotations: list[FunctionAnnotation],\n",
Expand Down
1 change: 1 addition & 0 deletions cookbook/tutorials/3_gfp_design.ipynb
Expand Up @@ -64,6 +64,7 @@
"import matplotlib.pyplot as pl\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand Down
1 change: 1 addition & 0 deletions cookbook/tutorials/4_forge_generate.ipynb
Expand Up @@ -36,6 +36,7 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand Down
73 changes: 64 additions & 9 deletions cookbook/tutorials/5_guided_generation.ipynb
Expand Up @@ -14,13 +14,13 @@
"3. Minimize a biophysical energy function\n",
"4. Use experimental screening data to guide designs with a regression model\n",
"\n",
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Li, et al 2024](https://arxiv.org/abs/2408.08252).\n",
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Li, et al 2024](https://arxiv.org/abs/2408.08252) and constrained optimization using the Modified Differential Method of Multipliers from [Platt & Barr 1987](https://proceedings.neurips.cc/paper_files/paper/1987/file/a1126573153ad7e9f44ba80e99316482-Paper.pdf).\n",
"\n",
"In this notebook we will walk through a few examples to illustrate how to use guided generation. \n",
"\n",
"1. Guide towards high pTM for improved generation quality\n",
"2. Generate a protein with no cysteine (C) residues\n",
"3. Maximize protein globularity by minimizing the radius of gyration\n",
"3. Maximize protein globularity by minimizing the radius of gyration, while keeping pTM high\n",
"\n"
]
},
Expand Down Expand Up @@ -49,6 +49,7 @@
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
]
Expand Down Expand Up @@ -269,6 +270,11 @@
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
Expand Down Expand Up @@ -302,7 +308,20 @@
"source": [
"## Maximize Globularity\n",
"\n",
"We use the radius of gyration as a proxy to maximize globularity, we also encourage generations to have high pTM"
"We use the radius of gyration as a proxy to maximize globularity, and we will also encourage generations to have high pTM by using constraints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from esm.sdk.experimental import (\n",
" ConstraintType,\n",
" ESM3GuidedDecodingWithConstraints,\n",
" GenerationConstraint,\n",
")"
]
},
{
Expand All @@ -313,12 +332,11 @@
"source": [
"class RadiousOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
" def __call__(self, protein: ESMProtein) -> float:\n",
" # Use the negative radius of gyration as the score to maximize\n",
" score = -1 * self.radius_of_gyration(protein)\n",
"\n",
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
" if protein.ptm < 0.5:\n",
" # Penalize proteins with low pTM scores\n",
" score = score * 2\n",
" # Re-scale the score to be in a similar magnitude as pTM\n",
" score = score / 100.0\n",
"\n",
" return score\n",
"\n",
Expand All @@ -335,8 +353,19 @@
"metadata": {},
"outputs": [],
"source": [
"radius_guided_decoding = ESM3GuidedDecoding(\n",
" client=model, scoring_function=RadiousOfGyrationScoringFunction()\n",
"# Constrain generation to have pTM >= 0.75\n",
"ptm_constraint = GenerationConstraint(\n",
" scoring_function=PTMScoringFunction(),\n",
" constraint_type=ConstraintType.GREATER_EQUAL,\n",
" value=0.75,\n",
")\n",
"\n",
"radius_guided_decoding = ESM3GuidedDecodingWithConstraints(\n",
" client=model,\n",
" scoring_function=RadiousOfGyrationScoringFunction(),\n",
" constraints=[ptm_constraint], # Add list of constraints\n",
" damping=1.0, # Damping factor for the MDMM algorithm\n",
" learning_rate=10.0, # Learning rate for the MDMM algorithm\n",
")"
]
},
Expand All @@ -346,6 +375,11 @@
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"radius_guided_protein = radius_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
Expand All @@ -359,11 +393,32 @@
"metadata": {},
"outputs": [],
"source": [
"# Visualize the trajectory of the constrained generation\n",
"radius_guided_decoding.visualize_latest_trajectory()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize the generated protein\n",
"view = py3Dmol.view(width=800, height=400)\n",
"view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
"view.zoomTo()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check pTM\n",
"radius_guided_protein.ptm"
]
}
],
"metadata": {
Expand Down
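The notebook changes above cite the Modified Differential Method of Multipliers (Platt & Barr 1987) for constrained generation. As a rough illustration of the idea only, here is a toy gradient-based sketch on a one-dimensional problem: minimize f(x) = x² subject to g(x) = x - 1 = 0. It is not the `ESM3GuidedDecodingWithConstraints` implementation, which is derivative-free and handles an inequality constraint. The function name `mdmm_minimize`, the toy objective, and the step size are assumptions made for this example, and `damping` only mirrors the argument name used in the notebook.

```python
def mdmm_minimize(steps: int = 500, lr: float = 0.1, damping: float = 1.0):
    """Toy MDMM: minimize f(x) = x**2 subject to g(x) = x - 1 = 0.

    Hypothetical example; the names and values are illustrative assumptions.
    """
    x, lam = 0.0, 0.0  # primal variable and Lagrange multiplier
    for _ in range(steps):
        g = x - 1.0  # constraint violation, zero when the constraint holds
        # d/dx of f(x) + lam * g(x) + (damping / 2) * g(x)**2, with dg/dx = 1
        grad_x = 2.0 * x + lam + damping * g
        # Gradient descent on x, gradient ascent on the multiplier,
        # both computed from the values at the start of the step.
        x, lam = x - lr * grad_x, lam + lr * g
    return x, lam


x_opt, lam_opt = mdmm_minimize()
print(f"x = {x_opt:.4f}, lambda = {lam_opt:.4f}")
```

The multiplier grows while the constraint is violated and settles once it is met, and the quadratic damping term suppresses the oscillations that plain descent-ascent on the Lagrangian would otherwise show.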
1 change: 0 additions & 1 deletion esm/__init__.py
@@ -1,2 +1 @@
__version__ = "3.2.1"

13 changes: 5 additions & 8 deletions esm/layers/attention.py
Expand Up @@ -5,15 +5,12 @@
import torch.nn.functional as F
from torch import nn

from esm.layers.rotary import (
RotaryEmbedding,
TritonRotaryEmbedding,
)
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding

try:
from flash_attn import flash_attn_varlen_qkvpacked_func # type:ignore
except ImportError:
flash_attn_varlen_func = None
from flash_attn import flash_attn_varlen_qkvpacked_func
except (ImportError, RuntimeError):
flash_attn_varlen_qkvpacked_func = None


class MultiHeadAttention(nn.Module):
Expand Down Expand Up @@ -117,7 +114,7 @@ def forward(self, x, seq_id):
)
qkv_N3HD = self.rotary(qkv_N3HD, cu_seqlens, max_seqlen)

context_NHD = flash_attn_varlen_qkvpacked_func(
context_NHD = flash_attn_varlen_qkvpacked_func( # type: ignore
qkv_N3HD, cu_seqlens, max_seqlen, softmax_scale=self.d_head**-0.5
)
context_ND = einops.rearrange(context_NHD, "n h d -> n (h d)")
Expand Down
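The import hunk above does two things: it also catches `RuntimeError`, and it fixes the fallback name. Previously the `except` branch assigned `flash_attn_varlen_func = None`, which left `flash_attn_varlen_qkvpacked_func` unbound when the import failed. Below is a minimal sketch of the same optional-dependency pattern using only the standard library. The helpers `optional_attr` and `require` are hypothetical and not part of `esm`.

```python
import importlib
from typing import Any, Optional


def optional_attr(module_name: str, attr_name: str) -> Optional[Any]:
    """Return `module.attr` if the module imports cleanly, else None."""
    try:
        module = importlib.import_module(module_name)
    except (ImportError, RuntimeError):
        # Same exception types as the guard in esm/layers/attention.py.
        return None
    return getattr(module, attr_name, None)


def require(func: Optional[Any], what: str) -> Any:
    """Raise a clear error at call time if the optional dependency is absent."""
    if func is None:
        raise ImportError(f"{what} is not available; install the optional dependency")
    return func


# A standard-library module stands in for the optional dependency here.
sqrt = require(optional_attr("math", "sqrt"), "math.sqrt")
print(sqrt(9.0))  # 3.0

# A module that does not exist resolves to None instead of raising.
print(optional_attr("no_such_module_for_demo", "anything"))  # None
```

Binding the fallback to the same name that call sites use means a missing dependency surfaces as an explicit check rather than a `NameError` deep inside `forward`.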
9 changes: 2 additions & 7 deletions esm/layers/blocks.py
Expand Up @@ -2,13 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F

from esm.layers.attention import (
FlashMultiHeadAttention,
MultiHeadAttention,
)
from esm.layers.geom_attention import (
GeometricReasoningOriginalImpl,
)
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
from esm.utils.structure.affine3d import Affine3D


Expand Down
5 changes: 1 addition & 4 deletions esm/layers/structure_proj.py
Expand Up @@ -2,10 +2,7 @@
import torch.nn as nn

from esm.utils.constants.physics import BB_COORDINATES
from esm.utils.structure.affine3d import (
Affine3D,
RotationMatrix,
)
from esm.utils.structure.affine3d import Affine3D, RotationMatrix


class Dim6RotStructureHead(nn.Module):
Expand Down