Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 204 additions & 20 deletions src/retromol/fingerprint/fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@
import hashlib
import struct
import logging
from typing import Any, Literal, Callable, Iterable, Sequence
from typing import Any, Literal, Callable, Iterable, Sequence, Iterator, TypeVar

import numpy as np
from numpy.typing import NDArray

_BIOCRACKER = False
try:
from biocracker.query.modules import (
LinearReadout as BioCrackerLinearReadout,
NRPSModule,
PKSModule,
PKSExtenderUnit,
)
_BIOCRACKER = True
except ImportError:
pass

from retromol.model.result import Result
from retromol.model.rules import MatchingRule
from retromol.model.reaction_graph import MolNode
Expand All @@ -23,6 +35,59 @@


# Unique placeholder object; presumably an identity-compared "no value" sentinel.
# NOTE(review): its uses are outside this excerpt - confirm before relying on this.
_MISS = object()
# Fallback used when a caller passes kmer_weights=None: maps k-mer size -> number of hashes
# (i.e. bits set) per k-mer of that size. Shared module-level object; treat as read-only.
DEFAULT_KMER_WEIGHTS: dict[int, int] = {1: 1, 2: 1}
# Fallback used when a caller passes kmer_sizes=None: the k-mer sizes that get fingerprinted.
# Shared module-level object; treat as read-only.
DEFAULT_KMER_SIZES: list[int] = [1, 2]


# Generic item type carried through iter_kmers_sequence (input sequence items -> k-mer tuple items)
T = TypeVar("T")
# Accepted values for the `direction` argument of iter_kmers_sequence
Direction = Literal["forward", "backward", "both"]


def iter_kmers_sequence(items: Sequence[T], k: int, direction: Direction = "both") -> Iterator[tuple[T, ...]]:
    """
    Iterate over k-mers of size k from a sequence of items.

    Forward k-mers are the contiguous windows ``items[i : i + k]`` taken from left to right. Backward
    k-mers are those same windows, visited in the same order, with the items inside each window reversed.
    With ``direction="both"`` all forward k-mers are yielded first, followed by all backward k-mers.
    1-mers have no orientation and are yielded exactly once per item, whatever the direction.

    This is a generator: the arguments are validated when iteration starts, not when the function is called.

    :param items: sequence of items
    :param k: size of the k-mers
    :param direction: direction of k-mers to generate: "forward", "backward", or "both"
    :return: iterator over k-mers as tuples
    :raises ValueError: if k is smaller than 1, or if direction is not one of the supported values
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    # Validate up front so that a misspelled direction is reported for every k,
    # not only when 1 < k <= len(items)
    if direction not in ("forward", "backward", "both"):
        raise ValueError(f"invalid direction: {direction}")

    n = len(items)
    if k > n:
        return  # no kmers possible

    # Special case: 1-mers have no meaningful direction, so never emit them twice
    if k == 1:
        for x in items:
            yield (x,)
        return

    def windows() -> Iterator[tuple[T, ...]]:
        """
        Yield every contiguous window of length k, left to right.
        """
        for start in range(n - k + 1):
            yield tuple(items[start : start + k])

    if direction in ("forward", "both"):
        yield from windows()

    if direction in ("backward", "both"):
        # Same windows in the same order; only the items inside each window are reversed
        for window in windows():
            yield window[::-1]


def encode_family_token(fam: str) -> str:
Expand Down Expand Up @@ -165,7 +230,7 @@ def _nh(_: int) -> int:
# Set bits to 1 (binary)
fp[idxs] = 1

return fp
return fp


class FingerprintGenerator:
Expand Down Expand Up @@ -274,11 +339,11 @@ def fingerprint_from_result(
"""
# Default kmer_sizes
if kmer_sizes is None:
kmer_sizes = [1, 2]
kmer_sizes = DEFAULT_KMER_SIZES

# Default kmer_weights
if kmer_weights is None:
kmer_weights = {1: 1, 2: 1}
kmer_weights = DEFAULT_KMER_WEIGHTS

