Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ py_wheel(
name="tesseract_decoder_wheel",
distribution = "tesseract_decoder",
deps=[
"//src:tesseract_decoder",
"//src:_core",
"//src/py:generated_stubs",
"//src/py/_tesseract_py_util:_tesseract_py_util",
"//src/py:tesseract_decoder",
":package_data",
],
version = "$(VERSION)",
Expand Down
61 changes: 61 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,67 @@ Here are some tips for improving performance:
* *DEM usage frequency output*: if `--dem-out` is specified, outputs estimated error frequencies.
* *Statistics output*: includes number of shots, errors, low confidence shots, and processing time.

---

## Multi-Pass Graph Shattering

For loopy 3D syndrome hypergraphs (such as circuit-level color codes under circuit noise), monolithic MWPM/beam search scales exponentially slow. Tesseract implements **Multi-Pass Graph Shattering** to sever physical error correlation edges, breaking the monolithic graph into independent, planar-like CSS stabilizer components.

Priors (LLRs) are dynamically updated and propagated between passes using conditional probabilities to preserve physical logical accuracy while delivering up to **$1,000\times$ decoding speedups**.

### ⚠️ Strict Annotation Requirements (Limitations)
To decode using graph shattering, Tesseract **must** be able to classify detectors into basis components. The input Stim circuit or Detector Error Model (DEM) **MUST** be annotated using one of the following conventions:
1. **Basis Tags**: Detector instructions must contain standard basis metadata tags (e.g. `detector(0, 0) D0 {"basis": "X"}` or `detector(0, 0) D1 {"basis": "Z"}`).
2. **Coordinate Conventions (Chromobius Style)**: Detector coordinates must contain at least 4 dimensions, where the 4th coordinate represents `color + 3 * basis` (Component 0: `0 <= coords[3] <= 2`, Component 1: `3 <= coords[3] <= 5`).

If an unannotated circuit/DEM is supplied with `--multipass` enabled, Tesseract will fail fast and throw a clear `std::invalid_argument` exception.

### CLI Options
* `--multipass`: Enable multi-pass graph shattering (default = false).
* `--num-passes`, `--num_passes`: Number of prior propagation passes (default = 2).
* `1`: Uncorrelated independent CSS decoding (planar speedup, no reweighting).
* `2`: Standard causally reweighted prior propagation decoding.
* *Note: values > 2 are experimental and were never systematically benchmarked.*
* `--multipass-strategy`, `--multipass_strategy`: Causal or static pass scheduling (default = causal).
* `causal` (Recommended): Dynamically schedules stabilizer components sequentially based on physical prior causal flow (Component 0 decodes first, updates prior edge weights, Component 1 decodes using those LLRs).
* `static`: Decodes all components in parallel without dynamic pass-to-pass LLR updates.

### CLI Examples

**1. Running Multi-Pass on Basis-Tag Annotated Surface Codes (using Long-Beam Settings):**
```bash
./bazel-bin/src/tesseract \
--circuit testdata/annotated_surface_codes/style=surface_code,d=5,basis=X,num_rounds=10,max_qubits_per_module=49,total_qubits=64,k=1,noise=SI1000,p=0.00100.stim \
--sample-num-shots 1000 \
--multipass \
--num-passes 2 \
--multipass-strategy causal \
--pqlimit 1000000 \
--beam 20 \
--beam-climbing \
--no-revisit-dets \
--num-det-orders 21 \
--print-stats
```

**2. Running Multi-Pass on Coordinate-Annotated Color Codes (using Long-Beam Settings):**
```bash
./bazel-bin/src/tesseract \
--circuit testdata/colorcodes/r=5,d=5,p=0.003,noise=si1000,c=midout_color_code_X,q=23,gates=cz.stim \
--sample-num-shots 1000 \
--multipass \
--num-passes 2 \
--multipass-strategy causal \
--pqlimit 1000000 \
--beam 20 \
--beam-climbing \
--no-revisit-dets \
--num-det-orders 21 \
--print-stats
```

---

## Python Interface

[Full Python wrapper documentation](src/py/README.md)
Expand Down
1 change: 1 addition & 0 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ cc_binary(
linkopts = OPT_LINKOPTS,
deps = [
":libtesseract",
":libmulti_pass_tesseract_decoder",
"@argparse",
"@nlohmann_json//:json",
"@stim//:stim_lib",
Expand Down
180 changes: 90 additions & 90 deletions src/error_correlations.cc
Original file line number Diff line number Diff line change
@@ -1,123 +1,123 @@
#include "error_correlations.h"
#include <sstream>

#include <iostream>
#include <sstream>

namespace tesseract {

std::string ImpliedProbability::str() const {
std::stringstream ss;
ss << "ImpliedProbability(affected={";
for (size_t i = 0; i < affected_hyperedge.size(); ++i) {
ss << affected_hyperedge[i] << (i == affected_hyperedge.size() - 1 ? "" : ",");
}
ss << "}, prob=" << probability << ")";
return ss.str();
std::stringstream ss;
ss << "ImpliedProbability(affected={";
for (size_t i = 0; i < affected_hyperedge.size(); ++i) {
ss << affected_hyperedge[i] << (i == affected_hyperedge.size() - 1 ? "" : ",");
}
ss << "}, prob=" << probability << ")";
return ss.str();
}

bool ImpliedProbability::operator==(const ImpliedProbability& other) const {
return affected_hyperedge == other.affected_hyperedge &&
std::abs(probability - other.probability) < 1e-12;
return affected_hyperedge == other.affected_hyperedge &&
std::abs(probability - other.probability) < 1e-12;
}

bool ImpliedProbability::operator<(const ImpliedProbability& other) const {
if (affected_hyperedge != other.affected_hyperedge) {
return affected_hyperedge < other.affected_hyperedge;
}
return probability < other.probability;
if (affected_hyperedge != other.affected_hyperedge) {
return affected_hyperedge < other.affected_hyperedge;
}
return probability < other.probability;
}

JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem) {
JointProbsMap joint_probs;
auto flattened = dem.flattened();

for (const auto& inst : flattened.instructions) {
if (inst.type != stim::DemInstructionType::DEM_ERROR) continue;

double p = inst.arg_data[0];

std::vector<Hyperedge> components;
size_t group_start = 0;
for (size_t k = 0; k <= inst.target_data.size(); ++k) {
if (k == inst.target_data.size() || inst.target_data[k].is_separator()) {
Hyperedge hyperedge;
for (size_t j = group_start; j < k; ++j) {
const auto& target = inst.target_data[j];
if (target.is_relative_detector_id()) {
hyperedge.push_back(target.val());
}
}
if (!hyperedge.empty()) {
std::sort(hyperedge.begin(), hyperedge.end());
components.push_back(hyperedge);
}
group_start = k + 1;
}
}
JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem,
const std::vector<int>& global_det_to_comp_id) {
JointProbsMap joint_probs;
auto flattened = dem.flattened();

// 1. Marginal probabilities (diagonal)
for (const auto& h : components) {
if (joint_probs[h].find(h) == joint_probs[h].end()) {
joint_probs[h][h] = 0.0;
}
// P(A) = P(A) XOR p
joint_probs[h][h] = joint_probs[h][h] * (1 - p) + p * (1 - joint_probs[h][h]);
}
for (const auto& inst : flattened.instructions) {
if (inst.type != stim::DemInstructionType::DEM_ERROR) continue;

double p = inst.arg_data[0];

std::map<int, Hyperedge> comp_targets;
for (const auto& target : inst.target_data) {
if (target.is_relative_detector_id()) {
int d = target.val();
int cid =
(d >= 0 && (size_t)d < global_det_to_comp_id.size()) ? global_det_to_comp_id[d] : -1;
if (cid != -1) comp_targets[cid].push_back(d);
}
}

std::vector<Hyperedge> components;
for (auto& [cid, h] : comp_targets) {
std::sort(h.begin(), h.end());
components.push_back(h);
}

// 1. Marginal probabilities (diagonal)
for (const auto& h : components) {
if (joint_probs[h].find(h) == joint_probs[h].end()) {
joint_probs[h][h] = 0.0;
}
// P(A) = P(A) XOR p
joint_probs[h][h] = joint_probs[h][h] * (1 - p) + p * (1 - joint_probs[h][h]);
}

// 2. Joint probabilities (off-diagonal)
// For a bridging error p connecting A and B, P(A and B) += p (approx)
// Actually, the joint probability is accurately tracked via the same XOR logic
// if we assume independence of other error mechanisms.
if (components.size() > 1) {
for (size_t i = 0; i < components.size(); ++i) {
for (size_t j = 0; j < components.size(); ++j) {
if (i == j) continue;
const auto& hi = components[i];
const auto& hj = components[j];
if (joint_probs[hi].find(hj) == joint_probs[hi].end()) {
joint_probs[hi][hj] = 0.0;
}
// For small p, joint probability P(A and B) is roughly the sum of p's of bridging errors
joint_probs[hi][hj] = joint_probs[hi][hj] * (1 - p) + p * (1 - joint_probs[hi][hj]);
}
}
// 2. Joint probabilities (off-diagonal)
// For a bridging error p connecting A and B, P(A and B) += p (approx)
// Actually, the joint probability is accurately tracked via the same XOR logic
// if we assume independence of other error mechanisms.
if (components.size() > 1) {
for (size_t i = 0; i < components.size(); ++i) {
for (size_t j = 0; j < components.size(); ++j) {
if (i == j) continue;
const auto& hi = components[i];
const auto& hj = components[j];
if (joint_probs[hi].find(hj) == joint_probs[hi].end()) {
joint_probs[hi][hj] = 0.0;
}
// For small p, joint probability P(A and B) is roughly the sum of p's of bridging errors
joint_probs[hi][hj] = joint_probs[hi][hj] * (1 - p) + p * (1 - joint_probs[hi][hj]);
}
}
}
}

return joint_probs;
return joint_probs;
}

ImpliedProbsMap get_implied_hyperedge_probabilities(const JointProbsMap& joint_probs) {
ImpliedProbsMap implied_probs;
ImpliedProbsMap implied_probs;

for (const auto& [causal, affected_map] : joint_probs) {
double p_causal = 0.0;
auto it_self = affected_map.find(causal);
if (it_self != affected_map.end()) {
p_causal = it_self->second;
}
for (const auto& [causal, affected_map] : joint_probs) {
double p_causal = 0.0;
auto it_self = affected_map.find(causal);
if (it_self != affected_map.end()) {
p_causal = it_self->second;
}

if (p_causal <= 0 || p_causal >= 1.0) continue;
if (p_causal <= 0 || p_causal >= 1.0) continue;

for (const auto& [affected, p_joint] : affected_map) {
if (causal == affected) continue;
for (const auto& [affected, p_joint] : affected_map) {
if (causal == affected) continue;

// Conditional Probability P(affected | causal) = P(affected and causal) / P(causal)
double p_conditional = p_joint / p_causal;

// Cap to 1.0 (numerical precision)
if (p_conditional > 1.0) p_conditional = 1.0;
if (p_conditional < 0.0) p_conditional = 0.0;
// Conditional Probability P(affected | causal) = P(affected and causal) / P(causal)
double p_conditional = p_joint / p_causal;

implied_probs[causal].push_back({affected, p_conditional});
}
// Cap to 1.0 (numerical precision)
if (p_conditional > 1.0) p_conditional = 1.0;
if (p_conditional < 0.0) p_conditional = 0.0;

implied_probs[causal].push_back({affected, p_conditional});
}
}

return implied_probs;
return implied_probs;
}

ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem) {
auto joint = get_hyperedge_joint_probabilities(dem);
return get_implied_hyperedge_probabilities(joint);
ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem,
const std::vector<int>& global_det_to_comp_id) {
auto joint = get_hyperedge_joint_probabilities(dem, global_det_to_comp_id);
return get_implied_hyperedge_probabilities(joint);
}

} // namespace tesseract
} // namespace tesseract
29 changes: 16 additions & 13 deletions src/error_correlations.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#ifndef ERROR_CORRELATIONS_H
#define ERROR_CORRELATIONS_H

#include <algorithm>
#include <cmath>
#include <map>
#include <numeric>
#include <vector>
#include <string>
#include <algorithm>
#include <vector>

#include "stim.h"

Expand All @@ -16,27 +16,29 @@ namespace tesseract {
* Represents a probability adjustment for an affected hyperedge given a causal hyperedge.
*/
struct ImpliedProbability {
std::vector<int> affected_hyperedge;
double probability; // Represents the conditional probability P(affected | causal)
std::vector<int> affected_hyperedge;
double probability; // Represents the conditional probability P(affected | causal)

std::string str() const;
bool operator==(const ImpliedProbability& other) const;
bool operator<(const ImpliedProbability& other) const;
std::string str() const;
bool operator==(const ImpliedProbability& other) const;
bool operator<(const ImpliedProbability& other) const;
};

// Type alias for hyperedge (sorted detector indices)
using Hyperedge = std::vector<int>;
// Type alias for joint probabilities map: causal_hyperedge -> {affected_hyperedge -> joint_prob}
using JointProbsMap = std::map<Hyperedge, std::map<Hyperedge, double>>;
// Type alias for implied probabilities map: causal_hyperedge -> list of conditional probability updates
// Type alias for implied probabilities map: causal_hyperedge -> list of conditional probability
// updates
using ImpliedProbsMap = std::map<Hyperedge, std::vector<ImpliedProbability>>;

/**
* Calculates marginal and joint probabilities for hyperedges in a DEM.
* Note: Assumes the input DEM has NOT been decomposed yet, as we need bridging errors
* Note: Assumes the input DEM has NOT been decomposed yet, as we need bridging errors
* to find joint probabilities.
*/
JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem);
JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem,
const std::vector<int>& global_det_to_comp_id);

/**
* Calculates conditional probabilities from joint probabilities.
Expand All @@ -46,8 +48,9 @@ ImpliedProbsMap get_implied_hyperedge_probabilities(const JointProbsMap& joint_p
/**
* Complete workflow for analyzing correlations within a stim::DetectorErrorModel.
*/
ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem);
ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem,
const std::vector<int>& global_det_to_comp_id);

} // namespace tesseract
} // namespace tesseract

#endif // ERROR_CORRELATIONS_H
#endif // ERROR_CORRELATIONS_H
Loading
Loading