Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "retromol"
version = "1.1.0"
version = "2.0.0"
description = "RetroMol is a retrosynthetic analysis tool for modular natural products"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
26 changes: 15 additions & 11 deletions src/retromol/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
from retromol.version import __version__
from retromol.utils.logging import setup_logging, add_file_handler
from retromol.model.rules import RuleSet
from retromol.model.submission import Submission
from retromol.model.result import Result
from retromol.model.readout import LinearReadout
from retromol.model.submission import Submission
from retromol.pipelines.parsing import run_retromol_with_timeout
from retromol.io.streaming import run_retromol_stream, stream_sdf_records, stream_table_rows, stream_json_records
from retromol.chem.mol import encode_mol
Expand Down Expand Up @@ -129,28 +128,33 @@ def main() -> None:
result: Result = run_retromol_with_timeout(submission, ruleset)
log.info(f"result: {result}")

# Write out result to file and then read back in again for visualization (test I/O)
result_dict = result.to_dict()
with open(os.path.join(args.outdir, "result.json"), "w") as f:
json.dump(result_dict, f, indent=4)

with open(os.path.join(args.outdir, "result.json"), "r") as f:
result_data = json.load(f)
result2 = Result.from_dict(result_data)

# Report on coverage as percentage of tags identified
coverage = result.calculate_coverage()
coverage = result2.calculate_coverage()
log.info(f"coverage: {coverage:.2%}")

# Get linear readout; print summary
linear_readout = LinearReadout.from_result(result)
# Get linear readout; draw assembly graph
linear_readout = result2.linear_readout
out_assembly_graph_fig = os.path.join(args.outdir, "assembly_graph.png")
linear_readout.assembly_graph.draw(show_unassigned=True, savepath=out_assembly_graph_fig)
log.info(f"linear readout: {linear_readout}")

# Visualize reaction graph
root = encode_mol(result.submission.mol)
root = encode_mol(result2.submission.mol)
visualize_reaction_graph(
result.reaction_graph,
result2.reaction_graph,
html_path=os.path.join(args.outdir, "reaction_graph.html"),
root_enc=root
)

result_dict = result.to_dict()
with open(os.path.join(args.outdir, "result.json"), "w") as f:
json.dump(result_dict, f, indent=4)

result_counts["successes"] += 1

# Batch mode
Expand Down
10 changes: 2 additions & 8 deletions src/retromol/fingerprint/fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ def __init__(
tanimoto_threshold: float = 0.6,
morgan_radius: int = 2,
morgan_num_bits: int = 2048,
family_token_weight: int = 1,
ancestor_token_weight: int = 1,
) -> None:
"""
Initialize FingerprintGenerator.
Expand All @@ -186,8 +184,6 @@ def __init__(
:param tanimoto_threshold: Tanimoto similarity threshold for collapsing monomers
:param morgan_radius: radius for Morgan fingerprinting when collapsing monomers
:param morgan_num_bits: number of bits for Morgan fingerprinting when collapsing monomers
:param family_token_weight: weight for family tokens in the fingerprint
:param ancestor_token_weight: weight for ancestor tokens in the fingerprint
"""
matching_rules = list(matching_rules)

