diff --git a/pyproject.toml b/pyproject.toml index b8a880b..c7c59d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,9 @@ doc = [ "jupytext>=1.17.2", ] +[tool.uv.sources] +pyzx-param = { git = "https://github.com/rafaelha/pyzx", rev = "f4d440c299b71f576ccc369886ba319fb51dd2e2" } + [tool.hatch.metadata] allow-direct-references = true diff --git a/src/tsim/circuit.py b/src/tsim/circuit.py index 90519f4..5404c55 100644 --- a/src/tsim/circuit.py +++ b/src/tsim/circuit.py @@ -7,6 +7,7 @@ import pyzx_param as zx import stim from pyzx_param.graph.base import BaseGraph +from pyzx_param.simulate import DecompositionStrategy from tsim.core.graph import build_sampling_graph from tsim.core.parse import parse_parametric_tag, parse_stim_circuit @@ -659,10 +660,17 @@ def get_sampling_graph(self, sample_detectors: bool = False) -> BaseGraph: built = parse_stim_circuit(self._stim_circ) return build_sampling_graph(built, sample_detectors=sample_detectors) - def compile_sampler(self, *, seed: int | None = None): + def compile_sampler( + self, + *, + strategy: DecompositionStrategy = "cat5", + seed: int | None = None, + ): """Compile circuit into a measurement sampler. Args: + strategy: Stabilizer rank decomposition strategy. + Must be one of "cat5", "bss", "cutting". seed: Random seed for the sampler. If None, a random seed will be generated. Returns: @@ -671,12 +679,19 @@ def compile_sampler(self, *, seed: int | None = None): """ from tsim.sampler import CompiledMeasurementSampler - return CompiledMeasurementSampler(self, seed=seed) + return CompiledMeasurementSampler(self, seed=seed, strategy=strategy) - def compile_detector_sampler(self, *, seed: int | None = None): + def compile_detector_sampler( + self, + *, + strategy: DecompositionStrategy = "cat5", + seed: int | None = None, + ): """Compile circuit into a detector sampler. Args: + strategy: Stabilizer rank decomposition strategy. + Must be one of "cat5", "bss", "cutting". seed: Random seed for the sampler. If None, a random seed will be generated. Returns: @@ -685,7 +700,7 @@ def compile_detector_sampler(self, *, seed: int | None = None): """ from tsim.sampler import CompiledDetectorSampler - return CompiledDetectorSampler(self, seed=seed) + return CompiledDetectorSampler(self, seed=seed, strategy=strategy) def cast_to_stim(self) -> stim.Circuit: """Return self with type cast to stim.Circuit. This is useful for passing the circuit to functions that expect a stim.Circuit.""" diff --git a/src/tsim/compile/pipeline.py b/src/tsim/compile/pipeline.py index 6522aad..d462bf4 100644 --- a/src/tsim/compile/pipeline.py +++ b/src/tsim/compile/pipeline.py @@ -7,6 +7,7 @@ import jax.numpy as jnp import pyzx_param as zx from pyzx_param.graph.base import BaseGraph +from pyzx_param.simulate import DecompositionStrategy from tsim.compile.compile import CompiledScalarGraphs, compile_scalar_graphs from tsim.compile.stabrank import find_stab @@ -20,6 +21,7 @@ def compile_program( prepared: SamplingGraph, *, mode: DecompositionMode, + strategy: DecompositionStrategy = "cat5", ) -> CompiledProgram: """Compile a prepared graph into an executable sampling program. @@ -37,6 +39,8 @@ def compile_program( mode: Decomposition mode: - "sequential": For sampling - creates [0, 1, 2, ..., n] circuits - "joint": For probability estimation - creates [0, n] circuits + strategy: Stabilizer rank decomposition strategy. + Must be one of "cat5", "bss", "cutting". Returns: A CompiledProgram ready for sampling. @@ -58,6 +62,7 @@ def compile_program( component=component, f_indices_global=f_indices_global, mode=mode, + strategy=strategy, ) compiled_components.append(compiled) output_order.extend(component.output_indices) @@ -89,6 +94,7 @@ def _compile_component( component: ConnectedComponent, f_indices_global: list[int], mode: DecompositionMode, + strategy: DecompositionStrategy = "cat5", ) -> CompiledComponent: """Compile a single connected component. @@ -96,6 +102,7 @@ def _compile_component( component: The connected component to compile. f_indices_global: Global list of all f-parameter indices (numerically sorted). mode: Decomposition mode (sequential or joint). + strategy: Stabilizer rank decomposition strategy. Returns: A CompiledComponent ready for sampling. @@ -142,7 +149,7 @@ def _compile_component( param_names += [f"m{output_indices[j]}" for j in range(num_m_plugged)] # Perform stabilizer rank decomposition and compile - g_list = find_stab(g_copy) + g_list = find_stab(g_copy, strategy=strategy) if len(g_list) == 1: # This is a Clifford graph, we can clear the global phase terms diff --git a/src/tsim/compile/stabrank.py b/src/tsim/compile/stabrank.py index febfd13..0173e1f 100644 --- a/src/tsim/compile/stabrank.py +++ b/src/tsim/compile/stabrank.py @@ -4,6 +4,7 @@ import pyzx_param as zx from pyzx_param.graph.base import BaseGraph +from pyzx_param.simulate import DecompositionStrategy def _decompose( @@ -30,25 +31,31 @@ def _decompose( return results -def find_stab_magic(graphs: Iterable[BaseGraph]) -> list[BaseGraph]: +def find_stab_magic( + graphs: Iterable[BaseGraph], strategy: DecompositionStrategy +) -> list[BaseGraph]: """Recursively decompose ZX-graphs into stabilizer components via magic-state removal.""" return _decompose( list(graphs), count_fn=zx.simplify.tcount, - replace_fn=lambda g: zx.simulate.replace_magic_states(g, pick_random=False), + replace_fn=lambda g: zx.simulate.replace_magic_states( + g, pick_random=False, strategy=strategy + ), ) -def find_stab_u3(graphs: Iterable[BaseGraph]) -> list[BaseGraph]: +def find_stab_u3( + graphs: Iterable[BaseGraph], strategy: DecompositionStrategy +) -> list[BaseGraph]: """Recursively decompose ZX-graphs by removing U3 phases.""" return _decompose( list(graphs), count_fn=zx.simplify.u3_count, - replace_fn=zx.simulate.replace_u3_states, + replace_fn=lambda g: zx.simulate.replace_u3_states(g, strategy=strategy), ) -def find_stab(graph: BaseGraph) -> list[BaseGraph]: +def find_stab(graph: BaseGraph, strategy: DecompositionStrategy) -> list[BaseGraph]: """Decompose a ZX-graph into a sum of stabilizer components. This is the main entry point for stabilizer rank decomposition. It first removes @@ -57,11 +64,12 @@ def find_stab(graph: BaseGraph) -> list[BaseGraph]: Args: graph: The ZX graph to decompose. + strategy: Decomposition strategy. Must be one of "cat5", "bss", "cutting". Returns: A list of scalar graphs whose sum equals the original graph. """ zx.full_reduce(graph, paramSafe=True) - graphs = find_stab_u3([graph]) - return find_stab_magic(graphs) + graphs = find_stab_u3([graph], strategy=strategy) + return find_stab_magic(graphs, strategy=strategy) diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index 22579e3..eefc2f7 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -8,6 +8,7 @@ import jax import jax.numpy as jnp import numpy as np +from pyzx_param.simulate import DecompositionStrategy from tsim.compile.evaluate import evaluate from tsim.compile.pipeline import compile_program @@ -138,6 +139,7 @@ def __init__( *, sample_detectors: bool, mode: Literal["sequential", "joint"], + strategy: DecompositionStrategy = "cat5", seed: int | None = None, ): """Initialize the sampler by compiling the circuit. @@ -146,6 +148,8 @@ def __init__( circuit: The quantum circuit to compile. sample_detectors: If True, sample detectors/observables instead of measurements. mode: Compilation mode - "sequential" for autoregressive, "joint" for probabilities. + strategy: Stabilizer rank decomposition strategy. + Must be one of "cat5", "bss", "cutting". seed: Random seed. If None, a random seed is generated. """ @@ -155,7 +159,7 @@ def __init__( self._key = jax.random.key(seed) prepared = prepare_graph(circuit, sample_detectors=sample_detectors) - self._program = compile_program(prepared, mode=mode) + self._program = compile_program(prepared, mode=mode, strategy=strategy) channel_seed = int(np.random.default_rng(seed).integers(0, 2**30)) self._channel_sampler = ChannelSampler( @@ -242,15 +246,29 @@ class CompiledMeasurementSampler(_CompiledSamplerBase): - compiled_scalar_graphs[i]: cumulative probability up to bit i """ - def __init__(self, circuit: Circuit, *, seed: int | None = None): + def __init__( + self, + circuit: Circuit, + *, + strategy: DecompositionStrategy = "cat5", + seed: int | None = None, + ): """Create a measurement sampler. Args: circuit: The quantum circuit to compile. + strategy: Stabilizer rank decomposition strategy. + Must be one of "cat5", "bss", "cutting". seed: Random seed for JAX. If None, a random seed is generated. """ - super().__init__(circuit, sample_detectors=False, mode="sequential", seed=seed) + super().__init__( + circuit, + sample_detectors=False, + mode="sequential", + seed=seed, + strategy=strategy, + ) def sample(self, shots: int, *, batch_size: int = 1024) -> np.ndarray: """Sample measurement outcomes from the circuit. @@ -278,15 +296,29 @@ def _maybe_bit_pack(array: np.ndarray, *, bit_packed: bool) -> np.ndarray: class CompiledDetectorSampler(_CompiledSamplerBase): """Samples detector and observable outcomes from a quantum circuit.""" - def __init__(self, circuit: Circuit, *, seed: int | None = None): + def __init__( + self, + circuit: Circuit, + *, + strategy: DecompositionStrategy = "cat5", + seed: int | None = None, + ): """Create a detector sampler. Args: circuit: The quantum circuit to compile. + strategy: Stabilizer rank decomposition strategy. + Must be one of "cat5", "bss", "cutting". seed: Random seed for JAX. If None, a random seed is generated. """ - super().__init__(circuit, sample_detectors=True, mode="sequential", seed=seed) + super().__init__( + circuit, + sample_detectors=True, + mode="sequential", + seed=seed, + strategy=strategy, + ) @overload def sample( @@ -381,6 +413,7 @@ def __init__( circuit: Circuit, *, sample_detectors: bool = False, + strategy: DecompositionStrategy = "cat5", seed: int | None = None, ): """Create a probability estimator. @@ -388,11 +421,17 @@ def __init__( Args: circuit: The quantum circuit to compile. sample_detectors: If True, compute detector/observable probabilities. + strategy: Stabilizer rank decomposition strategy. + Must be one of "cat5", "bss", "cutting". seed: Random seed for JAX. If None, a random seed is generated. """ super().__init__( - circuit, sample_detectors=sample_detectors, mode="joint", seed=seed + circuit, + sample_detectors=sample_detectors, + mode="joint", + seed=seed, + strategy=strategy, ) def probability_of(self, state: np.ndarray, *, batch_size: int) -> np.ndarray: diff --git a/test/integration/test_sampler.py b/test/integration/test_sampler.py index 66d7b46..9089318 100644 --- a/test/integration/test_sampler.py +++ b/test/integration/test_sampler.py @@ -5,6 +5,7 @@ import pytest import pyzx_param as zx import stim +from pyzx_param.simulate import DecompositionStrategy from tqdm import tqdm from tsim.circuit import Circuit @@ -58,7 +59,10 @@ def assert_samples_match(samples1: np.ndarray, samples2: np.ndarray): "color_code:memory_xyz", ], ) -def test_quantum_memory_codes_without_noise(code_task: str): +@pytest.mark.parametrize("strategy", ["cat5", "bss", "cutting"]) +def test_quantum_memory_codes_without_noise( + code_task: str, strategy: DecompositionStrategy +): circ = stim.Circuit.generated( code_task, @@ -70,13 +74,14 @@ def test_quantum_memory_codes_without_noise(code_task: str): after_reset_flip_probability=0.0, ) c = Circuit.from_stim_program(circ) - sampler = c.compile_detector_sampler(seed=0) + sampler = c.compile_detector_sampler(strategy=strategy, seed=0) samples = sampler.sample(10) assert not np.any(samples) @pytest.mark.parametrize("seed", [1, 2, 42]) -def test_sampler(seed): +@pytest.mark.parametrize("strategy", ["cat5", "bss", "cutting"]) +def test_sampler(seed, strategy): num_qubits = 3 stim_circuit = gen_stim_circuit( num_qubits, 100, include_measurements=True, seed=seed @@ -88,7 +93,7 @@ def test_sampler(seed): # Sample from both simulators stim_sampler = VecSampler(stim_circuit) - sampler = circuit.compile_sampler(seed=seed) + sampler = circuit.compile_sampler(strategy=strategy, seed=seed) stim_samples, _, _ = stim_sampler.sample(n_samples) tsim_samples = sampler.sample(n_samples, batch_size=batch_size) @@ -326,11 +331,14 @@ def simulate_with_vec_sampler(stim_circuit: stim.Circuit) -> np.ndarray: return np.abs(state_vector) ** 2 -def simulate_with_tsim(stim_circuit: stim.Circuit) -> np.ndarray: +def simulate_with_tsim( + stim_circuit: stim.Circuit, strategy: DecompositionStrategy = "cat5" +) -> np.ndarray: """Compute state probabilities using tsim's CompiledStateProbs. Args: stim_circuit: The stim circuit (with tags) to simulate. Should not include measurements. + strategy: Stabilizer rank decomposition strategy. Returns: Array of probabilities for each computational basis state. @@ -341,7 +349,7 @@ def simulate_with_tsim(stim_circuit: stim.Circuit) -> np.ndarray: "M " + " ".join([str(i) for i in range(stim_circuit.num_qubits)]) ) circuit = Circuit.from_stim_program(stim_circuit_with_m) - prob_sampler = CompiledStateProbs(circuit) + prob_sampler = CompiledStateProbs(circuit, strategy=strategy) probabilities = [] for i in range(2**num_qubits): @@ -376,14 +384,15 @@ def simulate_with_pyzx_tensor(stim_circuit: stim.Circuit) -> np.ndarray: @pytest.mark.parametrize("num_qubits", [3, 4]) @pytest.mark.parametrize("seed", [1, 2]) -def test_compare_to_statevector_simulator_and_pyzx_tensor(num_qubits, seed): +@pytest.mark.parametrize("strategy", ["cat5", "bss", "cutting"]) +def test_compare_to_statevector_simulator_and_pyzx_tensor(num_qubits, seed, strategy): stim_circuit = gen_stim_circuit( num_qubits, 100, include_measurements=False, seed=seed, ) - tsim_state_vector = simulate_with_tsim(stim_circuit) + tsim_state_vector = simulate_with_tsim(stim_circuit, strategy=strategy) pyzx_state_vector = simulate_with_pyzx_tensor(stim_circuit) stim_state_vector = simulate_with_vec_sampler(stim_circuit) @@ -393,8 +402,9 @@ def test_compare_to_statevector_simulator_and_pyzx_tensor(num_qubits, seed): @pytest.mark.parametrize("num_qubits", [3, 4]) @pytest.mark.parametrize("seed", [2, 42]) +@pytest.mark.parametrize("strategy", ["cat5", "bss", "cutting"]) def test_compare_to_statevector_simulator_and_pyzx_tensor_with_arbitrary_rotations( - num_qubits, seed + num_qubits, seed, strategy ): stim_circuit = gen_stim_circuit( num_qubits, @@ -409,7 +419,7 @@ def test_compare_to_statevector_simulator_and_pyzx_tensor_with_arbitrary_rotatio c = Circuit.from_stim_program(stim_circuit) assert zx.simplify.u3_count(c.get_graph()) > 0 - tsim_state_vector = simulate_with_tsim(stim_circuit) + tsim_state_vector = simulate_with_tsim(stim_circuit, strategy=strategy) pyzx_state_vector = simulate_with_pyzx_tensor(stim_circuit) stim_state_vector = simulate_with_vec_sampler(stim_circuit) diff --git a/uv.lock b/uv.lock index 7bc51c6..3e022a8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", @@ -212,7 +212,7 @@ requires-dist = [ { name = "jax", extras = ["cuda13"], marker = "extra == 'cuda13'", specifier = ">=0.6.0" }, { name = "lxml", specifier = ">=5.0.0" }, { name = "numpy", specifier = ">=1.25.0" }, - { name = "pyzx-param", specifier = ">=0.9.2" }, + { name = "pyzx-param", git = "https://github.com/rafaelha/pyzx?rev=f4d440c299b71f576ccc369886ba319fb51dd2e2" }, { name = "stim", specifier = ">=1.0.0" }, ] provides-extras = ["cuda12", "cuda13"] @@ -965,15 +965,15 @@ wheels = [ [[package]] name = "griffe-kirin" -version = "0.2.0" -source = { registry = "https://pypi.org/simple" } +version = "0.1.0" +source = { registry = "https://quera.jfrog.io/artifactory/api/pypi/kirin/simple" } dependencies = [ { name = "griffe" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a5/e0/9db75de974546903ab46a41179a95243abc84ceb2ba0b4101b5016f5301e/griffe_kirin-0.2.0.tar.gz", hash = "sha256:87d4280fbb1c77e1eb3fe7b627810b219295ec74651cf216748330b379a4ded5", size = 57664, upload-time = "2025-02-21T20:41:16.862Z" } +sdist = { url = "https://quera.jfrog.io/artifactory/api/pypi/kirin/griffe-kirin/0.1.0/griffe_kirin-0.1.0.tar.gz", hash = "sha256:b5ed85a8c07da0219787ac8d57de0a4f2a6df3109fc3aab47d78a12f96dda04d" } wheels = [ - { url = "https://files.pythonhosted.org/packages/38/b2/04bf90dc9e9e80f969b562731cf1619bd2eb0d0b726ba7befd38fda38131/griffe_kirin-0.2.0-py3-none-any.whl", hash = "sha256:81d999d6e1480cbe46e5c7b82933d2eee96863840dfbac283df620d85c13c2fc", size = 4063, upload-time = "2025-02-21T20:41:14.734Z" }, + { url = "https://quera.jfrog.io/artifactory/api/pypi/kirin/griffe-kirin/0.1.0/griffe_kirin-0.1.0-py3-none-any.whl", hash = "sha256:cdbc7a08c4fa5229f8f5f9137fd4de7dde35214166d3d448cbac8778079bd10a" }, ] [[package]] @@ -3613,7 +3613,7 @@ wheels = [ [[package]] name = "pyzx-param" version = "0.9.2" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/rafaelha/pyzx?rev=f4d440c299b71f576ccc369886ba319fb51dd2e2#f4d440c299b71f576ccc369886ba319fb51dd2e2" } dependencies = [ { name = "lark" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -3623,10 +3623,6 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/22/d3/160c226916130b4995de2e8e95949ed065fe26636c7d359fecabbf66884d/pyzx_param-0.9.2.tar.gz", hash = "sha256:8a4c7136f12c974f032934fa2fa0f66489b9b95eebd6850a3fa41e9adc6d7a98", size = 357647, upload-time = "2026-01-29T01:31:06.499Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/40/ca0b3e02add85fc8e3e7ad88c99c454f86a3f55e3cc3b08b02175be112ab/pyzx_param-0.9.2-py3-none-any.whl", hash = "sha256:32ccd56c172e88102e21e9660d5ba05deff00ea862ba35f080b73d9a130b45ee", size = 393805, upload-time = "2026-01-29T01:31:05.158Z" }, -] [[package]] name = "referencing"