Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 52 additions & 95 deletions cookbook/tutorials/2_embed.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion esm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.3"
__version__ = "3.2.4.a0"
7 changes: 6 additions & 1 deletion esm/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def from_protein_chain(
sasa=protein_chain.sasa().tolist(),
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
plddt=torch.tensor(protein_chain.confidence),
)
else:
return ESMProtein(
Expand All @@ -85,6 +86,7 @@ def from_protein_chain(
sasa=None,
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
plddt=torch.tensor(protein_chain.confidence),
)

@classmethod
Expand All @@ -104,6 +106,7 @@ def from_protein_complex(
coordinates=torch.tensor(
protein_complex.atom37_positions, dtype=torch.float32
),
plddt=torch.tensor(protein_complex.confidence),
)

def to_pdb(self, pdb_path: PathOrBuffer) -> None:
Expand Down Expand Up @@ -325,7 +328,9 @@ def use_generative_unmasking_strategy(self):
@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
temperature: float = 1.0
temperature: float = 0.1
seed: int | None = None
decode_in_residue_index_order: bool = False


## Low Level Endpoint Types
Expand Down
2 changes: 2 additions & 0 deletions esm/sdk/forge.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def process_inverse_fold_request(
inverse_folding_config = {
"invalid_ids": config.invalid_ids,
"temperature": config.temperature,
"seed": config.seed,
"decode_in_residue_index_order": config.decode_in_residue_index_order,
}
request = {
"coordinates": maybe_list(coordinates, convert_nan_to_none=True),
Expand Down
3 changes: 2 additions & 1 deletion esm/utils/structure/molecular_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,9 @@ def to_mmcif(self) -> str:
atom_array.chain_id = np.array(atom_chain_ids, dtype="U4")
atom_array.res_name = np.array(atom_res_names, dtype="U4")
atom_array.hetero = atom_hetero
atom_array.b_factor = atom_bfactors
atom_array.atom_name = np.array(atom_names, dtype="U4")
atom_array.add_annotation("b_factor", dtype=float)
atom_array.b_factor = atom_bfactors

# Use existing elements or infer them from atom names
if self.atom_elements is not None and len(self.atom_elements) == n_atoms:
Expand Down
4 changes: 3 additions & 1 deletion esm/utils/structure/protein_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,9 @@ def normalize_coordinates(self) -> ProteinChain:

def infer_oxygen(self) -> ProteinChain:
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
O_missing_indices = np.argwhere(
~np.isfinite(self.atoms["O"]).all(axis=1)
).squeeze()

O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
Expand Down
4 changes: 3 additions & 1 deletion esm/utils/structure/protein_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,9 @@ def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):

def infer_oxygen(self) -> ProteinComplex:
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
O_missing_indices = np.argwhere(
~np.isfinite(self.atoms["O"]).all(axis=1)
).squeeze()

O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
Expand Down
103 changes: 58 additions & 45 deletions pixi.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.2.3"
version = "3.2.4.a0"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.12,<3.13"
Expand All @@ -24,7 +24,7 @@ dependencies = [
"torch>=2.2.0",
"torchvision",
"torchtext",
"transformers<4.48.2",
"transformers==4.52.4",
"ipython",
"einops",
"biotite>=1.0.0",
Expand Down
6 changes: 5 additions & 1 deletion tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ DOCKER_TAG ?= dev
DOCKER_IMAGE_OSS=oss_pytests:${DOCKER_TAG}

build-oss-ci:
docker build -f oss_pytests/Dockerfile oss_pytests -t $(DOCKER_IMAGE_OSS)
docker build \
--output=type=docker \
-f oss_pytests/Dockerfile \
-t $(DOCKER_IMAGE_OSS) \
oss_pytests

start-docker-oss:
docker run \
Expand Down
Loading