Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/tsim/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,11 @@ def compile_sampler(self, *, seed: int | None = None):
def compile_detector_sampler(self, *, seed: int | None = None):
"""Compile circuit into a detector sampler.

Connected components whose single output is deterministically given by
one f-variable are handled via a fast direct path (no compilation or
autoregressive sampling). Remaining components go through the full
compilation pipeline.

Args:
seed: Random seed for the sampler. If None, a random seed will be generated.

Expand Down
107 changes: 99 additions & 8 deletions src/tsim/compile/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from fractions import Fraction
from typing import Literal

import jax.numpy as jnp
Expand All @@ -16,6 +17,63 @@
DecompositionMode = Literal["sequential", "joint"]


def _classify_direct(
    component: ConnectedComponent,
) -> tuple[int, bool] | None:
    """Check if a component is directly determined by a single f-variable.

    A component qualifies when its graph consists of exactly two vertices — one
    boundary output and one Z-spider — connected by a Hadamard edge, where the
    Z-spider carries a single ``f`` parameter and a constant phase of either 0
    (no flip) or π (flip).

    Args:
        component: The connected component to inspect.

    Returns:
        ``(f_index, flip)`` if the fast path applies, where ``f_index`` is the
        integer index parsed from the single ``f`` parameter's name and
        ``flip`` is True when the spider carries a constant π phase.
        Otherwise ``None``.

    """
    graph = component.graph
    outputs = list(graph.outputs())
    if len(outputs) != 1:
        return None

    vertices = list(graph.vertices())
    if len(vertices) != 2:
        return None

    v_out = outputs[0]
    neighbors = list(graph.neighbors(v_out))
    if len(neighbors) != 1:
        return None

    v_det = neighbors[0]
    if graph.type(v_det) != zx.utils.VertexType.Z:
        return None
    if graph.edge_type(graph.edge(v_out, v_det)) != zx.utils.EdgeType.HADAMARD:
        return None

    params = graph.get_params(v_det)
    if len(params) != 1:
        return None
    f_param = next(iter(params))
    if not f_param.startswith("f"):
        return None

    # The whole component must depend on no parameter besides this single f;
    # any extra parameter means the output is not a direct function of it.
    all_graph_params = get_params(graph)
    if all_graph_params != {f_param}:
        return None

    # Phases are in units of π: 0 → pass-through, Fraction(1, 1) → bit flip.
    phase = graph.phase(v_det)
    if phase == 0:
        flip = False
    elif phase == Fraction(1, 1):
        flip = True
    else:
        return None

    return int(f_param[1:]), flip


def compile_program(
prepared: SamplingGraph,
*,
Expand Down Expand Up @@ -48,23 +106,56 @@ def compile_program(
f_indices_global = _get_f_indices(prepared.graph)
num_outputs = prepared.num_outputs

direct_f_indices: list[int] = []
direct_flips: list[bool] = []
direct_output_order: list[int] = []
compiled_components: list[CompiledComponent] = []
output_order: list[int] = []
compiled_output_order: list[int] = []

sorted_components = sorted(components, key=lambda c: len(c.output_indices))

for component in sorted_components:
compiled = _compile_component(
component=component,
f_indices_global=f_indices_global,
mode=mode,
result = _classify_direct(component)
if result is not None:
f_idx, flip = result
direct_f_indices.append(f_idx)
direct_flips.append(flip)
direct_output_order.append(component.output_indices[0])
else:
compiled = _compile_component(
component=component,
f_indices_global=f_indices_global,
mode=mode,
)
compiled_components.append(compiled)
compiled_output_order.extend(component.output_indices)

# Sort direct entries by output index so that the concatenation layout
# in sample_program matches the original output order as closely as
# possible. When transform_error_basis also prioritises outputs, this
# often yields an identity permutation and avoids reindexing at sample time.
if direct_output_order:
order = sorted(
range(len(direct_output_order)), key=direct_output_order.__getitem__
)
compiled_components.append(compiled)
output_order.extend(component.output_indices)
direct_f_indices = [direct_f_indices[i] for i in order]
direct_flips = [direct_flips[i] for i in order]
direct_output_order = [direct_output_order[i] for i in order]

# output_order must match the concatenation layout in sample_program:
# [direct bits, compiled_0 outputs, compiled_1 outputs, ...]
output_order = jnp.array(
direct_output_order + compiled_output_order, dtype=jnp.int32
)
reindex = jnp.argsort(output_order)
is_identity = bool(jnp.all(reindex == jnp.arange(len(output_order))))

return CompiledProgram(
components=tuple(compiled_components),
output_order=jnp.array(output_order, dtype=jnp.int32),
direct_f_indices=jnp.array(direct_f_indices, dtype=jnp.int32),
direct_flips=jnp.array(direct_flips, dtype=jnp.bool_),
output_order=output_order,
output_reindex=None if is_identity else reindex,
num_outputs=num_outputs,
num_f_params=len(f_indices_global),
num_detectors=prepared.num_detectors,
Expand Down
22 changes: 20 additions & 2 deletions src/tsim/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,27 @@ def transform_error_basis(
then f0 = e1 XOR e3.

"""
parametrized_vertices = [
v for v in g.vertices() if v in g._phaseVars and g._phaseVars[v]
# Prioritize output-connected detector vertices so that f0, f1, ...
# are assigned in output order. This maximises the chance that the
# direct-component fast path produces an identity permutation, avoiding
# a column reindex at sample time.
output_detectors = []
for v_out in g.outputs():
neighbors = list(g.neighbors(v_out))
if (
len(neighbors) == 1
and neighbors[0] in g._phaseVars
and g._phaseVars[neighbors[0]]
):
output_detectors.append(neighbors[0])

output_det_set = set(output_detectors)
rest = [
v
for v in g.vertices()
if v not in output_det_set and v in g._phaseVars and g._phaseVars[v]
]
parametrized_vertices = output_detectors + rest

if not parametrized_vertices:
g.scalar = Scalar()
Expand Down
12 changes: 10 additions & 2 deletions src/tsim/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,24 @@ class CompiledProgram:

Attributes:
components: The compiled components, sorted by number of outputs.
output_order: Array for reordering component outputs to final order.
final_samples = combined[:, np.argsort(output_order)]
direct_f_indices: Precomputed f-parameter indices for direct components.
direct_flips: Precomputed flip flags for direct components.
output_order: Maps concatenated position to original output index.
The first ``len(direct_f_indices)`` entries correspond to direct
components; the remainder to compiled components.
output_reindex: Precomputed ``argsort(output_order)`` permutation,
or ``None`` when the outputs are already in order.
num_outputs: Total number of outputs across all components.
num_f_params: Total number of f-parameters.
num_detectors: Number of detector outputs (for detector sampling).

"""

components: tuple[CompiledComponent, ...]
direct_f_indices: Array
direct_flips: Array
output_order: Array
output_reindex: Array | None
num_outputs: int
num_f_params: int
num_detectors: int
137 changes: 82 additions & 55 deletions src/tsim/noise/channels.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Pauli noise channels and error sampling infrastructure."""

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import numpy as np


Expand All @@ -26,11 +26,6 @@ def num_bits(self) -> int:
"""Number of bits in the channel (k where probs has shape 2^k)."""
return int(np.log2(len(self.probs)))

@property
def logits(self) -> jax.Array:
"""Convert to logits for JAX sampling."""
return jnp.log(jnp.array(self.probs))


def error_probs(p: float) -> np.ndarray:
"""Single-bit error channel. Returns shape (2,)."""
Expand Down Expand Up @@ -382,49 +377,13 @@ def simplify_channels(
return channels


def _sample_channels(
key: jax.Array,
channels: list[Channel],
matrix: jax.Array,
num_samples: int,
) -> jax.Array:
"""Sample from multiple channels and combine results via XOR.

Args:
key: JAX random key
channels: List of channels to sample from
matrix: Signature matrix of shape (num_signatures, num_outputs)
num_samples: Number of samples to draw

Returns:
Samples array of shape (num_samples, num_outputs).

"""
num_outputs = matrix.shape[1]
res = jnp.zeros((num_samples, num_outputs), dtype=jnp.uint8)

keys = jax.random.split(key, len(channels))

for channel, subkey in zip(channels, keys, strict=True):
num_bits = channel.num_bits

samples = jax.random.categorical(subkey, channel.logits, shape=(num_samples,))
# Extract individual bits from sampled indices
bits = ((samples[:, None] >> jnp.arange(num_bits)) & 1).astype(jnp.uint8)

# XOR contribution from each bit into the result
for i, col_id in enumerate(channel.unique_col_ids):
res = res ^ (bits[:, i : i + 1] * matrix[col_id : col_id + 1, :])

return res


class ChannelSampler:
"""Samples from multiple error channels and transforms to a reduced basis.

This class combines multiple error channels (each producing error bits e0, e1, ...)
and applies a linear transformation over GF(2) to convert samples from the original
"e" basis to a reduced "f" basis.
"e" basis to a reduced "f" basis using geometric-skip sampling optimized for
low-noise regimes.

f_i = error_transform_ij * e_j mod 2

Expand Down Expand Up @@ -481,25 +440,93 @@ def __init__(
e_offset += num_bits

self.channels = simplify_channels(channels, null_col_id=null_col_id)
self.signature_matrix = jnp.array(signature_matrix, dtype=jnp.uint8)
self.signature_matrix = signature_matrix.astype(np.uint8)

self._key = jax.random.key(
seed if seed is not None else np.random.randint(0, 2**30)
self._rng = np.random.default_rng(
seed if seed is not None else np.random.default_rng().integers(0, 2**30)
)
self._sparse_data = self._precompute_sparse(
self.channels, self.signature_matrix
)

def sample(self, num_samples: int = 1) -> jax.Array:
@staticmethod
def _precompute_sparse(
channels: list[Channel], signature_matrix: np.ndarray
) -> list[tuple[float, np.ndarray, np.ndarray]]:
"""Precompute per-channel data for geometric-skip sampling.

For each channel with non-trivial fire probability, computes:
- p_fire: probability of any non-identity outcome
- cond_cdf: conditional CDF over non-identity outcomes
- xor_patterns: precomputed XOR output patterns per outcome

Args:
channels: List of noise channels to precompute data for.
signature_matrix: Binary matrix of shape (num_e, num_f) mapping
error-variable columns to output f-variables.

Returns:
List of (p_fire, cond_cdf, xor_patterns) tuples, one per channel
with non-trivial fire probability. ``p_fire`` is a float,
``cond_cdf`` is a float64 array of shape (n_outcomes - 1,), and
``xor_patterns`` is a uint8 array of shape (n_outcomes - 1, num_f).

"""
data: list[tuple[float, np.ndarray, np.ndarray]] = []
for ch in channels:
probs = ch.probs.astype(np.float64)
p_fire = 1.0 - float(probs[0])
n_outcomes = len(probs)

if p_fire <= 1e-15 or n_outcomes <= 1:
continue

cond_cdf = np.cumsum(probs[1:] / p_fire, dtype=np.float64)
cond_cdf /= cond_cdf[-1]

col_ids = np.asarray(ch.unique_col_ids)
num_bits = len(col_ids)
outcomes = np.arange(1, n_outcomes)
bits_mask = ((outcomes[:, None] >> np.arange(num_bits)) & 1).astype(
np.uint8
)
xor_patterns = (bits_mask @ signature_matrix[col_ids] % 2).astype(np.uint8)

data.append((p_fire, cond_cdf, xor_patterns))
return data

def sample(self, num_samples: int = 1) -> np.ndarray:
"""Sample from all error channels and transform to new error basis.

Uses geometric-skip sampling, optimized for low-noise regimes where
P(non-identity) << 1 per channel.

Args:
num_samples: Number of samples to draw.

Returns:
Array of shape (num_samples, num_f) with boolean values indicating
NumPy array of shape (num_samples, num_f) with uint8 values indicating
which f-variables are set for each sample.

"""
self._key, subkey = jax.random.split(self._key)
samples = _sample_channels(
subkey, self.channels, self.signature_matrix, num_samples
)
return samples
num_outputs = self.signature_matrix.shape[1]
result = np.zeros((num_samples, num_outputs), dtype=np.uint8)

for p_fire, cond_cdf, xor_pats in self._sparse_data:
expected = num_samples * p_fire
sigma = np.sqrt(expected * (1.0 - p_fire))
# At 7 sigma, we undersample in about 1 out of 1e12 cases
n_draws = int(expected + 7.0 * sigma) + 100

positions = np.cumsum(self._rng.geometric(p_fire, size=n_draws)) - 1
positions = positions[positions < num_samples]

if len(positions) == 0:
continue

outcome_idx = np.searchsorted(
cond_cdf, self._rng.uniform(size=len(positions))
)
result[positions] ^= xor_pats[outcome_idx]

return result
Loading
Loading