# xyzgraph is an optional backend: keep the module importable without it and
# let callers detect its absence through ``_xyzgraph is None``.
try:
    import xyzgraph as _xyzgraph
except ImportError:
    _xyzgraph = None


def _xyzgraph_to_molify_graph(xg_graph: nx.Graph, atoms: ase.Atoms) -> nx.Graph:
    """Convert an xyzgraph-produced NetworkX graph to molify's schema.

    Parameters
    ----------
    xg_graph : nx.Graph
        Graph produced by xyzgraph. Every node is read for its ``"symbol"``
        and ``"position"`` attributes (plus ``"formal_charge"`` when present)
        and every edge for its ``"bond_order"`` attribute.
    atoms : ase.Atoms
        Structure whose ``pbc`` and ``cell`` are stored as graph attributes.

    Returns
    -------
    nx.Graph
        A new graph that reuses the node ids of ``xg_graph``. Nodes carry
        ``atomic_number``, ``position``, ``original_index`` and ``charge``;
        edges carry ``bond_order``.

    Raises
    ------
    KeyError
        If a node lacks ``"symbol"``/``"position"``, an edge lacks
        ``"bond_order"``, or a symbol is unknown to ``ase.data``.
    """
    from ase.data import atomic_numbers

    graph = nx.Graph()
    graph.graph["pbc"] = atoms.pbc
    graph.graph["cell"] = atoms.cell

    for node_id, data in xg_graph.nodes(data=True):
        graph.add_node(
            node_id,
            atomic_number=atomic_numbers[data["symbol"]],
            position=np.array(data["position"]),
            original_index=node_id,
            # A missing formal charge is treated as neutral.
            charge=float(data.get("formal_charge", 0)),
        )

    for u, v, data in xg_graph.edges(data=True):
        graph.add_edge(u, v, bond_order=data["bond_order"])

    return graph


def _ase2networkx_xyzgraph(
    atoms: ase.Atoms,
    charge: int | None = None,
    **engine_kwargs,
) -> nx.Graph:
    """Build molecular graph using xyzgraph's cheminformatics pipeline.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to analyse. It is passed through
        ``molify.utils.unwrap_structures`` before its symbols and positions
        are handed to xyzgraph.
    charge : int or None, optional
        Total charge forwarded to ``xyzgraph.build_graph``. When None, the
        sum of the initial charges of the unwrapped structure is rounded to
        the nearest integer and used instead.
    **engine_kwargs
        Additional keyword arguments forwarded to ``xyzgraph.build_graph``.

    Returns
    -------
    nx.Graph
        Graph in molify's schema (see ``_xyzgraph_to_molify_graph``). Node
        positions are the ones reported by xyzgraph, while ``pbc`` and
        ``cell`` are taken from the ``atoms`` argument.

    Raises
    ------
    ImportError
        If the optional xyzgraph dependency is not installed.
    """
    from ase.data import chemical_symbols

    from molify.utils import unwrap_structures

    # Fail with an actionable message instead of an AttributeError on None
    # when this helper is reached without the optional backend.
    if _xyzgraph is None:
        raise ImportError(
            "xyzgraph is required for engine='xyzgraph'. "
            "Install it with: pip install molify[xyzgraph]"
        )

    unwrapped = unwrap_structures(atoms, engine="rdkit")

    xyzgraph_atoms = [
        (chemical_symbols[atom.number], tuple(atom.position)) for atom in unwrapped
    ]

    if charge is None:
        # Initial charges are floats. int() alone truncates toward zero, so a
        # sum such as -0.9999999 would silently become 0; round first.
        charge = int(round(float(sum(unwrapped.get_initial_charges()))))

    xg_graph = _xyzgraph.build_graph(xyzgraph_atoms, charge=charge, **engine_kwargs)

    return _xyzgraph_to_molify_graph(xg_graph, atoms)
