Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion esm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.4.a0"
__version__ = "3.2.4.a1"
1 change: 1 addition & 0 deletions esm/models/esm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor:
function=function_tokens,
residue_annotations=residue_annotation_tokens,
coordinates=coordinates,
potential_sequence_of_concern=input.potential_sequence_of_concern,
).to(next(self.parameters()).device)

def decode(self, input: ESMProteinTensor) -> ESMProtein:
Expand Down
7 changes: 4 additions & 3 deletions esm/models/esmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,10 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor:

if input.sequence is not None:
sequence_tokens = self._tokenize([input.sequence])[0]
return ESMProteinTensor(sequence=sequence_tokens).to(
next(self.parameters()).device
)
return ESMProteinTensor(
sequence=sequence_tokens,
potential_sequence_of_concern=input.potential_sequence_of_concern,
).to(next(self.parameters()).device)

def decode(self, input: ESMProteinTensor) -> ESMProtein:
input = attr.evolve(input) # Make a copy
Expand Down
11 changes: 9 additions & 2 deletions esm/sdk/forge.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ def _process_encode_response(data: dict[str, Any]) -> ESMProteinTensor:
sasa=maybe_tensor(data["outputs"]["sasa"]),
function=maybe_tensor(data["outputs"]["function"]),
residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]),
potential_sequence_of_concern=data["potential_sequence_of_concern"],
)

@staticmethod
Expand Down Expand Up @@ -1004,7 +1005,10 @@ async def async_encode(
except ESMProteinError as e:
return e

return ESMProteinTensor(sequence=maybe_tensor(data["outputs"]["sequence"]))
return ESMProteinTensor(
sequence=maybe_tensor(data["outputs"]["sequence"]),
potential_sequence_of_concern=data["potential_sequence_of_concern"],
)

@retry_decorator
def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
Expand All @@ -1018,7 +1022,10 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
except ESMProteinError as e:
return e

return ESMProteinTensor(sequence=maybe_tensor(data["outputs"]["sequence"]))
return ESMProteinTensor(
sequence=maybe_tensor(data["outputs"]["sequence"]),
potential_sequence_of_concern=data["potential_sequence_of_concern"],
)

@retry_decorator
async def async_decode(
Expand Down
3 changes: 3 additions & 0 deletions esm/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ def sample_sasa_logits(
sasa_value[max_prob_idx == 18] = float("inf")
sasa_value[~sampling_mask] = float("inf")

# Set BOS and EOS tokens to 0
sasa_value[..., 0] = 0.0
sasa_value[..., -1] = 0.0
return sasa_value


Expand Down
13 changes: 12 additions & 1 deletion esm/utils/structure/input_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]:

result: dict[str, Any] = {"sequences": sequences}

# Add covalent bonds if present
if all_atom_input.covalent_bonds is not None:
result["covalent_bonds"] = [
{
Expand All @@ -119,4 +118,16 @@ def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]:
for bond in all_atom_input.covalent_bonds
]

if all_atom_input.pocket is not None:
result["pocket"] = {
"binder_chain_id": all_atom_input.pocket.binder_chain_id,
"contacts": all_atom_input.pocket.contacts,
}

if all_atom_input.distogram_conditioning is not None:
result["distogram_conditioning"] = [
{"chain_id": disto.chain_id, "distogram": disto.distogram.tolist()}
for disto in all_atom_input.distogram_conditioning
]

return result
17 changes: 17 additions & 0 deletions esm/utils/structure/protein_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@
CHAIN_ID_CONST = "A"


def _str_key_to_int_key(dct: dict, ignore_keys: list[str] | None = None) -> dict:
new_dict = {}
for k, v in dct.items():
v_new = v
if k not in ignore_keys and isinstance(v, dict):
v_new = _str_key_to_int_key(v, ignore_keys=ignore_keys)
# Note assembly_composition is *supposed* to have string keys.
if isinstance(k, str) and k.isdigit():
new_dict[int(k)] = v_new
else:
new_dict[k] = v_new
return new_dict


def _num_non_null_residues(seqres_to_structure_chain: Mapping[int, Residue]) -> int:
return sum(
residue.residue_number is not None
Expand Down Expand Up @@ -366,6 +380,9 @@ def from_open_source(cls, pc: ProteinChain):

@classmethod
def from_state_dict(cls, dct):
# Note: assembly_composition is *supposed* to have string keys.
dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])

for k, v in dct.items():
if isinstance(v, list):
dct[k] = np.array(v)
Expand Down
4 changes: 4 additions & 0 deletions esm/utils/structure/protein_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
from esm.utils.structure.protein_chain import (
ProteinChain,
_str_key_to_int_key,
chain_to_ndarray,
index_by_atom_name,
infer_CB,
Expand Down Expand Up @@ -410,6 +411,9 @@ def to_blob(self, backbone_only=False) -> bytes:

@classmethod
def from_state_dict(cls, dct):
# Note: assembly_composition is *supposed* to have string keys.
dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])

for k, v in dct.items():
if isinstance(v, list):
dct[k] = np.array(v)
Expand Down
4 changes: 2 additions & 2 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.2.4.a0"
version = "3.2.4.a1"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.12,<3.13"
Expand Down
Loading