diff --git a/esm/__init__.py b/esm/__init__.py index 3dc1ccb5..0954090c 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1 +1 @@ -__version__ = "3.2.4.a0" +__version__ = "3.2.4.a1" diff --git a/esm/models/esm3.py b/esm/models/esm3.py index 218a8e90..0d3ead1d 100644 --- a/esm/models/esm3.py +++ b/esm/models/esm3.py @@ -489,6 +489,7 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor: function=function_tokens, residue_annotations=residue_annotation_tokens, coordinates=coordinates, + potential_sequence_of_concern=input.potential_sequence_of_concern, ).to(next(self.parameters()).device) def decode(self, input: ESMProteinTensor) -> ESMProtein: diff --git a/esm/models/esmc.py b/esm/models/esmc.py index 0807a218..d3a5bfc9 100644 --- a/esm/models/esmc.py +++ b/esm/models/esmc.py @@ -178,9 +178,10 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor: if input.sequence is not None: sequence_tokens = self._tokenize([input.sequence])[0] - return ESMProteinTensor(sequence=sequence_tokens).to( - next(self.parameters()).device - ) + return ESMProteinTensor( + sequence=sequence_tokens, + potential_sequence_of_concern=input.potential_sequence_of_concern, + ).to(next(self.parameters()).device) def decode(self, input: ESMProteinTensor) -> ESMProtein: input = attr.evolve(input) # Make a copy diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index 9c65c5f1..98973326 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -523,6 +523,7 @@ def _process_encode_response(data: dict[str, Any]) -> ESMProteinTensor: sasa=maybe_tensor(data["outputs"]["sasa"]), function=maybe_tensor(data["outputs"]["function"]), residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]), + potential_sequence_of_concern=data["potential_sequence_of_concern"], ) @staticmethod @@ -1004,7 +1005,10 @@ async def async_encode( except ESMProteinError as e: return e - return ESMProteinTensor(sequence=maybe_tensor(data["outputs"]["sequence"])) + return ESMProteinTensor( + sequence=maybe_tensor(data["outputs"]["sequence"]), + potential_sequence_of_concern=data["potential_sequence_of_concern"], + ) @retry_decorator def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError: @@ -1018,7 +1022,10 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError: except ESMProteinError as e: return e - return ESMProteinTensor(sequence=maybe_tensor(data["outputs"]["sequence"])) + return ESMProteinTensor( + sequence=maybe_tensor(data["outputs"]["sequence"]), + potential_sequence_of_concern=data["potential_sequence_of_concern"], + ) @retry_decorator async def async_decode( diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py index fdf8658d..98cf513e 100644 --- a/esm/utils/sampling.py +++ b/esm/utils/sampling.py @@ -284,6 +284,9 @@ def sample_sasa_logits( sasa_value[max_prob_idx == 18] = float("inf") sasa_value[~sampling_mask] = float("inf") + # Set BOS and EOS tokens to 0 + sasa_value[..., 0] = 0.0 + sasa_value[..., -1] = 0.0 return sasa_value diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py index e432b532..b2e664fb 100644 --- a/esm/utils/structure/input_builder.py +++ b/esm/utils/structure/input_builder.py @@ -105,7 +105,6 @@ def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]: result: dict[str, Any] = {"sequences": sequences} - # Add covalent bonds if present if all_atom_input.covalent_bonds is not None: result["covalent_bonds"] = [ { @@ -119,4 +118,16 @@ def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]: for bond in all_atom_input.covalent_bonds ] + if all_atom_input.pocket is not None: + result["pocket"] = { + "binder_chain_id": all_atom_input.pocket.binder_chain_id, + "contacts": all_atom_input.pocket.contacts, + } + + if all_atom_input.distogram_conditioning is not None: + result["distogram_conditioning"] = [ + {"chain_id": disto.chain_id, "distogram": disto.distogram.tolist()} + for disto in all_atom_input.distogram_conditioning + ] + return result diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py index 2b4c8165..07b1e6e2 100644 --- a/esm/utils/structure/protein_chain.py +++ b/esm/utils/structure/protein_chain.py @@ -38,6 +38,20 @@ CHAIN_ID_CONST = "A" +def _str_key_to_int_key(dct: dict, ignore_keys: list[str] | None = None) -> dict: + new_dict = {} + for k, v in dct.items(): + v_new = v + if k not in ignore_keys and isinstance(v, dict): + v_new = _str_key_to_int_key(v, ignore_keys=ignore_keys) + # Note assembly_composition is *supposed* to have string keys. + if isinstance(k, str) and k.isdigit(): + new_dict[int(k)] = v_new + else: + new_dict[k] = v_new + return new_dict + + def _num_non_null_residues(seqres_to_structure_chain: Mapping[int, Residue]) -> int: return sum( residue.residue_number is not None @@ -366,6 +380,9 @@ def from_open_source(cls, pc: ProteinChain): @classmethod def from_state_dict(cls, dct): + # Note: assembly_composition is *supposed* to have string keys. + dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"]) + for k, v in dct.items(): if isinstance(v, list): dct[k] = np.array(v) diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py index 0ef2465a..fa71d2ad 100644 --- a/esm/utils/structure/protein_complex.py +++ b/esm/utils/structure/protein_complex.py @@ -36,6 +36,7 @@ from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError from esm.utils.structure.protein_chain import ( ProteinChain, + _str_key_to_int_key, chain_to_ndarray, index_by_atom_name, infer_CB, @@ -410,6 +411,9 @@ def to_blob(self, backbone_only=False) -> bytes: @classmethod def from_state_dict(cls, dct): + # Note: assembly_composition is *supposed* to have string keys. + dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"]) + for k, v in dct.items(): if isinstance(v, list): dct[k] = np.array(v) diff --git a/pixi.lock b/pixi.lock index 2756766b..3aa514df 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1726,8 +1726,8 @@ packages: requires_python: '>=3.8' - pypi: ./ name: esm - version: 3.2.4a0 - sha256: aaf223456a7bbe86ef1885a0d14d5f6c1864bb8dc84b23d786948f95c6e53eae + version: 3.2.4a1 + sha256: 9a3a042ef03cda7a67cb638f08f7536d8c564329f9d9f9024e7428c7bcccc2d7 requires_dist: - torch>=2.2.0 - torchvision diff --git a/pyproject.toml b/pyproject.toml index 8f791392..8d15c1da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "esm" -version = "3.2.4.a0" +version = "3.2.4.a1" description = "EvolutionaryScale open model repository" readme = "README.md" requires-python = ">=3.12,<3.13"