Otherwise using distance-based cutoffs (edges will have bond_order=None) + 2. With ``engine="xyzgraph"``, using xyzgraph's cheminformatics pipeline + (provides bond orders and formal charges) + 3. Otherwise using distance-based cutoffs (edges will have bond_order=None) - To get bond orders, pass the graph to networkx2rdkit(). + To get bond orders without xyzgraph, pass the graph to networkx2rdkit(). Examples -------- @@ -156,8 +225,10 @@ def ase2networkx( """ if len(atoms) == 0: return nx.Graph() + charges = atoms.get_initial_charges() + # Use explicit connectivity when present (regardless of engine) if "connectivity" in atoms.info: connectivity = atoms.info["connectivity"] # ensure connectivity is list[tuple[int, int, float|None]] and @@ -168,6 +239,20 @@ def ase2networkx( ] return _create_graph_from_connectivity(atoms, connectivity, charges) + # Resolve engine (only reached when no explicit connectivity) + use_xyzgraph = False + if engine == "xyzgraph": + if _xyzgraph is None: + raise ImportError( + "xyzgraph is required for engine='xyzgraph'. " + "Install it with: pip install molify[xyzgraph]" + ) + use_xyzgraph = True + # engine == "auto" or "rdkit" -> use_xyzgraph stays False + + if use_xyzgraph: + return _ase2networkx_xyzgraph(atoms, charge=charge, **engine_kwargs) + connectivity_matrix, non_bonding_atomic_numbers = _compute_connectivity_matrix( atoms, scale, pbc ) @@ -184,7 +269,13 @@ def ase2networkx( return graph -def ase2rdkit(atoms: ase.Atoms, suggestions: list[str] | None = None) -> Chem.Mol: +def ase2rdkit( + atoms: ase.Atoms, + suggestions: list[str] | None = None, + engine: Literal["auto", "rdkit", "xyzgraph"] = "auto", + charge: int | None = None, + **engine_kwargs, +) -> Chem.Mol: """Convert an ASE Atoms object to an RDKit molecule. 
"""Tests for the optional ``xyzgraph`` engine of ase2networkx / ase2rdkit."""

from unittest.mock import patch

import ase
import numpy as np
import pytest
from rdkit import Chem

import molify

# The whole module is skipped when the optional backend is not installed.
xyzgraph = pytest.importorskip("xyzgraph")


def _atoms_without_connectivity(smiles: str) -> ase.Atoms:
    """Build atoms from ``smiles`` and drop the stored connectivity.

    ``ase2networkx`` uses ``atoms.info["connectivity"]`` whenever it is
    present, so it has to be removed for the selected engine to run.
    """
    atoms = molify.smiles2atoms(smiles)
    atoms.info.pop("connectivity")
    return atoms


def _canonical_smiles_with_hs(smiles: str) -> str:
    """Return the canonical SMILES of ``smiles`` with explicit hydrogens."""
    return Chem.MolToSmiles(Chem.AddHs(Chem.MolFromSmiles(smiles)), canonical=True)


def test_ase2networkx_xyzgraph_engine_water():
    """xyzgraph engine produces correct graph for water."""
    atoms = _atoms_without_connectivity("O")

    graph = molify.ase2networkx(atoms, engine="xyzgraph")

    assert graph.number_of_nodes() == 3
    assert graph.number_of_edges() == 2

    for _, data in graph.nodes(data=True):
        assert "atomic_number" in data
        assert "position" in data
        assert "charge" in data
        assert isinstance(data["atomic_number"], int)
        assert isinstance(data["charge"], float)

    for _, _, data in graph.edges(data=True):
        assert "bond_order" in data
        assert data["bond_order"] is not None
        assert data["bond_order"] == 1.0


def test_ase2networkx_xyzgraph_engine_ethanol():
    """xyzgraph engine produces correct graph for ethanol."""
    atoms = _atoms_without_connectivity("CCO")

    graph = molify.ase2networkx(atoms, engine="xyzgraph")

    assert graph.number_of_nodes() == 9
    assert graph.number_of_edges() == 8

    for _, _, data in graph.edges(data=True):
        assert data["bond_order"] is not None


