diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ac8304..2c9ac83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,26 @@ This project follows a lightweight "keep a log" style. +## 0.1.3 - Refactoring and robustness + +- **Added**: + - `undirected_edges_unique(...)` for undirected containers and use it internally for component generator extraction. + - Read-only accessors `PeriodicComponent.snf` and `PeriodicComponent.tree_parent_map()` and updated `canonical_lift(...)` to avoid touching private component caches. + +- **Changed**: + - `LiftPatch.to_networkx(as_undirected=True, ...)` now stores direction and orig-edge snapshots under a single reserved edge attribute `__pbcgraph__` to avoid collisions with user attributes. + - Minor: code that previously read `_pbc_*` or `orig_edges` from NetworkX exports must now read `__pbcgraph__`. + +- **Performance**: + - `lift_patch(...)` avoids redundant incoming-edge traversal for undirected containers (no semantic change). + - Refactored `edges()`, `neighbors()`, and `in_neighbors()` to use streaming deterministic iteration (avoid building a full edge list just to sort it). + +- **Refactors**: + - Refactored `LiftPatch.to_networkx(...)` into small helpers (no behavior change). + - Split `pbcgraph.alg.lift` into `_lift_patch` and `_canonical_lift` implementation modules, keeping the public API unchanged. + - General refactors: introduce an internal constant `PBC_META_KEY` for NetworkX export metadata, simplify internal key filtering in undirected containers, and apply small style cleanups. + + ## 0.1.2 - Finite lifts and canonical lifts - **Finite lift patches** diff --git a/docs/examples/lift_patch.ipynb b/docs/examples/lift_patch.ipynb index 5602262..4b6a349 100644 --- a/docs/examples/lift_patch.ipynb +++ b/docs/examples/lift_patch.ipynb @@ -39,6 +39,7 @@ " PeriodicMultiGraph,\n", " PeriodicDiGraph,\n", " PeriodicMultiDiGraph,\n", + " PBC_META_KEY,\n", ")\n" ] }, @@ -152,7 +153,7 @@ "You can still request an undirected view from the patch export:\n", "\n", "- `undirected_mode='multigraph'`: one undirected multiedge per directed edge\n", - "- `undirected_mode='orig_edges'`: one undirected edge with `orig_edges=[...]`\n" + "- `undirected_mode='orig_edges'`: one undirected edge with `__pbcgraph__={'orig_edges': [...]}`\n" ] }, { @@ -190,7 +191,7 @@ "for u, v, data in nxU.edges(data=True):\n", " if {u, v} != {('A', (0,)), ('B', (0,))}:\n", " continue\n", - " print(u, '--', v, 'label=', data.get('label'), 'tail=', data.get('_pbc_tail'), 'head=', data.get('_pbc_head'))\n" + " print(u, '--', v, 'label=', data.get('label'), 'tail=', data.get(PBC_META_KEY, {}).get('tail'), 'head=', data.get(PBC_META_KEY, {}).get('head'))\n" ] }, { @@ -205,7 +206,7 @@ "print(type(nxC))\n", "data = nxC.edges[('A', (0,)), ('B', (0,))]\n", "print('orig_edges records:')\n", - "pprint(data['orig_edges'])\n" + "pprint(data[PBC_META_KEY]['orig_edges'])\n" ] }, { @@ -286,4 +287,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/docs/general/concepts.md b/docs/general/concepts.md index 5bacbf1..1d5da44 100644 --- a/docs/general/concepts.md +++ b/docs/general/concepts.md @@ -133,9 +133,19 @@ Important details: - **Undirected views of directed patches** are available via `LiftPatch.to_networkx(as_undirected=True, undirected_mode=...)`: - `undirected_mode='multigraph'`: one undirected multiedge per directed - edge; direction metadata is stored in `_pbc_tail`/`_pbc_head`. + edge; direction metadata is stored under the `__pbcgraph__` edge + attribute. - `undirected_mode='orig_edges'`: collapsed simple graph; each undirected - adjacency stores `orig_edges=[...]` snapshots. + adjacency stores `orig_edges=[...]` snapshots under the `__pbcgraph__` + edge attribute. + + + +**NetworkX export metadata key.** Direction/origin information is stored under a single +reserved edge attribute named `__pbcgraph__`. In code, prefer using the constant +`PBC_META_KEY` (exported from `pbcgraph`) instead of hardcoding the string. +The library reserves this key for its own metadata; attempting to set it as a user +edge attribute will raise an error. These export options avoid silent loss of information when you want an undirected representation for an inherently directed relation. diff --git a/src/pbcgraph/__about__.py b/src/pbcgraph/__about__.py index f254bcc..ecfe157 100644 --- a/src/pbcgraph/__about__.py +++ b/src/pbcgraph/__about__.py @@ -1,4 +1,4 @@ -__version__ = '0.1.2' +__version__ = '0.1.3' __all__ = [ '__version__', diff --git a/src/pbcgraph/__init__.py b/src/pbcgraph/__init__.py index 984043d..1dee704 100644 --- a/src/pbcgraph/__init__.py +++ b/src/pbcgraph/__init__.py @@ -73,6 +73,7 @@ ) from pbcgraph.component import PeriodicComponent +from pbcgraph.core.constants import PBC_META_KEY from pbcgraph.core.exceptions import PBCGraphError, StaleComponentError from pbcgraph.core.types import ( TVec, @@ -88,6 +89,7 @@ __all__ = [ '__version__', + 'PBC_META_KEY', # containers 'PeriodicDiGraph', 'PeriodicGraph', diff --git a/src/pbcgraph/alg/_canonical_lift.py b/src/pbcgraph/alg/_canonical_lift.py new file mode 100644 index 0000000..d462c87 --- /dev/null +++ b/src/pbcgraph/alg/_canonical_lift.py @@ -0,0 +1,504 @@ +"""Canonical lifts (strand representatives). + +This module provides `canonical_lift(...)` and the `CanonicalLift` +container: a deterministic choice of one instance per quotient node +for a selected strand (coset) in the infinite lift. +""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Hashable, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) + +from pbcgraph.core.exceptions import CanonicalLiftError +from pbcgraph.core.ordering import fallback_key, stable_sorted +from pbcgraph.core.types import ( + NodeId, + NodeInst, + TVec, + add_tvec, + sub_tvec, + zero_tvec, + validate_tvec, +) + + +TreeEdgeRec = Tuple[NodeId, NodeId, TVec, int] + + +@dataclass(frozen=True) +class CanonicalLift: + """A deterministic finite representation of a single strand. + + Attributes: + nodes: Node instances `(u, shift)` in canonical order. Contains + exactly one instance for every quotient node in the component. + strand_key: Target strand (coset) key in `Z^d / L`. + anchor_site: Quotient node chosen to be placed in `anchor_shift`. + anchor_shift: Anchor cell translation vector. + placement: Placement mode used to construct the lift. + score: Placement score (smaller is better; 0 is best). + tree_edges: Optional spanning-tree edge records for debugging. + """ + + nodes: Tuple[NodeInst, ...] + strand_key: Hashable + anchor_site: NodeId + anchor_shift: TVec + placement: str + score: Union[int, float] + tree_edges: Optional[Tuple[TreeEdgeRec, ...]] = None + + +def _sorted_nodes_by_key( + nodes: Sequence[NodeId], + node_order: Optional[Callable[[NodeId], Any]], +) -> Tuple[NodeId, ...]: + seq = list(nodes) + if not seq: + return () + + if node_order is None: + return tuple(stable_sorted(seq)) + + def k(u: NodeId) -> Any: + return node_order(u) + + try: + return tuple( + sorted(seq, key=lambda u: (k(u), fallback_key(u))) + ) + except TypeError: + return tuple( + sorted(seq, key=lambda u: (fallback_key(k(u)), fallback_key(u))) + ) + + +def _sorted_node_insts( + insts: Sequence[NodeInst], + node_order: Optional[Callable[[NodeId], Any]], +) -> Tuple[NodeInst, ...]: + seq = list(insts) + if not seq: + return () + + if node_order is None: + try: + return tuple(sorted(seq, key=lambda x: (x[0], x[1]))) + except TypeError: + return tuple(sorted(seq, key=lambda x: (fallback_key(x[0]), x[1]))) + + def k(u: NodeId) -> Any: + return node_order(u) + + try: + return tuple(sorted( + seq, key=lambda x: ( + k(x[0]), x[1], fallback_key(x[0]) + ) + )) + except TypeError: + return tuple(sorted( + seq, key=lambda x: ( + fallback_key(k(x[0])), x[1], fallback_key(x[0]) + ) + )) + + +def _compute_lift_score( + snf: Any, + rel_shifts: Dict[NodeId, TVec], + nodes: Sequence[NodeId], + score: Literal['l1', 'l2'], +) -> int: + """Compute placement score for a lift. + + Args: + snf: SNF decomposition of the component translation subgroup. + rel_shifts: Per-node relative shifts with respect to the anchor site. + nodes: Quotient node ids in the component. + score: Score metric: 'l1' or 'l2'. + + Returns: + The deterministic integer score (smaller is better). + + Raises: + CanonicalLiftError: If the SNF decomposition is invalid. + """ + r = int(snf.rank) + total = 0 + for u in nodes: + y = snf.apply_U(rel_shifts[u]) + node_mag = 0 + for i in range(r): + di = int(snf.diag[i]) + if di == 0: + raise CanonicalLiftError('invalid SNF diagonal entry') + qi = int(y[i] // di) + if score == 'l1': + node_mag += abs(qi) + else: + node_mag += qi * qi + total += node_mag + return int(total) + + +def _compute_rel_abs_shifts( + pot: Dict[NodeId, TVec], + *, + anchor_site: NodeId, + anchor_shift: TVec, +) -> Tuple[Dict[NodeId, TVec], Dict[NodeId, TVec]]: + """Compute relative and absolute shifts for a given anchor site.""" + pot_anchor = pot[anchor_site] + rel: Dict[NodeId, TVec] = {} + abs_s: Dict[NodeId, TVec] = {} + for u, pu in pot.items(): + r = sub_tvec(pu, pot_anchor) + rel[u] = r + abs_s[u] = add_tvec(anchor_shift, r) + return rel, abs_s + + +def _build_internal_adj( + component: Any, + abs_shift: Dict[NodeId, TVec], +) -> Dict[NodeId, FrozenSet[NodeId]]: + """Build induced internal undirected adjacency on selected instances. + + An undirected adjacency between quotient nodes `u` and `v` exists if at + least one directed periodic edge between them is consistent with the + selected absolute shifts. + + Args: + component: PeriodicComponent. + abs_shift: Mapping `u -> shift` for exactly the component nodes. + + Returns: + Dict mapping node id to a frozen set of adjacent node ids. + """ + adj: Dict[NodeId, set[NodeId]] = {u: set() for u in component.nodes} + for u in component.nodes: + su = abs_shift[u] + for v, t, _k in component.graph.neighbors(u, keys=True, data=False): + if v not in component.nodes: + continue + if abs_shift[v] == add_tvec(su, t): + adj[u].add(v) + adj[v].add(u) + return {u: frozenset(nbs) for u, nbs in adj.items()} + + +def _is_connected_undirected( + adj: Dict[NodeId, FrozenSet[NodeId]], + nodes_ordered: Sequence[NodeId], + *, + skip: Optional[NodeId] = None, +) -> bool: + """Return True if the induced graph is connected + (optionally skipping a node).""" + nodes = [u for u in nodes_ordered if u != skip] + if not nodes: + return True + + start = nodes[0] + seen: set[NodeId] = {start} + q: deque[NodeId] = deque([start]) + + while q: + u = q.popleft() + for v in stable_sorted(list(adj.get(u, frozenset()))): + if v == skip: + continue + if v in seen: + continue + seen.add(v) + q.append(v) + return len(seen) == len(nodes) + + +def _boundary_deltas_for_node( + component: Any, + abs_shift: Dict[NodeId, TVec], + u: NodeId, +) -> Tuple[TVec, ...]: + """Enumerate per-node deltas induced by boundary periodic edges.""" + su = abs_shift[u] + deltas: set[TVec] = set() + + for v, t, _k in component.graph.neighbors(u, keys=True, data=False): + if v not in component.nodes: + continue + desired = add_tvec(su, t) + if abs_shift[v] == desired: + continue + # Want: abs_shift[v] == (su + delta) + t + delta = sub_tvec(sub_tvec(abs_shift[v], su), t) + deltas.add(delta) + + for v, t_in, _k in component.graph.in_neighbors(u, keys=True, data=False): + if v not in component.nodes: + continue + desired_u = add_tvec(abs_shift[v], t_in) + if desired_u == su: + continue + # Want: (su + delta) == abs_shift[v] + t_in + delta = sub_tvec(desired_u, su) + deltas.add(delta) + + if not deltas: + return () + + try: + return tuple(sorted(deltas)) + except TypeError: + return tuple(sorted(deltas, key=fallback_key)) + + +def canonical_lift( + component: Any, + *, + strand_key: Optional[Hashable] = None, + seed: Optional[NodeInst] = None, + anchor_shift: Optional[TVec] = None, + placement: Literal['tree', 'best_anchor', 'greedy_cut'] = 'tree', + score: Literal['l1', 'l2'] = 'l1', + return_tree: bool = False, + node_order: Optional[Callable[[NodeId], Any]] = None, + edge_order: Optional[Callable[[Tuple[Any, ...]], Any]] = None, +) -> CanonicalLift: + """Construct a deterministic finite representation of one strand. + + v0.1.2 step4 implements `placement='tree'`, `placement='best_anchor'`, and + `placement='greedy_cut'`. + + Args: + component: A :class:`~pbcgraph.component.PeriodicComponent`. + strand_key: Optional explicit strand key. + seed: Optional seed instance `(u, shift)`. + anchor_shift: Optional anchor cell shift. + placement: Placement mode (`'tree'` in step2). + score: Score metric: `'l1'` or `'l2'`. + return_tree: If True, include spanning-tree edge records. + node_order: Optional ordering key for quotient node ids. + edge_order: Optional ordering key for periodic edges (reserved). + + Returns: + A :class:`~pbcgraph.alg.lift.CanonicalLift`. + + Raises: + CanonicalLiftError: On invalid inputs or if the requested strand does + not intersect the anchor cell. + """ + del edge_order # Reserved for later placement modes. + + if placement not in ('tree', 'best_anchor', 'greedy_cut'): + raise CanonicalLiftError( + "canonical_lift placement must be one of 'tree', " + "'best_anchor', 'greedy_cut'" + ) + + dim = int(component.graph.dim) + + if seed is not None: + u_seed, s_seed = seed + validate_tvec(s_seed, dim) + else: + u_seed = None # noqa: F841 + s_seed = None + + if anchor_shift is None: + if s_seed is not None: + anchor_shift = s_seed + else: + anchor_shift = zero_tvec(dim) + else: + validate_tvec(anchor_shift, dim) + + if strand_key is None: + if seed is not None: + try: + K = component.inst_key(seed) + except KeyError as e: + raise CanonicalLiftError( + 'seed does not belong to component' + ) from e + else: + nodes_sorted = _sorted_nodes_by_key( + list(component.nodes), node_order + ) + if not nodes_sorted: + raise CanonicalLiftError('component has no nodes') + default_seed = (nodes_sorted[0], zero_tvec(dim)) + K = component.inst_key(default_seed) + else: + K = strand_key + + eligible: List[NodeId] = [] + for u in component.nodes: + if component.inst_key((u, anchor_shift)) == K: + eligible.append(u) + + if not eligible: + raise CanonicalLiftError( + 'requested strand_key does not intersect the anchor cell' + ) + + pot = {u: component.potential(u) for u in component.nodes} + + snf = component.snf + + if score not in ('l1', 'l2'): + raise CanonicalLiftError("score must be 'l1' or 'l2'") + + nodes_list = list(component.nodes) + eligible_sorted = _sorted_nodes_by_key(eligible, node_order) + + if placement == 'tree': + anchor_site = eligible_sorted[0] + rel_shift, abs_shift = _compute_rel_abs_shifts( + pot, + anchor_site=anchor_site, + anchor_shift=anchor_shift, + ) + total_score = _compute_lift_score(snf, rel_shift, nodes_list, score) + else: + best_anchor_site: Optional[NodeId] = None + best_rel: Optional[Dict[NodeId, TVec]] = None + best_abs: Optional[Dict[NodeId, TVec]] = None + best_score: Optional[int] = None + + for a in eligible_sorted: + rel_a, abs_a = _compute_rel_abs_shifts( + pot, + anchor_site=a, + anchor_shift=anchor_shift, + ) + s = _compute_lift_score(snf, rel_a, nodes_list, score) + if best_score is None or s < best_score: + best_score = int(s) + best_anchor_site = a + best_rel = rel_a + best_abs = abs_a + + if best_anchor_site is None or best_rel is None or best_abs is None: + raise CanonicalLiftError('failed to select anchor site') + + anchor_site = best_anchor_site + rel_shift = best_rel + abs_shift = best_abs + total_score = int(best_score) + + if placement == 'greedy_cut': + # Start from the best-anchor placement and perform local, per-node + # moves by elements of the translation subgroup L that improve score + # while keeping the induced internal graph connected. + nodes_sorted = _sorted_nodes_by_key(list(component.nodes), node_order) + abs_cur: Dict[NodeId, TVec] = dict(abs_shift) + cur_score = int(total_score) + + while True: + moved = False + adj = _build_internal_adj(component, abs_cur) + if not _is_connected_undirected(adj, nodes_sorted): + raise CanonicalLiftError( + 'internal induced graph is disconnected' + ) + + for u in nodes_sorted: + if u == anchor_site: + continue + deltas = _boundary_deltas_for_node(component, abs_cur, u) + if not deltas: + continue + + # Pre-filter: u must not be an articulation point of the + # current internal graph. + if not _is_connected_undirected(adj, nodes_sorted, skip=u): + continue + + best_move: Optional[Tuple[int, TVec]] = None + old_s = abs_cur[u] + + for delta in deltas: + new_s = add_tvec(old_s, delta) + if component.inst_key((u, new_s)) != K: + continue + + abs_cur[u] = new_s + new_adj = _build_internal_adj(component, abs_cur) + ok = True + if not new_adj.get(u, frozenset()): + ok = False + elif not _is_connected_undirected(new_adj, nodes_sorted): + ok = False + + if ok: + rel_tmp = { + x: sub_tvec(abs_cur[x], abs_cur[anchor_site]) + for x in component.nodes + } + s = _compute_lift_score( + snf, rel_tmp, nodes_list, score + ) + if s < cur_score: + if best_move is None: + best_move = (int(s), delta) + else: + best_s, best_delta = best_move + if int(s) < best_s or ( + int(s) == best_s and delta < best_delta + ): + best_move = (int(s), delta) + + abs_cur[u] = old_s + + if best_move is not None: + best_s, best_delta = best_move + abs_cur[u] = add_tvec(abs_cur[u], best_delta) + cur_score = int(best_s) + moved = True + break + + if not moved: + break + + abs_shift = abs_cur + total_score = int(cur_score) + + insts = [(u, abs_shift[u]) for u in component.nodes] + insts_sorted = _sorted_node_insts(insts, node_order) + + tree_edges: Optional[Tuple[TreeEdgeRec, ...]] = None + if return_tree: + recs: List[TreeEdgeRec] = [] + parent_map = component.tree_parent_map() + children = _sorted_nodes_by_key(list(parent_map.keys()), node_order) + for child in children: + parent, _t, k = parent_map[child] + tvec = sub_tvec(abs_shift[child], abs_shift[parent]) + recs.append((parent, child, tvec, int(k))) + tree_edges = tuple(recs) + + return CanonicalLift( + nodes=insts_sorted, + strand_key=K, + anchor_site=anchor_site, + anchor_shift=anchor_shift, + placement=placement, + score=int(total_score), + tree_edges=tree_edges, + ) diff --git a/src/pbcgraph/alg/_lift_patch.py b/src/pbcgraph/alg/_lift_patch.py new file mode 100644 index 0000000..a61afdc --- /dev/null +++ b/src/pbcgraph/alg/_lift_patch.py @@ -0,0 +1,537 @@ +"""Finite lift patches of periodic graphs. + +This module provides `lift_patch(...)` and the `LiftPatch` container, +a finite, non-periodic view extracted from the infinite lift of a +periodic quotient graph. +""" + + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) + +import networkx as nx + +from pbcgraph.core.constants import PBC_META_KEY +from pbcgraph.core.exceptions import LiftPatchError +from pbcgraph.core.ordering import fallback_key, stable_sorted +from pbcgraph.core.protocols import PeriodicDiGraphLike +from pbcgraph.core.types import ( + NodeInst, + TVec, + validate_tvec, +) + + +PatchEdgeRec = Tuple[NodeInst, NodeInst, Dict[str, Any]] +PatchMultiEdgeRec = Tuple[NodeInst, NodeInst, int, Dict[str, Any]] + + +def _validate_box( + box: Sequence[Sequence[int]], + dim: int, +) -> Tuple[Tuple[int, int], ...]: + if len(box) != dim: + raise LiftPatchError('box dimension mismatch') + out: List[Tuple[int, int]] = [] + for rng in box: + if len(rng) != 2: + raise LiftPatchError('box must be a sequence of (lo, hi) pairs') + lo = int(rng[0]) + hi = int(rng[1]) + if hi < lo: + raise LiftPatchError('box has invalid range (hi < lo)') + out.append((lo, hi)) + return tuple(out) + + +def _intersect_boxes( + a: Optional[Tuple[Tuple[int, int], ...]], + b: Optional[Tuple[Tuple[int, int], ...]], + dim: int, +) -> Optional[Tuple[Tuple[int, int], ...]]: + if a is None: + return b + if b is None: + return a + if len(a) != dim or len(b) != dim: + raise LiftPatchError('box dimension mismatch') + out: List[Tuple[int, int]] = [] + for (lo1, hi1), (lo2, hi2) in zip(a, b): + lo = max(lo1, lo2) + hi = min(hi1, hi2) + if hi < lo: + # Empty intersection: still return a valid box. + out.append((lo, lo)) + else: + out.append((lo, hi)) + return tuple(out) + + +def _in_box(shift: TVec, box: Optional[Tuple[Tuple[int, int], ...]]) -> bool: + if box is None: + return True + for x, (lo, hi) in zip(shift, box): + if x < lo or x >= hi: + return False + return True + + +def _try_sort_patch_edges( + records: List[Tuple[Any, Any, int, Any]], +) -> None: + """Sort patch edge candidates deterministically. + + Records are (u_inst, v_inst, key, payload). + """ + try: + records.sort(key=lambda r: (r[0], r[1], r[2])) + except TypeError: + records.sort( + key=lambda r: (fallback_key(r[0]), fallback_key(r[1]), r[2]) + ) + + +@dataclass(frozen=True) +class LiftPatch: + """A finite patch extracted from the infinite lift. + + Attributes: + nodes: Node instances `(u, shift)` in canonical order. + edges: Edges between included node instances. + + - For simple containers: `(u_inst, v_inst, attrs)`. + - For multigraph containers: `(u_inst, v_inst, key, attrs)`. + + For directed patches, `(u_inst, v_inst)` is ordered. + For undirected patches, endpoints are in canonical order. + seed: Seed node instance. + radius: BFS radius in the lifted graph (weak connectivity), if used. + box: Effective absolute box constraint after intersection, if used. + """ + + nodes: Tuple[NodeInst, ...] + edges: Tuple[Union[PatchEdgeRec, PatchMultiEdgeRec], ...] + seed: NodeInst + radius: Optional[int] + box: Optional[Tuple[Tuple[int, int], ...]] + _is_multigraph: bool = False + _is_directed: bool = False + + @property + def is_multigraph(self) -> bool: + """Whether the patch edges include keys.""" + return bool(self._is_multigraph) + + @property + def is_directed(self) -> bool: + """Whether the patch edges are directed.""" + return bool(self._is_directed) + + def to_networkx( + self, + *, + as_undirected: Optional[bool] = None, + undirected_mode: Literal['multigraph', 'orig_edges'] = 'multigraph', + ) -> Union[nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph]: + """Export the patch as a NetworkX graph. + + Notes: + - By default, directed patches export as directed NetworkX graphs, + and undirected patches export as undirected. + - For directed patches, `as_undirected=True` provides an undirected + view: + - `undirected_mode='multigraph'` returns a MultiGraph where + each directed edge becomes a distinct undirected multiedge, + with direction metadata stored under the `__pbcgraph__` + edge attribute. + - `undirected_mode='orig_edges'` returns a simple Graph where + each undirected adjacency stores `orig_edges=[...]` + snapshots under the `__pbcgraph__` edge attribute. + """ + if as_undirected is None: + as_undirected = not self.is_directed + + if not self.is_directed and as_undirected is False: + raise ValueError('cannot export an undirected patch as directed') + + if not as_undirected: + return _lift_patch_to_networkx_directed(self) + + if not self.is_directed: + return _lift_patch_to_networkx_undirected(self) + + if undirected_mode == 'multigraph': + return _lift_patch_to_networkx_directed_multigraph(self) + + if undirected_mode == 'orig_edges': + return _lift_patch_to_networkx_directed_orig_edges(self) + + raise ValueError('invalid undirected_mode') + + +def _lift_patch_to_networkx_directed( + patch: LiftPatch, +) -> Union[nx.DiGraph, nx.MultiDiGraph]: + if patch.is_multigraph: + Gd: Union[nx.DiGraph, nx.MultiDiGraph] = nx.MultiDiGraph() + else: + Gd = nx.DiGraph() + + for node in patch.nodes: + Gd.add_node(node) + + if patch.is_multigraph: + for u, v, key, attrs in patch.edges: # type: ignore[misc] + Gd.add_edge(u, v, key=int(key), **dict(attrs)) + else: + for u, v, attrs in patch.edges: # type: ignore[misc] + Gd.add_edge(u, v, **dict(attrs)) + return Gd + + +def _lift_patch_to_networkx_undirected( + patch: LiftPatch, +) -> Union[nx.Graph, nx.MultiGraph]: + if patch.is_multigraph: + Gu: Union[nx.Graph, nx.MultiGraph] = nx.MultiGraph() + else: + Gu = nx.Graph() + + for node in patch.nodes: + Gu.add_node(node) + + if patch.is_multigraph: + for u, v, key, attrs in patch.edges: # type: ignore[misc] + Gu.add_edge(u, v, key=int(key), **dict(attrs)) + else: + for u, v, attrs in patch.edges: # type: ignore[misc] + Gu.add_edge(u, v, **dict(attrs)) + return Gu + + +def _lift_patch_to_networkx_directed_multigraph( + patch: LiftPatch, +) -> nx.MultiGraph: + Gu = nx.MultiGraph() + for node in patch.nodes: + Gu.add_node(node) + + if patch.is_multigraph: + for u, v, key, attrs in patch.edges: # type: ignore[misc] + data = dict(attrs) + data[PBC_META_KEY] = { + 'tail': u, + 'head': v, + 'key': int(key), + } + Gu.add_edge(u, v, **data) + else: + for u, v, attrs in patch.edges: # type: ignore[misc] + data = dict(attrs) + data[PBC_META_KEY] = { + 'tail': u, + 'head': v, + 'key': None, + } + Gu.add_edge(u, v, **data) + return Gu + + +def _lift_patch_to_networkx_directed_orig_edges( + patch: LiftPatch, +) -> nx.Graph: + Gu = nx.Graph() + for node in patch.nodes: + Gu.add_node(node) + + def _canon_pair(a: NodeInst, b: NodeInst) -> Tuple[NodeInst, NodeInst]: + uu, vv = stable_sorted([a, b]) + return uu, vv + + buckets: Dict[Tuple[NodeInst, NodeInst], List[Dict[str, Any]]] = {} + if patch.is_multigraph: + for u, v, key, attrs in patch.edges: # type: ignore[misc] + a, b = _canon_pair(u, v) + rec = { + 'tail': u, + 'head': v, + 'key': int(key), + 'attrs': dict(attrs), + } + buckets.setdefault((a, b), []).append(rec) + else: + for u, v, attrs in patch.edges: # type: ignore[misc] + a, b = _canon_pair(u, v) + rec = { + 'tail': u, + 'head': v, + 'key': None, + 'attrs': dict(attrs), + } + buckets.setdefault((a, b), []).append(rec) + + for (a, b), recs in buckets.items(): + try: + recs.sort(key=lambda r: (r['tail'], r['head'], r['key'])) + except TypeError: + recs.sort( + key=lambda r: ( + fallback_key(r['tail']), + fallback_key(r['head']), + -1 if r['key'] is None else int(r['key']), + ) + ) + Gu.add_edge(a, b, **{PBC_META_KEY: {'orig_edges': recs}}) + return Gu + + +def lift_patch( + G: PeriodicDiGraphLike, + seed: NodeInst, + *, + radius: Optional[int] = None, + box: Optional[Tuple[Tuple[int, int], ...]] = None, + box_rel: Optional[Tuple[Tuple[int, int], ...]] = None, + include_edges: bool = True, + max_nodes: Optional[int] = None, + node_order: Optional[Callable[[NodeInst], Any]] = None, + edge_order: Optional[Callable[[Tuple[Any, ...]], Any]] = None, +) -> LiftPatch: + """Extract a finite patch of the lifted graph around a seed. + + The traversal uses weak connectivity in the infinite lift: from an instance + it considers both outgoing and incoming quotient edges. + + Notes: + The returned patch is directed if `G.is_undirected == False`, and + undirected otherwise. Use `LiftPatch.to_networkx(as_undirected=True, + ...)` to obtain undirected views of directed patches. + + + Args: + G: A periodic graph container. + seed: Seed instance `(u, shift)`. + radius: Optional BFS radius in the lifted graph. + box: Optional absolute half-open bounds per coordinate. + box_rel: Optional bounds relative to `seed.shift`. + include_edges: Whether to include edges between included nodes. + max_nodes: If provided, raise if the patch would include more than + `max_nodes` nodes. + node_order: Optional key function for ordering node instances. + edge_order: Optional key function for ordering edge records. + + Returns: + A :class:`~pbcgraph.alg.lift.LiftPatch`. + + Raises: + LiftPatchError: On invalid inputs or if `max_nodes` is exceeded. + """ + dim = int(G.dim) + u0, s0 = seed + validate_tvec(s0, dim) + if radius is None and box is None and box_rel is None: + raise LiftPatchError( + 'at least one of radius, box, or box_rel is required' + ) + if radius is not None: + radius = int(radius) + if radius < 0: + raise LiftPatchError('radius must be non-negative') + + abs_box: Optional[Tuple[Tuple[int, int], ...]] = None + if box is not None: + abs_box = _validate_box(box, dim) + + abs_box_rel: Optional[Tuple[Tuple[int, int], ...]] = None + if box_rel is not None: + rel = _validate_box(box_rel, dim) + out: List[Tuple[int, int]] = [] + for (lo, hi), x0 in zip(rel, s0): + out.append((int(x0) + lo, int(x0) + hi)) + abs_box_rel = tuple(out) + + eff_box = _intersect_boxes(abs_box, abs_box_rel, dim) + if not _in_box(s0, eff_box): + raise LiftPatchError('seed instance is outside the effective box') + + if max_nodes is not None: + max_nodes = int(max_nodes) + if max_nodes <= 0: + raise LiftPatchError('max_nodes must be positive') + + # ----------------- + # Traversal + # ----------------- + visited: Dict[NodeInst, int] = {seed: 0} + q: deque[NodeInst] = deque([seed]) + + def iter_weak_neighbors(inst: NodeInst) -> Iterator[NodeInst]: + for v, s2 in G.neighbors_inst(inst, keys=False, data=False): + yield v, s2 + if not G.is_undirected: + for v, s2 in G.in_neighbors_inst(inst, keys=False, data=False): + yield v, s2 + + while q: + cur = q.popleft() + dcur = visited[cur] + if radius is not None and dcur >= radius: + continue + + for nb in iter_weak_neighbors(cur): + _v, s2 = nb + validate_tvec(s2, dim) + if not _in_box(s2, eff_box): + continue + if nb in visited: + continue + visited[nb] = dcur + 1 + q.append(nb) + if max_nodes is not None and len(visited) > max_nodes: + raise LiftPatchError('max_nodes exceeded during traversal') + + # Canonical node order. + nodes_list = list(visited.keys()) + if node_order is None: + nodes = tuple(stable_sorted(nodes_list)) + else: + nodes = tuple(sorted(nodes_list, key=node_order)) + + patch_is_directed = not bool(G.is_undirected) + + # ----------------- + # Edge inclusion (no explicit tvec) + # ----------------- + edges_out: List[Union[PatchEdgeRec, PatchMultiEdgeRec]] = [] + if include_edges: + included_set = set(visited) + + if patch_is_directed: + records: List[ + Tuple[NodeInst, NodeInst, int, Any, Dict[str, Any]] + ] = [] + for inst in nodes: + for v, s2, k, attrs in G.neighbors_inst( + inst, keys=True, data=True + ): + nb = (v, s2) + if nb not in included_set: + continue + sel_key = (inst, nb, int(k)) + sc = ( + edge_order(sel_key) + if edge_order is not None + else sel_key + ) + records.append((inst, nb, int(k), sc, dict(attrs))) + + try: + records.sort(key=lambda r: (r[3], r[0], r[1], r[2])) + except TypeError: + records.sort( + key=lambda r: ( + fallback_key(r[3]), + fallback_key(r[0]), + fallback_key(r[1]), + r[2], + ) + ) + + if G.is_multigraph: + for u_inst, v_inst, kk, _sc, attrs in records: + edges_out.append((u_inst, v_inst, int(kk), dict(attrs))) + else: + for u_inst, v_inst, _kk, _sc, attrs in records: + edges_out.append((u_inst, v_inst, dict(attrs))) + + else: + candidates: List[ + Tuple[NodeInst, NodeInst, int, Dict[str, Any]] + ] = [] + for inst in nodes: + for v, s2, k, attrs in G.neighbors_inst( + inst, keys=True, data=True + ): + nb = (v, s2) + if nb not in included_set: + continue + candidates.append((inst, nb, int(k), dict(attrs))) + + # Canonicalize endpoints to undirected pairs. + canon: List[Tuple[NodeInst, NodeInst, int, Dict[str, Any]]] = [] + for a, b, k, attrs in candidates: + u_inst, v_inst = stable_sorted([a, b]) + canon.append((u_inst, v_inst, k, attrs)) + + # Deduplicate reciprocal realizations deterministically. + best: Dict[ + Tuple[NodeInst, NodeInst, Optional[int]], + Tuple[Any, Dict[str, Any]], + ] = {} + for u_inst, v_inst, k, attrs in canon: + if G.is_multigraph: + eid: Tuple[ + NodeInst, NodeInst, Optional[int] + ] = (u_inst, v_inst, k) + sel_key = (u_inst, v_inst, k) + else: + eid = (u_inst, v_inst, None) + sel_key = (u_inst, v_inst, k) + + score = ( + edge_order(sel_key) + if edge_order is not None + else sel_key + ) + + if eid not in best: + best[eid] = (score, attrs) + continue + prev_score, _prev_attrs = best[eid] + try: + better = score < prev_score + except TypeError: + better = fallback_key(score) < fallback_key(prev_score) + if better: + best[eid] = (score, attrs) + + if G.is_multigraph: + out_multi: List[Tuple[Any, Any, int, Any]] = [] + for (u_inst, v_inst, kk), (sc, attrs) in best.items(): + assert kk is not None + out_multi.append((u_inst, v_inst, int(kk), (sc, attrs))) + _try_sort_patch_edges(out_multi) + for u_inst, v_inst, kk, payload in out_multi: + _sc, attrs = payload + edges_out.append((u_inst, v_inst, int(kk), dict(attrs))) + else: + out_simple: List[Tuple[Any, Any, int, Any]] = [] + for (u_inst, v_inst, _), (sc, attrs) in best.items(): + out_simple.append((u_inst, v_inst, 0, (sc, attrs))) + _try_sort_patch_edges(out_simple) + for u_inst, v_inst, _kk, payload in out_simple: + _sc, attrs = payload + edges_out.append((u_inst, v_inst, dict(attrs))) + return LiftPatch( + nodes=nodes, + edges=tuple(edges_out), + seed=seed, + radius=radius, + box=eff_box, + _is_multigraph=bool(G.is_multigraph), + _is_directed=patch_is_directed, + ) diff --git a/src/pbcgraph/alg/lift.py b/src/pbcgraph/alg/lift.py index f6d0402..51fb91a 100644 --- a/src/pbcgraph/alg/lift.py +++ b/src/pbcgraph/alg/lift.py @@ -1,999 +1,17 @@ """Finite lifts of periodic graphs. -This module implements finite, non-periodic views derived from a periodic -quotient graph. +Public API re-exports: -v0.1.2 adds two high-level operations: +- `lift_patch`, `LiftPatch` +- `canonical_lift`, `CanonicalLift` -1) ``lift_patch``: extract a finite patch of the infinite lift - around a seed instance (directed for directed sources; undirected for - undirected sources). - -2) ``canonical_lift`` (added in later steps of the v0.1.2 plan). +Implementation lives in the private modules +`pbcgraph.alg._lift_patch` and `pbcgraph.alg._canonical_lift`. """ from __future__ import annotations -from collections import deque -from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - FrozenSet, - Hashable, - Iterator, - List, - Literal, - Optional, - Sequence, - Tuple, - Union, -) - -import networkx as nx - -from pbcgraph.core.exceptions import CanonicalLiftError, LiftPatchError -from pbcgraph.core.ordering import fallback_key, stable_sorted -from pbcgraph.core.protocols import PeriodicDiGraphLike -from pbcgraph.core.types import ( - NodeId, - NodeInst, - TVec, - add_tvec, - sub_tvec, - zero_tvec, - validate_tvec, -) - - -PatchEdgeRec = Tuple[NodeInst, NodeInst, Dict[str, Any]] -PatchMultiEdgeRec = Tuple[NodeInst, NodeInst, int, Dict[str, Any]] - - -def _validate_box( - box: Sequence[Sequence[int]], - dim: int, -) -> Tuple[Tuple[int, int], ...]: - if len(box) != dim: - raise LiftPatchError('box dimension mismatch') - out: List[Tuple[int, int]] = [] - for rng in box: - if len(rng) != 2: - raise LiftPatchError('box must be a sequence of (lo, hi) pairs') - lo = int(rng[0]) - hi = int(rng[1]) - if hi < lo: - raise LiftPatchError('box has invalid range (hi < lo)') - out.append((lo, hi)) - return tuple(out) - - -def _intersect_boxes( - a: Optional[Tuple[Tuple[int, int], ...]], - b: Optional[Tuple[Tuple[int, int], ...]], - dim: int, -) -> Optional[Tuple[Tuple[int, int], ...]]: - if a is None: - return b - if b is None: - return a - if len(a) != dim or len(b) != dim: - raise LiftPatchError('box dimension mismatch') - out: List[Tuple[int, int]] = [] - for (lo1, hi1), (lo2, hi2) in zip(a, b): - lo = max(lo1, lo2) - hi = min(hi1, hi2) - if hi < lo: - # Empty intersection: still return a valid box. - out.append((lo, lo)) - else: - out.append((lo, hi)) - return tuple(out) - - -def _in_box(shift: TVec, box: Optional[Tuple[Tuple[int, int], ...]]) -> bool: - if box is None: - return True - for x, (lo, hi) in zip(shift, box): - if x < lo or x >= hi: - return False - return True - - -def _try_sort_patch_edges( - records: List[Tuple[Any, Any, int, Any]], -) -> None: - """Sort patch edge candidates deterministically. - - Records are (u_inst, v_inst, key, payload). - """ - try: - records.sort(key=lambda r: (r[0], r[1], r[2])) - except TypeError: - records.sort( - key=lambda r: (fallback_key(r[0]), fallback_key(r[1]), r[2]) - ) - - -@dataclass(frozen=True) -class LiftPatch: - """A finite patch extracted from the infinite lift. - - Attributes: - nodes: Node instances `(u, shift)` in canonical order. - edges: Edges between included node instances. - - - For simple containers: `(u_inst, v_inst, attrs)`. - - For multigraph containers: `(u_inst, v_inst, key, attrs)`. - - For directed patches, `(u_inst, v_inst)` is ordered. - For undirected patches, endpoints are in canonical order. - seed: Seed node instance. - radius: BFS radius in the lifted graph (weak connectivity), if used. - box: Effective absolute box constraint after intersection, if used. - """ - - nodes: Tuple[NodeInst, ...] - edges: Tuple[Union[PatchEdgeRec, PatchMultiEdgeRec], ...] - seed: NodeInst - radius: Optional[int] - box: Optional[Tuple[Tuple[int, int], ...]] - _is_multigraph: bool = False - _is_directed: bool = False - - @property - def is_multigraph(self) -> bool: - """Whether the patch edges include keys.""" - return bool(self._is_multigraph) - - @property - def is_directed(self) -> bool: - """Whether the patch edges are directed.""" - return bool(self._is_directed) - - def to_networkx( - self, - *, - as_undirected: Optional[bool] = None, - undirected_mode: Literal['multigraph', 'orig_edges'] = 'multigraph', - ) -> Union[nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph]: - """Export the patch as a NetworkX graph. - - Notes: - - By default, directed patches export as directed NetworkX graphs, - and undirected patches export as undirected. - - For directed patches, `as_undirected=True` provides an undirected - view: - - `undirected_mode='multigraph'` returns a MultiGraph where - each directed edge becomes a distinct undirected multiedge, - with direction metadata in edge attributes. - - `undirected_mode='orig_edges'` returns a simple Graph where - each undirected adjacency stores `orig_edges=[...]` - snapshots. - """ - if as_undirected is None: - as_undirected = not self.is_directed - - if not self.is_directed and as_undirected is False: - raise ValueError('cannot export an undirected patch as directed') - - # Directed export (default for directed patches). - if not as_undirected: - if self.is_multigraph: - Gd: Union[nx.DiGraph, nx.MultiDiGraph] = nx.MultiDiGraph() - else: - Gd = nx.DiGraph() - - for node in self.nodes: - Gd.add_node(node) - - if self.is_multigraph: - for u, v, key, attrs in self.edges: # type: ignore[misc] - Gd.add_edge(u, v, key=int(key), **dict(attrs)) - else: - for u, v, attrs in self.edges: # type: ignore[misc] - Gd.add_edge(u, v, **dict(attrs)) - return Gd - - # Undirected export for undirected patches. - if not self.is_directed: - if self.is_multigraph: - Gu: Union[nx.Graph, nx.MultiGraph] = nx.MultiGraph() - else: - Gu = nx.Graph() - - for node in self.nodes: - Gu.add_node(node) - - if self.is_multigraph: - for u, v, key, attrs in self.edges: # type: ignore[misc] - Gu.add_edge(u, v, key=int(key), **dict(attrs)) - else: - for u, v, attrs in self.edges: # type: ignore[misc] - Gu.add_edge(u, v, **dict(attrs)) - return Gu - - # Directed patch -> undirected view. - if undirected_mode == 'multigraph': - Gu2 = nx.MultiGraph() - for node in self.nodes: - Gu2.add_node(node) - - if self.is_multigraph: - for u, v, key, attrs in self.edges: # type: ignore[misc] - data = dict(attrs) - data['_pbc_tail'] = u - data['_pbc_head'] = v - data['_pbc_key'] = int(key) - Gu2.add_edge(u, v, **data) - else: - for u, v, attrs in self.edges: # type: ignore[misc] - data = dict(attrs) - data['_pbc_tail'] = u - data['_pbc_head'] = v - data['_pbc_key'] = None - Gu2.add_edge(u, v, **data) - return Gu2 - - if undirected_mode != 'orig_edges': - raise ValueError('invalid undirected_mode') - - Gu3 = nx.Graph() - for node in self.nodes: - Gu3.add_node(node) - - def _canon_pair(a: NodeInst, b: NodeInst) -> Tuple[NodeInst, NodeInst]: - uu, vv = stable_sorted([a, b]) - return uu, vv - - buckets: Dict[Tuple[NodeInst, NodeInst], List[Dict[str, Any]]] = {} - if self.is_multigraph: - for u, v, key, attrs in self.edges: # type: ignore[misc] - a, b = _canon_pair(u, v) - rec = { - 'tail': u, - 'head': v, - 'key': int(key), - 'attrs': dict(attrs), - } - buckets.setdefault((a, b), []).append(rec) - else: - for u, v, attrs in self.edges: # type: ignore[misc] - a, b = _canon_pair(u, v) - rec = { - 'tail': u, - 'head': v, - 'key': None, - 'attrs': dict(attrs), - } - buckets.setdefault((a, b), []).append(rec) - - for (a, b), recs in buckets.items(): - try: - recs.sort(key=lambda r: (r['tail'], r['head'], r['key'])) - except TypeError: - recs.sort( - key=lambda r: ( - fallback_key(r['tail']), - fallback_key(r['head']), - -1 if r['key'] is None else int(r['key']), - ) - ) - Gu3.add_edge(a, b, orig_edges=recs) - return Gu3 - - -def lift_patch( - G: PeriodicDiGraphLike, - seed: NodeInst, - *, - radius: Optional[int] = None, - box: Optional[Tuple[Tuple[int, int], ...]] = None, - box_rel: Optional[Tuple[Tuple[int, int], ...]] = None, - include_edges: bool = True, - max_nodes: Optional[int] = None, - node_order: Optional[Callable[[NodeInst], Any]] = None, - edge_order: Optional[Callable[[Tuple[Any, ...]], Any]] = None, -) -> LiftPatch: - """Extract a finite patch of the lifted graph around a seed. - - The traversal uses weak connectivity in the infinite lift: from an instance - it considers both outgoing and incoming quotient edges. - - Notes: - The returned patch is directed if `G.is_undirected == False`, and - undirected otherwise. Use `LiftPatch.to_networkx(as_undirected=True, - ...)` to obtain undirected views of directed patches. - - - Args: - G: A periodic graph container. - seed: Seed instance `(u, shift)`. - radius: Optional BFS radius in the lifted graph. - box: Optional absolute half-open bounds per coordinate. - box_rel: Optional bounds relative to `seed.shift`. - include_edges: Whether to include edges between included nodes. - max_nodes: If provided, raise if the patch would include more than - `max_nodes` nodes. - node_order: Optional key function for ordering node instances. - edge_order: Optional key function for ordering edge records. - - Returns: - A :class:`~pbcgraph.alg.lift.LiftPatch`. - - Raises: - LiftPatchError: On invalid inputs or if `max_nodes` is exceeded. - """ - dim = int(G.dim) - u0, s0 = seed - validate_tvec(s0, dim) - if radius is None and box is None and box_rel is None: - raise LiftPatchError( - 'at least one of radius, box, or box_rel is required' - ) - if radius is not None: - radius = int(radius) - if radius < 0: - raise LiftPatchError('radius must be non-negative') - - abs_box: Optional[Tuple[Tuple[int, int], ...]] = None - if box is not None: - abs_box = _validate_box(box, dim) - - abs_box_rel: Optional[Tuple[Tuple[int, int], ...]] = None - if box_rel is not None: - rel = _validate_box(box_rel, dim) - out: List[Tuple[int, int]] = [] - for (lo, hi), x0 in zip(rel, s0): - out.append((int(x0) + lo, int(x0) + hi)) - abs_box_rel = tuple(out) - - eff_box = _intersect_boxes(abs_box, abs_box_rel, dim) - if not _in_box(s0, eff_box): - raise LiftPatchError('seed instance is outside the effective box') - - if max_nodes is not None: - max_nodes = int(max_nodes) - if max_nodes <= 0: - raise LiftPatchError('max_nodes must be positive') - - # ----------------- - # Traversal - # ----------------- - visited: Dict[NodeInst, int] = {seed: 0} - q: deque[NodeInst] = deque([seed]) - - def iter_weak_neighbors(inst: NodeInst) -> Iterator[NodeInst]: - for v, s2 in G.neighbors_inst(inst, keys=False, data=False): - yield v, s2 - for v, s2 in G.in_neighbors_inst(inst, keys=False, data=False): - yield v, s2 - - while q: - cur = q.popleft() - dcur = visited[cur] - if radius is not None and dcur >= radius: - continue - - for nb in iter_weak_neighbors(cur): - _v, s2 = nb - validate_tvec(s2, dim) - if not _in_box(s2, eff_box): - continue - if nb in visited: - continue - visited[nb] = dcur + 1 - q.append(nb) - if max_nodes is not None and len(visited) > max_nodes: - raise LiftPatchError('max_nodes exceeded during traversal') - - # Canonical node order. - nodes_list = list(visited.keys()) - if node_order is None: - nodes = tuple(stable_sorted(nodes_list)) - else: - nodes = tuple(sorted(nodes_list, key=node_order)) - - patch_is_directed = not bool(G.is_undirected) - - # ----------------- - # Edge inclusion (no explicit tvec) - # ----------------- - edges_out: List[Union[PatchEdgeRec, PatchMultiEdgeRec]] = [] - if include_edges: - included_set = set(visited) - - if patch_is_directed: - records: List[ - Tuple[NodeInst, NodeInst, int, Any, Dict[str, Any]] - ] = [] - for inst in nodes: - for v, s2, k, attrs in G.neighbors_inst( - inst, keys=True, data=True - ): - nb = (v, s2) - if nb not in included_set: - continue - sel_key = (inst, nb, int(k)) - sc = ( - edge_order(sel_key) - if edge_order is not None - else sel_key - ) - records.append((inst, nb, int(k), sc, dict(attrs))) - - try: - records.sort(key=lambda r: (r[3], r[0], r[1], r[2])) - except TypeError: - records.sort( - key=lambda r: ( - fallback_key(r[3]), - fallback_key(r[0]), - fallback_key(r[1]), - r[2], - ) - ) - - if G.is_multigraph: - for u_inst, v_inst, kk, _sc, attrs in records: - edges_out.append((u_inst, v_inst, int(kk), dict(attrs))) - else: - for u_inst, v_inst, _kk, _sc, attrs in records: - edges_out.append((u_inst, v_inst, dict(attrs))) - - else: - candidates: List[ - Tuple[NodeInst, NodeInst, int, Dict[str, Any]] - ] = [] - for inst in nodes: - for v, s2, k, attrs in G.neighbors_inst( - inst, keys=True, data=True - ): - nb = (v, s2) - if nb not in included_set: - continue - candidates.append((inst, nb, int(k), dict(attrs))) - for v, s2, k, attrs in G.in_neighbors_inst( - inst, keys=True, data=True - ): - nb = (v, s2) - if nb not in included_set: - continue - candidates.append((inst, nb, int(k), dict(attrs))) - - # Canonicalize endpoints to undirected pairs. - canon: List[Tuple[NodeInst, NodeInst, int, Dict[str, Any]]] = [] - for a, b, k, attrs in candidates: - u_inst, v_inst = stable_sorted([a, b]) - canon.append((u_inst, v_inst, k, attrs)) - - # Deduplicate reciprocal realizations deterministically. - best: Dict[ - Tuple[NodeInst, NodeInst, Optional[int]], - Tuple[Any, Dict[str, Any]], - ] = {} - for u_inst, v_inst, k, attrs in canon: - if G.is_multigraph: - eid: Tuple[ - NodeInst, NodeInst, Optional[int] - ] = (u_inst, v_inst, k) - sel_key = (u_inst, v_inst, k) - else: - eid = (u_inst, v_inst, None) - sel_key = (u_inst, v_inst, k) - - score = ( - edge_order(sel_key) - if edge_order is not None - else sel_key - ) - - if eid not in best: - best[eid] = (score, attrs) - continue - prev_score, _prev_attrs = best[eid] - try: - better = score < prev_score - except TypeError: - better = fallback_key(score) < fallback_key(prev_score) - if better: - best[eid] = (score, attrs) - - if G.is_multigraph: - out_multi: List[Tuple[Any, Any, int, Any]] = [] - for (u_inst, v_inst, kk), (sc, attrs) in best.items(): - assert kk is not None - out_multi.append((u_inst, v_inst, int(kk), (sc, attrs))) - _try_sort_patch_edges(out_multi) - for u_inst, v_inst, kk, payload in out_multi: - _sc, attrs = payload - edges_out.append((u_inst, v_inst, int(kk), dict(attrs))) - else: - out_simple: List[Tuple[Any, Any, int, Any]] = [] - for (u_inst, v_inst, _), (sc, attrs) in best.items(): - out_simple.append((u_inst, v_inst, 0, (sc, attrs))) - _try_sort_patch_edges(out_simple) - for u_inst, v_inst, _kk, payload in out_simple: - _sc, attrs = payload - edges_out.append((u_inst, v_inst, dict(attrs))) - return LiftPatch( - nodes=nodes, - edges=tuple(edges_out), - seed=seed, - radius=radius, - box=eff_box, - _is_multigraph=bool(G.is_multigraph), - _is_directed=patch_is_directed, - ) - - -TreeEdgeRec = Tuple[NodeId, NodeId, TVec, int] - - -@dataclass(frozen=True) -class CanonicalLift: - """A deterministic finite representation of a single strand. - - Attributes: - nodes: Node instances `(u, shift)` in canonical order. Contains - exactly one instance for every quotient node in the component. - strand_key: Target strand (coset) key in `Z^d / L`. - anchor_site: Quotient node chosen to be placed in `anchor_shift`. - anchor_shift: Anchor cell translation vector. - placement: Placement mode used to construct the lift. - score: Placement score (smaller is better; 0 is best). - tree_edges: Optional spanning-tree edge records for debugging. - """ - - nodes: Tuple[NodeInst, ...] - strand_key: Hashable - anchor_site: NodeId - anchor_shift: TVec - placement: str - score: Union[int, float] - tree_edges: Optional[Tuple[TreeEdgeRec, ...]] = None - - -def _sorted_nodes_by_key( - nodes: Sequence[NodeId], - node_order: Optional[Callable[[NodeId], Any]], -) -> Tuple[NodeId, ...]: - seq = list(nodes) - if not seq: - return () - - if node_order is None: - return tuple(stable_sorted(seq)) - - def k(u: NodeId) -> Any: - return node_order(u) - - try: - return tuple( - sorted(seq, key=lambda u: (k(u), fallback_key(u))) - ) - except TypeError: - return tuple( - sorted(seq, key=lambda u: (fallback_key(k(u)), fallback_key(u))) - ) - - -def _sorted_node_insts( - insts: Sequence[NodeInst], - node_order: Optional[Callable[[NodeId], Any]], -) -> Tuple[NodeInst, ...]: - seq = list(insts) - if not seq: - return () - - if node_order is None: - try: - return tuple(sorted(seq, key=lambda x: (x[0], x[1]))) - except TypeError: - return tuple(sorted(seq, key=lambda x: (fallback_key(x[0]), x[1]))) - - def k(u: NodeId) -> Any: - return node_order(u) - - try: - return tuple(sorted( - seq, key=lambda x: ( - k(x[0]), x[1], fallback_key(x[0]) - ) - )) - except TypeError: - return tuple(sorted( - seq, key=lambda x: ( - fallback_key(k(x[0])), x[1], fallback_key(x[0]) - ) - )) - - -def _compute_lift_score( - snf: Any, - rel_shifts: Dict[NodeId, TVec], - nodes: Sequence[NodeId], - score: Literal['l1', 'l2'], -) -> int: - """Compute placement score for a lift. - - Args: - snf: SNF decomposition of the component translation subgroup. - rel_shifts: Per-node relative shifts with respect to the anchor site. - nodes: Quotient node ids in the component. - score: Score metric: 'l1' or 'l2'. - - Returns: - The deterministic integer score (smaller is better). - - Raises: - CanonicalLiftError: If the SNF decomposition is invalid. - """ - r = int(snf.rank) - total = 0 - for u in nodes: - y = snf.apply_U(rel_shifts[u]) - node_mag = 0 - for i in range(r): - di = int(snf.diag[i]) - if di == 0: - raise CanonicalLiftError('invalid SNF diagonal entry') - qi = int(y[i] // di) - if score == 'l1': - node_mag += abs(qi) - else: - node_mag += qi * qi - total += node_mag - return int(total) - - -def _compute_rel_abs_shifts( - pot: Dict[NodeId, TVec], - *, - anchor_site: NodeId, - anchor_shift: TVec, -) -> Tuple[Dict[NodeId, TVec], Dict[NodeId, TVec]]: - """Compute relative and absolute shifts for a given anchor site.""" - pot_anchor = pot[anchor_site] - rel: Dict[NodeId, TVec] = {} - abs_s: Dict[NodeId, TVec] = {} - for u, pu in pot.items(): - r = sub_tvec(pu, pot_anchor) - rel[u] = r - abs_s[u] = add_tvec(anchor_shift, r) - return rel, abs_s - - -def _build_internal_adj( - component: Any, - abs_shift: Dict[NodeId, TVec], -) -> Dict[NodeId, FrozenSet[NodeId]]: - """Build induced internal undirected adjacency on selected instances. - - An undirected adjacency between quotient nodes `u` and `v` exists if at - least one directed periodic edge between them is consistent with the - selected absolute shifts. - - Args: - component: PeriodicComponent. - abs_shift: Mapping `u -> shift` for exactly the component nodes. - - Returns: - Dict mapping node id to a frozen set of adjacent node ids. - """ - adj: Dict[NodeId, set[NodeId]] = {u: set() for u in component.nodes} - for u in component.nodes: - su = abs_shift[u] - for v, t, _k in component.graph.neighbors(u, keys=True, data=False): - if v not in component.nodes: - continue - if abs_shift[v] == add_tvec(su, t): - adj[u].add(v) - adj[v].add(u) - return {u: frozenset(nbs) for u, nbs in adj.items()} - - -def _is_connected_undirected( - adj: Dict[NodeId, FrozenSet[NodeId]], - nodes_ordered: Sequence[NodeId], - *, - skip: Optional[NodeId] = None, -) -> bool: - """Return True if the induced graph is connected - (optionally skipping a node).""" - nodes = [u for u in nodes_ordered if u != skip] - if not nodes: - return True - - start = nodes[0] - seen: set[NodeId] = {start} - q: deque[NodeId] = deque([start]) - - while q: - u = q.popleft() - for v in stable_sorted(list(adj.get(u, frozenset()))): - if v == skip: - continue - if v in seen: - continue - seen.add(v) - q.append(v) - return len(seen) == len(nodes) - - -def _boundary_deltas_for_node( - component: Any, - abs_shift: Dict[NodeId, TVec], - u: NodeId, -) -> Tuple[TVec, ...]: - """Enumerate per-node deltas induced by boundary periodic edges.""" - su = abs_shift[u] - deltas: set[TVec] = set() - - for v, t, _k in component.graph.neighbors(u, keys=True, data=False): - if v not in component.nodes: - continue - desired = add_tvec(su, t) - if abs_shift[v] == desired: - continue - # Want: abs_shift[v] == (su + delta) + t - delta = sub_tvec(sub_tvec(abs_shift[v], su), t) - deltas.add(delta) - - for v, t_in, _k in component.graph.in_neighbors(u, keys=True, data=False): - if v not in component.nodes: - continue - desired_u = add_tvec(abs_shift[v], t_in) - if desired_u == su: - continue - # Want: (su + delta) == abs_shift[v] + t_in - delta = sub_tvec(desired_u, su) - deltas.add(delta) - - if not deltas: - return () - - try: - return tuple(sorted(deltas)) - except TypeError: - return tuple(sorted(deltas, key=fallback_key)) - - -def canonical_lift( - component: Any, - *, - strand_key: Optional[Hashable] = None, - seed: Optional[NodeInst] = None, - anchor_shift: Optional[TVec] = None, - placement: Literal['tree', 'best_anchor', 'greedy_cut'] = 'tree', - score: Literal['l1', 'l2'] = 'l1', - return_tree: bool = False, - node_order: Optional[Callable[[NodeId], Any]] = None, - edge_order: Optional[Callable[[Tuple[Any, ...]], Any]] = None, -) -> CanonicalLift: - """Construct a deterministic finite representation of one strand. - - v0.1.2 step4 implements `placement='tree'`, `placement='best_anchor'`, and - `placement='greedy_cut'`. - - Args: - component: A :class:`~pbcgraph.component.PeriodicComponent`. - strand_key: Optional explicit strand key. - seed: Optional seed instance `(u, shift)`. - anchor_shift: Optional anchor cell shift. - placement: Placement mode (`'tree'` in step2). - score: Score metric: `'l1'` or `'l2'`. - return_tree: If True, include spanning-tree edge records. - node_order: Optional ordering key for quotient node ids. - edge_order: Optional ordering key for periodic edges (reserved). - - Returns: - A :class:`~pbcgraph.alg.lift.CanonicalLift`. - - Raises: - CanonicalLiftError: On invalid inputs or if the requested strand does - not intersect the anchor cell. - """ - del edge_order # Reserved for later placement modes. - - if placement not in ('tree', 'best_anchor', 'greedy_cut'): - raise CanonicalLiftError( - "canonical_lift placement must be one of 'tree', " - "'best_anchor', 'greedy_cut'" - ) - - dim = int(component.graph.dim) - - if seed is not None: - u_seed, s_seed = seed - validate_tvec(s_seed, dim) - else: - u_seed = None # noqa: F841 - s_seed = None - - if anchor_shift is None: - if s_seed is not None: - anchor_shift = s_seed - else: - anchor_shift = zero_tvec(dim) - else: - validate_tvec(anchor_shift, dim) - - if strand_key is None: - if seed is not None: - try: - K = component.inst_key(seed) - except KeyError as e: - raise CanonicalLiftError( - 'seed does not belong to component' - ) from e - else: - nodes_sorted = _sorted_nodes_by_key( - list(component.nodes), node_order - ) - if not nodes_sorted: - raise CanonicalLiftError('component has no nodes') - default_seed = (nodes_sorted[0], zero_tvec(dim)) - K = component.inst_key(default_seed) - else: - K = strand_key - - eligible: List[NodeId] = [] - for u in component.nodes: - if component.inst_key((u, anchor_shift)) == K: - eligible.append(u) - - if not eligible: - raise CanonicalLiftError( - 'requested strand_key does not intersect the anchor cell' - ) - - pot = {u: component.potential(u) for u in component.nodes} - - snf = component._snf - if snf is None: - raise CanonicalLiftError('component has no SNF decomposition') - - if score not in ('l1', 'l2'): - raise CanonicalLiftError("score must be 'l1' or 'l2'") - - nodes_list = list(component.nodes) - eligible_sorted = _sorted_nodes_by_key(eligible, node_order) - - if placement == 'tree': - anchor_site = eligible_sorted[0] - rel_shift, abs_shift = _compute_rel_abs_shifts( - pot, - anchor_site=anchor_site, - anchor_shift=anchor_shift, - ) - total_score = _compute_lift_score(snf, rel_shift, nodes_list, score) - else: - best_anchor_site: Optional[NodeId] = None - best_rel: Optional[Dict[NodeId, TVec]] = None - best_abs: Optional[Dict[NodeId, TVec]] = None - best_score: Optional[int] = None - - for a in eligible_sorted: - rel_a, abs_a = _compute_rel_abs_shifts( - pot, - anchor_site=a, - anchor_shift=anchor_shift, - ) - s = _compute_lift_score(snf, rel_a, nodes_list, score) - if best_score is None or s < best_score: - best_score = int(s) - best_anchor_site = a - best_rel = rel_a - best_abs = abs_a - - if best_anchor_site is None or best_rel is None or best_abs is None: - raise CanonicalLiftError('failed to select anchor site') - - anchor_site = best_anchor_site - rel_shift = best_rel - abs_shift = best_abs - total_score = int(best_score) - - if placement == 'greedy_cut': - # Start from the best-anchor placement and perform local, per-node - # moves by elements of the translation subgroup L that improve score - # while keeping the induced internal graph connected. - nodes_sorted = _sorted_nodes_by_key(list(component.nodes), node_order) - abs_cur: Dict[NodeId, TVec] = dict(abs_shift) - cur_score = int(total_score) - - while True: - moved = False - adj = _build_internal_adj(component, abs_cur) - if not _is_connected_undirected(adj, nodes_sorted): - raise CanonicalLiftError( - 'internal induced graph is disconnected' - ) - - for u in nodes_sorted: - if u == anchor_site: - continue - deltas = _boundary_deltas_for_node(component, abs_cur, u) - if not deltas: - continue - - # Pre-filter: u must not be an articulation point of the - # current internal graph. - if not _is_connected_undirected(adj, nodes_sorted, skip=u): - continue - - best_move: Optional[Tuple[int, TVec]] = None - old_s = abs_cur[u] - - for delta in deltas: - new_s = add_tvec(old_s, delta) - if component.inst_key((u, new_s)) != K: - continue - - abs_cur[u] = new_s - new_adj = _build_internal_adj(component, abs_cur) - ok = True - if not new_adj.get(u, frozenset()): - ok = False - elif not _is_connected_undirected(new_adj, nodes_sorted): - ok = False - - if ok: - rel_tmp = { - x: sub_tvec(abs_cur[x], abs_cur[anchor_site]) - for x in component.nodes - } - s = _compute_lift_score( - snf, rel_tmp, nodes_list, score - ) - if s < cur_score: - if best_move is None: - best_move = (int(s), delta) - else: - best_s, best_delta = best_move - if int(s) < best_s or ( - int(s) == best_s and delta < best_delta - ): - best_move = (int(s), delta) - - abs_cur[u] = old_s - - if best_move is not None: - best_s, best_delta = best_move - abs_cur[u] = add_tvec(abs_cur[u], best_delta) - cur_score = int(best_s) - moved = True - break - - if not moved: - break - - abs_shift = abs_cur - total_score = int(cur_score) - - insts = [(u, abs_shift[u]) for u in component.nodes] - insts_sorted = _sorted_node_insts(insts, node_order) - - tree_edges: Optional[Tuple[TreeEdgeRec, ...]] = None - if return_tree: - recs: List[TreeEdgeRec] = [] - children = _sorted_nodes_by_key( - list(component._tree_parent.keys()), node_order - ) - for child in children: - parent, _t, k = component._tree_parent[child] - tvec = sub_tvec(abs_shift[child], abs_shift[parent]) - recs.append((parent, child, tvec, int(k))) - tree_edges = tuple(recs) +from pbcgraph.alg._canonical_lift import CanonicalLift, canonical_lift +from pbcgraph.alg._lift_patch import LiftPatch, lift_patch - return CanonicalLift( - nodes=insts_sorted, - strand_key=K, - anchor_site=anchor_site, - anchor_shift=anchor_shift, - placement=placement, - score=int(total_score), - tree_edges=tree_edges, - ) +__all__ = ['lift_patch', 'LiftPatch', 'canonical_lift', 'CanonicalLift'] diff --git a/src/pbcgraph/component.py b/src/pbcgraph/component.py index 8bfbd04..66f6db7 100644 --- a/src/pbcgraph/component.py +++ b/src/pbcgraph/component.py @@ -13,6 +13,7 @@ from collections import deque from dataclasses import dataclass, field +from types import MappingProxyType from typing import ( Callable, Any, @@ -20,12 +21,12 @@ FrozenSet, Hashable, List, + Mapping, Optional, Tuple, ) from pbcgraph.core.exceptions import StaleComponentError -from pbcgraph.core.ordering import fallback_key from pbcgraph.core.types import ( NodeId, NodeInst, @@ -128,6 +129,31 @@ def _require_fresh(self) -> None: 'PeriodicComponent is stale: graph structure has changed' ) + @property + def snf(self) -> SNFDecomposition: + """Smith normal form decomposition for the translation subgroup. + + Raises: + StaleComponentError: If the parent graph has changed structurally. + """ + self._require_fresh() + dec = self._snf + assert dec is not None + return dec + + def tree_parent_map(self) -> Mapping[NodeId, Tuple[NodeId, TVec, int]]: + """Read-only spanning-tree parent mapping. + + The mapping records, for each non-root node `child`, a tuple + `(parent, tvec, key)` describing the tree edge used to assign the + node potential. + + Raises: + StaleComponentError: If the parent graph has changed structurally. + """ + self._require_fresh() + return MappingProxyType(self._tree_parent) + # ----------------- # Potential # ----------------- @@ -376,36 +402,12 @@ def _compute_generators(self, pot: Dict[NodeId, TVec]) -> List[TVec]: gens.append(g) return gens - def _node_leq(a: NodeId, b: NodeId) -> bool: - try: - return a <= b # type: ignore[operator] - except TypeError: - return fallback_key(a) <= fallback_key(b) - - seen = set() - for u, v, t, k in self.graph.edges(keys=True, data=False, tvec=True): + for u, v, t, k in self.graph.undirected_edges_unique( + keys=True, data=False, tvec=True + ): if u not in self.nodes or v not in self.nodes: continue - - tv = tuple(int(x) for x in t) - if u == v: - tv_abs = min(tv, neg_tvec(tv)) - ident = (u, u, tv_abs, int(k)) - if ident in seen: - continue - seen.add(ident) - g = tv_abs - else: - if _node_leq(u, v): - a, b, tv_use = u, v, tv - else: - a, b, tv_use = v, u, neg_tvec(tv) - ident = (a, b, tv_use, int(k)) - if ident in seen: - continue - seen.add(ident) - g = sub_tvec(add_tvec(pot[a], tv_use), pot[b]) - + g = sub_tvec(add_tvec(pot[u], t), pot[v]) if _tvec_is_zero(g): continue gens.append(g) diff --git a/src/pbcgraph/core/__init__.py b/src/pbcgraph/core/__init__.py index 02369cc..6237af8 100644 --- a/src/pbcgraph/core/__init__.py +++ b/src/pbcgraph/core/__init__.py @@ -6,4 +6,8 @@ Most users should import from the top-level `pbcgraph` namespace. """ -__all__ = [] +from pbcgraph.core.constants import PBC_META_KEY + +__all__ = [ + 'PBC_META_KEY', +] diff --git a/src/pbcgraph/core/constants.py b/src/pbcgraph/core/constants.py new file mode 100644 index 0000000..3c6632f --- /dev/null +++ b/src/pbcgraph/core/constants.py @@ -0,0 +1,14 @@ +"""Shared constants. + +This module centralizes constant values used across pbcgraph. + +Some constants are part of the public API when they define interoperability +contracts with external data structures. +""" + +from __future__ import annotations + + +# Reserved key used to store pbcgraph export metadata inside external +# data structures (e.g. NetworkX edge attribute dictionaries). +PBC_META_KEY = '__pbcgraph__' diff --git a/src/pbcgraph/graph/__init__.py b/src/pbcgraph/graph/__init__.py new file mode 100644 index 0000000..4b85fa8 --- /dev/null +++ b/src/pbcgraph/graph/__init__.py @@ -0,0 +1,31 @@ +"""Periodic graph containers. + +pbcgraph represents a periodic graph by a finite quotient graph, where each +directed quotient edge carries an integer translation vector in ``Z^d``. + +Internally, quotient edges are stored in a NetworkX +:class:`networkx.MultiDiGraph`. + +Two container families are provided: + +- :class:`~pbcgraph.graph.PeriodicDiGraph` / + :class:`~pbcgraph.graph.PeriodicGraph`: + at most one edge per ``(u, v, tvec)``. +- :class:`~pbcgraph.graph.PeriodicMultiDiGraph` / + :class:`~pbcgraph.graph.PeriodicMultiGraph`: + allow multiple edges per ``(u, v, tvec)`` (distinguished by edge keys). +""" + +from __future__ import annotations + +from pbcgraph.graph.directed import PeriodicDiGraph +from pbcgraph.graph.multidirected import PeriodicMultiDiGraph +from pbcgraph.graph.undirected import PeriodicGraph +from pbcgraph.graph.multigraph import PeriodicMultiGraph + +__all__ = [ + 'PeriodicDiGraph', + 'PeriodicGraph', + 'PeriodicMultiDiGraph', + 'PeriodicMultiGraph', +] diff --git a/src/pbcgraph/graph.py b/src/pbcgraph/graph/directed.py similarity index 56% rename from src/pbcgraph/graph.py rename to src/pbcgraph/graph/directed.py index cec2abe..7eda837 100644 --- a/src/pbcgraph/graph.py +++ b/src/pbcgraph/graph/directed.py @@ -1,34 +1,8 @@ -"""Periodic graph containers. - -pbcgraph represents a periodic graph by a finite quotient graph, where each -directed quotient edge carries an integer translation vector in ``Z^d``. - -Internally, quotient edges are stored in a NetworkX -:class:`networkx.MultiDiGraph`. -However, pbcgraph exposes *two* containers families: - -- `PeriodicDiGraph` / `PeriodicGraph`: at most one edge per ``(u, v, tvec)``. -- `PeriodicMultiDiGraph` / `PeriodicMultiGraph`: allow multiple edges per - ``(u, v, tvec)`` (distinguished by edge keys). - -Exports: - PeriodicDiGraph: Directed periodic graph on ``Z^d`` (unique per - ``(u, v, tvec)``). - PeriodicGraph: Undirected periodic graph implemented as a pair of directed - realizations per undirected edge (unique per undirected - ``{u, v, tvec}`` up to reversal). - PeriodicMultiDiGraph: Directed periodic multigraph on ``Z^d``. - PeriodicMultiGraph: Undirected periodic multigraph. - -Attributes: - _TVEC_ATTR: Internal edge-data key for translation vectors. - _USER_ATTRS: Internal edge-data key for the live user-attributes mapping. -""" +"""Directed periodic graph container.""" from __future__ import annotations -from dataclasses import dataclass - +from types import MappingProxyType from typing import ( TYPE_CHECKING, Any, @@ -41,24 +15,19 @@ Tuple, ) -from types import MappingProxyType - import networkx as nx if TYPE_CHECKING: - from pbcgraph.component import PeriodicComponent from pbcgraph.alg.lift import LiftPatch + from pbcgraph.component import PeriodicComponent from pbcgraph.alg.components import components as _components - from pbcgraph.core.ordering import ( + fallback_key, stable_sorted, stable_tvec, - # stable_unique_sorted, try_sort_edges, - try_sort_neighbor_edges, ) - from pbcgraph.core.types import ( EdgeKey, NodeId, @@ -69,63 +38,14 @@ sub_tvec, validate_tvec, ) - - -_TVEC_ATTR = '_tvec' -_USER_ATTRS = '_attrs' - - -def _ro(mapping: Dict[str, Any]) -> MappingProxyType: - """Return a read-only live view of a mapping.""" - return MappingProxyType(mapping) - - -def _validate_edge_key(key: EdgeKey) -> None: - """Validate an edge key. - - Edge keys must be ints, but ``bool`` is rejected (even though it is a - subclass of ``int``). - """ - if isinstance(key, bool) or not isinstance(key, int): - raise TypeError('edge key must be an int (bool is not allowed)') - - -@dataclass(frozen=True) -class _UKey: - """Private directed-edge key for undirected containers. - - `PeriodicGraph` and `PeriodicMultiGraph` represent each undirected - edge as two directed realizations. When `u == v` (self-loop in the - quotient) these two realizations would collide in NetworkX if they - shared the same `(u, v, key)` triple. - - `_UKey` splits the user-visible *base* key into two internal keys - distinguished by `dir` in {+1, -1}. - - The public API always exposes the base key (an int). - """ - - base: int - dir: int - - def __post_init__(self) -> None: - if self.dir not in (-1, 1): - raise ValueError('dir must be +1 or -1') - - -def _base_key(k: object) -> int: - """Return the public base edge key for an internal key. - - Args: - k: An internal key, either an int (directed containers) or a - `_UKey` (undirected containers). - - Returns: - The user-visible base key as a Python int. - """ - if isinstance(k, _UKey): - return int(k.base) - return int(k) +from pbcgraph.graph.shared import ( + _TVEC_ATTR, + _USER_ATTRS, + base_key as _base_key, + check_reserved_edge_attrs as _check_reserved_edge_attrs, + ro as _ro, + validate_edge_key as _validate_edge_key, +) class PeriodicDiGraph: @@ -362,6 +282,7 @@ def _add_edge_impl( tvec_norm = stable_tvec(tvec) user_attrs: Dict[str, Any] = dict(attrs) + _check_reserved_edge_attrs(user_attrs) self._g.add_edge( u, v, key=key, **{_TVEC_ATTR: tvec_norm, _USER_ATTRS: user_attrs} ) @@ -455,6 +376,7 @@ def set_edge_attrs( if data is None: raise KeyError((u, v, key)) if attrs: + _check_reserved_edge_attrs(attrs) data[_USER_ATTRS].update(attrs) self.data_version += 1 @@ -490,15 +412,120 @@ def edges( - `(u, v, tvec, attrs)` - `(u, v, tvec, key, attrs)` """ + # Streaming deterministic iteration: + # iterate u, then v, then edges on (u, v) ordered by (tvec, key). + for u in stable_sorted(list(self._g.nodes)): + adj = self._g.adj[u] + for v in stable_sorted(list(adj.keys())): + kd = adj[v] + items: List[Tuple[Tuple[int, ...], int, Any]] = [] + for ik, ed in kd.items(): + items.append( + ( + stable_tvec(ed[_TVEC_ATTR]), + _base_key(ik), + ed[_USER_ATTRS], + ) + ) + items.sort(key=lambda r: (r[0], r[1])) + for tv, k, attrs in items: + if data: + attrs_ro = _ro(attrs) + if tvec: + if keys: + if data: + yield u, v, tv, k, attrs_ro + else: + yield u, v, tv, k + else: + if data: + yield u, v, tv, attrs_ro + else: + yield u, v, tv + else: + if keys: + if data: + yield u, v, k, attrs_ro + else: + yield u, v, k + else: + if data: + yield u, v, attrs_ro + else: + yield u, v + + def undirected_edges_unique( + self, keys: bool = False, data: bool = False, tvec: bool = False + ) -> Iterable: + """Iterate unique undirected edges in deterministic order. + + This iterator is only defined for undirected containers + (`PeriodicGraph` / `PeriodicMultiGraph`). It returns each undirected + quotient edge exactly once, in a canonical orientation. + + Canonicalization rules: + - For `u != v`, the returned endpoints satisfy `u <= v` under the + same ordering policy used elsewhere in pbcgraph. + - For quotient self-loops with nonzero translation, the returned + translation vector is canonicalized to `min(tvec, -tvec)`. + + Args: + keys: If True, include the public base edge key. + data: If True, include the read-only user attribute mapping. + tvec: If True, include the translation vector. + + Returns: + An iterable of: + - `(u, v)` + - `(u, v, attrs)` + - `(u, v, key)` + - `(u, v, tvec)` + - `(u, v, tvec, key)` + - `(u, v, key, attrs)` + - `(u, v, tvec, attrs)` + - `(u, v, tvec, key, attrs)` + + Raises: + TypeError: If called on a directed container. + """ + if not self.is_undirected: + raise TypeError( + 'undirected_edges_unique is only available for ' + 'undirected containers' + ) + records: List[Tuple[Any, Any, Tuple[int, ...], int, Any]] = [] + seen: set[Tuple[Any, Any, Tuple[int, ...], int]] = set() + for u, v, k, edata in self._g.edges(keys=True, data=True): - records.append( - ( - u, v, - stable_tvec(edata[_TVEC_ATTR]), - _base_key(k), edata[_USER_ATTRS] - ) - ) + base = int(_base_key(k)) + tv = stable_tvec(edata[_TVEC_ATTR]) + + if u == v: + tv_neg = stable_tvec(neg_tvec(tv)) + tv_abs = tv if tv <= tv_neg else tv_neg + ident = (u, u, tv_abs, base) + if ident in seen: + continue + seen.add(ident) + records.append((u, u, tv_abs, base, edata[_USER_ATTRS])) + continue + + try: + leq = u <= v # type: ignore[operator] + except TypeError: + leq = fallback_key(u) <= fallback_key(v) + + if leq: + a, b, tv_use = u, v, tv + else: + a, b, tv_use = v, u, stable_tvec(neg_tvec(tv)) + + ident = (a, b, tv_use, base) + if ident in seen: + continue + seen.add(ident) + records.append((a, b, tv_use, base, edata[_USER_ATTRS])) try_sort_edges(records) @@ -546,34 +573,32 @@ def neighbors( if not self._g.has_node(u): raise KeyError(u) - records: List[Tuple[Any, Tuple[int, ...], int, Any]] = [] adj = self._g.adj[u] - for v in adj: + for v in stable_sorted(list(adj.keys())): kd = adj[v] - for k in kd: - ed = kd[k] - records.append( + items: List[Tuple[Tuple[int, ...], int, Any]] = [] + for ik, ed in kd.items(): + items.append( ( - v, stable_tvec(ed[_TVEC_ATTR]), - _base_key(k), ed[_USER_ATTRS] + stable_tvec(ed[_TVEC_ATTR]), + _base_key(ik), + ed[_USER_ATTRS], ) ) - - try_sort_neighbor_edges(records) - - for v, tv, k, attrs in records: - if data: - attrs_ro = _ro(attrs) - if keys: - if data: - yield v, tv, k, attrs_ro - else: - yield v, tv, k - else: + items.sort(key=lambda r: (r[0], r[1])) + for tv, k, attrs in items: if data: - yield v, tv, attrs_ro + attrs_ro = _ro(attrs) + if keys: + if data: + yield v, tv, k, attrs_ro + else: + yield v, tv, k else: - yield v, tv + if data: + yield v, tv, attrs_ro + else: + yield v, tv def in_neighbors( self, u: NodeId, keys: bool = False, data: bool = False @@ -593,34 +618,32 @@ def in_neighbors( if not self._g.has_node(u): raise KeyError(u) - records: List[Tuple[Any, Tuple[int, ...], int, Any]] = [] pred_adj = self._g.pred[u] - for v in pred_adj: + for v in stable_sorted(list(pred_adj.keys())): kd = pred_adj[v] - for k in kd: - ed = kd[k] - records.append( + items: List[Tuple[Tuple[int, ...], int, Any]] = [] + for ik, ed in kd.items(): + items.append( ( - v, stable_tvec(ed[_TVEC_ATTR]), - _base_key(k), ed[_USER_ATTRS] + stable_tvec(ed[_TVEC_ATTR]), + _base_key(ik), + ed[_USER_ATTRS], ) ) - - try_sort_neighbor_edges(records) - - for v, tv, k, attrs in records: - if data: - attrs_ro = _ro(attrs) - if keys: - if data: - yield v, tv, k, attrs_ro - else: - yield v, tv, k - else: + items.sort(key=lambda r: (r[0], r[1])) + for tv, k, attrs in items: if data: - yield v, tv, attrs_ro + attrs_ro = _ro(attrs) + if keys: + if data: + yield v, tv, k, attrs_ro + else: + yield v, tv, k else: - yield v, tv + if data: + yield v, tv, attrs_ro + else: + yield v, tv def neighbors_inst( self, node_inst: NodeInst, keys: bool = False, data: bool = False @@ -792,360 +815,3 @@ def lift_patch( node_order=node_order, edge_order=edge_order, ) - - -class PeriodicMultiDiGraph(PeriodicDiGraph): - """Directed periodic multigraph on ``Z^d``. - - Unlike `PeriodicDiGraph`, this container allows multiple edges for the same - directed triple ``(u, v, tvec)``. Such parallel edges are distinguished by - their edge keys. - """ - - @property - def is_multigraph(self) -> bool: - """Whether this container allows multiple edges per `(u, v, tvec)`.""" - return True - - def add_edge( - self, - u: NodeId, - v: NodeId, - tvec: TVec, - key: Optional[EdgeKey] = None, - **attrs: Any, - ) -> EdgeKey: - """Add a directed periodic edge (parallel edges allowed).""" - return self._add_edge_impl(u, v, tvec, key=key, attrs=dict(attrs)) - - -class PeriodicGraph(PeriodicDiGraph): - """Undirected periodic graph. - - Internally, an undirected periodic edge is represented by two directed - realizations: - - - ``u -> v`` with translation ``tvec`` - - ``v -> u`` with translation ``-tvec`` - - Both realizations share the same underlying user-attributes dict. - The public API returns read-only live views of that mapping. - - **Important:** a crystallographically common pattern is a quotient - self-loop with non-zero translation (``u == v`` and ``tvec != 0``), - representing a bond to a periodic image of the same motif. - - NetworkX identifies multiedges by ``(u, v, key)``. For self-loops, - the two directed realizations would collide if they shared the same key. - - To avoid this, `PeriodicGraph` stores directed realizations using a private - internal key type `_UKey(base, dir)`, where `base` is the user-visible - integer key and `dir` is `+1` / `-1`. The public API always exposes - the base key. - - In addition to the undirected-invariant pairing, this container enforces - an invariant analogous to `PeriodicDiGraph`: - - *For any undirected triple ``{u, v, tvec}`` (up to reversal), at most one - edge exists.* - - To allow multiple contacts for the same motif pair and translation, use - `PeriodicMultiGraph`. - - Notes: - `PeriodicGraph` is a subclass of `PeriodicDiGraph`, but restricts some - operations (for example, directed connectivity modes in algorithms). - """ - - @property - def is_undirected(self) -> bool: - """Whether this container should be treated as undirected - by algorithms.""" - return True - - def edges( - self, keys: bool = False, data: bool = False, tvec: bool = False - ) -> Iterable: - """Iterate directed realizations in deterministic order. - - This iterator yields *directed realizations* of undirected edges. - - Note: - For self-loop periodic edges (``u == v`` and ``tvec != 0``), the - two directed realizations share the same ``(u, v, key)`` triple and - differ only by the translation vector. If `keys=True` but - `tvec=False`, this may yield duplicate ``(u, u, key)`` records. Use - `tvec=True` to disambiguate. - - See `PeriodicDiGraph.edges` for the record formats. - """ - return super().edges(keys=keys, data=data, tvec=tvec) - - def _internal_keys_for_base( - self, u: NodeId, v: NodeId, key: EdgeKey - ) -> List[object]: - """Return internal keys on (u -> v) whose public base key - equals `key`.""" - kd = self._g.get_edge_data(u, v) or {} - out: List[object] = [] - for ik in kd: - if isinstance(ik, _UKey): - if int(ik.base) == int(key): - out.append(ik) - else: - if int(ik) == int(key): - out.append(ik) - return out - - def _choose_internal_key( - self, u: NodeId, v: NodeId, key: EdgeKey - ) -> object: - """Choose a deterministic internal key for accessing - shared attrs/tvec.""" - keys = self._internal_keys_for_base(u, v, key) - if not keys: - raise KeyError((u, v, key)) - # Prefer the "forward" realization when present. - for ik in keys: - if isinstance(ik, _UKey) and ik.dir == 1: - return ik - # Otherwise choose a deterministic order. - return sorted( - keys, - key=lambda x: ( - 0 if isinstance(x, _UKey) else 1, - getattr(x, 'dir', 0), - repr(x), - ), - )[0] - - def _has_undirected_base(self, u: NodeId, v: NodeId, key: EdgeKey) -> bool: - """Return True if an undirected edge with base key exists - between u and v.""" - if u != v: - return ( - len(self._internal_keys_for_base(u, v, key)) == 1 and - len(self._internal_keys_for_base(v, u, key)) == 1 - ) - # Self-loop: both realizations live on (u -> u). - return len(self._internal_keys_for_base(u, u, key)) == 2 - - def add_edge( - self, - u: NodeId, - v: NodeId, - tvec: TVec, - key: Optional[EdgeKey] = None, - **attrs: Any, - ) -> EdgeKey: - validate_tvec(tvec, self._dim) - - existing = self._key_for_tvec(u, v, tvec) - existing_rev = self._key_for_tvec(v, u, neg_tvec(tvec)) - if existing is not None or existing_rev is not None: - raise ValueError( - 'undirected edge already exists for {u, v, tvec}: ' - f'({u!r}, {v!r}, {tuple(tvec)!r}); ' - f'key={existing if existing is not None else existing_rev!r}' - ) - - return self._add_undirected_impl( - u, v, tvec, key=key, attrs=dict(attrs) - ) - - def _add_undirected_impl( - self, - u: NodeId, - v: NodeId, - tvec: TVec, - *, - key: Optional[EdgeKey], - attrs: Dict[str, Any], - ) -> EdgeKey: - """Implementation for adding an undirected edge (no tvec checks).""" - validate_tvec(tvec, self._dim) - if not self._g.has_node(u): - self.add_node(u) - if not self._g.has_node(v): - self.add_node(v) - - if key is None: - key = self._alloc_key_undirected(u, v) - else: - _validate_edge_key(key) - - # Disallow overwriting an existing base key in either direction. - if ( - self._internal_keys_for_base(u, v, key) - or self._internal_keys_for_base(v, u, key) - ): - raise KeyError((u, v, key)) - - tvec_norm = stable_tvec(tvec) - user_attrs: Dict[str, Any] = dict(attrs) - - k_fwd = _UKey(int(key), 1) - k_rev = _UKey(int(key), -1) - - self._g.add_edge( - u, - v, - key=k_fwd, - **{_TVEC_ATTR: tvec_norm, _USER_ATTRS: user_attrs}, - ) - self._g.add_edge( - v, - u, - key=k_rev, - **{ - _TVEC_ATTR: stable_tvec(neg_tvec(tvec)), - _USER_ATTRS: user_attrs, - }, - ) - self.structural_version += 1 - return int(key) - - def has_edge( - self, u: NodeId, v: NodeId, key: Optional[EdgeKey] = None - ) -> bool: - if key is None: - if u != v: - return self._g.has_edge(u, v) and self._g.has_edge(v, u) - kd = self._g.get_edge_data(u, u) or {} - return len(kd) >= 2 - return self._has_undirected_base(u, v, key) - - def edge_tvec(self, u: NodeId, v: NodeId, key: EdgeKey) -> TVec: - ik = self._choose_internal_key(u, v, key) - data = self._g.get_edge_data(u, v, ik) - if data is None: - raise KeyError((u, v, key)) - return stable_tvec(data[_TVEC_ATTR]) - - def get_edge_data( - self, u: NodeId, v: NodeId, key: EdgeKey, default: Any = None - ) -> Any: - try: - ik = self._choose_internal_key(u, v, key) - except KeyError: - return default - data = self._g.get_edge_data(u, v, ik) - if data is None: - return default - return _ro(data[_USER_ATTRS]) - - def set_edge_attrs( - self, u: NodeId, v: NodeId, key: EdgeKey, **attrs: Any - ) -> None: - if not self._has_undirected_base(u, v, key): - raise KeyError((u, v, key)) - if not attrs: - return - ik = self._choose_internal_key(u, v, key) - data = self._g.get_edge_data(u, v, ik) - if data is None: - raise KeyError((u, v, key)) - data[_USER_ATTRS].update(attrs) - self.data_version += 1 - - def remove_edge(self, u: NodeId, v: NodeId, key: EdgeKey) -> None: - triples = set() - for a, b in ((u, v), (v, u)): - kd = self._g.get_edge_data(a, b) or {} - for ik in kd: - if _base_key(ik) == int(key): - triples.add((a, b, ik)) - if len(triples) != 2: - raise KeyError((u, v, key)) - for a, b, ik in triples: - self._g.remove_edge(a, b, key=ik) - self.structural_version += 1 - - def check_invariants(self, *, strict: bool = False) -> Dict[str, Any]: - """Check undirected pairing invariants. - - Returns a structured report and optionally raises on errors. - - Invariants checked: - - For every directed realization there is a paired reverse one. - - Translation vectors satisfy t(v,u,rev) = -t(u,v,fwd). - - The user-attributes dict is the *same object* for paired - realizations. - - Args: - strict: If True, raise ValueError on the first violation. - - Returns: - A dict with keys: `ok`, `errors`, `n_edges`. - """ - errors: List[str] = [] - - for u, v, ik, ed in self._g.edges(keys=True, data=True): - base = _base_key(ik) - if isinstance(ik, _UKey): - rev_key: object = _UKey(base, -ik.dir) - else: - rev_key = ik - rev = self._g.get_edge_data(v, u, rev_key) - if rev is None: - msg = ( - 'missing reverse edge for ' - f'({u!r}, {v!r}, base={base!r}, ik={ik!r})' - ) - if strict: - raise ValueError(msg) - errors.append(msg) - continue - tv = stable_tvec(ed[_TVEC_ATTR]) - tv_rev = stable_tvec(rev[_TVEC_ATTR]) - if tv_rev != stable_tvec(neg_tvec(tv)): - msg = ( - 'translation mismatch for paired edges: ' - f'({u!r}->{v!r}, base={base!r}) has {tv!r}, ' - f'({v!r}->{u!r}, base={base!r}) has {tv_rev!r}' - ) - if strict: - raise ValueError(msg) - errors.append(msg) - if ed[_USER_ATTRS] is not rev[_USER_ATTRS]: - msg = ( - 'attribute mapping is not shared for paired edges: ' - f'({u!r},{v!r}, base={base!r})' - ) - if strict: - raise ValueError(msg) - errors.append(msg) - - return { - 'ok': len(errors) == 0, - 'errors': errors, - 'n_edges': int(self._g.number_of_edges()), - } - - -class PeriodicMultiGraph(PeriodicGraph): - """Undirected periodic multigraph. - - Unlike `PeriodicGraph`, this container allows multiple undirected edges for - the same motif pair and translation (i.e. multiple edges for the same - undirected ``{u, v, tvec}`` up to reversal). Parallel edges are - distinguished by their edge keys. - """ - - @property - def is_multigraph(self) -> bool: - """Whether this container allows multiple edges per `(u, v, tvec)`.""" - return True - - def add_edge( - self, - u: NodeId, - v: NodeId, - tvec: TVec, - key: Optional[EdgeKey] = None, - **attrs: Any, - ) -> EdgeKey: - """Add an undirected periodic edge (parallel edges allowed).""" - return self._add_undirected_impl( - u, v, tvec, key=key, attrs=dict(attrs) - ) diff --git a/src/pbcgraph/graph/multidirected.py b/src/pbcgraph/graph/multidirected.py new file mode 100644 index 0000000..3965bef --- /dev/null +++ b/src/pbcgraph/graph/multidirected.py @@ -0,0 +1,33 @@ +"""Directed periodic multigraph container.""" + +from __future__ import annotations + +from typing import Any, Optional + +from pbcgraph.core.types import EdgeKey, NodeId, TVec +from pbcgraph.graph.directed import PeriodicDiGraph + + +class PeriodicMultiDiGraph(PeriodicDiGraph): + """Directed periodic multigraph on ``Z^d``. + + Unlike `PeriodicDiGraph`, this container allows multiple edges for the same + directed triple ``(u, v, tvec)``. Such parallel edges are distinguished by + their edge keys. + """ + + @property + def is_multigraph(self) -> bool: + """Whether this container allows multiple edges per `(u, v, tvec)`.""" + return True + + def add_edge( + self, + u: NodeId, + v: NodeId, + tvec: TVec, + key: Optional[EdgeKey] = None, + **attrs: Any, + ) -> EdgeKey: + """Add a directed periodic edge (parallel edges allowed).""" + return self._add_edge_impl(u, v, tvec, key=key, attrs=dict(attrs)) diff --git a/src/pbcgraph/graph/multigraph.py b/src/pbcgraph/graph/multigraph.py new file mode 100644 index 0000000..093e5a8 --- /dev/null +++ b/src/pbcgraph/graph/multigraph.py @@ -0,0 +1,36 @@ +"""Undirected periodic multigraph container.""" + +from __future__ import annotations + +from typing import Any, Optional + +from pbcgraph.core.types import EdgeKey, NodeId, TVec +from pbcgraph.graph.undirected import PeriodicGraph + + +class PeriodicMultiGraph(PeriodicGraph): + """Undirected periodic multigraph. + + Unlike `PeriodicGraph`, this container allows multiple undirected edges for + the same motif pair and translation (i.e. multiple edges for the same + undirected ``{u, v, tvec}`` up to reversal). Parallel edges are + distinguished by their edge keys. + """ + + @property + def is_multigraph(self) -> bool: + """Whether this container allows multiple edges per `(u, v, tvec)`.""" + return True + + def add_edge( + self, + u: NodeId, + v: NodeId, + tvec: TVec, + key: Optional[EdgeKey] = None, + **attrs: Any, + ) -> EdgeKey: + """Add an undirected periodic edge (parallel edges allowed).""" + return self._add_undirected_impl( + u, v, tvec, key=key, attrs=dict(attrs) + ) diff --git a/src/pbcgraph/graph/shared.py b/src/pbcgraph/graph/shared.py new file mode 100644 index 0000000..eae5537 --- /dev/null +++ b/src/pbcgraph/graph/shared.py @@ -0,0 +1,83 @@ +"""Shared helpers for periodic graph containers. + +This module contains internal implementation details that are shared across +container classes in :mod:`pbcgraph.graph`. + +Notes: + - Translation vectors are stored under the internal edge-data key + :data:`_TVEC_ATTR`. + - User attributes are stored in a dedicated mapping under + :data:`_USER_ATTRS`. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import MappingProxyType +from typing import Any, Dict + +from pbcgraph.core.constants import PBC_META_KEY +from pbcgraph.core.types import EdgeKey + + +_TVEC_ATTR = '_tvec' +_USER_ATTRS = '_attrs' + + +def ro(mapping: Dict[str, Any]) -> MappingProxyType: + """Return a read-only live view of a mapping.""" + return MappingProxyType(mapping) + + +def validate_edge_key(key: EdgeKey) -> None: + """Validate an edge key. + + Edge keys must be ints, but ``bool`` is rejected (even though it is a + subclass of ``int``). + """ + if isinstance(key, bool) or not isinstance(key, int): + raise TypeError('edge key must be an int (bool is not allowed)') + + +def check_reserved_edge_attrs(attrs: Dict[str, Any]) -> None: + """Reject reserved edge-attribute keys. + + The key :data:`pbcgraph.PBC_META_KEY` is reserved for pbcgraph export + metadata (for example, in NetworkX edge attribute dicts produced by + :meth:`pbcgraph.alg.lift.LiftPatch.to_networkx`). + """ + if PBC_META_KEY in attrs: + raise ValueError( + f'edge attribute key {PBC_META_KEY!r} is reserved for pbcgraph ' + 'metadata' + ) + + +@dataclass(frozen=True) +class UKey: + """Private directed-edge key for undirected containers. + + `PeriodicGraph` and `PeriodicMultiGraph` represent each undirected edge as + two directed realizations. When `u == v` (a quotient self-loop), these two + realizations would collide in NetworkX if they shared the same + ``(u, v, key)`` triple. + + `UKey` splits the user-visible base key into two internal keys + distinguished by `dir` in {+1, -1}. + + The public API always exposes the base key (an int). + """ + + base: int + dir: int + + def __post_init__(self) -> None: + if self.dir not in (-1, 1): + raise ValueError('dir must be +1 or -1') + + +def base_key(k: object) -> int: + """Return the public base edge key for an internal key.""" + if isinstance(k, UKey): + return int(k.base) + return int(k) diff --git a/src/pbcgraph/graph/undirected.py b/src/pbcgraph/graph/undirected.py new file mode 100644 index 0000000..da055c7 --- /dev/null +++ b/src/pbcgraph/graph/undirected.py @@ -0,0 +1,321 @@ +"""Undirected periodic graph container.""" + +from __future__ import annotations + +from typing import Any, Dict, Iterable, List, Optional + +from pbcgraph.core.ordering import stable_tvec +from pbcgraph.core.types import EdgeKey, NodeId, TVec, neg_tvec, validate_tvec +from pbcgraph.graph.directed import PeriodicDiGraph +from pbcgraph.graph.shared import ( + _TVEC_ATTR, + _USER_ATTRS, + UKey as _UKey, + base_key as _base_key, + check_reserved_edge_attrs as _check_reserved_edge_attrs, + ro as _ro, + validate_edge_key as _validate_edge_key, +) + + +class PeriodicGraph(PeriodicDiGraph): + """Undirected periodic graph. + + Internally, an undirected periodic edge is represented by two directed + realizations: + + - ``u -> v`` with translation ``tvec`` + - ``v -> u`` with translation ``-tvec`` + + Both realizations share the same underlying user-attributes dict. + The public API returns read-only live views of that mapping. + + **Important:** a crystallographically common pattern is a quotient + self-loop with non-zero translation (``u == v`` and ``tvec != 0``), + representing a bond to a periodic image of the same motif. + + NetworkX identifies multiedges by ``(u, v, key)``. For self-loops, + the two directed realizations would collide if they shared the same key. + + To avoid this, `PeriodicGraph` stores directed realizations using a private + internal key type `_UKey(base, dir)`, where `base` is the user-visible + integer key and `dir` is `+1` / `-1`. The public API always exposes + the base key. + + In addition to the undirected-invariant pairing, this container enforces + an invariant analogous to `PeriodicDiGraph`: + + *For any undirected triple ``{u, v, tvec}`` (up to reversal), at most one + edge exists.* + + To allow multiple contacts for the same motif pair and translation, use + `PeriodicMultiGraph`. + + Notes: + `PeriodicGraph` is a subclass of `PeriodicDiGraph`, but restricts some + operations (for example, directed connectivity modes in algorithms). + """ + + @property + def is_undirected(self) -> bool: + """Whether this container should be treated as undirected + by algorithms.""" + return True + + def edges( + self, keys: bool = False, data: bool = False, tvec: bool = False + ) -> Iterable: + """Iterate directed realizations in deterministic order. + + This iterator yields *directed realizations* of undirected edges. + + Note: + For self-loop periodic edges (``u == v`` and ``tvec != 0``), the + two directed realizations share the same ``(u, v, key)`` triple and + differ only by the translation vector. If `keys=True` but + `tvec=False`, this may yield duplicate ``(u, u, key)`` records. Use + `tvec=True` to disambiguate. + + See `PeriodicDiGraph.edges` for the record formats. + """ + return super().edges(keys=keys, data=data, tvec=tvec) + + def _internal_keys_for_base( + self, u: NodeId, v: NodeId, key: EdgeKey + ) -> List[object]: + """Return internal keys on (u -> v) whose public base key + equals `key`.""" + kd = self._g.get_edge_data(u, v) or {} + base = int(key) + out: List[object] = [] + for ik in kd: + if _base_key(ik) == base: + out.append(ik) + return out + + def _choose_internal_key( + self, u: NodeId, v: NodeId, key: EdgeKey + ) -> object: + """Choose a deterministic internal key for accessing + shared attrs/tvec.""" + keys = self._internal_keys_for_base(u, v, key) + if not keys: + raise KeyError((u, v, key)) + # Prefer the "forward" realization when present. + for ik in keys: + if isinstance(ik, _UKey) and ik.dir == 1: + return ik + # Otherwise choose a deterministic order. + return sorted( + keys, + key=lambda x: ( + 0 if isinstance(x, _UKey) else 1, + getattr(x, 'dir', 0), + repr(x), + ), + )[0] + + def _has_undirected_base(self, u: NodeId, v: NodeId, key: EdgeKey) -> bool: + """Return True if an undirected edge with base key exists + between u and v.""" + if u != v: + return ( + len(self._internal_keys_for_base(u, v, key)) == 1 and + len(self._internal_keys_for_base(v, u, key)) == 1 + ) + # Self-loop: both realizations live on (u -> u). + return len(self._internal_keys_for_base(u, u, key)) == 2 + + def add_edge( + self, + u: NodeId, + v: NodeId, + tvec: TVec, + key: Optional[EdgeKey] = None, + **attrs: Any, + ) -> EdgeKey: + validate_tvec(tvec, self._dim) + + existing = self._key_for_tvec(u, v, tvec) + existing_rev = self._key_for_tvec(v, u, neg_tvec(tvec)) + if existing is not None or existing_rev is not None: + raise ValueError( + 'undirected edge already exists for {u, v, tvec}: ' + f'({u!r}, {v!r}, {tuple(tvec)!r}); ' + f'key={existing if existing is not None else existing_rev!r}' + ) + + return self._add_undirected_impl( + u, v, tvec, key=key, attrs=dict(attrs) + ) + + def _add_undirected_impl( + self, + u: NodeId, + v: NodeId, + tvec: TVec, + *, + key: Optional[EdgeKey], + attrs: Dict[str, Any], + ) -> EdgeKey: + """Implementation for adding an undirected edge (no tvec checks).""" + validate_tvec(tvec, self._dim) + if not self._g.has_node(u): + self.add_node(u) + if not self._g.has_node(v): + self.add_node(v) + + if key is None: + key = self._alloc_key_undirected(u, v) + else: + _validate_edge_key(key) + + # Disallow overwriting an existing base key in either direction. + if ( + self._internal_keys_for_base(u, v, key) + or self._internal_keys_for_base(v, u, key) + ): + raise KeyError((u, v, key)) + + tvec_norm = stable_tvec(tvec) + user_attrs: Dict[str, Any] = dict(attrs) + _check_reserved_edge_attrs(user_attrs) + + k_fwd = _UKey(int(key), 1) + k_rev = _UKey(int(key), -1) + + self._g.add_edge( + u, + v, + key=k_fwd, + **{_TVEC_ATTR: tvec_norm, _USER_ATTRS: user_attrs}, + ) + self._g.add_edge( + v, + u, + key=k_rev, + **{ + _TVEC_ATTR: stable_tvec(neg_tvec(tvec)), + _USER_ATTRS: user_attrs, + }, + ) + self.structural_version += 1 + return int(key) + + def has_edge( + self, u: NodeId, v: NodeId, key: Optional[EdgeKey] = None + ) -> bool: + if key is None: + if u != v: + return self._g.has_edge(u, v) and self._g.has_edge(v, u) + kd = self._g.get_edge_data(u, u) or {} + return len(kd) >= 2 + return self._has_undirected_base(u, v, key) + + def edge_tvec(self, u: NodeId, v: NodeId, key: EdgeKey) -> TVec: + ik = self._choose_internal_key(u, v, key) + data = self._g.get_edge_data(u, v, ik) + if data is None: + raise KeyError((u, v, key)) + return stable_tvec(data[_TVEC_ATTR]) + + def get_edge_data( + self, u: NodeId, v: NodeId, key: EdgeKey, default: Any = None + ) -> Any: + try: + ik = self._choose_internal_key(u, v, key) + except KeyError: + return default + data = self._g.get_edge_data(u, v, ik) + if data is None: + return default + return _ro(data[_USER_ATTRS]) + + def set_edge_attrs( + self, u: NodeId, v: NodeId, key: EdgeKey, **attrs: Any + ) -> None: + if not self._has_undirected_base(u, v, key): + raise KeyError((u, v, key)) + if not attrs: + return + ik = self._choose_internal_key(u, v, key) + data = self._g.get_edge_data(u, v, ik) + if data is None: + raise KeyError((u, v, key)) + _check_reserved_edge_attrs(attrs) + data[_USER_ATTRS].update(attrs) + self.data_version += 1 + + def remove_edge(self, u: NodeId, v: NodeId, key: EdgeKey) -> None: + triples = set() + for a, b in ((u, v), (v, u)): + kd = self._g.get_edge_data(a, b) or {} + for ik in kd: + if _base_key(ik) == int(key): + triples.add((a, b, ik)) + if len(triples) != 2: + raise KeyError((u, v, key)) + for a, b, ik in triples: + self._g.remove_edge(a, b, key=ik) + self.structural_version += 1 + + def check_invariants(self, *, strict: bool = False) -> Dict[str, Any]: + """Check undirected pairing invariants. + + Returns a structured report and optionally raises on errors. + + Invariants checked: + - For every directed realization there is a paired reverse one. + - Translation vectors satisfy t(v,u,rev) = -t(u,v,fwd). + - The user-attributes dict is the *same object* for paired + realizations. + + Args: + strict: If True, raise ValueError on the first violation. + + Returns: + A dict with keys: `ok`, `errors`, `n_edges`. + """ + errors: List[str] = [] + + for u, v, ik, ed in self._g.edges(keys=True, data=True): + base = _base_key(ik) + if isinstance(ik, _UKey): + rev_key: object = _UKey(base, -ik.dir) + else: + rev_key = ik + rev = self._g.get_edge_data(v, u, rev_key) + if rev is None: + msg = ( + 'missing reverse edge for ' + f'({u!r}, {v!r}, base={base!r}, ik={ik!r})' + ) + if strict: + raise ValueError(msg) + errors.append(msg) + continue + tv = stable_tvec(ed[_TVEC_ATTR]) + tv_rev = stable_tvec(rev[_TVEC_ATTR]) + if tv_rev != stable_tvec(neg_tvec(tv)): + msg = ( + 'translation mismatch for paired edges: ' + f'({u!r}->{v!r}, base={base!r}) has {tv!r}, ' + f'({v!r}->{u!r}, base={base!r}) has {tv_rev!r}' + ) + if strict: + raise ValueError(msg) + errors.append(msg) + if ed[_USER_ATTRS] is not rev[_USER_ATTRS]: + msg = ( + 'attribute mapping is not shared for paired edges: ' + f'({u!r},{v!r}, base={base!r})' + ) + if strict: + raise ValueError(msg) + errors.append(msg) + + return { + 'ok': len(errors) == 0, + 'errors': errors, + 'n_edges': int(self._g.number_of_edges()), + } diff --git a/tests/test_components_rank.py b/tests/alg/components/test_components_rank.py similarity index 100% rename from tests/test_components_rank.py rename to tests/alg/components/test_components_rank.py diff --git a/tests/test_canonical_lift_best_anchor.py b/tests/alg/lift/test_canonical_lift_best_anchor.py similarity index 100% rename from tests/test_canonical_lift_best_anchor.py rename to tests/alg/lift/test_canonical_lift_best_anchor.py diff --git a/tests/test_canonical_lift_greedy_cut.py b/tests/alg/lift/test_canonical_lift_greedy_cut.py similarity index 100% rename from tests/test_canonical_lift_greedy_cut.py rename to tests/alg/lift/test_canonical_lift_greedy_cut.py diff --git a/tests/test_canonical_lift_tree.py b/tests/alg/lift/test_canonical_lift_tree.py similarity index 100% rename from tests/test_canonical_lift_tree.py rename to tests/alg/lift/test_canonical_lift_tree.py diff --git a/tests/test_lift_patch.py b/tests/alg/lift/test_lift_patch.py similarity index 87% rename from tests/test_lift_patch.py rename to tests/alg/lift/test_lift_patch.py index 593e867..7bb5efe 100644 --- a/tests/test_lift_patch.py +++ b/tests/alg/lift/test_lift_patch.py @@ -1,7 +1,12 @@ import networkx as nx import pytest -from pbcgraph import PeriodicDiGraph, PeriodicGraph, PeriodicMultiGraph +from pbcgraph import ( + PBC_META_KEY, + PeriodicDiGraph, + PeriodicGraph, + PeriodicMultiGraph, +) from pbcgraph.core.exceptions import LiftPatchError @@ -69,16 +74,18 @@ def test_lift_patch_directed_preserves_both_directions_and_exports(): if {u, v} != {('A', (0,)), ('B', (0,))}: continue labels.append(data['label']) - assert data['_pbc_tail'] in {('A', (0,)), ('B', (0,))} - assert data['_pbc_head'] in {('A', (0,)), ('B', (0,))} + meta = data[PBC_META_KEY] + assert meta['tail'] in {('A', (0,)), ('B', (0,))} + assert meta['head'] in {('A', (0,)), ('B', (0,))} assert sorted(labels) == ['x', 'y'] nxC = patch.to_networkx(as_undirected=True, undirected_mode='orig_edges') assert isinstance(nxC, nx.Graph) data = nxC.edges[('A', (0,)), ('B', (0,))] - assert 'orig_edges' in data - assert len(data['orig_edges']) == 2 - labels2 = sorted([rec['attrs']['label'] for rec in data['orig_edges']]) + assert PBC_META_KEY in data + recs = data[PBC_META_KEY]['orig_edges'] + assert len(recs) == 2 + labels2 = sorted([rec['attrs']['label'] for rec in recs]) assert labels2 == ['x', 'y'] diff --git a/tests/test_paths.py b/tests/alg/paths/test_paths.py similarity index 100% rename from tests/test_paths.py rename to tests/alg/paths/test_paths.py diff --git a/tests/test_component_inst_key.py b/tests/component/test_component_inst_key.py similarity index 100% rename from tests/test_component_inst_key.py rename to tests/component/test_component_inst_key.py diff --git a/tests/test_component_staleness.py b/tests/component/test_component_staleness.py similarity index 100% rename from tests/test_component_staleness.py rename to tests/component/test_component_staleness.py diff --git a/tests/test_component_torsion.py b/tests/component/test_component_torsion.py similarity index 100% rename from tests/test_component_torsion.py rename to tests/component/test_component_torsion.py diff --git a/tests/test_types.py b/tests/core/test_types.py similarity index 100% rename from tests/test_types.py rename to tests/core/test_types.py diff --git a/tests/test_edges_tvec_flag.py b/tests/graph/test_edges_tvec_flag.py similarity index 100% rename from tests/test_edges_tvec_flag.py rename to tests/graph/test_edges_tvec_flag.py diff --git a/tests/test_graph_basic.py b/tests/graph/test_graph_basic.py similarity index 100% rename from tests/test_graph_basic.py rename to tests/graph/test_graph_basic.py diff --git a/tests/test_multiedges.py b/tests/graph/test_multiedges.py similarity index 100% rename from tests/test_multiedges.py rename to tests/graph/test_multiedges.py diff --git a/tests/test_neighbors_inst.py b/tests/graph/test_neighbors_inst.py similarity index 100% rename from tests/test_neighbors_inst.py rename to tests/graph/test_neighbors_inst.py diff --git a/tests/graph/test_reserved_edge_attrs.py b/tests/graph/test_reserved_edge_attrs.py new file mode 100644 index 0000000..301ba88 --- /dev/null +++ b/tests/graph/test_reserved_edge_attrs.py @@ -0,0 +1,35 @@ +import pytest + +from pbcgraph import ( + PBC_META_KEY, + PeriodicDiGraph, + PeriodicGraph, + PeriodicMultiDiGraph, + PeriodicMultiGraph, +) + + +def test_reserved_meta_key_rejected_in_add_edge_directed(): + G = PeriodicDiGraph(dim=1) + with pytest.raises(ValueError): + G.add_edge('A', 'B', (0,), **{PBC_META_KEY: {'x': 1}}) + + +def test_reserved_meta_key_rejected_in_add_edge_undirected(): + G = PeriodicGraph(dim=1) + with pytest.raises(ValueError): + G.add_edge('A', 'A', (1,), **{PBC_META_KEY: {'x': 1}}) + + +def test_reserved_meta_key_rejected_in_set_edge_attrs_directed(): + G = PeriodicMultiDiGraph(dim=1) + k = G.add_edge('A', 'B', (0,), kind='bond') + with pytest.raises(ValueError): + G.set_edge_attrs('A', 'B', k, **{PBC_META_KEY: {'x': 1}}) + + +def test_reserved_meta_key_rejected_in_set_edge_attrs_undirected(): + G = PeriodicMultiGraph(dim=1) + k = G.add_edge('A', 'A', (1,), kind='bond') + with pytest.raises(ValueError): + G.set_edge_attrs('A', 'A', k, **{PBC_META_KEY: {'x': 1}}) diff --git a/tests/graph/test_undirected_edges_unique.py b/tests/graph/test_undirected_edges_unique.py new file mode 100644 index 0000000..cebe010 --- /dev/null +++ b/tests/graph/test_undirected_edges_unique.py @@ -0,0 +1,37 @@ +from types import MappingProxyType + +import pytest + +from pbcgraph import PeriodicDiGraph, PeriodicGraph + + +def test_undirected_edges_unique_self_loop() -> None: + G = PeriodicGraph(dim=1) + G.add_edge('A', 'A', (1,), label='bond') + + recs = list(G.undirected_edges_unique(keys=True, data=True, tvec=True)) + assert len(recs) == 1 + + u, v, t, k, attrs = recs[0] + assert u == 'A' + assert v == 'A' + assert t == (-1,) + assert k == 0 + assert isinstance(attrs, MappingProxyType) + assert attrs['label'] == 'bond' + + +def test_undirected_edges_unique_non_loop() -> None: + G = PeriodicGraph(dim=1) + G.add_edge('A', 'B', (0,), label='ab') + + recs = list(G.undirected_edges_unique(keys=True, data=False, tvec=True)) + assert recs == [('A', 'B', (0,), 0)] + + +def test_undirected_edges_unique_directed_raises() -> None: + G = PeriodicDiGraph(dim=1) + G.add_edge('A', 'B', (0,)) + + with pytest.raises(TypeError): + list(G.undirected_edges_unique()) diff --git a/tests/test_undirected_self_loop.py b/tests/graph/test_undirected_self_loop.py similarity index 100% rename from tests/test_undirected_self_loop.py rename to tests/graph/test_undirected_self_loop.py diff --git a/tests/test_imports.py b/tests/meta/test_imports.py similarity index 100% rename from tests/test_imports.py rename to tests/meta/test_imports.py diff --git a/tests/test_imports_acyclic.py b/tests/meta/test_imports_acyclic.py similarity index 100% rename from tests/test_imports_acyclic.py rename to tests/meta/test_imports_acyclic.py