diff --git a/src/tsim/circuit.py b/src/tsim/circuit.py index 2ebaa29..785eee9 100644 --- a/src/tsim/circuit.py +++ b/src/tsim/circuit.py @@ -575,6 +575,11 @@ def compile_sampler(self, *, seed: int | None = None): def compile_detector_sampler(self, *, seed: int | None = None): """Compile circuit into a detector sampler. + Connected components whose single output is deterministically given by + one f-variable are handled via a fast direct path (no compilation or + autoregressive sampling). Remaining components go through the full + compilation pipeline. + Args: seed: Random seed for the sampler. If None, a random seed will be generated. diff --git a/src/tsim/compile/pipeline.py b/src/tsim/compile/pipeline.py index 6522aad..b06fe3f 100644 --- a/src/tsim/compile/pipeline.py +++ b/src/tsim/compile/pipeline.py @@ -2,6 +2,7 @@ from __future__ import annotations +from fractions import Fraction from typing import Literal import jax.numpy as jnp @@ -16,6 +17,63 @@ DecompositionMode = Literal["sequential", "joint"] +def _classify_direct( + component: ConnectedComponent, +) -> tuple[int, bool] | None: + """Check if a component is directly determined by a single f-variable. + + A component qualifies when its graph consists of exactly two vertices — one + boundary output and one Z-spider — connected by a Hadamard edge, where the + Z-spider carries a single ``f`` parameter and a constant phase of either 0 + (no flip) or π (flip). + + Returns: + ``(f_index, flip)`` if the fast path applies, otherwise + ``None``. 
+ + """ + graph = component.graph + outputs = list(graph.outputs()) + if len(outputs) != 1: + return None + + vertices = list(graph.vertices()) + if len(vertices) != 2: + return None + + v_out = outputs[0] + neighbors = list(graph.neighbors(v_out)) + if len(neighbors) != 1: + return None + + v_det = neighbors[0] + if graph.type(v_det) != zx.utils.VertexType.Z: + return None + if graph.edge_type(graph.edge(v_out, v_det)) != zx.utils.EdgeType.HADAMARD: + return None + + params = graph.get_params(v_det) + if len(params) != 1: + return None + f_param = next(iter(params)) + if not f_param.startswith("f"): + return None + + all_graph_params = get_params(graph) + if all_graph_params != {f_param}: + return None + + phase = graph.phase(v_det) + if phase == 0: + flip = False + elif phase == Fraction(1, 1): + flip = True + else: + return None + + return int(f_param[1:]), flip + + def compile_program( prepared: SamplingGraph, *, @@ -48,23 +106,56 @@ def compile_program( f_indices_global = _get_f_indices(prepared.graph) num_outputs = prepared.num_outputs + direct_f_indices: list[int] = [] + direct_flips: list[bool] = [] + direct_output_order: list[int] = [] compiled_components: list[CompiledComponent] = [] - output_order: list[int] = [] + compiled_output_order: list[int] = [] sorted_components = sorted(components, key=lambda c: len(c.output_indices)) for component in sorted_components: - compiled = _compile_component( - component=component, - f_indices_global=f_indices_global, - mode=mode, + result = _classify_direct(component) + if result is not None: + f_idx, flip = result + direct_f_indices.append(f_idx) + direct_flips.append(flip) + direct_output_order.append(component.output_indices[0]) + else: + compiled = _compile_component( + component=component, + f_indices_global=f_indices_global, + mode=mode, + ) + compiled_components.append(compiled) + compiled_output_order.extend(component.output_indices) + + # Sort direct entries by output index so that the concatenation layout + 
# in sample_program matches the original output order as closely as + # possible. When transform_error_basis also prioritises outputs, this + # often yields an identity permutation and avoids reindexing at sample time. + if direct_output_order: + order = sorted( + range(len(direct_output_order)), key=direct_output_order.__getitem__ ) - compiled_components.append(compiled) - output_order.extend(component.output_indices) + direct_f_indices = [direct_f_indices[i] for i in order] + direct_flips = [direct_flips[i] for i in order] + direct_output_order = [direct_output_order[i] for i in order] + + # output_order must match the concatenation layout in sample_program: + # [direct bits, compiled_0 outputs, compiled_1 outputs, ...] + output_order = jnp.array( + direct_output_order + compiled_output_order, dtype=jnp.int32 + ) + reindex = jnp.argsort(output_order) + is_identity = bool(jnp.all(reindex == jnp.arange(len(output_order)))) return CompiledProgram( components=tuple(compiled_components), - output_order=jnp.array(output_order, dtype=jnp.int32), + direct_f_indices=jnp.array(direct_f_indices, dtype=jnp.int32), + direct_flips=jnp.array(direct_flips, dtype=jnp.bool_), + output_order=output_order, + output_reindex=None if is_identity else reindex, num_outputs=num_outputs, num_f_params=len(f_indices_global), num_detectors=prepared.num_detectors, diff --git a/src/tsim/core/graph.py b/src/tsim/core/graph.py index 8cb638a..2ef7209 100644 --- a/src/tsim/core/graph.py +++ b/src/tsim/core/graph.py @@ -274,9 +274,27 @@ def transform_error_basis( then f0 = e1 XOR e3. """ - parametrized_vertices = [ - v for v in g.vertices() if v in g._phaseVars and g._phaseVars[v] + # Prioritize output-connected detector vertices so that f0, f1, ... + # are assigned in output order. This maximises the chance that the + # direct-component fast path produces an identity permutation, avoiding + # a column reindex at sample time. 
+ output_detectors = [] + for v_out in g.outputs(): + neighbors = list(g.neighbors(v_out)) + if ( + len(neighbors) == 1 + and neighbors[0] in g._phaseVars + and g._phaseVars[neighbors[0]] + ): + output_detectors.append(neighbors[0]) + + output_det_set = set(output_detectors) + rest = [ + v + for v in g.vertices() + if v not in output_det_set and v in g._phaseVars and g._phaseVars[v] ] + parametrized_vertices = output_detectors + rest if not parametrized_vertices: g.scalar = Scalar() diff --git a/src/tsim/core/types.py b/src/tsim/core/types.py index c8f1cd8..3fe1869 100644 --- a/src/tsim/core/types.py +++ b/src/tsim/core/types.py @@ -86,8 +86,13 @@ class CompiledProgram: Attributes: components: The compiled components, sorted by number of outputs. - output_order: Array for reordering component outputs to final order. - final_samples = combined[:, np.argsort(output_order)] + direct_f_indices: Precomputed f-parameter indices for direct components. + direct_flips: Precomputed flip flags for direct components. + output_order: Maps concatenated position to original output index. + The first ``len(direct_f_indices)`` entries correspond to direct + components; the remainder to compiled components. + output_reindex: Precomputed ``argsort(output_order)`` permutation, + or ``None`` when the outputs are already in order. num_outputs: Total number of outputs across all components. num_f_params: Total number of f-parameters. num_detectors: Number of detector outputs (for detector sampling). @@ -95,7 +100,10 @@ class CompiledProgram: """ components: tuple[CompiledComponent, ...] 
+ direct_f_indices: Array + direct_flips: Array output_order: Array + output_reindex: Array | None num_outputs: int num_f_params: int num_detectors: int diff --git a/src/tsim/noise/channels.py b/src/tsim/noise/channels.py index 4b8b134..4bee034 100644 --- a/src/tsim/noise/channels.py +++ b/src/tsim/noise/channels.py @@ -1,10 +1,10 @@ """Pauli noise channels and error sampling infrastructure.""" +from __future__ import annotations + from collections import defaultdict from dataclasses import dataclass -import jax -import jax.numpy as jnp import numpy as np @@ -26,11 +26,6 @@ def num_bits(self) -> int: """Number of bits in the channel (k where probs has shape 2^k).""" return int(np.log2(len(self.probs))) - @property - def logits(self) -> jax.Array: - """Convert to logits for JAX sampling.""" - return jnp.log(jnp.array(self.probs)) - def error_probs(p: float) -> np.ndarray: """Single-bit error channel. Returns shape (2,).""" @@ -382,49 +377,13 @@ def simplify_channels( return channels -def _sample_channels( - key: jax.Array, - channels: list[Channel], - matrix: jax.Array, - num_samples: int, -) -> jax.Array: - """Sample from multiple channels and combine results via XOR. - - Args: - key: JAX random key - channels: List of channels to sample from - matrix: Signature matrix of shape (num_signatures, num_outputs) - num_samples: Number of samples to draw - - Returns: - Samples array of shape (num_samples, num_outputs). 
- - """ - num_outputs = matrix.shape[1] - res = jnp.zeros((num_samples, num_outputs), dtype=jnp.uint8) - - keys = jax.random.split(key, len(channels)) - - for channel, subkey in zip(channels, keys, strict=True): - num_bits = channel.num_bits - - samples = jax.random.categorical(subkey, channel.logits, shape=(num_samples,)) - # Extract individual bits from sampled indices - bits = ((samples[:, None] >> jnp.arange(num_bits)) & 1).astype(jnp.uint8) - - # XOR contribution from each bit into the result - for i, col_id in enumerate(channel.unique_col_ids): - res = res ^ (bits[:, i : i + 1] * matrix[col_id : col_id + 1, :]) - - return res - - class ChannelSampler: """Samples from multiple error channels and transforms to a reduced basis. This class combines multiple error channels (each producing error bits e0, e1, ...) and applies a linear transformation over GF(2) to convert samples from the original - "e" basis to a reduced "f" basis. + "e" basis to a reduced "f" basis using geometric-skip sampling optimized for + low-noise regimes. f_i = error_transform_ij * e_j mod 2 @@ -481,25 +440,93 @@ def __init__( e_offset += num_bits self.channels = simplify_channels(channels, null_col_id=null_col_id) - self.signature_matrix = jnp.array(signature_matrix, dtype=jnp.uint8) + self.signature_matrix = signature_matrix.astype(np.uint8) - self._key = jax.random.key( - seed if seed is not None else np.random.randint(0, 2**30) + self._rng = np.random.default_rng( + seed if seed is not None else np.random.default_rng().integers(0, 2**30) + ) + self._sparse_data = self._precompute_sparse( + self.channels, self.signature_matrix ) - def sample(self, num_samples: int = 1) -> jax.Array: + @staticmethod + def _precompute_sparse( + channels: list[Channel], signature_matrix: np.ndarray + ) -> list[tuple[float, np.ndarray, np.ndarray]]: + """Precompute per-channel data for geometric-skip sampling. 
+ + For each channel with non-trivial fire probability, computes: + - p_fire: probability of any non-identity outcome + - cond_cdf: conditional CDF over non-identity outcomes + - xor_patterns: precomputed XOR output patterns per outcome + + Args: + channels: List of noise channels to precompute data for. + signature_matrix: Binary matrix of shape (num_e, num_f) mapping + error-variable columns to output f-variables. + + Returns: + List of (p_fire, cond_cdf, xor_patterns) tuples, one per channel + with non-trivial fire probability. ``p_fire`` is a float, + ``cond_cdf`` is a float64 array of shape (n_outcomes - 1,), and + ``xor_patterns`` is a uint8 array of shape (n_outcomes - 1, num_f). + + """ + data: list[tuple[float, np.ndarray, np.ndarray]] = [] + for ch in channels: + probs = ch.probs.astype(np.float64) + p_fire = 1.0 - float(probs[0]) + n_outcomes = len(probs) + + if p_fire <= 1e-15 or n_outcomes <= 1: + continue + + cond_cdf = np.cumsum(probs[1:] / p_fire, dtype=np.float64) + cond_cdf /= cond_cdf[-1] + + col_ids = np.asarray(ch.unique_col_ids) + num_bits = len(col_ids) + outcomes = np.arange(1, n_outcomes) + bits_mask = ((outcomes[:, None] >> np.arange(num_bits)) & 1).astype( + np.uint8 + ) + xor_patterns = (bits_mask @ signature_matrix[col_ids] % 2).astype(np.uint8) + + data.append((p_fire, cond_cdf, xor_patterns)) + return data + + def sample(self, num_samples: int = 1) -> np.ndarray: """Sample from all error channels and transform to new error basis. + Uses geometric-skip sampling, optimized for low-noise regimes where + P(non-identity) << 1 per channel. + Args: num_samples: Number of samples to draw. Returns: - Array of shape (num_samples, num_f) with boolean values indicating + NumPy array of shape (num_samples, num_f) with uint8 values indicating which f-variables are set for each sample. 
""" - self._key, subkey = jax.random.split(self._key) - samples = _sample_channels( - subkey, self.channels, self.signature_matrix, num_samples - ) - return samples + num_outputs = self.signature_matrix.shape[1] + result = np.zeros((num_samples, num_outputs), dtype=np.uint8) + + for p_fire, cond_cdf, xor_pats in self._sparse_data: + expected = num_samples * p_fire + sigma = np.sqrt(expected * (1.0 - p_fire)) + # At 7 sigma, we undersample in about 1 out of 1e12 cases + n_draws = int(expected + 7.0 * sigma) + 100 + + positions = np.cumsum(self._rng.geometric(p_fire, size=n_draws)) - 1 + positions = positions[positions < num_samples] + + if len(positions) == 0: + continue + + outcome_idx = np.searchsorted( + cond_cdf, self._rng.uniform(size=len(positions)) + ) + result[positions] ^= xor_pats[outcome_idx] + + return result diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index 2f6c9d0..cdbb4bc 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -121,12 +121,21 @@ def sample_program( """ results: list[jax.Array] = [] + if len(program.direct_f_indices) > 0: + direct_bits = ( + f_params[:, program.direct_f_indices].astype(jnp.bool_) + ^ program.direct_flips + ) + results.append(direct_bits) + for component in program.components: samples, key = sample_component(component, f_params, key) results.append(samples) combined = jnp.concatenate(results, axis=1) - return combined[:, jnp.argsort(program.output_order)] + if program.output_reindex is not None: + combined = combined[:, program.output_reindex] + return combined class _CompiledSamplerBase: @@ -157,8 +166,7 @@ def __init__( prepared = prepare_graph(circuit, sample_detectors=sample_detectors) self._program = compile_program(prepared, mode=mode) - self._key, subkey = jax.random.split(self._key) - channel_seed = int(jax.random.randint(subkey, (), 0, 2**30)) + channel_seed = int(np.random.default_rng(seed).integers(0, 2**30)) self._channel_sampler = ChannelSampler( channel_probs=prepared.channel_probs, 
error_transform=prepared.error_transform, @@ -168,22 +176,56 @@ def __init__( self.circuit = circuit self._num_detectors = prepared.num_detectors + # Pre-cache numpy arrays for the direct fast path so we don't + # convert from JAX on every sample call. + prog = self._program + n_direct = len(prog.direct_f_indices) + self._direct_f_np = np.asarray(prog.direct_f_indices) + self._direct_fl_np = np.asarray(prog.direct_flips) + self._direct_reindex_np = ( + np.asarray(prog.output_reindex) if prog.output_reindex is not None else None + ) + self._direct_any_flips = bool(np.any(self._direct_fl_np)) + self._direct_contiguous = n_direct > 0 and np.array_equal( + self._direct_f_np, np.arange(n_direct) + ) + def _sample_batches(self, shots: int, batch_size: int | None = None) -> np.ndarray: """Sample in batches and concatenate results.""" + if not self._program.components: + return self._sample_direct(shots) + if batch_size is None: batch_size = shots batches: list[jax.Array] = [] for _ in range(ceil(shots / batch_size)): - f_params = self._channel_sampler.sample(batch_size) + f_params_np = self._channel_sampler.sample(batch_size) + f_params = jnp.asarray(f_params_np) self._key, subkey = jax.random.split(self._key) samples = sample_program(self._program, f_params, subkey) batches.append(samples) return np.concatenate(batches)[:shots] + def _sample_direct(self, shots: int) -> np.ndarray: + """Fast path when all components are direct (pure numpy, no JAX).""" + f_params = self._channel_sampler.sample(shots) + n = len(self._direct_f_np) + if self._direct_contiguous: + result = f_params[:, :n] if n < f_params.shape[1] else f_params + else: + result = f_params[:, self._direct_f_np] + if self._direct_any_flips: + result = result ^ self._direct_fl_np + if self._direct_reindex_np is not None: + result = result[:, self._direct_reindex_np] + return result + def __repr__(self) -> str: """Return a string representation with compilation statistics.""" + n_direct = 
len(self._program.direct_f_indices) + c_graphs = [] c_params = [] c_a_terms = [] @@ -222,11 +264,13 @@ def _format_bytes(n: int) -> str: error_channel_bits = sum( channel.num_bits for channel in self._channel_sampler.channels ) + max_outputs = int(np.max(num_outputs)) if num_outputs else 0 return ( - f"{type(self).__name__}({np.sum(c_graphs)} graphs, " + f"{type(self).__name__}({n_direct} direct, " + f"{np.sum(c_graphs)} graphs, " f"{error_channel_bits} error channel bits, " - f"{np.max(num_outputs)} outputs for largest cc, " + f"{max_outputs} outputs for largest cc, " f"≤ {np.max(c_params) if c_params else 0} parameters, {np.sum(c_a_terms)} A terms, " f"{np.sum(c_b_terms)} B terms, " f"{np.sum(c_c_terms)} C terms, {np.sum(c_d_terms)} D terms, " @@ -406,10 +450,19 @@ def probability_of(self, state: np.ndarray, *, batch_size: int) -> np.ndarray: Array of probabilities P(state | error_sample) for each error sample. """ - f_samples = self._channel_sampler.sample(batch_size) + f_samples = jnp.asarray(self._channel_sampler.sample(batch_size)) p_norm = jnp.ones(batch_size) p_joint = jnp.ones(batch_size) + if len(self._program.direct_f_indices) > 0: + bits = ( + f_samples[:, self._program.direct_f_indices].astype(jnp.bool_) + ^ self._program.direct_flips + ) + n_direct = len(self._program.direct_f_indices) + targets = state[self._program.output_order[:n_direct]] + p_joint = p_joint * (bits == targets).all(axis=1) + for component in self._program.components: assert len(component.compiled_scalar_graphs) == 2 diff --git a/test/integration/test_sampler_circuits.py b/test/integration/test_sampler_circuits.py index 6408be8..21a9ecd 100644 --- a/test/integration/test_sampler_circuits.py +++ b/test/integration/test_sampler_circuits.py @@ -21,7 +21,7 @@ def test_sample_bell_state(): m = sampler.sample(100) assert np.array_equal(m[:, 0], m[:, 1]) - assert np.count_nonzero(m[:, 0]) == 53 + assert np.count_nonzero(m[:, 0]) == 48 def 
test_detector_sampler_bell_state_with_measurement_error(): @@ -52,7 +52,7 @@ def test_t_gate(): ) sampler = c.compile_sampler(seed=0) m = sampler.sample(100) - assert np.count_nonzero(m) == 16 + assert np.count_nonzero(m) == 9 def test_s_gate(): @@ -66,7 +66,7 @@ def test_s_gate(): ) sampler = c.compile_sampler(seed=0) m = sampler.sample(100) - assert np.count_nonzero(m) == 53 + assert np.count_nonzero(m) == 48 def test_t_dag_gate(): @@ -114,13 +114,13 @@ def test_r_gate(): ) sampler = c.compile_sampler(seed=0) m = sampler.sample(10) - assert np.count_nonzero(m[:, 0]) == 4 - assert np.count_nonzero(m[:, 1]) == 6 + assert np.count_nonzero(m[:, 0]) == 7 + assert np.count_nonzero(m[:, 1]) == 4 assert np.count_nonzero(m[:, 2]) == 0 det_sampler = c.compile_detector_sampler(seed=0) d = det_sampler.sample(10) - assert np.count_nonzero(d) == 4 + assert np.count_nonzero(d) == 7 @pytest.mark.parametrize( diff --git a/test/unit/noise/test_channels.py b/test/unit/noise/test_channels.py index 8b5a8fb..d922a93 100644 --- a/test/unit/noise/test_channels.py +++ b/test/unit/noise/test_channels.py @@ -1,12 +1,9 @@ -import jax -import jax.numpy as jnp import numpy as np from numpy.testing import assert_allclose from tsim.noise.channels import ( Channel, ChannelSampler, - _sample_channels, absorb_subset_channels, correlated_error_probs, error_probs, @@ -140,8 +137,20 @@ def test_chain_with_certain_first_error(self): assert_allclose(probs[4], 0.0) # Third error (blocked) +def _sample_channels(channels, matrix, n_samples, seed=42): + """Sample from channels using the ChannelSampler infrastructure.""" + sampler = object.__new__(ChannelSampler) + sampler.channels = channels + sampler.signature_matrix = np.asarray(matrix, dtype=np.uint8) + sampler._rng = np.random.default_rng(seed) + sampler._sparse_data = ChannelSampler._precompute_sparse( + channels, sampler.signature_matrix + ) + return sampler.sample(n_samples) + + def assert_sampling_matches( - matrix: jnp.ndarray, + matrix: 
np.ndarray, channels_before: list[Channel], channels_after: list[Channel], n_samples: int = 500_000, @@ -152,13 +161,11 @@ def assert_sampling_matches( Compares the mean of each output bit (f-variable) between the two channel sets. """ - key1 = jax.random.key(seed) - bits1 = _sample_channels(key1, channels_before, matrix, n_samples) - freq1 = np.mean(np.asarray(bits1), axis=0) + bits1 = _sample_channels(channels_before, matrix, n_samples, seed=seed) + freq1 = np.mean(bits1, axis=0) - key2 = jax.random.key(seed + 1) - bits2 = _sample_channels(key2, channels_after, matrix, n_samples) - freq2 = np.mean(np.asarray(bits2), axis=0) + bits2 = _sample_channels(channels_after, matrix, n_samples, seed=seed + 1) + freq2 = np.mean(bits2, axis=0) assert_allclose( freq1, @@ -251,13 +258,13 @@ def test_no_merge_different_signatures(self): def test_sampling_matches_after_merge(self): """Sampling statistics should match before and after merging.""" - mat = jnp.array( + mat = np.array( [ [1, 0, 0], [0, 1, 0], [1, 1, 0], ], - dtype=jnp.uint8, + dtype=np.uint8, ) c1 = Channel(probs=error_probs(0.1), unique_col_ids=(0,)) @@ -357,7 +364,7 @@ def test_preserves_sampling_statistics(self): normalized = normalize_channels([c]) - mat = jnp.eye(2, dtype=jnp.uint8) + mat = np.eye(2, dtype=np.uint8) assert_sampling_matches(mat, [c], normalized) @@ -406,13 +413,13 @@ def test_no_absorb_partial_overlap(self): def test_sampling_matches_after_absorb(self): """Sampling statistics should match before and after absorption.""" - mat = jnp.array( + mat = np.array( [ [1, 0, 0], [0, 1, 0], [1, 1, 0], ], - dtype=jnp.uint8, + dtype=np.uint8, ) # c1 has signature (0,), c2 has signatures (0, 1) @@ -435,14 +442,14 @@ class TestSimplifyChannels: def test_simplify_mixed_channels(self): """Test simplification with a mix of channel types.""" - mat = jnp.array( + mat = np.array( [ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [1, 1, 1, 0], ], - dtype=jnp.uint8, + dtype=np.uint8, ) # Create channels: @@ -467,7 +474,7 @@ 
def test_simplify_mixed_channels(self): def test_simplify_many_1bit_channels(self): """Test simplification of many 1-bit channels with same signature.""" - mat = jnp.array([[1], [1]], dtype=jnp.uint8) + mat = np.array([[1], [1]], dtype=np.uint8) # 10 channels all with the same signature channels = [ @@ -481,13 +488,13 @@ def test_simplify_many_1bit_channels(self): def test_simplify_preserves_independent_channels(self): """Channels with disjoint signatures should remain separate.""" - mat = jnp.array( + mat = np.array( [ [1, 0, 0], [0, 1, 0], [0, 0, 1], ], - dtype=jnp.uint8, + dtype=np.uint8, ) c1 = Channel(probs=error_probs(0.1), unique_col_ids=(0,)) @@ -502,31 +509,27 @@ def test_simplify_preserves_independent_channels(self): class TestSampleChannels: - """Tests for _sample_channels function.""" + """Tests for channel sampling.""" def test_single_channel(self): """Test that sampling produces correct frequencies for a single channel.""" - mat = jnp.array([[1]], dtype=jnp.uint8) + mat = np.array([[1]], dtype=np.uint8) c = Channel(probs=error_probs(0.3), unique_col_ids=(0,)) - key = jax.random.key(42) - bits = _sample_channels(key, [c], mat, 100_000) - freq = np.mean(np.asarray(bits[:, 0])) + bits = _sample_channels([c], mat, 100_000) + freq = np.mean(bits[:, 0]) assert_allclose(freq, 0.3, rtol=0.05) def test_xor_two_channels(self): """Test that sampling correctly XORs two independent channels.""" - # Matrix shape: (num_signatures, num_f_vars) - # Both signatures (0 and 1) affect f0 - mat = jnp.array([[1], [1]], dtype=jnp.uint8) + mat = np.array([[1], [1]], dtype=np.uint8) c1 = Channel(probs=error_probs(0.2), unique_col_ids=(0,)) c2 = Channel(probs=error_probs(0.3), unique_col_ids=(1,)) - key = jax.random.key(42) - bits = _sample_channels(key, [c1, c2], mat, 100_000) - freq = np.mean(np.asarray(bits[:, 0])) + bits = _sample_channels([c1, c2], mat, 100_000) + freq = np.mean(bits[:, 0]) # P(f0=1) = P(e0 XOR e1 = 1) = 0.2*0.7 + 0.3*0.8 = 0.14 + 0.24 = 0.38 expected = 
0.2 * 0.7 + 0.3 * 0.8 diff --git a/test/unit/test_sampler.py b/test/unit/test_sampler.py index 892663d..f86cea0 100644 --- a/test/unit/test_sampler.py +++ b/test/unit/test_sampler.py @@ -38,10 +38,10 @@ def test_seed(): ) for _ in range(2): sampler = c.compile_sampler(seed=0) + assert np.count_nonzero(sampler.sample(100)) == 48 assert np.count_nonzero(sampler.sample(100)) == 53 assert np.count_nonzero(sampler.sample(100)) == 52 assert np.count_nonzero(sampler.sample(100)) == 50 - assert np.count_nonzero(sampler.sample(100)) == 48 def test_sampler_repr():