Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ doc = [
"jupytext>=1.17.2",
]

[tool.uv.sources]
pyzx-param = { git = "https://github.com/rafaelha/pyzx", rev = "f4d440c299b71f576ccc369886ba319fb51dd2e2" }

Comment on lines +97 to +99
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The project code now imports DecompositionStrategy and passes strategy=... into pyzx_param.simulate.replace_* APIs. However, [project.dependencies] still allows pyzx-param>=0.9.2 from PyPI, while only uv users will pick up the git revision via [tool.uv.sources]. To avoid broken installs for pip/packaging users, either (a) bump the minimum released pyzx-param version that contains these APIs, or (b) switch the main dependency to a direct URL reference (PEP 508) matching this git rev.

Copilot uses AI. Check for mistakes.
[tool.hatch.metadata]
allow-direct-references = true

Expand Down
23 changes: 19 additions & 4 deletions src/tsim/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pyzx_param as zx
import stim
from pyzx_param.graph.base import BaseGraph
from pyzx_param.simulate import DecompositionStrategy

from tsim.core.graph import build_sampling_graph
from tsim.core.parse import parse_parametric_tag, parse_stim_circuit
Expand Down Expand Up @@ -659,10 +660,17 @@ def get_sampling_graph(self, sample_detectors: bool = False) -> BaseGraph:
built = parse_stim_circuit(self._stim_circ)
return build_sampling_graph(built, sample_detectors=sample_detectors)

def compile_sampler(self, *, seed: int | None = None):
def compile_sampler(
self,
*,
strategy: DecompositionStrategy = "cat5",
seed: int | None = None,
):
"""Compile circuit into a measurement sampler.

Args:
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".
seed: Random seed for the sampler. If None, a random seed will be generated.

Returns:
Expand All @@ -671,12 +679,19 @@ def compile_sampler(self, *, seed: int | None = None):
"""
from tsim.sampler import CompiledMeasurementSampler

return CompiledMeasurementSampler(self, seed=seed)
return CompiledMeasurementSampler(self, seed=seed, strategy=strategy)

def compile_detector_sampler(self, *, seed: int | None = None):
def compile_detector_sampler(
self,
*,
strategy: DecompositionStrategy = "cat5",
seed: int | None = None,
):
"""Compile circuit into a detector sampler.

Args:
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".
seed: Random seed for the sampler. If None, a random seed will be generated.

Returns:
Expand All @@ -685,7 +700,7 @@ def compile_detector_sampler(self, *, seed: int | None = None):
"""
from tsim.sampler import CompiledDetectorSampler

return CompiledDetectorSampler(self, seed=seed)
return CompiledDetectorSampler(self, seed=seed, strategy=strategy)

def cast_to_stim(self) -> stim.Circuit:
"""Return self with type cast to stim.Circuit. This is useful for passing the circuit to functions that expect a stim.Circuit."""
Expand Down
9 changes: 8 additions & 1 deletion src/tsim/compile/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax.numpy as jnp
import pyzx_param as zx
from pyzx_param.graph.base import BaseGraph
from pyzx_param.simulate import DecompositionStrategy

from tsim.compile.compile import CompiledScalarGraphs, compile_scalar_graphs
from tsim.compile.stabrank import find_stab
Expand All @@ -20,6 +21,7 @@ def compile_program(
prepared: SamplingGraph,
*,
mode: DecompositionMode,
strategy: DecompositionStrategy = "cat5",
) -> CompiledProgram:
"""Compile a prepared graph into an executable sampling program.

Expand All @@ -37,6 +39,8 @@ def compile_program(
mode: Decomposition mode:
- "sequential": For sampling - creates [0, 1, 2, ..., n] circuits
- "joint": For probability estimation - creates [0, n] circuits
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".

Returns:
A CompiledProgram ready for sampling.
Expand All @@ -58,6 +62,7 @@ def compile_program(
component=component,
f_indices_global=f_indices_global,
mode=mode,
strategy=strategy,
)
compiled_components.append(compiled)
output_order.extend(component.output_indices)
Expand Down Expand Up @@ -89,13 +94,15 @@ def _compile_component(
component: ConnectedComponent,
f_indices_global: list[int],
mode: DecompositionMode,
strategy: DecompositionStrategy = "cat5",
) -> CompiledComponent:
"""Compile a single connected component.

Args:
component: The connected component to compile.
f_indices_global: Global list of all f-parameter indices (numerically sorted).
mode: Decomposition mode (sequential or joint).
strategy: Stabilizer rank decomposition strategy.