def test_ase2networkx_xyzgraph_engine_formaldehyde():
    """xyzgraph engine detects double bond in formaldehyde."""
    atoms = _atoms_without_connectivity("C=O")

    graph = molify.ase2networkx(atoms, engine="xyzgraph")

    assert graph.number_of_nodes() == 4

    # Collect every C-O edge so that a missing or duplicated bond fails the
    # test explicitly instead of only the last match being inspected.
    co_bond_orders = []
    for u, v, data in graph.edges(data=True):
        elements = {graph.nodes[u]["atomic_number"], graph.nodes[v]["atomic_number"]}
        if elements == {6, 8}:
            co_bond_orders.append(data["bond_order"])
    assert co_bond_orders == [2.0]


def test_ase2networkx_xyzgraph_engine_preserves_pbc_cell():
    """xyzgraph engine preserves pbc and cell graph attributes."""
    atoms = _atoms_without_connectivity("O")

    graph = molify.ase2networkx(atoms, engine="xyzgraph")

    assert "pbc" in graph.graph
    assert "cell" in graph.graph
    # Presence alone would also pass for wrong values; compare the contents.
    assert np.array_equal(graph.graph["pbc"], atoms.pbc)
    assert np.array_equal(np.asarray(graph.graph["cell"]), np.asarray(atoms.cell))


def test_ase2networkx_xyzgraph_charge_parameter():
    """Explicit charge parameter is forwarded to xyzgraph."""
    atoms = _atoms_without_connectivity("[OH-]")

    graph = molify.ase2networkx(atoms, engine="xyzgraph", charge=-1)

    assert graph.number_of_nodes() == 2
    assert graph.number_of_edges() == 1

    total_charge = sum(data["charge"] for _, data in graph.nodes(data=True))
    assert total_charge == pytest.approx(-1.0, abs=0.1)


def test_ase2networkx_xyzgraph_engine_kwargs():
    """engine_kwargs are forwarded to xyzgraph.build_graph."""
    atoms = _atoms_without_connectivity("O")

    graph = molify.ase2networkx(atoms, engine="xyzgraph", quick=True)
    assert graph.number_of_nodes() == 3


def test_ase2networkx_xyzgraph_engine_empty_atoms():
    """xyzgraph engine handles empty atoms gracefully."""
    graph = molify.ase2networkx(ase.Atoms(), engine="xyzgraph")

    assert graph.number_of_nodes() == 0
    assert graph.number_of_edges() == 0


def test_ase2networkx_rdkit_engine_unchanged():
    """engine='rdkit' preserves current behavior exactly."""
    atoms = _atoms_without_connectivity("O")

    graph = molify.ase2networkx(atoms, engine="rdkit")

    assert graph.number_of_nodes() == 3
    assert graph.number_of_edges() == 2

    for _, _, data in graph.edges(data=True):
        assert data["bond_order"] is None


def test_ase2rdkit_xyzgraph_engine_water():
    """ase2rdkit with xyzgraph engine produces correct RDKit molecule for water."""
    atoms = _atoms_without_connectivity("O")

    mol = molify.ase2rdkit(atoms, engine="xyzgraph")

    assert mol.GetNumAtoms() == 3
    assert Chem.MolToSmiles(mol, canonical=True) == _canonical_smiles_with_hs("O")


def test_ase2rdkit_xyzgraph_engine_ethanol():
    """ase2rdkit with xyzgraph engine produces correct molecule for ethanol."""
    atoms = _atoms_without_connectivity("CCO")

    mol = molify.ase2rdkit(atoms, engine="xyzgraph")

    assert mol.GetNumAtoms() == 9
    assert Chem.MolToSmiles(mol, canonical=True) == _canonical_smiles_with_hs("CCO")


