diff --git a/CHANGELOG.md b/CHANGELOG.md index b9e01cc..9ac8304 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,28 @@ This project follows a lightweight "keep a log" style. +## 0.1.2 - Finite lifts and canonical lifts + +- **Finite lift patches** + - Added `lift_patch(...)`: extract a finite patch of the infinite lift around a seed instance, using either a BFS radius and/or absolute/relative cell-index bounding boxes. + - Patch edges store **snapshot** attribute dicts. + - For undirected containers, paired directed realizations are deduplicated deterministically. + +- **Patch export and directed semantics** + - For directed periodic containers, `lift_patch(...)` now produces a **directed** patch by default (exported as `nx.DiGraph` / `nx.MultiDiGraph`). + - `LiftPatch.to_networkx(as_undirected=True, undirected_mode=...)` provides undirected views of directed patches: + - `undirected_mode='multigraph'`: one undirected multiedge per directed edge (direction preserved in `_pbc_tail`/`_pbc_head`). + - `undirected_mode='orig_edges'`: collapsed simple graph with `orig_edges=[...]` snapshots for each adjacency. + +- **Canonical lifts (strand representatives)** + - Added `canonical_lift(...)` to select one instance per quotient node for a chosen strand (coset in `Z^d/L`). + - Implemented placements: `tree`, `best_anchor`, and `greedy_cut`. + - Stored deterministic spanning-tree parent edges on `PeriodicComponent` to optionally return `tree_edges`. + +- **Errors** + - Added `CanonicalLiftError` and `LiftPatchError` for well-scoped failure modes. + + ## 0.1.1 - Refactoring - **Deterministic iteration** diff --git a/README.md b/README.md index b4b1b20..8cef3ea 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,8 @@ What you get in v0.1: - `PeriodicGraph` / `PeriodicDiGraph`: unique edge per `(u, v, tvec)`. - `PeriodicMultiGraph` / `PeriodicMultiDiGraph`: parallel edges allowed for the same `(u, v, tvec)`. 
- `PeriodicComponent`: lattice invariants (rank, SNF torsion) and exact instance connectivity via `same_fragment(...)`. +- `lift_patch(...)`: extract a finite (non-periodic) patch of the infinite lift around a seed instance. +- `canonical_lift(...)`: select one lifted instance per quotient node for a chosen strand (coset in `Z^d/L`). ## Status @@ -25,15 +27,13 @@ The API may still evolve, but the library is already useful for research code an ## Install -Requires Python 3.10+. - -Once the project is published on PyPI: +Requires Python 3.10+. The latest stable release is usually available on PyPI: ```bash python -m pip install pbcgraph ``` -Until then (or for the latest `dev` branch), install from GitHub: +To install the latest unreleased version (or the latest `dev` branch), install from GitHub: ```bash python -m pip install git+https://github.com/IvanChernyshov/pbcgraph.git @@ -70,6 +70,20 @@ neighbors = list(G.neighbors_inst(('A', (0, 0)))) comp = G.components()[0] assert comp.same_fragment(('A', (0, 0)), ('A', (1, 0))) assert not comp.same_fragment(('A', (0, 0)), ('A', (0, 1))) + +# Extract a finite patch of the infinite lift around a seed instance. +patch = G.lift_patch(('A', (0, 0)), radius=2) +nx_patch = patch.to_networkx() # nx.Graph / nx.MultiGraph for undirected sources + +# For directed sources, patches are directed by default: +# nx_patch = patch.to_networkx() # nx.DiGraph / nx.MultiDiGraph +# and you can obtain undirected views via: +# nx_u = patch.to_networkx(as_undirected=True, undirected_mode='multigraph') +# nx_c = patch.to_networkx(as_undirected=True, undirected_mode='orig_edges') + +# Canonical lift: pick one instance per quotient node for a strand. 
+lift = comp.canonical_lift(placement='tree') +assert len(lift.instances) == len(comp.nodes) ``` ## Documentation diff --git a/docs/api/algorithms.md b/docs/api/algorithms.md index 15cfc3a..c639d41 100644 --- a/docs/api/algorithms.md +++ b/docs/api/algorithms.md @@ -12,6 +12,24 @@ options: show_source: false +## Lifts + +::: pbcgraph.alg.lift.lift_patch + options: + show_source: false + +::: pbcgraph.alg.lift.LiftPatch + options: + show_source: false + +::: pbcgraph.alg.lift.canonical_lift + options: + show_source: false + +::: pbcgraph.alg.lift.CanonicalLift + options: + show_source: false + ## Lattice utilities ::: pbcgraph.alg.lattice.snf_decomposition diff --git a/docs/examples/canonical_lift.ipynb b/docs/examples/canonical_lift.ipynb new file mode 100644 index 0000000..8c385ce --- /dev/null +++ b/docs/examples/canonical_lift.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ef2e2a4a", + "metadata": {}, + "source": [ + "# Canonical lift\n", + "\n", + "`canonical_lift(...)` selects **exactly one lifted instance** for every quotient\n", + "node in a `PeriodicComponent`, producing a deterministic finite representation\n", + "of a single *strand* (a connected component of the infinite lift).\n", + "\n", + "In v0.1.2 you can choose between three placement modes:\n", + "\n", + "- `placement='tree'`: place the deterministic spanning tree with a chosen anchor\n", + "- `placement='best_anchor'`: try all valid anchors and pick the best score\n", + "- `placement='greedy_cut'`: locally improve the score while preserving connectivity\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e379cb70", + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "from pbcgraph import PeriodicDiGraph\n" + ] + }, + { + "cell_type": "markdown", + "id": "f2d4646c", + "metadata": {}, + "source": [ + "## Helper: inspect a canonical lift\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"id": "eae34cdf", + "metadata": {}, + "outputs": [], + "source": [ + "def summarize_canon(component, out):\n", + " print('placement:', out.placement)\n", + " print('score:', out.score)\n", + " print('strand_key:', out.strand_key)\n", + " print('anchor_site:', out.anchor_site)\n", + " print('anchor_shift:', out.anchor_shift)\n", + " print('\\nnodes (u, shift):')\n", + " pprint(list(out.nodes))\n", + " print('\\nall nodes are in the target strand:', all(\n", + " component.inst_key((u, s)) == out.strand_key for u, s in out.nodes\n", + " ))\n", + " if out.tree_edges is not None:\n", + " print('\\ntree edges (parent, child, tvec, key):')\n", + " pprint(list(out.tree_edges))\n" + ] + }, + { + "cell_type": "markdown", + "id": "a025daf5", + "metadata": {}, + "source": [ + "## 1) Tree placement and `tree_edges`\n", + "\n", + "This is a small 1D quotient with a periodic cycle.\n", + "We request `return_tree=True` to see the spanning-tree edges used to compute\n", + "potentials.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e943f16", + "metadata": {}, + "outputs": [], + "source": [ + "G = PeriodicDiGraph(dim=1)\n", + "G.add_edge('A', 'B', (0,))\n", + "G.add_edge('B', 'C', (0,))\n", + "G.add_edge('C', 'A', (1,))\n", + "\n", + "c = G.components()[0]\n", + "out_tree = c.canonical_lift(seed=('B', (0,)), anchor_shift=(0,), return_tree=True)\n", + "summarize_canon(c, out_tree)\n" + ] + }, + { + "cell_type": "markdown", + "id": "980ee941", + "metadata": {}, + "source": [ + "## 2) `best_anchor`: same strand, better score\n", + "\n", + "Here we intentionally make deterministic potentials very unbalanced.\n", + "`best_anchor` tries all anchors that exist in the requested strand inside the\n", + "anchor cell and chooses the one that minimizes the score.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2a33ce2", + "metadata": {}, + "outputs": [], + "source": [ + "H = PeriodicDiGraph(dim=1)\n", + "H.add_edge('A', 'B', (2,))\n", + 
"H.add_edge('B', 'C', (98,))\n", + "H.add_edge('C', 'A', (-99,)) # cycle generator = 1 -> L = Z\n", + "\n", + "c2 = H.components()[0]\n", + "out_tree2 = c2.canonical_lift(anchor_shift=(0,), placement='tree', score='l1')\n", + "out_best2 = c2.canonical_lift(anchor_shift=(0,), placement='best_anchor', score='l1')\n", + "\n", + "summarize_canon(c2, out_tree2)\n", + "print('\\n---')\n", + "summarize_canon(c2, out_best2)\n", + "print('\\nbest_anchor improves score:', out_best2.score < out_tree2.score)\n" + ] + }, + { + "cell_type": "markdown", + "id": "e4257447", + "metadata": {}, + "source": [ + "## 3) `greedy_cut`: local improvement beyond `best_anchor`\n", + "\n", + "This example has two distinct quotient edges between `C` and `A`.\n", + "The deterministic spanning tree picks one of them, but `greedy_cut` can locally\n", + "switch to the alternative periodic relation and reduce the score while keeping\n", + "the induced internal graph connected.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab80728a", + "metadata": {}, + "outputs": [], + "source": [ + "K = PeriodicDiGraph(dim=1)\n", + "K.add_edge('A', 'B', (2,))\n", + "K.add_edge('B', 'C', (98,))\n", + "K.add_edge('C', 'A', (-100,))\n", + "K.add_edge('C', 'A', (-99,))\n", + "\n", + "c3 = K.components()[0]\n", + "out_best3 = c3.canonical_lift(anchor_shift=(0,), placement='best_anchor', score='l1')\n", + "out_greedy3 = c3.canonical_lift(anchor_shift=(0,), placement='greedy_cut', score='l1')\n", + "\n", + "summarize_canon(c3, out_best3)\n", + "print('\\n---')\n", + "summarize_canon(c3, out_greedy3)\n", + "print('\\ngreedy_cut improves score:', out_greedy3.score < out_best3.score)\n" + ] + }, + { + "cell_type": "markdown", + "id": "05db5fe5", + "metadata": {}, + "source": [ + "## 4) Strand keys and the \"strand absent in the anchor cell\" error\n", + "\n", + "If the translation subgroup is a proper sublattice of `Z^d`, the infinite lift\n", + "splits into multiple disconnected strands 
(torsion / interpenetration).\n", + "\n", + "In this case, a requested `strand_key` might have **no representatives in the\n", + "anchor cell**. Then `canonical_lift` raises `CanonicalLiftError`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5dde9333", + "metadata": {}, + "outputs": [], + "source": [ + "from pbcgraph.core.exceptions import CanonicalLiftError\n", + "\n", + "T = PeriodicDiGraph(dim=1)\n", + "T.add_edge('A', 'A', (2,)) # L = 2Z -> torsion 2 (even/odd strands)\n", + "\n", + "c4 = T.components()[0]\n", + "print('torsion invariants:', c4.torsion_invariants)\n", + "\n", + "k0 = c4.inst_key(('A', (0,)))\n", + "k1 = c4.inst_key(('A', (1,)))\n", + "print('strand key at A@(0):', k0)\n", + "print('strand key at A@(1):', k1)\n", + "\n", + "try:\n", + " c4.canonical_lift(strand_key=k1, anchor_shift=(0,))\n", + "except CanonicalLiftError as e:\n", + " print('expected error:', e)\n", + "\n", + "# Fix: choose an anchor cell that actually contains the strand.\n", + "out_fix = c4.canonical_lift(strand_key=k1, anchor_shift=(1,))\n", + "summarize_canon(c4, out_fix)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/lift_patch.ipynb b/docs/examples/lift_patch.ipynb new file mode 100644 index 0000000..5602262 --- /dev/null +++ b/docs/examples/lift_patch.ipynb @@ -0,0 +1,289 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c501b7dd", + "metadata": {}, + "source": [ + "# Lift patch\n", + "\n", + "This notebook demonstrates **finite patches** extracted from the infinite lift using\n", + "`lift_patch(seed, ...)`.\n", + "\n", + "You will see how the output behaves for different container types:\n", + "\n", + "- `PeriodicGraph` / `PeriodicMultiGraph` (undirected containers)\n", + "- `PeriodicDiGraph` / `PeriodicMultiDiGraph` (directed containers)\n", + "\n", + "Key idea:\n", + "\n", + "- Traversal uses **weak 
connectivity** in the lift (successors and predecessors).\n", + "- The returned patch is **directed** when the source container is directed.\n", + "- For directed patches, you can still obtain an undirected view via\n", + " `patch.to_networkx(as_undirected=True, ...)`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac399bbd", + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "import networkx as nx\n", + "\n", + "from pbcgraph import (\n", + " PeriodicGraph,\n", + " PeriodicMultiGraph,\n", + " PeriodicDiGraph,\n", + " PeriodicMultiDiGraph,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "af4bcded", + "metadata": {}, + "source": [ + "## Helper: summarize a patch\n", + "\n", + "A `LiftPatch` stores node instances and edge records. The most convenient way to\n", + "work with it is often to export to NetworkX via `to_networkx()`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fe08034", + "metadata": {}, + "outputs": [], + "source": [ + "def summarize_patch(patch, *, max_edges=12):\n", + " print('patch nodes:', len(patch.nodes))\n", + " print('patch edges:', len(patch.edges))\n", + " print('is_directed:', patch.is_directed)\n", + " print('is_multigraph:', patch.is_multigraph)\n", + " print('seed:', patch.seed)\n", + " print('radius:', patch.radius)\n", + " print('box:', patch.box)\n", + " print('\\nfirst nodes:')\n", + " pprint(list(patch.nodes)[:10])\n", + " print('\\nfirst edges:')\n", + " pprint(list(patch.edges)[:max_edges])\n" + ] + }, + { + "cell_type": "markdown", + "id": "ce7b47ca", + "metadata": {}, + "source": [ + "## 1) Undirected periodic graph (`PeriodicGraph`)\n", + "\n", + "Here the source container is undirected (internally stored as two directed\n", + "realizations per bond). 
The patch exports as `nx.Graph`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45fe0870", + "metadata": {}, + "outputs": [], + "source": [ + "G = PeriodicGraph(dim=2)\n", + "G.add_edge('A', 'B', (0, 0))\n", + "G.add_edge('B', 'C', (0, 0))\n", + "G.add_edge('C', 'A', (1, 0)) # periodic cycle generator along x\n", + "\n", + "patch = G.lift_patch(('A', (0, 0)), radius=2)\n", + "summarize_patch(patch)\n", + "\n", + "nxG = patch.to_networkx()\n", + "print('\\nexport type:', type(nxG))\n", + "print('nx nodes:', nxG.number_of_nodes())\n", + "print('nx edges:', nxG.number_of_edges())\n" + ] + }, + { + "cell_type": "markdown", + "id": "01c29ce7", + "metadata": {}, + "source": [ + "## 2) Undirected multigraph (`PeriodicMultiGraph`)\n", + "\n", + "A multigraph can store multiple periodic edges between the same quotient nodes.\n", + "The patch exports as `nx.MultiGraph` and preserves edge keys.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "844383e6", + "metadata": {}, + "outputs": [], + "source": [ + "H = PeriodicMultiGraph(dim=1)\n", + "H.add_edge('A', 'B', (0,), label='bond-1')\n", + "H.add_edge('A', 'B', (1,), label='bond-2')\n", + "\n", + "patch2 = H.lift_patch(('A', (0,)), radius=1)\n", + "summarize_patch(patch2)\n", + "\n", + "nxH = patch2.to_networkx()\n", + "print('\\nexport type:', type(nxH))\n", + "print('edges with data:')\n", + "for u, v, k, data in nxH.edges(keys=True, data=True):\n", + " print(u, ' -- ', v, ' key=', k, ' data=', data)\n" + ] + }, + { + "cell_type": "markdown", + "id": "43918857", + "metadata": {}, + "source": [ + "## 3) Directed periodic graph (`PeriodicDiGraph`)\n", + "\n", + "As of v0.1.2, `lift_patch` is **direction-preserving** for directed containers.\n", + "This avoids the old drawback where `u -> v` and `v -> u` could collapse in an\n", + "undirected patch.\n", + "\n", + "You can still request an undirected view from the patch export:\n", + "\n", + "- 
`undirected_mode='multigraph'`: one undirected multiedge per directed edge\n", + "- `undirected_mode='orig_edges'`: one undirected edge with `orig_edges=[...]`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52c2e564", + "metadata": {}, + "outputs": [], + "source": [ + "D = PeriodicDiGraph(dim=1)\n", + "D.add_edge('A', 'B', (0,), label='x')\n", + "D.add_edge('B', 'A', (0,), label='y')\n", + "\n", + "patch3 = D.lift_patch(('A', (0,)), radius=1)\n", + "summarize_patch(patch3)\n", + "\n", + "nxD = patch3.to_networkx()\n", + "print('\\nexport type:', type(nxD))\n", + "print('directed edges:')\n", + "for u, v, data in nxD.edges(data=True):\n", + " print(u, ' -> ', v, ' data=', data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a08b591", + "metadata": {}, + "outputs": [], + "source": [ + "# Undirected view: multigraph\n", + "nxU = patch3.to_networkx(as_undirected=True, undirected_mode='multigraph')\n", + "print(type(nxU))\n", + "print('undirected multiedges between A and B:', nxU.number_of_edges(('A', (0,)), ('B', (0,))))\n", + "for u, v, data in nxU.edges(data=True):\n", + " if {u, v} != {('A', (0,)), ('B', (0,))}:\n", + " continue\n", + " print(u, '--', v, 'label=', data.get('label'), 'tail=', data.get('_pbc_tail'), 'head=', data.get('_pbc_head'))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0785bdf2", + "metadata": {}, + "outputs": [], + "source": [ + "# Undirected view: collapsed Graph with orig_edges bags\n", + "nxC = patch3.to_networkx(as_undirected=True, undirected_mode='orig_edges')\n", + "print(type(nxC))\n", + "data = nxC.edges[('A', (0,)), ('B', (0,))]\n", + "print('orig_edges records:')\n", + "pprint(data['orig_edges'])\n" + ] + }, + { + "cell_type": "markdown", + "id": "a80e6dac", + "metadata": {}, + "source": [ + "## 4) Directed multigraph (`PeriodicMultiDiGraph`)\n", + "\n", + "Parallel directed edges are preserved in the patch export as `nx.MultiDiGraph`.\n" + ] + }, + 
{ + "cell_type": "code", + "execution_count": null, + "id": "e138e498", + "metadata": {}, + "outputs": [], + "source": [ + "M = PeriodicMultiDiGraph(dim=1)\n", + "M.add_edge('A', 'B', (0,), label='e1')\n", + "M.add_edge('A', 'B', (0,), label='e2')\n", + "M.add_edge('B', 'A', (0,), label='back')\n", + "\n", + "patch4 = M.lift_patch(('A', (0,)), radius=1)\n", + "summarize_patch(patch4)\n", + "\n", + "nxM = patch4.to_networkx()\n", + "print('\\nexport type:', type(nxM))\n", + "print('directed multiedges A->B:', nxM.number_of_edges(('A', (0,)), ('B', (0,))))\n", + "for u, v, k, data in nxM.edges(keys=True, data=True):\n", + " if u == ('A', (0,)) and v == ('B', (0,)):\n", + " print('A->B key=', k, 'label=', data.get('label'))\n" + ] + }, + { + "cell_type": "markdown", + "id": "27b372d8", + "metadata": {}, + "source": [ + "## 5) Using a bounding box (`box` and `box_rel`)\n", + "\n", + "Besides a BFS `radius`, you can restrict the patch by an absolute cell box.\n", + "The box is a tuple of half-open `(lo, hi)` intervals (`lo <= x < hi`), one per lattice coordinate.\n", + "\n", + "`box_rel` is convenient when you want a window defined relative to the seed shift.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7af7712", + "metadata": {}, + "outputs": [], + "source": [ + "P = PeriodicGraph(dim=1)\n", + "P.add_edge('A', 'A', (1,), label='step')\n", + "\n", + "# radius-based patch\n", + "patch_r = P.lift_patch(('A', (0,)), radius=3)\n", + "print('radius=3 nodes:', patch_r.nodes)\n", + "\n", + "# box-based patch: half-open bounds, so only shifts -1 and 0\n", + "patch_b = P.lift_patch(('A', (0,)), box=((-1, 1),))\n", + "print('box=[-1,1] nodes:', patch_b.nodes)\n", + "\n", + "# box_rel: half-open window relative to the seed shift (here shifts 4 and 5)\n", + "patch_br = P.lift_patch(('A', (5,)), box_rel=((-1, 1),))\n", + "print('seed shift 5, box_rel=[-1,1] nodes:', patch_br.nodes)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} 
diff --git a/docs/examples/snf_interpenetration.ipynb b/docs/examples/snf.ipynb similarity index 98% rename from docs/examples/snf_interpenetration.ipynb rename to docs/examples/snf.ipynb index 56cb1ef..631ee51 100644 --- a/docs/examples/snf_interpenetration.ipynb +++ b/docs/examples/snf.ipynb @@ -5,7 +5,7 @@ "id": "6917600d", "metadata": {}, "source": [ - "# Smith Normal Form (SNF), torsion, and interpenetration\n", + "# Smith Normal Form (SNF)\n", "\n", "`PeriodicComponent` computes the translation subgroup `L ⊂ Z^d` induced by quotient cycles, then uses a Smith\n", "Normal Form (SNF) decomposition to expose:\n", diff --git a/docs/general/concepts.md b/docs/general/concepts.md index c563bee..5bacbf1 100644 --- a/docs/general/concepts.md +++ b/docs/general/concepts.md @@ -108,3 +108,34 @@ The numerical representation depends on a unimodular change of coordinates (see - `connectivity='weak'`: successors ∪ predecessors (default behavior for undirected use-cases) This is a quotient path; it does *not* compute an instance-aware shortest path in the infinite lift. + +## Finite patches of the lift + +Sometimes you want a **finite non-periodic graph** that represents a local +fragment of the infinite lift (for visualization, local feature computation, +or feeding to non-periodic algorithms). + +`lift_patch(...)` builds such a patch around a seed instance `(u, shift)` using +either a BFS radius and/or a cell-index bounding box. + +Important details: + +- **Traversal is weakly connected** in the lift: from an instance it considers + both outgoing and incoming periodic edges (successors ∪ predecessors). + This makes patch extraction useful even for directed quotient graphs. + +- **Patch direction follows the container**: + - from undirected containers (`PeriodicGraph`, `PeriodicMultiGraph`), the + patch is undirected; + - from directed containers (`PeriodicDiGraph`, `PeriodicMultiDiGraph`), the + patch is directed. 
+ +- **Undirected views of directed patches** are available via + `LiftPatch.to_networkx(as_undirected=True, undirected_mode=...)`: + - `undirected_mode='multigraph'`: one undirected multiedge per directed + edge; direction metadata is stored in `_pbc_tail`/`_pbc_head`. + - `undirected_mode='orig_edges'`: collapsed simple graph; each undirected + adjacency stores `orig_edges=[...]` snapshots. + +These export options avoid silent loss of information when you want an +undirected representation for an inherently directed relation. diff --git a/docs/general/graph_types.md b/docs/general/graph_types.md index 0764d54..006954f 100644 --- a/docs/general/graph_types.md +++ b/docs/general/graph_types.md @@ -74,6 +74,11 @@ Examples where direction is meaningful: If you only want to *label* an interaction (e.g., donor/acceptor role), but the relation should still be treated as symmetric for connectivity, an undirected container plus attributes is usually a better fit. +!!! note + `lift_patch(...)` follows the container direction: patches extracted from + `PeriodicDiGraph` / `PeriodicMultiDiGraph` are directed by default (while + traversal still uses weak connectivity). + ## `PeriodicMultiDiGraph` Use `PeriodicMultiDiGraph` when edges are directed **and** multiple distinct edges may exist for the same diff --git a/docs/index.md b/docs/index.md index 224770c..d534fb5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,6 +20,8 @@ The key idea is simple and useful: you store a **finite quotient graph** (intern - deterministic `inst_key(...)` keys for lifted instances within a component. - `same_fragment(...)`: exact “are these two lifted instances in the same connected fragment?” checks. - `shortest_path_quotient(...)`: fast BFS in the quotient with `connectivity='directed'|'weak'`. +- `lift_patch(...)`: extract a finite patch of the infinite lift around a seed instance. 
+- `canonical_lift(...)`: pick a canonical set of lifted instances (one per quotient node) for a chosen strand. ## Design philosophy diff --git a/mkdocs.yml b/mkdocs.yml index 5231e6f..631c2ee 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -59,7 +59,9 @@ nav: - Roadmap: general/roadmap.md - Examples: - Quickstart: examples/quickstart.ipynb - - SNF and interpenetration: examples/snf_interpenetration.ipynb + - Lift patch: examples/lift_patch.ipynb + - Canonical lift: examples/canonical_lift.ipynb + - Smith Normal Form (SNF): examples/snf.ipynb - API: - Reference: api/index.md - Graphs: api/graphs.md diff --git a/src/pbcgraph/__about__.py b/src/pbcgraph/__about__.py index 27573f3..f254bcc 100644 --- a/src/pbcgraph/__about__.py +++ b/src/pbcgraph/__about__.py @@ -1,4 +1,4 @@ -__version__ = '0.1.1' +__version__ = '0.1.2' __all__ = [ '__version__', diff --git a/src/pbcgraph/alg/__init__.py b/src/pbcgraph/alg/__init__.py index a3d5755..ee24f67 100644 --- a/src/pbcgraph/alg/__init__.py +++ b/src/pbcgraph/alg/__init__.py @@ -1,6 +1,12 @@ """Algorithms for pbcgraph (v0.1).""" from pbcgraph.alg.components import components, connected_components +from pbcgraph.alg.lift import ( + CanonicalLift, + LiftPatch, + canonical_lift, + lift_patch, +) from pbcgraph.alg.paths import Connectivity, shortest_path_quotient from pbcgraph.lattice import ( SNFDecomposition, @@ -11,6 +17,10 @@ __all__ = [ 'components', 'connected_components', + 'lift_patch', + 'LiftPatch', + 'canonical_lift', + 'CanonicalLift', 'shortest_path_quotient', 'Connectivity', 'SNFDecomposition', diff --git a/src/pbcgraph/alg/lift.py b/src/pbcgraph/alg/lift.py new file mode 100644 index 0000000..f6d0402 --- /dev/null +++ b/src/pbcgraph/alg/lift.py @@ -0,0 +1,999 @@ +"""Finite lifts of periodic graphs. + +This module implements finite, non-periodic views derived from a periodic +quotient graph. 
+ +v0.1.2 adds two high-level operations: + +1) ``lift_patch``: extract a finite patch of the infinite lift + around a seed instance (directed for directed sources; undirected for + undirected sources). + +2) ``canonical_lift`` (added in later steps of the v0.1.2 plan). +""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Hashable, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) + +import networkx as nx + +from pbcgraph.core.exceptions import CanonicalLiftError, LiftPatchError +from pbcgraph.core.ordering import fallback_key, stable_sorted +from pbcgraph.core.protocols import PeriodicDiGraphLike +from pbcgraph.core.types import ( + NodeId, + NodeInst, + TVec, + add_tvec, + sub_tvec, + zero_tvec, + validate_tvec, +) + + +PatchEdgeRec = Tuple[NodeInst, NodeInst, Dict[str, Any]] +PatchMultiEdgeRec = Tuple[NodeInst, NodeInst, int, Dict[str, Any]] + + +def _validate_box( + box: Sequence[Sequence[int]], + dim: int, +) -> Tuple[Tuple[int, int], ...]: + if len(box) != dim: + raise LiftPatchError('box dimension mismatch') + out: List[Tuple[int, int]] = [] + for rng in box: + if len(rng) != 2: + raise LiftPatchError('box must be a sequence of (lo, hi) pairs') + lo = int(rng[0]) + hi = int(rng[1]) + if hi < lo: + raise LiftPatchError('box has invalid range (hi < lo)') + out.append((lo, hi)) + return tuple(out) + + +def _intersect_boxes( + a: Optional[Tuple[Tuple[int, int], ...]], + b: Optional[Tuple[Tuple[int, int], ...]], + dim: int, +) -> Optional[Tuple[Tuple[int, int], ...]]: + if a is None: + return b + if b is None: + return a + if len(a) != dim or len(b) != dim: + raise LiftPatchError('box dimension mismatch') + out: List[Tuple[int, int]] = [] + for (lo1, hi1), (lo2, hi2) in zip(a, b): + lo = max(lo1, lo2) + hi = min(hi1, hi2) + if hi < lo: + # Empty intersection: still return a valid box. 
+ out.append((lo, lo)) + else: + out.append((lo, hi)) + return tuple(out) + + +def _in_box(shift: TVec, box: Optional[Tuple[Tuple[int, int], ...]]) -> bool: + if box is None: + return True + for x, (lo, hi) in zip(shift, box): + if x < lo or x >= hi: + return False + return True + + +def _try_sort_patch_edges( + records: List[Tuple[Any, Any, int, Any]], +) -> None: + """Sort patch edge candidates deterministically. + + Records are (u_inst, v_inst, key, payload). + """ + try: + records.sort(key=lambda r: (r[0], r[1], r[2])) + except TypeError: + records.sort( + key=lambda r: (fallback_key(r[0]), fallback_key(r[1]), r[2]) + ) + + +@dataclass(frozen=True) +class LiftPatch: + """A finite patch extracted from the infinite lift. + + Attributes: + nodes: Node instances `(u, shift)` in canonical order. + edges: Edges between included node instances. + + - For simple containers: `(u_inst, v_inst, attrs)`. + - For multigraph containers: `(u_inst, v_inst, key, attrs)`. + + For directed patches, `(u_inst, v_inst)` is ordered. + For undirected patches, endpoints are in canonical order. + seed: Seed node instance. + radius: BFS radius in the lifted graph (weak connectivity), if used. + box: Effective absolute box constraint after intersection, if used. + """ + + nodes: Tuple[NodeInst, ...] + edges: Tuple[Union[PatchEdgeRec, PatchMultiEdgeRec], ...] 
+ seed: NodeInst + radius: Optional[int] + box: Optional[Tuple[Tuple[int, int], ...]] + _is_multigraph: bool = False + _is_directed: bool = False + + @property + def is_multigraph(self) -> bool: + """Whether the patch edges include keys.""" + return bool(self._is_multigraph) + + @property + def is_directed(self) -> bool: + """Whether the patch edges are directed.""" + return bool(self._is_directed) + + def to_networkx( + self, + *, + as_undirected: Optional[bool] = None, + undirected_mode: Literal['multigraph', 'orig_edges'] = 'multigraph', + ) -> Union[nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph]: + """Export the patch as a NetworkX graph. + + Notes: + - By default, directed patches export as directed NetworkX graphs, + and undirected patches export as undirected. + - For directed patches, `as_undirected=True` provides an undirected + view: + - `undirected_mode='multigraph'` returns a MultiGraph where + each directed edge becomes a distinct undirected multiedge, + with direction metadata in edge attributes. + - `undirected_mode='orig_edges'` returns a simple Graph where + each undirected adjacency stores `orig_edges=[...]` + snapshots. + """ + if as_undirected is None: + as_undirected = not self.is_directed + + if not self.is_directed and as_undirected is False: + raise ValueError('cannot export an undirected patch as directed') + + # Directed export (default for directed patches). + if not as_undirected: + if self.is_multigraph: + Gd: Union[nx.DiGraph, nx.MultiDiGraph] = nx.MultiDiGraph() + else: + Gd = nx.DiGraph() + + for node in self.nodes: + Gd.add_node(node) + + if self.is_multigraph: + for u, v, key, attrs in self.edges: # type: ignore[misc] + Gd.add_edge(u, v, key=int(key), **dict(attrs)) + else: + for u, v, attrs in self.edges: # type: ignore[misc] + Gd.add_edge(u, v, **dict(attrs)) + return Gd + + # Undirected export for undirected patches. 
+ if not self.is_directed: + if self.is_multigraph: + Gu: Union[nx.Graph, nx.MultiGraph] = nx.MultiGraph() + else: + Gu = nx.Graph() + + for node in self.nodes: + Gu.add_node(node) + + if self.is_multigraph: + for u, v, key, attrs in self.edges: # type: ignore[misc] + Gu.add_edge(u, v, key=int(key), **dict(attrs)) + else: + for u, v, attrs in self.edges: # type: ignore[misc] + Gu.add_edge(u, v, **dict(attrs)) + return Gu + + # Directed patch -> undirected view. + if undirected_mode == 'multigraph': + Gu2 = nx.MultiGraph() + for node in self.nodes: + Gu2.add_node(node) + + if self.is_multigraph: + for u, v, key, attrs in self.edges: # type: ignore[misc] + data = dict(attrs) + data['_pbc_tail'] = u + data['_pbc_head'] = v + data['_pbc_key'] = int(key) + Gu2.add_edge(u, v, **data) + else: + for u, v, attrs in self.edges: # type: ignore[misc] + data = dict(attrs) + data['_pbc_tail'] = u + data['_pbc_head'] = v + data['_pbc_key'] = None + Gu2.add_edge(u, v, **data) + return Gu2 + + if undirected_mode != 'orig_edges': + raise ValueError('invalid undirected_mode') + + Gu3 = nx.Graph() + for node in self.nodes: + Gu3.add_node(node) + + def _canon_pair(a: NodeInst, b: NodeInst) -> Tuple[NodeInst, NodeInst]: + uu, vv = stable_sorted([a, b]) + return uu, vv + + buckets: Dict[Tuple[NodeInst, NodeInst], List[Dict[str, Any]]] = {} + if self.is_multigraph: + for u, v, key, attrs in self.edges: # type: ignore[misc] + a, b = _canon_pair(u, v) + rec = { + 'tail': u, + 'head': v, + 'key': int(key), + 'attrs': dict(attrs), + } + buckets.setdefault((a, b), []).append(rec) + else: + for u, v, attrs in self.edges: # type: ignore[misc] + a, b = _canon_pair(u, v) + rec = { + 'tail': u, + 'head': v, + 'key': None, + 'attrs': dict(attrs), + } + buckets.setdefault((a, b), []).append(rec) + + for (a, b), recs in buckets.items(): + try: + recs.sort(key=lambda r: (r['tail'], r['head'], r['key'])) + except TypeError: + recs.sort( + key=lambda r: ( + fallback_key(r['tail']), + 
fallback_key(r['head']), + -1 if r['key'] is None else int(r['key']), + ) + ) + Gu3.add_edge(a, b, orig_edges=recs) + return Gu3 + + +def lift_patch( + G: PeriodicDiGraphLike, + seed: NodeInst, + *, + radius: Optional[int] = None, + box: Optional[Tuple[Tuple[int, int], ...]] = None, + box_rel: Optional[Tuple[Tuple[int, int], ...]] = None, + include_edges: bool = True, + max_nodes: Optional[int] = None, + node_order: Optional[Callable[[NodeInst], Any]] = None, + edge_order: Optional[Callable[[Tuple[Any, ...]], Any]] = None, +) -> LiftPatch: + """Extract a finite patch of the lifted graph around a seed. + + The traversal uses weak connectivity in the infinite lift: from an instance + it considers both outgoing and incoming quotient edges. + + Notes: + The returned patch is directed if `G.is_undirected == False`, and + undirected otherwise. Use `LiftPatch.to_networkx(as_undirected=True, + ...)` to obtain undirected views of directed patches. + + + Args: + G: A periodic graph container. + seed: Seed instance `(u, shift)`. + radius: Optional BFS radius in the lifted graph. + box: Optional absolute half-open bounds per coordinate. + box_rel: Optional bounds relative to `seed.shift`. + include_edges: Whether to include edges between included nodes. + max_nodes: If provided, raise if the patch would include more than + `max_nodes` nodes. + node_order: Optional key function for ordering node instances. + edge_order: Optional key function for ordering edge records. + + Returns: + A :class:`~pbcgraph.alg.lift.LiftPatch`. + + Raises: + LiftPatchError: On invalid inputs or if `max_nodes` is exceeded. 
+ """ + dim = int(G.dim) + u0, s0 = seed + validate_tvec(s0, dim) + if radius is None and box is None and box_rel is None: + raise LiftPatchError( + 'at least one of radius, box, or box_rel is required' + ) + if radius is not None: + radius = int(radius) + if radius < 0: + raise LiftPatchError('radius must be non-negative') + + abs_box: Optional[Tuple[Tuple[int, int], ...]] = None + if box is not None: + abs_box = _validate_box(box, dim) + + abs_box_rel: Optional[Tuple[Tuple[int, int], ...]] = None + if box_rel is not None: + rel = _validate_box(box_rel, dim) + out: List[Tuple[int, int]] = [] + for (lo, hi), x0 in zip(rel, s0): + out.append((int(x0) + lo, int(x0) + hi)) + abs_box_rel = tuple(out) + + eff_box = _intersect_boxes(abs_box, abs_box_rel, dim) + if not _in_box(s0, eff_box): + raise LiftPatchError('seed instance is outside the effective box') + + if max_nodes is not None: + max_nodes = int(max_nodes) + if max_nodes <= 0: + raise LiftPatchError('max_nodes must be positive') + + # ----------------- + # Traversal + # ----------------- + visited: Dict[NodeInst, int] = {seed: 0} + q: deque[NodeInst] = deque([seed]) + + def iter_weak_neighbors(inst: NodeInst) -> Iterator[NodeInst]: + for v, s2 in G.neighbors_inst(inst, keys=False, data=False): + yield v, s2 + for v, s2 in G.in_neighbors_inst(inst, keys=False, data=False): + yield v, s2 + + while q: + cur = q.popleft() + dcur = visited[cur] + if radius is not None and dcur >= radius: + continue + + for nb in iter_weak_neighbors(cur): + _v, s2 = nb + validate_tvec(s2, dim) + if not _in_box(s2, eff_box): + continue + if nb in visited: + continue + visited[nb] = dcur + 1 + q.append(nb) + if max_nodes is not None and len(visited) > max_nodes: + raise LiftPatchError('max_nodes exceeded during traversal') + + # Canonical node order. 
+ nodes_list = list(visited.keys()) + if node_order is None: + nodes = tuple(stable_sorted(nodes_list)) + else: + nodes = tuple(sorted(nodes_list, key=node_order)) + + patch_is_directed = not bool(G.is_undirected) + + # ----------------- + # Edge inclusion (no explicit tvec) + # ----------------- + edges_out: List[Union[PatchEdgeRec, PatchMultiEdgeRec]] = [] + if include_edges: + included_set = set(visited) + + if patch_is_directed: + records: List[ + Tuple[NodeInst, NodeInst, int, Any, Dict[str, Any]] + ] = [] + for inst in nodes: + for v, s2, k, attrs in G.neighbors_inst( + inst, keys=True, data=True + ): + nb = (v, s2) + if nb not in included_set: + continue + sel_key = (inst, nb, int(k)) + sc = ( + edge_order(sel_key) + if edge_order is not None + else sel_key + ) + records.append((inst, nb, int(k), sc, dict(attrs))) + + try: + records.sort(key=lambda r: (r[3], r[0], r[1], r[2])) + except TypeError: + records.sort( + key=lambda r: ( + fallback_key(r[3]), + fallback_key(r[0]), + fallback_key(r[1]), + r[2], + ) + ) + + if G.is_multigraph: + for u_inst, v_inst, kk, _sc, attrs in records: + edges_out.append((u_inst, v_inst, int(kk), dict(attrs))) + else: + for u_inst, v_inst, _kk, _sc, attrs in records: + edges_out.append((u_inst, v_inst, dict(attrs))) + + else: + candidates: List[ + Tuple[NodeInst, NodeInst, int, Dict[str, Any]] + ] = [] + for inst in nodes: + for v, s2, k, attrs in G.neighbors_inst( + inst, keys=True, data=True + ): + nb = (v, s2) + if nb not in included_set: + continue + candidates.append((inst, nb, int(k), dict(attrs))) + for v, s2, k, attrs in G.in_neighbors_inst( + inst, keys=True, data=True + ): + nb = (v, s2) + if nb not in included_set: + continue + candidates.append((inst, nb, int(k), dict(attrs))) + + # Canonicalize endpoints to undirected pairs. 
+ canon: List[Tuple[NodeInst, NodeInst, int, Dict[str, Any]]] = [] + for a, b, k, attrs in candidates: + u_inst, v_inst = stable_sorted([a, b]) + canon.append((u_inst, v_inst, k, attrs)) + + # Deduplicate reciprocal realizations deterministically. + best: Dict[ + Tuple[NodeInst, NodeInst, Optional[int]], + Tuple[Any, Dict[str, Any]], + ] = {} + for u_inst, v_inst, k, attrs in canon: + if G.is_multigraph: + eid: Tuple[ + NodeInst, NodeInst, Optional[int] + ] = (u_inst, v_inst, k) + sel_key = (u_inst, v_inst, k) + else: + eid = (u_inst, v_inst, None) + sel_key = (u_inst, v_inst, k) + + score = ( + edge_order(sel_key) + if edge_order is not None + else sel_key + ) + + if eid not in best: + best[eid] = (score, attrs) + continue + prev_score, _prev_attrs = best[eid] + try: + better = score < prev_score + except TypeError: + better = fallback_key(score) < fallback_key(prev_score) + if better: + best[eid] = (score, attrs) + + if G.is_multigraph: + out_multi: List[Tuple[Any, Any, int, Any]] = [] + for (u_inst, v_inst, kk), (sc, attrs) in best.items(): + assert kk is not None + out_multi.append((u_inst, v_inst, int(kk), (sc, attrs))) + _try_sort_patch_edges(out_multi) + for u_inst, v_inst, kk, payload in out_multi: + _sc, attrs = payload + edges_out.append((u_inst, v_inst, int(kk), dict(attrs))) + else: + out_simple: List[Tuple[Any, Any, int, Any]] = [] + for (u_inst, v_inst, _), (sc, attrs) in best.items(): + out_simple.append((u_inst, v_inst, 0, (sc, attrs))) + _try_sort_patch_edges(out_simple) + for u_inst, v_inst, _kk, payload in out_simple: + _sc, attrs = payload + edges_out.append((u_inst, v_inst, dict(attrs))) + return LiftPatch( + nodes=nodes, + edges=tuple(edges_out), + seed=seed, + radius=radius, + box=eff_box, + _is_multigraph=bool(G.is_multigraph), + _is_directed=patch_is_directed, + ) + + +TreeEdgeRec = Tuple[NodeId, NodeId, TVec, int] + + +@dataclass(frozen=True) +class CanonicalLift: + """A deterministic finite representation of a single strand. 
+ + Attributes: + nodes: Node instances `(u, shift)` in canonical order. Contains + exactly one instance for every quotient node in the component. + strand_key: Target strand (coset) key in `Z^d / L`. + anchor_site: Quotient node chosen to be placed in `anchor_shift`. + anchor_shift: Anchor cell translation vector. + placement: Placement mode used to construct the lift. + score: Placement score (smaller is better; 0 is best). + tree_edges: Optional spanning-tree edge records for debugging. + """ + + nodes: Tuple[NodeInst, ...] + strand_key: Hashable + anchor_site: NodeId + anchor_shift: TVec + placement: str + score: Union[int, float] + tree_edges: Optional[Tuple[TreeEdgeRec, ...]] = None + + +def _sorted_nodes_by_key( + nodes: Sequence[NodeId], + node_order: Optional[Callable[[NodeId], Any]], +) -> Tuple[NodeId, ...]: + seq = list(nodes) + if not seq: + return () + + if node_order is None: + return tuple(stable_sorted(seq)) + + def k(u: NodeId) -> Any: + return node_order(u) + + try: + return tuple( + sorted(seq, key=lambda u: (k(u), fallback_key(u))) + ) + except TypeError: + return tuple( + sorted(seq, key=lambda u: (fallback_key(k(u)), fallback_key(u))) + ) + + +def _sorted_node_insts( + insts: Sequence[NodeInst], + node_order: Optional[Callable[[NodeId], Any]], +) -> Tuple[NodeInst, ...]: + seq = list(insts) + if not seq: + return () + + if node_order is None: + try: + return tuple(sorted(seq, key=lambda x: (x[0], x[1]))) + except TypeError: + return tuple(sorted(seq, key=lambda x: (fallback_key(x[0]), x[1]))) + + def k(u: NodeId) -> Any: + return node_order(u) + + try: + return tuple(sorted( + seq, key=lambda x: ( + k(x[0]), x[1], fallback_key(x[0]) + ) + )) + except TypeError: + return tuple(sorted( + seq, key=lambda x: ( + fallback_key(k(x[0])), x[1], fallback_key(x[0]) + ) + )) + + +def _compute_lift_score( + snf: Any, + rel_shifts: Dict[NodeId, TVec], + nodes: Sequence[NodeId], + score: Literal['l1', 'l2'], +) -> int: + """Compute placement score for a 
lift. + + Args: + snf: SNF decomposition of the component translation subgroup. + rel_shifts: Per-node relative shifts with respect to the anchor site. + nodes: Quotient node ids in the component. + score: Score metric: 'l1' or 'l2'. + + Returns: + The deterministic integer score (smaller is better). + + Raises: + CanonicalLiftError: If the SNF decomposition is invalid. + """ + r = int(snf.rank) + total = 0 + for u in nodes: + y = snf.apply_U(rel_shifts[u]) + node_mag = 0 + for i in range(r): + di = int(snf.diag[i]) + if di == 0: + raise CanonicalLiftError('invalid SNF diagonal entry') + qi = int(y[i] // di) + if score == 'l1': + node_mag += abs(qi) + else: + node_mag += qi * qi + total += node_mag + return int(total) + + +def _compute_rel_abs_shifts( + pot: Dict[NodeId, TVec], + *, + anchor_site: NodeId, + anchor_shift: TVec, +) -> Tuple[Dict[NodeId, TVec], Dict[NodeId, TVec]]: + """Compute relative and absolute shifts for a given anchor site.""" + pot_anchor = pot[anchor_site] + rel: Dict[NodeId, TVec] = {} + abs_s: Dict[NodeId, TVec] = {} + for u, pu in pot.items(): + r = sub_tvec(pu, pot_anchor) + rel[u] = r + abs_s[u] = add_tvec(anchor_shift, r) + return rel, abs_s + + +def _build_internal_adj( + component: Any, + abs_shift: Dict[NodeId, TVec], +) -> Dict[NodeId, FrozenSet[NodeId]]: + """Build induced internal undirected adjacency on selected instances. + + An undirected adjacency between quotient nodes `u` and `v` exists if at + least one directed periodic edge between them is consistent with the + selected absolute shifts. + + Args: + component: PeriodicComponent. + abs_shift: Mapping `u -> shift` for exactly the component nodes. + + Returns: + Dict mapping node id to a frozen set of adjacent node ids. 
+ """ + adj: Dict[NodeId, set[NodeId]] = {u: set() for u in component.nodes} + for u in component.nodes: + su = abs_shift[u] + for v, t, _k in component.graph.neighbors(u, keys=True, data=False): + if v not in component.nodes: + continue + if abs_shift[v] == add_tvec(su, t): + adj[u].add(v) + adj[v].add(u) + return {u: frozenset(nbs) for u, nbs in adj.items()} + + +def _is_connected_undirected( + adj: Dict[NodeId, FrozenSet[NodeId]], + nodes_ordered: Sequence[NodeId], + *, + skip: Optional[NodeId] = None, +) -> bool: + """Return True if the induced graph is connected + (optionally skipping a node).""" + nodes = [u for u in nodes_ordered if u != skip] + if not nodes: + return True + + start = nodes[0] + seen: set[NodeId] = {start} + q: deque[NodeId] = deque([start]) + + while q: + u = q.popleft() + for v in stable_sorted(list(adj.get(u, frozenset()))): + if v == skip: + continue + if v in seen: + continue + seen.add(v) + q.append(v) + return len(seen) == len(nodes) + + +def _boundary_deltas_for_node( + component: Any, + abs_shift: Dict[NodeId, TVec], + u: NodeId, +) -> Tuple[TVec, ...]: + """Enumerate per-node deltas induced by boundary periodic edges.""" + su = abs_shift[u] + deltas: set[TVec] = set() + + for v, t, _k in component.graph.neighbors(u, keys=True, data=False): + if v not in component.nodes: + continue + desired = add_tvec(su, t) + if abs_shift[v] == desired: + continue + # Want: abs_shift[v] == (su + delta) + t + delta = sub_tvec(sub_tvec(abs_shift[v], su), t) + deltas.add(delta) + + for v, t_in, _k in component.graph.in_neighbors(u, keys=True, data=False): + if v not in component.nodes: + continue + desired_u = add_tvec(abs_shift[v], t_in) + if desired_u == su: + continue + # Want: (su + delta) == abs_shift[v] + t_in + delta = sub_tvec(desired_u, su) + deltas.add(delta) + + if not deltas: + return () + + try: + return tuple(sorted(deltas)) + except TypeError: + return tuple(sorted(deltas, key=fallback_key)) + + +def canonical_lift( + component: Any, + 
*, + strand_key: Optional[Hashable] = None, + seed: Optional[NodeInst] = None, + anchor_shift: Optional[TVec] = None, + placement: Literal['tree', 'best_anchor', 'greedy_cut'] = 'tree', + score: Literal['l1', 'l2'] = 'l1', + return_tree: bool = False, + node_order: Optional[Callable[[NodeId], Any]] = None, + edge_order: Optional[Callable[[Tuple[Any, ...]], Any]] = None, +) -> CanonicalLift: + """Construct a deterministic finite representation of one strand. + + v0.1.2 step4 implements `placement='tree'`, `placement='best_anchor'`, and + `placement='greedy_cut'`. + + Args: + component: A :class:`~pbcgraph.component.PeriodicComponent`. + strand_key: Optional explicit strand key. + seed: Optional seed instance `(u, shift)`. + anchor_shift: Optional anchor cell shift. + placement: Placement mode (`'tree'` in step2). + score: Score metric: `'l1'` or `'l2'`. + return_tree: If True, include spanning-tree edge records. + node_order: Optional ordering key for quotient node ids. + edge_order: Optional ordering key for periodic edges (reserved). + + Returns: + A :class:`~pbcgraph.alg.lift.CanonicalLift`. + + Raises: + CanonicalLiftError: On invalid inputs or if the requested strand does + not intersect the anchor cell. + """ + del edge_order # Reserved for later placement modes. 
+ + if placement not in ('tree', 'best_anchor', 'greedy_cut'): + raise CanonicalLiftError( + "canonical_lift placement must be one of 'tree', " + "'best_anchor', 'greedy_cut'" + ) + + dim = int(component.graph.dim) + + if seed is not None: + u_seed, s_seed = seed + validate_tvec(s_seed, dim) + else: + u_seed = None # noqa: F841 + s_seed = None + + if anchor_shift is None: + if s_seed is not None: + anchor_shift = s_seed + else: + anchor_shift = zero_tvec(dim) + else: + validate_tvec(anchor_shift, dim) + + if strand_key is None: + if seed is not None: + try: + K = component.inst_key(seed) + except KeyError as e: + raise CanonicalLiftError( + 'seed does not belong to component' + ) from e + else: + nodes_sorted = _sorted_nodes_by_key( + list(component.nodes), node_order + ) + if not nodes_sorted: + raise CanonicalLiftError('component has no nodes') + default_seed = (nodes_sorted[0], zero_tvec(dim)) + K = component.inst_key(default_seed) + else: + K = strand_key + + eligible: List[NodeId] = [] + for u in component.nodes: + if component.inst_key((u, anchor_shift)) == K: + eligible.append(u) + + if not eligible: + raise CanonicalLiftError( + 'requested strand_key does not intersect the anchor cell' + ) + + pot = {u: component.potential(u) for u in component.nodes} + + snf = component._snf + if snf is None: + raise CanonicalLiftError('component has no SNF decomposition') + + if score not in ('l1', 'l2'): + raise CanonicalLiftError("score must be 'l1' or 'l2'") + + nodes_list = list(component.nodes) + eligible_sorted = _sorted_nodes_by_key(eligible, node_order) + + if placement == 'tree': + anchor_site = eligible_sorted[0] + rel_shift, abs_shift = _compute_rel_abs_shifts( + pot, + anchor_site=anchor_site, + anchor_shift=anchor_shift, + ) + total_score = _compute_lift_score(snf, rel_shift, nodes_list, score) + else: + best_anchor_site: Optional[NodeId] = None + best_rel: Optional[Dict[NodeId, TVec]] = None + best_abs: Optional[Dict[NodeId, TVec]] = None + best_score: 
Optional[int] = None + + for a in eligible_sorted: + rel_a, abs_a = _compute_rel_abs_shifts( + pot, + anchor_site=a, + anchor_shift=anchor_shift, + ) + s = _compute_lift_score(snf, rel_a, nodes_list, score) + if best_score is None or s < best_score: + best_score = int(s) + best_anchor_site = a + best_rel = rel_a + best_abs = abs_a + + if best_anchor_site is None or best_rel is None or best_abs is None: + raise CanonicalLiftError('failed to select anchor site') + + anchor_site = best_anchor_site + rel_shift = best_rel + abs_shift = best_abs + total_score = int(best_score) + + if placement == 'greedy_cut': + # Start from the best-anchor placement and perform local, per-node + # moves by elements of the translation subgroup L that improve score + # while keeping the induced internal graph connected. + nodes_sorted = _sorted_nodes_by_key(list(component.nodes), node_order) + abs_cur: Dict[NodeId, TVec] = dict(abs_shift) + cur_score = int(total_score) + + while True: + moved = False + adj = _build_internal_adj(component, abs_cur) + if not _is_connected_undirected(adj, nodes_sorted): + raise CanonicalLiftError( + 'internal induced graph is disconnected' + ) + + for u in nodes_sorted: + if u == anchor_site: + continue + deltas = _boundary_deltas_for_node(component, abs_cur, u) + if not deltas: + continue + + # Pre-filter: u must not be an articulation point of the + # current internal graph. 
+ if not _is_connected_undirected(adj, nodes_sorted, skip=u): + continue + + best_move: Optional[Tuple[int, TVec]] = None + old_s = abs_cur[u] + + for delta in deltas: + new_s = add_tvec(old_s, delta) + if component.inst_key((u, new_s)) != K: + continue + + abs_cur[u] = new_s + new_adj = _build_internal_adj(component, abs_cur) + ok = True + if not new_adj.get(u, frozenset()): + ok = False + elif not _is_connected_undirected(new_adj, nodes_sorted): + ok = False + + if ok: + rel_tmp = { + x: sub_tvec(abs_cur[x], abs_cur[anchor_site]) + for x in component.nodes + } + s = _compute_lift_score( + snf, rel_tmp, nodes_list, score + ) + if s < cur_score: + if best_move is None: + best_move = (int(s), delta) + else: + best_s, best_delta = best_move + if int(s) < best_s or ( + int(s) == best_s and delta < best_delta + ): + best_move = (int(s), delta) + + abs_cur[u] = old_s + + if best_move is not None: + best_s, best_delta = best_move + abs_cur[u] = add_tvec(abs_cur[u], best_delta) + cur_score = int(best_s) + moved = True + break + + if not moved: + break + + abs_shift = abs_cur + total_score = int(cur_score) + + insts = [(u, abs_shift[u]) for u in component.nodes] + insts_sorted = _sorted_node_insts(insts, node_order) + + tree_edges: Optional[Tuple[TreeEdgeRec, ...]] = None + if return_tree: + recs: List[TreeEdgeRec] = [] + children = _sorted_nodes_by_key( + list(component._tree_parent.keys()), node_order + ) + for child in children: + parent, _t, k = component._tree_parent[child] + tvec = sub_tvec(abs_shift[child], abs_shift[parent]) + recs.append((parent, child, tvec, int(k))) + tree_edges = tuple(recs) + + return CanonicalLift( + nodes=insts_sorted, + strand_key=K, + anchor_site=anchor_site, + anchor_shift=anchor_shift, + placement=placement, + score=int(total_score), + tree_edges=tree_edges, + ) diff --git a/src/pbcgraph/component.py b/src/pbcgraph/component.py index ca95478..8bfbd04 100644 --- a/src/pbcgraph/component.py +++ b/src/pbcgraph/component.py @@ -14,6 +14,8 @@ 
from collections import deque from dataclasses import dataclass, field from typing import ( + Callable, + Any, Dict, FrozenSet, Hashable, @@ -35,6 +37,7 @@ ) from pbcgraph.core.protocols import PeriodicDiGraphLike from pbcgraph.lattice.snf import SNFDecomposition, snf_decomposition +from pbcgraph.alg.lift import CanonicalLift def _tvec_is_zero(t: TVec) -> bool: @@ -89,17 +92,21 @@ class PeriodicComponent: # Private caches. _potentials: Dict[NodeId, TVec] = field(default_factory=dict, repr=False) + _tree_parent: Dict[NodeId, Tuple[NodeId, TVec, int]] = field( + default_factory=dict, repr=False + ) _snf: Optional[SNFDecomposition] = field(default=None, repr=False) def __post_init__(self) -> None: # Compute potentials and lattice invariants eagerly for determinism. # The dataclass is frozen, so we must use `object.__setattr__` to # populate computed fields and caches during initialization. - pot = self._compute_potentials() + pot, parent = self._compute_potentials() gens = self._compute_generators(pot) dec = snf_decomposition(gens, self.graph.dim) object.__setattr__(self, '_potentials', pot) + object.__setattr__(self, '_tree_parent', parent) object.__setattr__(self, '_snf', dec) object.__setattr__(self, 'rank', dec.rank) object.__setattr__(self, 'translation_generators', tuple(gens)) @@ -259,12 +266,64 @@ def transversal_basis(self) -> Dict[str, List[TVec]]: 'torsion_moduli': torsion_moduli, } + # ----------------- + # Canonical lifts + # ----------------- + def canonical_lift( + self, + *, + strand_key: Hashable | None = None, + seed: NodeInst | None = None, + anchor_shift: TVec | None = None, + placement: str = 'tree', + score: str = 'l1', + return_tree: bool = False, + node_order: Callable[[NodeId], Any] | None = None, + edge_order: Callable[[tuple], Any] | None = None, + ) -> 'CanonicalLift': + """Return a deterministic finite representation of a single strand. + + This is a thin wrapper over :func:`pbcgraph.alg.lift.canonical_lift`. 
+ + Args: + strand_key: Optional explicit strand (coset) key. + seed: Optional seed instance used to determine `strand_key` and/or + default `anchor_shift`. + anchor_shift: Target anchor cell shift. + placement: Placement mode. v0.1.2 step4 implements `'tree'`, + `'best_anchor'`, and `'greedy_cut'`. + score: Score metric, `'l1'` or `'l2'`. + return_tree: If True, include spanning-tree edge records. + node_order: Optional ordering key for quotient node ids. + edge_order: Optional ordering key for periodic edges (reserved for + later placement modes). + + Returns: + A :class:`~pbcgraph.alg.lift.CanonicalLift`. + """ + from pbcgraph.alg.lift import canonical_lift as _canonical_lift + + return _canonical_lift( + self, + strand_key=strand_key, + seed=seed, + anchor_shift=anchor_shift, + placement=placement, + score=score, + return_tree=return_tree, + node_order=node_order, + edge_order=edge_order, + ) + # ----------------- # Internal computations # ----------------- - def _compute_potentials(self) -> Dict[NodeId, TVec]: + def _compute_potentials( + self, + ) -> Tuple[Dict[NodeId, TVec], Dict[NodeId, Tuple[NodeId, TVec, int]]]: dim = self.graph.dim pot: Dict[NodeId, TVec] = {self.root: zero_tvec(dim)} + parent: Dict[NodeId, Tuple[NodeId, TVec, int]] = {} q = deque([self.root]) while q: @@ -278,10 +337,11 @@ def _compute_potentials(self) -> Dict[NodeId, TVec]: if v in pot: continue pot[v] = add_tvec(pu, tvec) + parent[v] = (u, tvec, int(k)) q.append(v) # Incoming edges next (weak traversal). - for v, t_in, _k in self.graph.in_neighbors( + for v, t_in, k in self.graph.in_neighbors( u, keys=True, data=False ): if v not in self.nodes: @@ -289,7 +349,9 @@ def _compute_potentials(self) -> Dict[NodeId, TVec]: if v in pot: continue pot[v] = sub_tvec(pu, t_in) + parent[v] = (u, neg_tvec(t_in), int(k)) q.append(v) + if len(pot) != len(self.nodes): # This should never happen if component extraction is correct. 
missing = [u for u in self.nodes if u not in pot] @@ -297,7 +359,7 @@ def _compute_potentials(self) -> Dict[NodeId, TVec]: 'component potential assignment incomplete, ' f'missing: {missing}' ) - return pot + return pot, parent def _compute_generators(self, pot: Dict[NodeId, TVec]) -> List[TVec]: gens: List[TVec] = [] diff --git a/src/pbcgraph/core/exceptions.py b/src/pbcgraph/core/exceptions.py index 0ab26a2..c347a53 100644 --- a/src/pbcgraph/core/exceptions.py +++ b/src/pbcgraph/core/exceptions.py @@ -7,3 +7,11 @@ class PBCGraphError(Exception): class StaleComponentError(PBCGraphError): """Raised when a PeriodicComponent is used after its graph has changed.""" + + +class LiftPatchError(PBCGraphError): + """Raised when finite patch extraction fails.""" + + +class CanonicalLiftError(PBCGraphError): + """Raised when canonical lift construction fails.""" diff --git a/src/pbcgraph/core/protocols.py b/src/pbcgraph/core/protocols.py index 99ebac3..4121f2b 100644 --- a/src/pbcgraph/core/protocols.py +++ b/src/pbcgraph/core/protocols.py @@ -22,6 +22,7 @@ class PeriodicDiGraphLike(Protocol): structural_version: int data_version: int is_undirected: bool + is_multigraph: bool # Nodes def has_node(self, u: NodeId) -> bool: @@ -52,5 +53,18 @@ def edges( ) -> Iterable: ... + # Lifted neighborhoods + def neighbors_inst( + self, node_inst: tuple[NodeId, TVec], keys: bool = False, + data: bool = False + ) -> Iterable: + ... + + def in_neighbors_inst( + self, node_inst: tuple[NodeId, TVec], keys: bool = False, + data: bool = False + ) -> Iterable: + ... + def edge_tvec(self, u: NodeId, v: NodeId, key: EdgeKey) -> TVec: ... 
diff --git a/src/pbcgraph/graph.py b/src/pbcgraph/graph.py index e6055f7..cec2abe 100644 --- a/src/pbcgraph/graph.py +++ b/src/pbcgraph/graph.py @@ -32,6 +32,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Iterable, Iterator, @@ -46,6 +47,7 @@ if TYPE_CHECKING: from pbcgraph.component import PeriodicComponent + from pbcgraph.alg.lift import LiftPatch from pbcgraph.alg.components import components as _components @@ -170,6 +172,11 @@ def is_undirected(self) -> bool: by algorithms.""" return False + @property + def is_multigraph(self) -> bool: + """Whether this container allows multiple edges per `(u, v, tvec)`.""" + return False + def __len__(self) -> int: return self._g.number_of_nodes() @@ -748,6 +755,44 @@ def components(self) -> List['PeriodicComponent']: """Return connected components as `PeriodicComponent` objects.""" return _components(self) + # ----------------- + # Finite lifts + # ----------------- + def lift_patch( + self, + seed: NodeInst, + *, + radius: Optional[int] = None, + box: Optional[Tuple[Tuple[int, int], ...]] = None, + box_rel: Optional[Tuple[Tuple[int, int], ...]] = None, + include_edges: bool = True, + max_nodes: Optional[int] = None, + node_order: Optional[Callable[[NodeInst], Any]] = None, + edge_order: Optional[Callable[[Tuple[Any, ...]], Any]] = None, + ) -> 'LiftPatch': + """Extract a finite patch of the lifted graph. + + This is a thin wrapper over :func:`pbcgraph.alg.lift.lift_patch`. + + Notes: + For directed containers this patch is directed by default (exported + as `nx.DiGraph` / `nx.MultiDiGraph`). Use + `patch.to_networkx(as_undirected=True, ...)` for undirected views. + """ + from pbcgraph.alg.lift import lift_patch as _lift_patch + + return _lift_patch( + self, + seed, + radius=radius, + box=box, + box_rel=box_rel, + include_edges=include_edges, + max_nodes=max_nodes, + node_order=node_order, + edge_order=edge_order, + ) + class PeriodicMultiDiGraph(PeriodicDiGraph): """Directed periodic multigraph on ``Z^d``. 
@@ -757,6 +802,11 @@ class PeriodicMultiDiGraph(PeriodicDiGraph): their edge keys. """ + @property + def is_multigraph(self) -> bool: + """Whether this container allows multiple edges per `(u, v, tvec)`.""" + return True + def add_edge( self, u: NodeId, @@ -1082,6 +1132,11 @@ class PeriodicMultiGraph(PeriodicGraph): distinguished by their edge keys. """ + @property + def is_multigraph(self) -> bool: + """Whether this container allows multiple edges per `(u, v, tvec)`.""" + return True + def add_edge( self, u: NodeId, diff --git a/tests/test_canonical_lift_best_anchor.py b/tests/test_canonical_lift_best_anchor.py new file mode 100644 index 0000000..6f3a23b --- /dev/null +++ b/tests/test_canonical_lift_best_anchor.py @@ -0,0 +1,34 @@ +from pbcgraph import PeriodicDiGraph +from pbcgraph.alg.components import components + + +def test_canonical_lift_best_anchor_selects_min_score_anchor(): + # Build a 1D directed quotient where potentials are highly unbalanced: + # pot(A)=0, pot(B)=2, pot(C)=100 (deterministic root is A). + # Add an extra edge to make the translation subgroup L = Z so scoring is + # sensitive to absolute displacements. + G = PeriodicDiGraph(dim=1) + G.add_edge('A', 'B', (2,)) + G.add_edge('B', 'C', (98,)) + G.add_edge('C', 'A', (-99,)) # cycle generator = 1 -> L = Z + + c = components(G)[0] + K0 = c.inst_key(('A', (0,))) + assert c.inst_key(('B', (0,))) == K0 + assert c.inst_key(('C', (0,))) == K0 + + out_tree = c.canonical_lift( + anchor_shift=(0,), placement='tree', score='l1' + ) + out_best = c.canonical_lift( + anchor_shift=(0,), placement='best_anchor', score='l1' + ) + + assert out_tree.anchor_site == 'A' + assert out_best.anchor_site == 'B' + assert out_best.placement == 'best_anchor' + assert out_best.score < out_tree.score + + # Still returns exactly one instance per quotient node. 
+ assert {u for u, _s in out_best.nodes} == {'A', 'B', 'C'} + assert len(out_best.nodes) == 3 diff --git a/tests/test_canonical_lift_greedy_cut.py b/tests/test_canonical_lift_greedy_cut.py new file mode 100644 index 0000000..6e928dd --- /dev/null +++ b/tests/test_canonical_lift_greedy_cut.py @@ -0,0 +1,72 @@ +from collections import deque + +from pbcgraph import PeriodicDiGraph +from pbcgraph.alg.components import components +from pbcgraph.core.types import add_tvec + + +def _internal_adj(component, shift_map): + adj = {u: set() for u in component.nodes} + for u in component.nodes: + su = shift_map[u] + for v, t, _k in component.graph.neighbors(u, keys=True, data=False): + if v not in component.nodes: + continue + if shift_map[v] == add_tvec(su, t): + adj[u].add(v) + adj[v].add(u) + return adj + + +def _is_connected(adj, nodes): + nodes = list(nodes) + if not nodes: + return True + start = sorted(nodes, key=str)[0] + seen = {start} + q = deque([start]) + while q: + u = q.popleft() + for v in sorted(adj[u], key=str): + if v in seen: + continue + seen.add(v) + q.append(v) + return len(seen) == len(nodes) + + +def test_canonical_lift_greedy_cut_improves_score_and_preserves_connectivity(): + # A 1D quotient where best_anchor is good but local redistribution can + # further reduce the score while preserving internal connectivity. + G = PeriodicDiGraph(dim=1) + G.add_edge('A', 'B', (2,)) + G.add_edge('B', 'C', (98,)) + # Two distinct quotient edges between C and A. The spanning-tree + # potentials use the first one in deterministic order, while + # `greedy_cut` can locally switch to the other to reduce score. 
+ G.add_edge('C', 'A', (-100,)) + G.add_edge('C', 'A', (-99,)) + + c = components(G)[0] + + out_best = c.canonical_lift( + anchor_shift=(0,), placement='best_anchor', score='l1' + ) + out_greedy = c.canonical_lift( + anchor_shift=(0,), placement='greedy_cut', score='l1' + ) + + assert out_greedy.placement == 'greedy_cut' + assert out_greedy.anchor_site == out_best.anchor_site + assert out_greedy.anchor_shift == out_best.anchor_shift + assert out_greedy.score <= out_best.score + assert out_greedy.score < out_best.score + + shift_map = {u: s for u, s in out_greedy.nodes} + assert set(shift_map) == set(c.nodes) + + for u, s in out_greedy.nodes: + assert c.inst_key((u, s)) == out_greedy.strand_key + + adj = _internal_adj(c, shift_map) + assert _is_connected(adj, c.nodes) diff --git a/tests/test_canonical_lift_tree.py b/tests/test_canonical_lift_tree.py new file mode 100644 index 0000000..b53a5c0 --- /dev/null +++ b/tests/test_canonical_lift_tree.py @@ -0,0 +1,50 @@ +import pytest + +from pbcgraph import PeriodicDiGraph +from pbcgraph.alg.components import components +from pbcgraph.core.exceptions import CanonicalLiftError +from pbcgraph.core.types import sub_tvec + + +def test_canonical_lift_tree_basic_properties_and_tree_edges(): + G = PeriodicDiGraph(dim=1) + G.add_edge('A', 'B', (0,)) + G.add_edge('B', 'C', (0,)) + G.add_edge('C', 'A', (1,)) + + c = components(G)[0] + out = c.canonical_lift( + seed=('B', (0,)), anchor_shift=(0,), return_tree=True + ) + + assert {u for u, _s in out.nodes} == {'A', 'B', 'C'} + assert len(out.nodes) == 3 + assert out.anchor_site == 'A' + assert out.anchor_shift == (0,) + + for u, s in out.nodes: + assert c.inst_key((u, s)) == out.strand_key + + assert out.tree_edges is not None + assert len(out.tree_edges) == 2 + + shift_map = {u: s for u, s in out.nodes} + children = set() + for parent, child, tvec, key in out.tree_edges: + assert parent in shift_map + assert child in shift_map + children.add(child) + assert tvec == 
sub_tvec(shift_map[child], shift_map[parent]) + assert isinstance(key, int) + assert children == {'B', 'C'} + + +def test_canonical_lift_tree_raises_when_strand_absent_in_anchor_cell(): + G = PeriodicDiGraph(dim=1) + G.add_edge('A', 'A', (2,)) + + c = components(G)[0] + assert c.inst_key(('A', (0,))) == (0,) + + with pytest.raises(CanonicalLiftError): + c.canonical_lift(strand_key=(1,), anchor_shift=(0,)) diff --git a/tests/test_lift_patch.py b/tests/test_lift_patch.py new file mode 100644 index 0000000..593e867 --- /dev/null +++ b/tests/test_lift_patch.py @@ -0,0 +1,103 @@ +import networkx as nx +import pytest + +from pbcgraph import PeriodicDiGraph, PeriodicGraph, PeriodicMultiGraph +from pbcgraph.core.exceptions import LiftPatchError + + +def test_lift_patch_radius_uses_weak_connectivity(): + G = PeriodicDiGraph(dim=1) + G.add_edge('A', 'B', (0,)) + + patch = G.lift_patch(('B', (0,)), radius=1) + + assert ('B', (0,)) in patch.nodes + assert ('A', (0,)) in patch.nodes + assert len(patch.edges) == 1 + assert patch.is_directed + nxG = patch.to_networkx() + assert isinstance(nxG, nx.DiGraph) + assert (('A', (0,)), ('B', (0,))) in nxG.edges + + +def test_lift_patch_box_rel_bounds_and_termination(): + G = PeriodicGraph(dim=1) + G.add_edge('A', 'A', (1,)) + + patch = G.lift_patch(('A', (0,)), box_rel=((0, 3),)) + assert patch.nodes == (('A', (0,)), ('A', (1,)), ('A', (2,))) + assert len(patch.edges) == 2 + + +def test_lift_patch_multigraph_preserves_key_and_dedupes_reciprocals(): + G = PeriodicMultiGraph(dim=1) + G.add_edge('A', 'A', (1,), key=7, kind='bond') + + patch = G.lift_patch(('A', (0,)), box=((0, 2),)) + assert patch.nodes == (('A', (0,)), ('A', (1,))) + assert len(patch.edges) == 1 + + u, v, key, attrs = patch.edges[0] + assert {u, v} == {('A', (0,)), ('A', (1,))} + assert key == 7 + assert attrs['kind'] == 'bond' + + nxG = patch.to_networkx() + assert isinstance(nxG, nx.MultiGraph) + + +def test_lift_patch_directed_preserves_both_directions_and_exports(): + 
G = PeriodicDiGraph(dim=1) + G.add_edge('A', 'B', (0,), label='x') + G.add_edge('B', 'A', (0,), label='y') + + patch = G.lift_patch(('A', (0,)), radius=1) + assert patch.is_directed + assert len(patch.edges) == 2 + + nxD = patch.to_networkx() + assert isinstance(nxD, nx.DiGraph) + assert nxD.edges[('A', (0,)), ('B', (0,))]['label'] == 'x' + assert nxD.edges[('B', (0,)), ('A', (0,))]['label'] == 'y' + + nxU = patch.to_networkx(as_undirected=True, undirected_mode='multigraph') + assert isinstance(nxU, nx.MultiGraph) + assert nxU.number_of_edges(('A', (0,)), ('B', (0,))) == 2 + + labels = [] + for u, v, data in nxU.edges(data=True): + if {u, v} != {('A', (0,)), ('B', (0,))}: + continue + labels.append(data['label']) + assert data['_pbc_tail'] in {('A', (0,)), ('B', (0,))} + assert data['_pbc_head'] in {('A', (0,)), ('B', (0,))} + assert sorted(labels) == ['x', 'y'] + + nxC = patch.to_networkx(as_undirected=True, undirected_mode='orig_edges') + assert isinstance(nxC, nx.Graph) + data = nxC.edges[('A', (0,)), ('B', (0,))] + assert 'orig_edges' in data + assert len(data['orig_edges']) == 2 + labels2 = sorted([rec['attrs']['label'] for rec in data['orig_edges']]) + assert labels2 == ['x', 'y'] + + +def test_lift_patch_to_networkx_snapshots_edge_attrs(): + G = PeriodicGraph(dim=1) + k = G.add_edge('A', 'A', (1,), kind='bond') + + patch = G.lift_patch(('A', (0,)), box=((0, 2),)) + nxG = patch.to_networkx() + + (u, v) = list(nxG.edges())[0] + nxG.edges[u, v]['kind'] = 'modified' + + # Underlying pbcgraph edge attrs are unaffected. + assert G.get_edge_data('A', 'A', k)['kind'] == 'bond' + + +def test_lift_patch_requires_a_finiteness_constraint(): + G = PeriodicGraph(dim=1) + G.add_edge('A', 'A', (1,)) + with pytest.raises(LiftPatchError): + G.lift_patch(('A', (0,)))