Expand Down Expand Up @@ -279,10 +275,8 @@ def fingerprint_from_result(
if kmer_weights is None:
kmer_weights = {1: 1, 2: 1}

# Create assembly graph of monomers; first collect nodes to include
root = result.submission.mol
collected = result.reaction_graph.get_leaf_nodes(identified_only=False)
a = AssemblyGraph.build(root_mol=root, monomers=collected, include_unassigned=True)
# Retrieve AssemblyGraph from Result
a = result.linear_readout.assembly_graph

# Calculate kmers from AssemblyGraph
tokenized_kmers: list[tuple[str | None, ...]] = []
Expand Down
108 changes: 107 additions & 1 deletion src/retromol/model/assembly_graph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Module contains utilities for defining and working with assembly graphs."""

from dataclasses import dataclass
from dataclasses import dataclass, asdict
from typing import Any, Iterable, Iterable, Iterator, Generator

from rdkit.Chem.rdchem import Mol
import matplotlib.pyplot as plt
import networkx as nx

from retromol.model.reaction_graph import MolNode
from retromol.model.identity import MolIdentity
from retromol.chem.tagging import get_tags_mol


Expand Down Expand Up @@ -37,6 +38,24 @@ class RootBondLink:
bond_type: str # stringified version of RDKit BondType
bond_order: float | int | None # include if available

def to_dict(self) -> dict[str, Any]:
    """
    Serialize this RootBondLink into a plain dictionary.

    The conversion is delegated to ``dataclasses.asdict``, so every dataclass
    field of the instance ends up as a key in the returned mapping.

    :return: dictionary representation of the RootBondLink
    """
    payload: dict[str, Any] = asdict(self)
    return payload

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "RootBondLink":
    """
    Rebuild a RootBondLink from its dictionary form.

    :param data: dictionary whose items are forwarded to the constructor as keyword arguments
    :return: RootBondLink instance
    """
    # Unpack into a fresh mapping first; the constructor receives exactly the given keys
    field_values = {**data}
    return cls(**field_values)


def build_assembly_graph(
root_mol: Mol,
Expand Down Expand Up @@ -497,6 +516,93 @@ def validate(self) -> None:

if not isinstance(data["n_bonds"], int):
raise ValueError(f"AssemblyGraph edge {u!r}-{v!r} n_bonds must be int")

def to_dict(self) -> dict[str, Any]:
    """
    Serialize the AssemblyGraph into a dictionary.

    Each node is stored with its id, its tags as a sorted list, and the
    ``to_dict`` output of its "identity" and "molnode" attributes (``None`` when
    the attribute is missing or ``None``). Each edge is stored with its two
    endpoints, the ``to_dict`` output of every bond, and its bond count.

    :return: dictionary with the keys "unassigned", "nodes" and "edges"
    """

    def _node_record(node_id: Any, attrs: dict[str, Any]) -> dict[str, Any]:
        # Sorting turns the tag set into a stable, JSON-friendly list
        sorted_tags = sorted(attrs.get("tags", set()))

        molnode = attrs.get("molnode", None)
        molnode_payload = molnode.to_dict() if molnode is not None else None

        identity = attrs.get("identity", None)
        identity_payload = identity.to_dict() if identity is not None else None

        return {
            "id": node_id,
            "tags": sorted_tags,
            "identity": identity_payload,
            "molnode": molnode_payload,
        }

    node_records: list[dict[str, Any]] = [
        _node_record(node_id, attrs) for node_id, attrs in self.g.nodes(data=True)
    ]

    edge_records: list[dict[str, Any]] = []
    for src, dst, attrs in self.g.edges(data=True):
        links = attrs.get("bonds", [])
        record = {
            "u": src,
            "v": dst,
            "bonds": [link.to_dict() for link in links],
            # Fall back to the number of stored bonds when no explicit count is present
            "n_bonds": int(attrs.get("n_bonds", len(links))),
        }
        edge_records.append(record)

    return {
        "unassigned": self.unassigned,
        "nodes": node_records,
        "edges": edge_records,
    }

@classmethod
def from_dict(cls, data: dict[str, Any], validate: bool = True) -> "AssemblyGraph":
    """
    Rebuild an AssemblyGraph from its dictionary form.

    Missing "nodes" or "edges" entries are treated as empty, and a missing
    "unassigned" entry falls back to the string "unassigned".

    :param data: dictionary representation of the AssemblyGraph
    :param validate: whether to validate the graph after creation (default: True)
    :return: AssemblyGraph instance
    """
    unassigned_label = data.get("unassigned", "unassigned")
    graph = nx.Graph()

    # Restore nodes together with their tags, molnode and identity attributes
    for node_payload in data.get("nodes", []):
        node_key = node_payload["id"]
        node_tags = set(node_payload.get("tags", []))

        raw_molnode = node_payload.get("molnode", None)
        node_molnode = MolNode.from_dict(raw_molnode) if raw_molnode is not None else None

        raw_identity = node_payload.get("identity", None)
        node_identity = MolIdentity.from_dict(raw_identity) if raw_identity is not None else None

        graph.add_node(node_key, molnode=node_molnode, tags=node_tags, identity=node_identity)

    # Restore edges together with their bond links and bond counts
    for edge_payload in data.get("edges", []):
        src, dst = edge_payload["u"], edge_payload["v"]

        links = [RootBondLink.from_dict(raw_link) for raw_link in edge_payload.get("bonds", [])]

        # Fall back to the number of restored bonds when no explicit count is stored
        bond_count = int(edge_payload.get("n_bonds", len(links)))
        graph.add_edge(src, dst, bonds=links, n_bonds=bond_count)

    # Validation on construction is switched off; it is run below only when requested
    assembly_graph = cls(g=graph, unassigned=unassigned_label, validate_upon_initialization=False)
    if validate:
        assembly_graph.validate()

    return assembly_graph

@classmethod
def build(
Expand Down
4 changes: 2 additions & 2 deletions src/retromol/model/reaction_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def from_dict(cls, data: dict[str, Any]) -> "ReactionGraph":
:return: ReactionGraph object
"""
reaction_graph = cls(
nodes={int(enc): MolNode.from_dict(node_data) for enc, node_data in data["nodes"].items()},
nodes={enc: MolNode.from_dict(node_data) for enc, node_data in data["nodes"].items()},
edges=[RxnEdge.from_dict(edge_data) for edge_data in data["edges"]],
out_edges={int(enc): indices for enc, indices in data["out_edges"].items()},
out_edges={enc: indices for enc, indices in data["out_edges"].items()},
)
return reaction_graph
41 changes: 30 additions & 11 deletions src/retromol/model/readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from dataclasses import dataclass
from typing import Literal

from retromol.model.reaction_graph import MolNode
from retromol.model.reaction_graph import MolNode, ReactionGraph
from retromol.model.assembly_graph import AssemblyGraph
from retromol.model.result import Result
from retromol.model.rules import MatchingRule
from retromol.chem.mol import encode_mol
from retromol.chem.tagging import get_tags_mol
Expand All @@ -32,18 +31,18 @@ def __str__(self) -> str:
return f"LinearReadout(assembly_graph_nodes={self.assembly_graph.g.number_of_nodes()}; assembly_graph_edges={self.assembly_graph.g.number_of_edges()}; num_paths={len(self.paths)})"

@classmethod
def from_result(
def from_reaction_graph(
cls,
result: Result,
root_enc: str | None = None,
root_enc: str,
reaction_graph: ReactionGraph,
exclude_identities: list[MatchingRule] | None = None,
include_identities: list[MatchingRule] | None = None,
) -> "LinearReadout":
"""
Create a LinearReadout from a ReactionGraph object.

:param result: RetroMol parsing result
:param root_enc: optional root molecule encoding; if None, use submission molecule
:param root_enc: encoding of the root molecule
:param reaction_graph: ReactionGraph object
:param exclude_identities: list of matching rules to exclude identities (not used here)
:param include_identities: list of matching rules to include identities (not used here)
:return: LinearReadout instance
Expand All @@ -57,10 +56,7 @@ def from_result(
if include_identities is not None:
include_identities = set([r.id for r in include_identities])

g = result.reaction_graph
if root_enc is None:
root_enc = encode_mol(result.submission.mol)

g = reaction_graph
if root_enc not in g.nodes:
raise ValueError(f"root_enc {root_enc} not found in reaction graph nodes")

Expand All @@ -83,3 +79,26 @@ def from_result(
paths.append(path)

return cls(assembly_graph=a, paths=paths)

def to_dict(self) -> dict:
    """
    Convert the LinearReadout into a dictionary.

    The assembly graph is serialized through its own ``to_dict`` method, and every
    node of every path is serialized through the node's ``to_dict`` method.

    :return: dict with the keys "assembly_graph" and "paths"
    """
    graph_payload = self.assembly_graph.to_dict()

    path_payloads = []
    for path in self.paths:
        # One list of node dictionaries per path, preserving node order
        path_payloads.append([node.to_dict() for node in path])

    return {"assembly_graph": graph_payload, "paths": path_payloads}

@classmethod
def from_dict(cls, data: dict) -> "LinearReadout":
    """
    Rebuild a LinearReadout from its dictionary form.

    :param data: dict with an "assembly_graph" entry and a "paths" entry (a list of lists of node dicts)
    :return: LinearReadout instance
    """
    restored_graph = AssemblyGraph.from_dict(data["assembly_graph"])

    restored_paths = []
    for path_payload in data["paths"]:
        # Each path is a list of serialized MolNode dictionaries
        restored_paths.append([MolNode.from_dict(node_payload) for node_payload in path_payload])

    return cls(assembly_graph=restored_graph, paths=restored_paths)
11 changes: 10 additions & 1 deletion src/retromol/model/result.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
"""Module defining the Result data class."""

from dataclasses import dataclass, asdict
from dataclasses import dataclass
from typing import Any

from retromol.model.submission import Submission
from retromol.model.reaction_graph import ReactionGraph
from retromol.model.readout import LinearReadout
from retromol.chem.tagging import get_tags_mol


@dataclass(frozen=True)
class Result:
"""
Represents a RetroMol parsing result.

:var submission: Submission: the original submission associated with this result
:var reaction_graph: ReactionGraph: the reaction graph generated from retrosynthetic analysis
:var linear_readout: LinearReadout: the linear readout representation of the reaction graph
"""

submission: Submission
reaction_graph: ReactionGraph
linear_readout: LinearReadout

def __str__(self) -> str:
"""
Expand Down Expand Up @@ -55,6 +61,7 @@ def to_dict(self) -> dict[str, Any]:
return {
"submission": self.submission.to_dict(),
"reaction_graph": self.reaction_graph.to_dict(),
"linear_readout": self.linear_readout.to_dict(),
}

@classmethod
Expand All @@ -67,8 +74,10 @@ def from_dict(cls, data: dict[str, Any]) -> "Result":
"""
submission = Submission.from_dict(data["submission"])
reaction_graph = ReactionGraph.from_dict(data["reaction_graph"])
linear_readout = LinearReadout.from_dict(data["linear_readout"])

return cls(
submission=submission,
reaction_graph=reaction_graph,
linear_readout=linear_readout,
)
13 changes: 10 additions & 3 deletions src/retromol/model/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@ class Submission:
:var smiles: str: SMILES representation of the submitted molecule
:var name: str | None: optional name of the submitted molecule
:var props: dict[str, Any] | None: optional additional properties associated with the submission
:var keep_stereo: bool: whether to keep stereochemistry during standardization
:var neutralize: bool: whether to neutralize the molecule during standardization
:var canonicalize_tautomer: bool: whether to canonicalize the tautomer during standardization
"""

smiles: str
name: str | None = None
props: dict[str, Any] | None = None

keep_stereo: bool = True
neutralize: bool = True
canonicalize_tautomer: bool = False

mol: Mol = field(init=False, repr=False)
inchikey: str = field(init=False, repr=False)

Expand All @@ -36,9 +43,9 @@ def __post_init__(self) -> None:
# Generate standardized molecule
mol = standardize_from_smiles(
smiles,
keep_stereo=True,
neutralize=True,
tautomer_canon=True,
keep_stereo=self.keep_stereo,
neutralize=self.neutralize,
tautomer_canon=self.canonicalize_tautomer,
)

# Generate InChIKey
Expand Down
Loading