def test_ase2rdkit_xyzgraph_charge_forwarded():
    """ase2rdkit forwards charge parameter to xyzgraph engine."""
    atoms = _atoms_without_connectivity("[OH-]")

    mol = molify.ase2rdkit(atoms, engine="xyzgraph", charge=-1)
    assert mol.GetNumAtoms() == 2


def test_ase2rdkit_xyzgraph_engine_formaldehyde():
    """ase2rdkit with xyzgraph engine correctly identifies double bonds."""
    atoms = _atoms_without_connectivity("C=O")

    mol = molify.ase2rdkit(atoms, engine="xyzgraph")

    assert Chem.MolToSmiles(mol, canonical=True) == _canonical_smiles_with_hs("C=O")


def test_ase2networkx_xyzgraph_importerror():
    """engine='xyzgraph' raises ImportError when xyzgraph is not installed."""
    atoms = _atoms_without_connectivity("O")

    with patch("molify.ase2x._xyzgraph", None):
        with pytest.raises(ImportError, match="xyzgraph is required"):
            molify.ase2networkx(atoms, engine="xyzgraph")


def test_ase2networkx_auto_engine_no_xyzgraph():
    """engine='auto' falls back to rdkit behavior when xyzgraph is not installed."""
    atoms = _atoms_without_connectivity("O")

    with patch("molify.ase2x._xyzgraph", None):
        graph = molify.ase2networkx(atoms, engine="auto")

    # Should work fine with rdkit fallback
    assert graph.number_of_nodes() == 3
    assert graph.number_of_edges() == 2
    # rdkit path has bond_order=None
    for _, _, data in graph.edges(data=True):
        assert data["bond_order"] is None


def test_ase2networkx_connectivity_takes_precedence_over_engine():
    """When connectivity is present in atoms.info, engine parameter is ignored."""
    atoms = molify.smiles2atoms("O")  # Has connectivity in info

    # Even with engine="xyzgraph", connectivity should be used
    graph = molify.ase2networkx(atoms, engine="xyzgraph")

    assert graph.number_of_nodes() == 3
    assert graph.number_of_edges() == 2
    # Bond orders come from connectivity (not None)
    for _, _, data in graph.edges(data=True):
        assert data["bond_order"] is not None
} dependencies = [ { name = "ase" }, @@ -1246,6 +1246,9 @@ dependencies = [ vesin = [ { name = "vesin" }, ] +xyzgraph = [ + { name = "xyzgraph" }, +] [package.dev-dependencies] dev = [ @@ -1272,8 +1275,9 @@ requires-dist = [ { name = "packmol", specifier = ">=21.1.2" }, { name = "rdkit", specifier = ">=2024" }, { name = "vesin", marker = "extra == 'vesin'", specifier = ">=0.3.7" }, + { name = "xyzgraph", marker = "extra == 'xyzgraph'", specifier = ">=1.6.1" }, ] -provides-extras = ["vesin"] +provides-extras = ["xyzgraph", "vesin"] [package.metadata.requires-dev] dev = [ @@ -2674,3 +2678,19 @@ sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda5308 wheels = [ { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, ] + +[[package]] +name = "xyzgraph" +version = "1.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "rdkit" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/ec/f82fe482987523cb861561f27ce9cee3ae99c0cb8c4fec07440453f80d4b/xyzgraph-1.6.1.tar.gz", hash = "sha256:e005eacf55a73f5208b99c7665ba317709b9274374b511062881f761a3eaeef1", size = 142683, upload-time = "2026-02-23T10:16:28.55Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/9c/ab/5867ddf13cc6c31286324cfdc4e5899ef6bba6b1eda6e551a1c9c936e7a4/xyzgraph-1.6.1-py3-none-any.whl", hash = "sha256:8d3b1b96eb9410698c3609baa88f9016e37d2586fe8db89e3f35e914ecdf3049", size = 106653, upload-time = "2026-02-23T10:16:26.887Z" }, +]