Returns:
A CompiledComponent ready for sampling.
Expand Down Expand Up @@ -142,7 +149,7 @@ def _compile_component(
param_names += [f"m{output_indices[j]}" for j in range(num_m_plugged)]

# Perform stabilizer rank decomposition and compile
g_list = find_stab(g_copy)
g_list = find_stab(g_copy, strategy=strategy)

if len(g_list) == 1:
# This is a Clifford graph, we can clear the global phase terms
Expand Down
22 changes: 15 additions & 7 deletions src/tsim/compile/stabrank.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pyzx_param as zx
from pyzx_param.graph.base import BaseGraph
from pyzx_param.simulate import DecompositionStrategy


def _decompose(
Expand All @@ -30,25 +31,31 @@ def _decompose(
return results


def find_stab_magic(graphs: Iterable[BaseGraph]) -> list[BaseGraph]:
def find_stab_magic(
graphs: Iterable[BaseGraph], strategy: DecompositionStrategy
) -> list[BaseGraph]:
"""Recursively decompose ZX-graphs into stabilizer components via magic-state removal."""
return _decompose(
list(graphs),
count_fn=zx.simplify.tcount,
replace_fn=lambda g: zx.simulate.replace_magic_states(g, pick_random=False),
replace_fn=lambda g: zx.simulate.replace_magic_states(
g, pick_random=False, strategy=strategy
),
)


def find_stab_u3(graphs: Iterable[BaseGraph]) -> list[BaseGraph]:
def find_stab_u3(
graphs: Iterable[BaseGraph], strategy: DecompositionStrategy
) -> list[BaseGraph]:
"""Recursively decompose ZX-graphs by removing U3 phases."""
return _decompose(
list(graphs),
count_fn=zx.simplify.u3_count,
replace_fn=zx.simulate.replace_u3_states,
replace_fn=lambda g: zx.simulate.replace_u3_states(g, strategy=strategy),
)


def find_stab(graph: BaseGraph) -> list[BaseGraph]:
def find_stab(graph: BaseGraph, strategy: DecompositionStrategy) -> list[BaseGraph]:
"""Decompose a ZX-graph into a sum of stabilizer components.

This is the main entry point for stabilizer rank decomposition. It first removes
Expand All @@ -57,11 +64,12 @@ def find_stab(graph: BaseGraph) -> list[BaseGraph]:

Args:
graph: The ZX graph to decompose.
strategy: Decomposition strategy. Must be one of "cat5", "bss", "cutting".

Returns:
A list of scalar graphs whose sum equals the original graph.

"""
zx.full_reduce(graph, paramSafe=True)
graphs = find_stab_u3([graph])
return find_stab_magic(graphs)
graphs = find_stab_u3([graph], strategy=strategy)
return find_stab_magic(graphs, strategy=strategy)
51 changes: 45 additions & 6 deletions src/tsim/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from pyzx_param.simulate import DecompositionStrategy

from tsim.compile.evaluate import evaluate
from tsim.compile.pipeline import compile_program
Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
*,
sample_detectors: bool,
mode: Literal["sequential", "joint"],
strategy: DecompositionStrategy = "cat5",
seed: int | None = None,
):
"""Initialize the sampler by compiling the circuit.
Expand All @@ -146,6 +148,8 @@ def __init__(
circuit: The quantum circuit to compile.
sample_detectors: If True, sample detectors/observables instead of measurements.
mode: Compilation mode - "sequential" for autoregressive, "joint" for probabilities.
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".
seed: Random seed. If None, a random seed is generated.

"""
Expand All @@ -155,7 +159,7 @@ def __init__(
self._key = jax.random.key(seed)

prepared = prepare_graph(circuit, sample_detectors=sample_detectors)
self._program = compile_program(prepared, mode=mode)
self._program = compile_program(prepared, mode=mode, strategy=strategy)

channel_seed = int(np.random.default_rng(seed).integers(0, 2**30))
self._channel_sampler = ChannelSampler(
Expand Down Expand Up @@ -242,15 +246,29 @@ class CompiledMeasurementSampler(_CompiledSamplerBase):
- compiled_scalar_graphs[i]: cumulative probability up to bit i
"""

def __init__(self, circuit: Circuit, *, seed: int | None = None):
def __init__(
self,
circuit: Circuit,
*,
strategy: DecompositionStrategy = "cat5",
seed: int | None = None,
):
"""Create a measurement sampler.

Args:
circuit: The quantum circuit to compile.
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".
seed: Random seed for JAX. If None, a random seed is generated.

"""
super().__init__(circuit, sample_detectors=False, mode="sequential", seed=seed)
super().__init__(
circuit,
sample_detectors=False,
mode="sequential",
seed=seed,
strategy=strategy,
)

def sample(self, shots: int, *, batch_size: int = 1024) -> np.ndarray:
"""Sample measurement outcomes from the circuit.
Expand Down Expand Up @@ -278,15 +296,29 @@ def _maybe_bit_pack(array: np.ndarray, *, bit_packed: bool) -> np.ndarray:
class CompiledDetectorSampler(_CompiledSamplerBase):
"""Samples detector and observable outcomes from a quantum circuit."""

def __init__(self, circuit: Circuit, *, seed: int | None = None):
def __init__(
self,
circuit: Circuit,
*,
strategy: DecompositionStrategy = "cat5",
seed: int | None = None,
):
"""Create a detector sampler.

Args:
circuit: The quantum circuit to compile.
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".
seed: Random seed for JAX. If None, a random seed is generated.

"""
super().__init__(circuit, sample_detectors=True, mode="sequential", seed=seed)
super().__init__(
circuit,
sample_detectors=True,
mode="sequential",
seed=seed,
strategy=strategy,
)

@overload
def sample(
Expand Down Expand Up @@ -381,18 +413,25 @@ def __init__(
circuit: Circuit,
*,
sample_detectors: bool = False,
strategy: DecompositionStrategy = "cat5",
seed: int | None = None,
):
"""Create a probability estimator.

Args:
circuit: The quantum circuit to compile.
sample_detectors: If True, compute detector/observable probabilities.
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".
seed: Random seed for JAX. If None, a random seed is generated.

"""
super().__init__(
circuit, sample_detectors=sample_detectors, mode="joint", seed=seed
circuit,
sample_detectors=sample_detectors,
mode="joint",
seed=seed,
strategy=strategy,
)

def probability_of(self, state: np.ndarray, *, batch_size: int) -> np.ndarray:
Expand Down
Loading
Loading