# Retrieve AssemblyGraph from Result
a = result.linear_readout.assembly_graph
Expand All @@ -289,27 +354,36 @@ def fingerprint_from_result(
for kmer_size in kmer_sizes:
for kmer in a.iter_kmers(k=kmer_size):

per_node_ancestors: list[list[str | None]] = []
max_depth = 0

for node in kmer:
anc = self.ancestor_list_for_node(node)
per_node_ancestors.append(anc)
max_depth = max(max_depth, len(anc))

# Emit ancestor-aligned kmers
# Ancestral tokens for items in kmer
per_item_ancestors: list[list[str | None]] = []

for item in kmer:
# Structural token is the lowest level ancestor
ancestors: list[str | None] = []

# First get the structural token (lowest level ancestor)
if item.is_identified:
smiles = item.smiles
ancestors.append(g.token if (g := self.assign_to_group(smiles)) is not None else None)
else:
ancestors.append(None)

# Then get the rest of the ancestors
# We reverse to have the highest level ancestor last
ancestors.extend(reversed(self.ancestor_list_for_node(item)))

per_item_ancestors.append(ancestors)

assert len(per_item_ancestors) == len(kmer), "length mismatch in ancestor tokens"

# Get tokenized kmer from every level of ancestor
max_depth = max(len(anc) for anc in per_item_ancestors)
for level in range(max_depth):
tokenized_kmers.append(tuple(
anc[level] if level < len(anc) else None
for anc in per_node_ancestors
for anc in per_item_ancestors
))

# Emite structural kmer separately (structure only)
tokenized_kmers.append(tuple(
(g.token if (g := self.assign_to_group(node.smiles)) is not None else None)
for node in kmer
))

# Gather additional 1-mer virtual family tokens (defined in matching rules); only once per found monomer
for node in a.monomer_nodes():
ident = node.identity if node.is_identified else None
Expand All @@ -328,3 +402,113 @@ def fingerprint_from_result(
)

return fp

def fingerprint_from_biocracker_readout(
self,
readout: BioCrackerLinearReadout,
num_bits: int = 2048,
kmer_sizes: list[int] | None = None,
kmer_weights: dict[int, int] | None = None,
counted: bool = False,
) -> NDArray[np.int8]:
"""
Generate a fingerprint from a BioCracker LinearReadout.

:param readout: BioCracker LinearReadout object
:param num_bits: number of bits in the fingerprint
:param kmer_sizes: list of k-mer sizes to consider
:param kmer_weights: weights for each k-mer size. Determines how many bits each k-mer sets.
:param counted: if True, count the number of times each k-mer appears.
:return: fingerprint as a numpy array
:raises ImportError: if biocracker is not installed
:raises ValueError: if unsupported module type is encountered
:raises AssertionError: if length mismatches occur in token lists
"""
if not _BIOCRACKER:
raise ImportError("biocracker is not installed; cannot generate fingerprint from biocracker readout")

# Default kmer_sizes
if kmer_sizes is None:
kmer_sizes = DEFAULT_KMER_SIZES

# Default kmer_weights
if kmer_weights is None:
kmer_weights = DEFAULT_KMER_WEIGHTS

# Calculate kmers from BioCracker's linear readout
tokenized_kmers: list[tuple[str | None, ...]] = []

modules = readout.biosynthetic_order()
for kmer_size in kmer_sizes:
for kmer in iter_kmers_sequence(modules, k=kmer_size):

# Ancestral tokens for items in kmer
per_module_ancestors: list[list[str | None]] = []

for module in kmer:
# Structural token is the lowest level ancestor
ancestors: list[str | None] = []

if isinstance(module, NRPSModule):
# Extract SMILES of predicted substrate
if module.substrate is not None:
if module.substrate.name == "graminine":
# Graminine SMILES was incorrect in BioCracker versions <2.0.1; fix here for backwards compatibility
smiles = r"O=NN(O)CCC[C@H](N)(C(=O)O)"
else:
smiles = module.substrate.smiles
else:
smiles = None

if smiles is not None:
ancestors.append(g.token if (g := self.assign_to_group(smiles)) is not None else None)
else:
# No predicted substrate
ancestors.append(None)

# We don't add ancestral tokens for NRPSModule

elif isinstance(module, PKSModule):
# PKSModule has no structural token
ancestors.append(None)

# Extract ancestral tokens
match module.substrate.extender_unit:
case PKSExtenderUnit.PKS_A: ancestors.extend(["A", "PKS"])
case PKSExtenderUnit.PKS_B: ancestors.extend(["B", "PKS"])
case PKSExtenderUnit.PKS_C: ancestors.extend(["C", "PKS"])
case PKSExtenderUnit.PKS_D: ancestors.extend(["D", "PKS"])

else:
# Unsupported module type
log.warning(f"Unsupported module type: {type(module)}")
ancestors.append(None)

per_module_ancestors.append(ancestors)

assert len(per_module_ancestors) == len(kmer), "length mismatch in ancestor tokens"

# Get tokenized kmer from every level of ancestor
max_depth = max(len(anc) for anc in per_module_ancestors)
for level in range(max_depth):
tokenized_kmers.append(tuple(
anc[level] if level < len(anc) else None
for anc in per_module_ancestors
))

# Add modifiers as family tokens
for modifier in readout.modifiers:
tokenized_kmers.append((encode_family_token(modifier),))

# Hash kmers
fp = kmers_to_fingerprint(
tokenized_kmers,
num_bits=num_bits,
num_hashes_per_kmer=lambda k: kmer_weights.get(k, 1),
seed=42,
none_policy="keep",
counted=counted,
)

return fp