diff --git a/esm/__init__.py b/esm/__init__.py index 98a35b2..3348d7f 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1 +1 @@ -__version__ = "3.2.2.post2" +__version__ = "3.2.3" diff --git a/esm/sdk/base_forge_client.py b/esm/sdk/base_forge_client.py index ff05b54..2ed3763 100644 --- a/esm/sdk/base_forge_client.py +++ b/esm/sdk/base_forge_client.py @@ -128,7 +128,7 @@ async def _async_post( except Exception as e: raise ESMProteinError( error_code=500, - error_msg=f"Failed to submit request to {endpoint}. Error: {e}", + error_msg=f"Failed to submit request to {endpoint}. Error: {str(e)}", ) def _post( @@ -158,5 +158,5 @@ def _post( except Exception as e: raise ESMProteinError( error_code=500, - error_msg=f"Failed to submit request to {endpoint}. Error: {e}", + error_msg=f"Failed to submit request to {endpoint}. Error: {str(e)}", ) diff --git a/esm/utils/msa/filter_sequences.py b/esm/utils/msa/filter_sequences.py index d44549d..da860d2 100644 --- a/esm/utils/msa/filter_sequences.py +++ b/esm/utils/msa/filter_sequences.py @@ -1,3 +1,4 @@ +import os import tempfile from pathlib import Path @@ -53,7 +54,9 @@ def hhfilter( qsc: float = -20.0, binary: str = "hhfilter", ) -> list[int]: - with tempfile.TemporaryDirectory(dir="/dev/shm") as tempdirname: + with tempfile.TemporaryDirectory( + dir="/dev/shm" if os.path.exists("/dev/shm") else None + ) as tempdirname: tempdir = Path(tempdirname) fasta_file = tempdir / "input.fasta" fasta_file.write_text( diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py index 026912f..e432b53 100644 --- a/esm/utils/structure/input_builder.py +++ b/esm/utils/structure/input_builder.py @@ -50,11 +50,22 @@ class PocketConditioning: contacts: list[tuple[str, int]] +@dataclass +class CovalentBond: + chain_id1: str + res_idx1: int + atom_idx1: int + chain_id2: str + res_idx2: int + atom_idx2: int + + @dataclass class StructurePredictionInput: sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput] pocket: PocketConditioning | None = None distogram_conditioning: list[DistogramConditioning] | None = None + covalent_bonds: list[CovalentBond] | None = None def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput): @@ -92,4 +103,20 @@ def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]: else: raise ValueError(f"Unsupported sequence input type: {type(seq_input)}") - return {"sequences": sequences} + result: dict[str, Any] = {"sequences": sequences} + + # Add covalent bonds if present + if all_atom_input.covalent_bonds is not None: + result["covalent_bonds"] = [ + { + "chain_id1": bond.chain_id1, + "res_idx1": bond.res_idx1, + "atom_idx1": bond.atom_idx1, + "chain_id2": bond.chain_id2, + "res_idx2": bond.res_idx2, + "atom_idx2": bond.atom_idx2, + } + for bond in all_atom_input.covalent_bonds + ] + + return result diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py index 6b6da1c..3b2ffe7 100644 --- a/esm/utils/structure/molecular_complex.py +++ b/esm/utils/structure/molecular_complex.py @@ -35,6 +35,8 @@ class MolecularComplexResult: pair_chains_iptm: torch.Tensor | None = None output_embedding_sequence: torch.Tensor | None = None output_embedding_pair_pooled: torch.Tensor | None = None + residue_index: torch.Tensor | None = None + entity_id: torch.Tensor | None = None @dataclass diff --git a/pixi.lock b/pixi.lock index 897f11a..33ad120 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1726,8 +1726,8 @@ packages: requires_python: '>=3.8' - pypi: ./ name: esm - version: 3.2.2.post2 - sha256: 3f59a2977c85d35b4b1353902fa90e35d02acbabe6ffb506727bd406ec987ad1 + version: 3.2.3 + sha256: 7f3df1026fb23f4812615d3c4968f643f04d9cbf7735000615b011620ac83007 requires_dist: - torch>=2.2.0 - torchvision diff --git a/pyproject.toml b/pyproject.toml index 0dcc398..923d306 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "esm" -version = "3.2.2.post2" +version = "3.2.3" description = "EvolutionaryScale open model repository" readme = "README.md" requires-python = ">=3.12,<3.13"