From 9aeafe64fedc27bc9955551ade97e685dcf89828 Mon Sep 17 00:00:00 2001 From: Dave Woodruff Date: Wed, 20 May 2026 16:05:04 -0700 Subject: [PATCH 1/5] DEBUG: instrument SPBase.allreduce_or for LOR_bug A spurious shutdown is firing on every xhatter rank despite no rank having written 1.0 to the SHUTDOWN buffer; replacing Allreduce(LOR) with Allreduce(SUM) returns a stable-by-pattern nonzero (~69), and the input local_val has been verified zero on the xhatter ranks themselves. Four hypotheses remain: (i) self.mpicomm has wider membership than the xhatter cylinder (ii) buffer memory underneath local_val is not 0 when MPI reads it (iii) the Allreduce reducer path is malfunctioning (iv) duplicate rank participation in self.mpicomm This patch packs four diagnostic axes into a single Allreduce call: 1. an Allgather of (world_rk, cyl_rk, local_int) - shows exactly which world ranks participate and what each one contributed 2. parallel SUM, MAX, LOR reductions - MAX distinguishes "many small contributions" from "few large ones," LOR confirms observed call-site behavior 3. a rank-sum sanity reduction (each rank contributes its mpicomm rank), expected to equal n*(n-1)/2; mismatch flags a corrupt SUM reducer 4. a comparison between the Allgather-summed values and the Allreduce(SUM) result; divergence isolates the bug to the reducer path Output is printed on cyl_rk == 0 with the call counter, class name, host, pid, comm name, and world-rank min/max/count/unique; on any anomaly it also lists nonzero rows and the full participant list. One Allreduce call now does 4 reductions + 1 Allgather; cost is dominated by run-launch overhead in the target environment, so the extra collectives are acceptable. Revert before merging to main. Reading the output (greps): # 1. Did anything print at all? grep '^\[LOR_bug' out.log | head # 2. What does each cylinder think its comm size is? # (one printer per cylinder; size should be that cylinder's rank count) # If you see size=150 here, hypothesis (i) is the bug. grep 'mpicomm size=' out.log | sort -u # 3. World-rank membership of each cylinder. # count==unique is required; unique --- mpisppy/spbase.py | 116 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 112 insertions(+), 4 deletions(-) diff --git a/mpisppy/spbase.py b/mpisppy/spbase.py index 27873763e..e1befbac8 100644 --- a/mpisppy/spbase.py +++ b/mpisppy/spbase.py @@ -642,10 +642,118 @@ def spcomm(self, value): def allreduce_or(self, val): - local_val = np.array([val], dtype='int8') - global_val = np.zeros(1, dtype='int8') - self.mpicomm.Allreduce(local_val, global_val, op=MPI.LOR) - if global_val[0] > 0: + # ====== DEBUG: LOR_bug instrumentation ====== + # Yields per call (on cyl_rk == 0 of self.mpicomm) the full picture + # needed to localize an Allreduce(LOR) returning nonzero when every + # rank intends a zero. Probes four axes: + # 1. comm membership (world ranks participating, size, uniqueness) + # 2. data going in (Allgather of every rank's local_val) + # 3. reduction sanity (SUM/MAX/LOR + a rank-sum check whose + # expected value is n*(n-1)/2) + # 4. consistency (compare Allgather sum to Allreduce SUM + # to tell "input was wrong" from + # "Allreduce is wrong") + # Remove before merging to main. See PR description for hypothesis tree. + import os + import socket + import sys + sz = self.mpicomm.Get_size() + cyl_rk = self.mpicomm.Get_rank() + world_rk = MPI.COMM_WORLD.Get_rank() + host = socket.gethostname() + pid = os.getpid() + + local_int = 1 if val else 0 + local_int32 = np.array([local_int], dtype='int32') + local_int8 = np.array([local_int], dtype='int8') + + # (3) Reductions — three ops in parallel, plus a rank-sum sanity check. + sum_out = np.zeros(1, dtype='int32') + self.mpicomm.Allreduce(local_int32, sum_out, op=MPI.SUM) + + max_out = np.zeros(1, dtype='int32') + self.mpicomm.Allreduce(local_int32, max_out, op=MPI.MAX) + + lor_out = np.zeros(1, dtype='int8') + self.mpicomm.Allreduce(local_int8, lor_out, op=MPI.LOR) + + rank_in = np.array([cyl_rk], dtype='int32') + rank_out = np.zeros(1, dtype='int32') + self.mpicomm.Allreduce(rank_in, rank_out, op=MPI.SUM) + expected_rank_sum = sz * (sz - 1) // 2 + + # (1) + (2) Allgather of (world_rk, cyl_rk, local_int) so we see + # exactly which ranks participated and what each one contributed. + report = np.array([world_rk, cyl_rk, local_int], dtype='int32') + all_reports = np.zeros(3 * sz, dtype='int32') + self.mpicomm.Allgather(report, all_reports) + + # Track a per-instance call counter so logs are correlatable. + self._lor_diag_count = getattr(self, "_lor_diag_count", 0) + 1 + call_n = self._lor_diag_count + + if cyl_rk == 0: + rows = all_reports.reshape(sz, 3) + wr = rows[:, 0].tolist() + nonzero_rows = rows[rows[:, 2] != 0] + gather_sum = int(rows[:, 2].sum()) + cls = type(self).__name__ + try: + comm_name = self.mpicomm.Get_name() + except Exception: + comm_name = "" + print( + f"[LOR_bug call={call_n} cls={cls} " + f"world_rk={world_rk} host={host} pid={pid}] " + f"mpicomm size={sz} name={comm_name!r}", + flush=True, + ) + print( + f" world_ranks: min={min(wr)} max={max(wr)} " + f"count={len(wr)} unique={len(set(wr))}", + flush=True, + ) + print( + f" reductions: sum={int(sum_out[0])} max={int(max_out[0])} " + f"lor={int(lor_out[0])} rank_sum={int(rank_out[0])} " + f"expected_rank_sum={expected_rank_sum}", + flush=True, + ) + print( + f" gather: gather_sum={gather_sum} " + f"nonzero_reports={len(nonzero_rows)}", + flush=True, + ) + # If anything looks bad, dump the per-rank rows. + bad = ( + int(sum_out[0]) != 0 + or int(lor_out[0]) != 0 + or int(rank_out[0]) != expected_rank_sum + or int(sum_out[0]) != gather_sum + or len(set(wr)) != len(wr) + ) + if bad: + limit = min(64, len(nonzero_rows)) + for w, c, v in nonzero_rows[:limit].tolist(): + print( + f" nonzero: world_rk={w} cyl_rk={c} local_val={v}", + flush=True, + ) + if len(nonzero_rows) > limit: + print( + f" (... {len(nonzero_rows) - limit} more nonzero rows truncated ...)", + flush=True, + ) + # Also dump the full world-rank list once so we can see exactly + # who is participating in this comm. + print( + f" ALL world_ranks: {wr}", + flush=True, + ) + sys.stdout.flush() + # ====== END DEBUG ====== + + if lor_out[0] > 0: return True else: return False From 3d38b250286e0ba264df126e815a255377020146 Mon Sep 17 00:00:00 2001 From: Dave Woodruff Date: Wed, 20 May 2026 18:14:44 -0700 Subject: [PATCH 2/5] DEBUG(LOR_bug): tighten anomaly trigger to invariant violations Initial trigger fired on any nonzero reduce result, which caught legitimate shutdown signals (sum=lor=1, gather_sum=1, all consistent) and would flood logs on every cylinder finalization. Real-bug signature is invariant-violating, not just nonzero: - rank_sum != expected_rank_sum -> SUM broken on this comm - sum != gather_sum -> reducer disagrees with gather - unique != count -> duplicate world ranks - max > 1 -> non-boolean input on some rank Verified: 0 false positives across ~440k allreduce_or calls in a sizes 3-scen 3-rank xhatshuffle+lagrangian run (sizes_cylinders.py --num-scens 3 --xhatshuffle --lagrangian). --- mpisppy/spbase.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mpisppy/spbase.py b/mpisppy/spbase.py index e1befbac8..4dc054fc8 100644 --- a/mpisppy/spbase.py +++ b/mpisppy/spbase.py @@ -724,13 +724,19 @@ def allreduce_or(self, val): f"nonzero_reports={len(nonzero_rows)}", flush=True, ) - # If anything looks bad, dump the per-rank rows. + # "Bad" = invariant-violating, NOT just "nonzero result." A + # legitimate shutdown signal returns sum=lor=1 with + # gather_sum=1 (consistent), which is fine. The real bug + # signature is gather_sum disagreeing with the Allreduce SUM + # (the reducer lying), or the rank-sum sanity check failing + # (SUM broken on this comm), or duplicate world ranks + # (group membership corrupted), or some rank packing >1 + # (non-boolean input — only possible under memory aliasing). bad = ( - int(sum_out[0]) != 0 - or int(lor_out[0]) != 0 - or int(rank_out[0]) != expected_rank_sum + int(rank_out[0]) != expected_rank_sum or int(sum_out[0]) != gather_sum or len(set(wr)) != len(wr) + or int(max_out[0]) > 1 ) if bad: limit = min(64, len(nonzero_rows)) From b54e8880ded50c5466b4b9e346326170169071f2 Mon Sep 17 00:00:00 2001 From: Dave Woodruff Date: Fri, 22 May 2026 14:21:45 -0700 Subject: [PATCH 3/5] DEBUG(LOR_bug): add lor_bug_report.py to summarize diagnostic output Parses a log containing [LOR_bug ...] blocks emitted by the instrumented SPBase.allreduce_or and writes a short per-hypothesis report (H1 wider membership, H2 buffer aliasing, H3 reducer malfunction, H4 duplicate ranks). Intended to live and die with this debug branch. Co-Authored-By: Claude Opus 4.7 (1M context) --- lor_bug_report.py | 225 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 lor_bug_report.py diff --git a/lor_bug_report.py b/lor_bug_report.py new file mode 100644 index 000000000..9cb7b8ad0 --- /dev/null +++ b/lor_bug_report.py @@ -0,0 +1,225 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. +############################################################################### +"""Summarize LOR_bug diagnostic output from PR #717. + +PR #717 instruments mpisppy/spbase.py::SPBase.allreduce_or to print a 4-line +block on every call (from cyl_rk == 0 of the cylinder's mpicomm). This script +parses such a log and writes a short report on the four hypotheses being +tested: + + H1. self.mpicomm has wider membership than the cylinder it should. + H2. Buffer memory underneath local_val was nonzero / non-boolean. + H3. The Allreduce reducer path is malfunctioning. + H4. Duplicate rank participation in self.mpicomm. + +Usage: + python lor_bug_report.py +""" + +import re +import sys +from collections import defaultdict + + +HEADER_RE = re.compile( + r"^\[LOR_bug call=(?P\d+) cls=(?P\S+) " + r"world_rk=(?P\d+) host=(?P\S+) pid=(?P\d+)\] " + r"mpicomm size=(?P\d+) name=(?P.+)$" +) +WORLD_RANKS_RE = re.compile( + r"^\s*world_ranks: min=(?P\d+) max=(?P\d+) " + r"count=(?P\d+) unique=(?P\d+)$" +) +REDUCTIONS_RE = re.compile( + r"^\s*reductions: sum=(?P-?\d+) max=(?P-?\d+) " + r"lor=(?P-?\d+) rank_sum=(?P-?\d+) " + r"expected_rank_sum=(?P-?\d+)$" +) +GATHER_RE = re.compile( + r"^\s*gather: gather_sum=(?P-?\d+) " + r"nonzero_reports=(?P\d+)$" +) + + +def parse(path): + """Return a list of dicts, one per [LOR_bug ...] block.""" + with open(path) as f: + lines = f.readlines() + + entries = [] + i = 0 + n = len(lines) + while i < n: + m = HEADER_RE.match(lines[i].rstrip()) + if not m: + i += 1 + continue + entry = { + "call": int(m["call"]), + "cls": m["cls"], + "world_rk": int(m["world_rk"]), + "host": m["host"], + "pid": int(m["pid"]), + "size": int(m["size"]), + "name": m["name"], + } + i += 1 + # The next three lines should be world_ranks / reductions / gather, + # in that order. Tolerate missing lines defensively. + for pat in (WORLD_RANKS_RE, REDUCTIONS_RE, GATHER_RE): + if i >= n: + break + mm = pat.match(lines[i].rstrip()) + if not mm: + break + for k, v in mm.groupdict().items(): + entry[k] = int(v) + i += 1 + entries.append(entry) + return entries + + +def _examples(rows, n=5): + out = [] + for e in rows[:n]: + out.append( + f" cls={e['cls']} call={e['call']} " + f"world_rk={e['world_rk']} host={e['host']}" + ) + if len(rows) > n: + out.append(f" (... {len(rows) - n} more truncated ...)") + return "\n".join(out) + + +def report(entries, path): + print(f"LOR_bug report for: {path}") + print(f"Parsed {len(entries)} [LOR_bug ...] blocks.") + if not entries: + print("\nNo diagnostic blocks found. Was the run on the LOR_bug branch?") + return + + # ---------- per-comm summary ---------- + by_comm = defaultdict(list) + for e in entries: + by_comm[(e["cls"], e["name"])].append(e) + + print("\nPer-comm summary (one printer per comm; cyl_rk == 0 only):") + for (cls, name), es in sorted(by_comm.items()): + sizes = sorted({e["size"] for e in es}) + wrs = sorted({e["world_rk"] for e in es}) + print(f" cls={cls} name={name}") + print(f" calls={len(es)} sizes={sizes} printer_world_rk={wrs}") + + # ---------- H1: wider membership ---------- + # Signal: size varies within a single (cls, name) bucket, OR printer + # world_rk varies across calls for the same logical comm (meaning + # different ranks took the "rank 0" role — only possible if comm + # membership shifted). + print("\nH1 — wider mpicomm membership than expected:") + h1_hits = [] + for (cls, name), es in by_comm.items(): + sizes = {e["size"] for e in es} + printers = {e["world_rk"] for e in es} + if len(sizes) > 1 or len(printers) > 1: + h1_hits.append((cls, name, sorted(sizes), sorted(printers))) + if h1_hits: + print(" WARNING: comm membership is not stable across calls:") + for cls, name, sizes, printers in h1_hits: + print(f" cls={cls} name={name} sizes={sizes} " + f"printer_world_rks={printers}") + else: + print(" OK: every comm has a stable size and stable rank-0 printer.") + + # Also: if two different comms share the same printer world rank, that + # rank straddles two cylinders -- possible cross-cylinder contamination. + printer_to_comms = defaultdict(set) + for (cls, name), es in by_comm.items(): + for e in es: + printer_to_comms[e["world_rk"]].add((cls, name)) + shared = {wr: cs for wr, cs in printer_to_comms.items() if len(cs) > 1} + if shared: + print(" NOTE: world ranks acting as printer for multiple comms:") + for wr, cs in sorted(shared.items()): + print(f" world_rk={wr} comms={sorted(cs)}") + + # ---------- H2: buffer aliasing / non-boolean input ---------- + # Signature per PR description: nonzero local_val where it should be 0. + # The unambiguous tell is max > 1 (input was not a Python bool). + print("\nH2 — buffer aliasing / non-boolean input:") + nonbool = [e for e in entries if e.get("max", 0) > 1] + nonzero = [e for e in entries if e.get("gather_sum", 0) > 0] + print(f" Calls with any nonzero local_val: {len(nonzero)} / {len(entries)}" + f" (these may be legitimate True returns)") + if nonbool: + print(f" STRONG SIGNAL: {len(nonbool)} calls had max > 1 " + f"(input was not boolean)") + print(_examples(nonbool)) + else: + print(" OK: every nonzero local_val was 1 (boolean).") + + # ---------- H3: reducer malfunction ---------- + # (a) Allreduce SUM disagrees with the Allgather-summed local_vals. + # (b) rank_sum != expected sum-of-ranks for a comm of this size. + print("\nH3 — Allreduce reducer malfunction:") + sum_mismatch = [e for e in entries + if "sum" in e and "gather_sum" in e + and e["sum"] != e["gather_sum"]] + rank_sum_fail = [e for e in entries + if "rank_sum" in e and "expected_rank_sum" in e + and e["rank_sum"] != e["expected_rank_sum"]] + print(f" sum != gather_sum (reducer disagreeing with gather): " + f"{len(sum_mismatch)}") + if sum_mismatch: + print(_examples(sum_mismatch)) + print(f" rank_sum sanity failures (SUM broken on this comm): " + f"{len(rank_sum_fail)}") + if rank_sum_fail: + print(_examples(rank_sum_fail)) + + # ---------- H4: duplicate rank participation ---------- + print("\nH4 — duplicate rank participation in mpicomm:") + dups = [e for e in entries + if "unique" in e and "count" in e and e["unique"] < e["count"]] + print(f" Calls with duplicate world ranks: {len(dups)}") + if dups: + print(_examples(dups)) + + # ---------- Verdict ---------- + print("\nVerdict:") + triggered = [] + if h1_hits: + triggered.append("H1 (wider/unstable membership)") + if nonbool: + triggered.append("H2 (non-boolean input)") + if sum_mismatch or rank_sum_fail: + triggered.append("H3 (reducer)") + if dups: + triggered.append("H4 (duplicate ranks)") + if triggered: + print(" Hypotheses triggered: " + ", ".join(triggered)) + else: + if nonzero: + print(" No invariant violations. Some calls returned nonzero;" + " consistent with legitimate shutdown signals.") + else: + print(" Clean log: no anomalies on any of the four hypotheses.") + + +def main(argv): + if len(argv) != 2: + print(f"Usage: {argv[0]} ", file=sys.stderr) + return 2 + path = argv[1] + entries = parse(path) + report(entries, path) + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) From 6578aefe64528c36a19d732e6297265e6fb7afec Mon Sep 17 00:00:00 2001 From: Dave Woodruff Date: Fri, 22 May 2026 14:25:24 -0700 Subject: [PATCH 4/5] DEBUG(LOR_bug): rename log_file -> stdout_file in usage text Co-Authored-By: Claude Opus 4.7 (1M context) --- lor_bug_report.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lor_bug_report.py b/lor_bug_report.py index 9cb7b8ad0..983c379ad 100644 --- a/lor_bug_report.py +++ b/lor_bug_report.py @@ -19,7 +19,7 @@ H4. Duplicate rank participation in self.mpicomm. Usage: - python lor_bug_report.py + python lor_bug_report.py """ import re @@ -213,7 +213,7 @@ def report(entries, path): def main(argv): if len(argv) != 2: - print(f"Usage: {argv[0]} ", file=sys.stderr) + print(f"Usage: {argv[0]} ", file=sys.stderr) return 2 path = argv[1] entries = parse(path) From 6c0b9c48e78c817972da1236473e6b22f5637f1a Mon Sep 17 00:00:00 2001 From: Dave Woodruff Date: Fri, 22 May 2026 14:28:37 -0700 Subject: [PATCH 5/5] =?UTF-8?q?DEBUG(LOR=5Fbug):=20lor=5Fbug=5Freport.py?= =?UTF-8?q?=20=E2=80=94=20show=20offending=20values,=20warn=20on=20default?= =?UTF-8?q?=20comm=20names?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Example lines now include the values that tripped each detector (max/gather_sum for H2, sum/gather_sum for H3 reducer, rank_sum/ expected for H3 sanity, count/unique for H4) so the triager doesn't have to grep back into the raw log. - H1 emits a NOTE when any comm has a default/empty Get_name() value, since (cls, name) grouping can collapse distinct physical comms and spuriously trip the "varying size" check. Co-Authored-By: Claude Opus 4.7 (1M context) --- lor_bug_report.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/lor_bug_report.py b/lor_bug_report.py index 983c379ad..2f2b4dd48 100644 --- a/lor_bug_report.py +++ b/lor_bug_report.py @@ -85,18 +85,27 @@ def parse(path): return entries -def _examples(rows, n=5): +def _examples(rows, fields, n=5): + """Format up to n example rows, showing call ID + the listed fields.""" out = [] for e in rows[:n]: + extras = " ".join(f"{f}={e.get(f, '?')}" for f in fields) out.append( f" cls={e['cls']} call={e['call']} " - f"world_rk={e['world_rk']} host={e['host']}" + f"world_rk={e['world_rk']} host={e['host']} {extras}" ) if len(rows) > n: out.append(f" (... {len(rows) - n} more truncated ...)") return "\n".join(out) +# MPI implementations often leave new communicators with an empty or +# generic default name. When that happens, grouping by (cls, name) can +# collapse distinct physical comms into one bucket and falsely trip H1. +_DEFAULT_COMM_NAMES = {"''", '""', "'MPI_COMM_WORLD'", "'MPI_COMMUNICATOR'", + "", "''"} + + def report(entries, path): print(f"LOR_bug report for: {path}") print(f"Parsed {len(entries)} [LOR_bug ...] blocks.") @@ -135,6 +144,12 @@ def report(entries, path): f"printer_world_rks={printers}") else: print(" OK: every comm has a stable size and stable rank-0 printer.") + defaulted = sorted({n for (_, n) in by_comm if n in _DEFAULT_COMM_NAMES}) + if defaulted: + print(f" NOTE: some comms have default/empty names ({defaulted}); " + "distinct physical comms may collapse into one bucket here " + "and produce spurious H1 hits. Check `printer_world_rk` " + "in the per-comm summary above.") # Also: if two different comms share the same printer world rank, that # rank straddles two cylinders -- possible cross-cylinder contamination. @@ -159,7 +174,7 @@ def report(entries, path): if nonbool: print(f" STRONG SIGNAL: {len(nonbool)} calls had max > 1 " f"(input was not boolean)") - print(_examples(nonbool)) + print(_examples(nonbool, ["max", "gather_sum"])) else: print(" OK: every nonzero local_val was 1 (boolean).") @@ -176,11 +191,11 @@ def report(entries, path): print(f" sum != gather_sum (reducer disagreeing with gather): " f"{len(sum_mismatch)}") if sum_mismatch: - print(_examples(sum_mismatch)) + print(_examples(sum_mismatch, ["sum", "gather_sum"])) print(f" rank_sum sanity failures (SUM broken on this comm): " f"{len(rank_sum_fail)}") if rank_sum_fail: - print(_examples(rank_sum_fail)) + print(_examples(rank_sum_fail, ["rank_sum", "expected_rank_sum"])) # ---------- H4: duplicate rank participation ---------- print("\nH4 — duplicate rank participation in mpicomm:") @@ -188,7 +203,7 @@ def report(entries, path): if "unique" in e and "count" in e and e["unique"] < e["count"]] print(f" Calls with duplicate world ranks: {len(dups)}") if dups: - print(_examples(dups)) + print(_examples(dups, ["count", "unique"])) # ---------- Verdict ---------- print("\nVerdict:")