diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py index ab5ea6d6..535b0068 100644 --- a/esm/utils/structure/protein_chain.py +++ b/esm/utils/structure/protein_chain.py @@ -7,6 +7,8 @@ from typing import Sequence, TypeVar, Union import biotite.structure as bs +import openmm +import pdbfixer import brotli import msgpack import msgpack_numpy @@ -643,8 +645,30 @@ def from_rcsb( cls, pdb_id: str, chain_id: str = "detect", + fix_pdb: bool = False, ): - f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore + f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore (_io.StringIO) + + if fix_pdb: + fixer = pdbfixer.PDBFixer(pdbfile=f) + + # PDBFixer operations + fixer.findNonstandardResidues() + fixer.replaceNonstandardResidues() + fixer.findMissingResidues() + fixer.findMissingAtoms() + fixer.addMissingAtoms(seed=0) + fixer.addMissingHydrogens() + + # Create a StringIO object + f = io.StringIO() + + # Write the PDBFixer object to the StringIO object + openmm.app.PDBFile.writeFile(fixer.topology, fixer.positions, f, keepIds=True) + + # Reset StringIO pointer to the beginning + f.seek(0) + return cls.from_pdb(f, chain_id=chain_id, id=pdb_id) @classmethod diff --git a/pyproject.toml b/pyproject.toml index b9b4c67b..d7078e0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ dependencies = [ "brotli", "attrs", "pandas", + "openmm", + "pdbfixer", ] [tool.setuptools]