# =============================================================================
# File: lor_bug_report.py   (new file in this change; index 000000000..2f2b4dd48)
# =============================================================================
###############################################################################
# mpi-sppy: MPI-based Stochastic Programming in PYthon
#
# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for
# Sustainable Energy, LLC, The Regents of the University of California, et al.
# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for
# full copyright and license information.
###############################################################################
"""Summarize LOR_bug diagnostic output from PR #717.

PR #717 instruments mpisppy/spbase.py::SPBase.allreduce_or to print a 4-line
block on every call (from cyl_rk == 0 of the cylinder's mpicomm). This script
parses such a log and writes a short report on the four hypotheses being
tested:

 H1. self.mpicomm has wider membership than the cylinder it should.
 H2. Buffer memory underneath local_val was nonzero / non-boolean.
 H3. The Allreduce reducer path is malfunctioning.
 H4. Duplicate rank participation in self.mpicomm.

Usage:
 python lor_bug_report.py <logfile>
"""

import re
import sys
from collections import defaultdict


# First line of a diagnostic block.  `name` is the repr() of the communicator
# name, so it normally arrives wrapped in quotes (e.g. 'cyl0' or '').
HEADER_RE = re.compile(
    r"^\[LOR_bug call=(?P<call>\d+) cls=(?P<cls>\S+) "
    r"world_rk=(?P<world_rk>\d+) host=(?P<host>\S+) pid=(?P<pid>\d+)\] "
    r"mpicomm size=(?P<size>\d+) name=(?P<name>.+)$"
)
# The min/max groups are deliberately named wr_min / wr_max: every group of
# the three detail lines is merged into one dict, and a plain "max" here would
# collide with the reductions "max" that the H2 check reads.
WORLD_RANKS_RE = re.compile(
    r"^\s*world_ranks: min=(?P<wr_min>\d+) max=(?P<wr_max>\d+) "
    r"count=(?P<count>\d+) unique=(?P<unique>\d+)$"
)
REDUCTIONS_RE = re.compile(
    r"^\s*reductions: sum=(?P<sum>-?\d+) max=(?P<max>-?\d+) "
    r"lor=(?P<lor>-?\d+) rank_sum=(?P<rank_sum>-?\d+) "
    r"expected_rank_sum=(?P<expected_rank_sum>-?\d+)$"
)
GATHER_RE = re.compile(
    r"^\s*gather: gather_sum=(?P<gather_sum>-?\d+) "
    r"nonzero_reports=(?P<nonzero_reports>\d+)$"
)

# Detail lines in the order the instrumentation prints them after the header.
_DETAIL_PATTERNS = (WORLD_RANKS_RE, REDUCTIONS_RE, GATHER_RE)


def parse(path):
    """Return a list of dicts, one per [LOR_bug ...] block.

    Args:
        path: Path of the log file to read.

    Returns:
        One dict per header line found, in file order.  Each dict always has
        the header fields ``call``, ``cls``, ``world_rk``, ``host``, ``pid``,
        ``size`` and ``name`` (``cls``, ``host`` and ``name`` as strings, the
        rest as ints).  The integer fields of the world_ranks / reductions /
        gather lines are added only for the detail lines that directly follow
        the header in the expected order, so callers must treat them as
        optional.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Logs collected from MPI jobs can contain stray bytes from interleaved
    # output; replace them instead of aborting the whole report.
    with open(path, encoding="utf-8", errors="replace") as f:
        lines = [line.rstrip() for line in f]

    entries = []
    i = 0
    n = len(lines)
    while i < n:
        m = HEADER_RE.match(lines[i])
        i += 1
        if not m:
            continue
        entry = {
            "call": int(m["call"]),
            "cls": m["cls"],
            "world_rk": int(m["world_rk"]),
            "host": m["host"],
            "pid": int(m["pid"]),
            "size": int(m["size"]),
            "name": m["name"],
        }
        # The next three lines should be world_ranks / reductions / gather,
        # in that order. Tolerate missing lines defensively: stop at the first
        # line that does not fit and leave it for the outer loop, which will
        # re-examine it as a possible header.
        for pat in _DETAIL_PATTERNS:
            if i >= n:
                break
            mm = pat.match(lines[i])
            if not mm:
                break
            entry.update((k, int(v)) for k, v in mm.groupdict().items())
            i += 1
        entries.append(entry)
    return entries


def _examples(rows, fields, n=5):
    """Format up to n example rows, showing call ID + the listed fields.

    Fields missing from a row (its detail line was absent from the log) are
    shown as ``?``.  Returns a single newline-joined string.
    """
    out = []
    for e in rows[:n]:
        extras = " ".join(f"{f}={e.get(f, '?')}" for f in fields)
        out.append(
            f" cls={e['cls']} call={e['call']} "
            f"world_rk={e['world_rk']} host={e['host']} {extras}"
        )
    if len(rows) > n:
        out.append(f" (... {len(rows) - n} more truncated ...)")
    return "\n".join(out)


# MPI implementations often leave new communicators with an empty or
# generic default name. When that happens, grouping by (cls, name) can
# collapse distinct physical comms into one bucket and falsely trip H1.
# Names are compared as logged, i.e. as repr() strings including the quotes.
_DEFAULT_COMM_NAMES = frozenset(
    {"''", '""', "'MPI_COMM_WORLD'", "'MPI_COMMUNICATOR'", ""}
)


def report(entries, path):
    """Print the four-hypothesis report for parsed diagnostic blocks.

    Args:
        entries: List of dicts as returned by ``parse``.
        path: Log file path; used only in the report's title line.

    Everything is written to stdout; nothing is returned.
    """
    print(f"LOR_bug report for: {path}")
    print(f"Parsed {len(entries)} [LOR_bug ...] blocks.")
    if not entries:
        print("\nNo diagnostic blocks found. Was the run on the LOR_bug branch?")
        return

    # ---------- per-comm summary ----------
    by_comm = defaultdict(list)
    for e in entries:
        by_comm[(e["cls"], e["name"])].append(e)

    print("\nPer-comm summary (one printer per comm; cyl_rk == 0 only):")
    # Keys are unique, so the sort never falls through to comparing the
    # (unorderable) lists of dicts.
    for (cls, name), es in sorted(by_comm.items()):
        sizes = sorted({e["size"] for e in es})
        wrs = sorted({e["world_rk"] for e in es})
        print(f" cls={cls} name={name}")
        print(f" calls={len(es)} sizes={sizes} printer_world_rk={wrs}")

    # ---------- H1: wider membership ----------
    # Signal: size varies within a single (cls, name) bucket, OR printer
    # world_rk varies across calls for the same logical comm (meaning
    # different ranks took the "rank 0" role — only possible if comm
    # membership shifted).
    print("\nH1 — wider mpicomm membership than expected:")
    h1_hits = []
    for (cls, name), es in by_comm.items():
        sizes = {e["size"] for e in es}
        printers = {e["world_rk"] for e in es}
        if len(sizes) > 1 or len(printers) > 1:
            h1_hits.append((cls, name, sorted(sizes), sorted(printers)))
    if h1_hits:
        print(" WARNING: comm membership is not stable across calls:")
        for cls, name, sizes, printers in h1_hits:
            print(f" cls={cls} name={name} sizes={sizes} "
                  f"printer_world_rks={printers}")
    else:
        print(" OK: every comm has a stable size and stable rank-0 printer.")
    defaulted = sorted({n for (_, n) in by_comm if n in _DEFAULT_COMM_NAMES})
    if defaulted:
        print(f" NOTE: some comms have default/empty names ({defaulted}); "
              "distinct physical comms may collapse into one bucket here "
              "and produce spurious H1 hits. Check `printer_world_rk` "
              "in the per-comm summary above.")

    # Also: if two different comms share the same printer world rank, that
    # rank straddles two cylinders -- possible cross-cylinder contamination.
    printer_to_comms = defaultdict(set)
    for (cls, name), es in by_comm.items():
        for e in es:
            printer_to_comms[e["world_rk"]].add((cls, name))
    shared = {wr: cs for wr, cs in printer_to_comms.items() if len(cs) > 1}
    if shared:
        print(" NOTE: world ranks acting as printer for multiple comms:")
        for wr, cs in sorted(shared.items()):
            print(f" world_rk={wr} comms={sorted(cs)}")

    # ---------- H2: buffer aliasing / non-boolean input ----------
    # Signature per PR description: nonzero local_val where it should be 0.
    # The unambiguous tell is max > 1 (input was not a Python bool).
    # "max" is the Allreduce(MAX) value from the reductions line; blocks
    # whose reductions line is missing count as 0 here.
    print("\nH2 — buffer aliasing / non-boolean input:")
    nonbool = [e for e in entries if e.get("max", 0) > 1]
    nonzero = [e for e in entries if e.get("gather_sum", 0) > 0]
    print(f" Calls with any nonzero local_val: {len(nonzero)} / {len(entries)}"
          f" (these may be legitimate True returns)")
    if nonbool:
        print(f" STRONG SIGNAL: {len(nonbool)} calls had max > 1 "
              f"(input was not boolean)")
        print(_examples(nonbool, ["max", "gather_sum"]))
    else:
        print(" OK: every nonzero local_val was 1 (boolean).")

    # ---------- H3: reducer malfunction ----------
    # (a) Allreduce SUM disagrees with the Allgather-summed local_vals.
    # (b) rank_sum != expected sum-of-ranks for a comm of this size.
    print("\nH3 — Allreduce reducer malfunction:")
    sum_mismatch = [e for e in entries
                    if "sum" in e and "gather_sum" in e
                    and e["sum"] != e["gather_sum"]]
    rank_sum_fail = [e for e in entries
                     if "rank_sum" in e and "expected_rank_sum" in e
                     and e["rank_sum"] != e["expected_rank_sum"]]
    print(f" sum != gather_sum (reducer disagreeing with gather): "
          f"{len(sum_mismatch)}")
    if sum_mismatch:
        print(_examples(sum_mismatch, ["sum", "gather_sum"]))
    print(f" rank_sum sanity failures (SUM broken on this comm): "
          f"{len(rank_sum_fail)}")
    if rank_sum_fail:
        print(_examples(rank_sum_fail, ["rank_sum", "expected_rank_sum"]))

    # ---------- H4: duplicate rank participation ----------
    print("\nH4 — duplicate rank participation in mpicomm:")
    dups = [e for e in entries
            if "unique" in e and "count" in e and e["unique"] < e["count"]]
    print(f" Calls with duplicate world ranks: {len(dups)}")
    if dups:
        print(_examples(dups, ["count", "unique"]))

    # ---------- Verdict ----------
    print("\nVerdict:")
    triggered = []
    if h1_hits:
        triggered.append("H1 (wider/unstable membership)")
    if nonbool:
        triggered.append("H2 (non-boolean input)")
    if sum_mismatch or rank_sum_fail:
        triggered.append("H3 (reducer)")
    if dups:
        triggered.append("H4 (duplicate ranks)")
    if triggered:
        print(" Hypotheses triggered: " + ", ".join(triggered))
    elif nonzero:
        print(" No invariant violations. Some calls returned nonzero;"
              " consistent with legitimate shutdown signals.")
    else:
        print(" Clean log: no anomalies on any of the four hypotheses.")


def main(argv):
    """Command-line entry point.

    Args:
        argv: Full argument vector (program name followed by the log path).

    Returns:
        Process exit status: 2 on a usage error, 0 after printing the report.
    """
    if len(argv) != 2:
        print(f"Usage: {argv[0]} <logfile>", file=sys.stderr)
        return 2
    path = argv[1]
    entries = parse(path)
    report(entries, path)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))


# =============================================================================
# File: mpisppy/spbase.py   (index 27873763e..4dc054fc8,
#                            hunk @@ -642,10 +642,124 @@ after `def spcomm(self, value):`)
#
# `allreduce_or` below is a METHOD of SPBase in mpisppy/spbase.py. It is shown
# at module level only because its enclosing class is outside this view; it
# belongs at method indentation inside that class. `np` and `MPI` are names
# the pre-existing body of the method already used, so they come from that
# file's own imports (not from this script).
#
# The change replaces the previous body, which did a single
#     self.mpicomm.Allreduce(local_val, global_val, op=MPI.LOR)
# on int8 buffers built from `val`, with the instrumented version below.
# =============================================================================
def allreduce_or(self, val):
    """Logical OR of ``val`` across ``self.mpicomm``, with LOR_bug diagnostics.

    Every call performs, in this fixed order, four Allreduce operations
    (SUM, MAX and LOR of the local flag, then SUM of the comm rank) and one
    Allgather on ``self.mpicomm``.  These are collective calls, so the order
    must stay identical on every participating rank.  The rank with
    ``cyl_rk == 0`` prints a four-line diagnostic block, plus extra detail
    when an invariant is violated.

    Args:
        val: Local flag; only its truthiness is used.

    Returns:
        bool: True when the LOR reduction result is greater than zero.
    """
    # ====== DEBUG: LOR_bug instrumentation ======
    # Yields per call (on cyl_rk == 0 of self.mpicomm) the full picture
    # needed to localize an Allreduce(LOR) returning nonzero when every
    # rank intends a zero. Probes four axes:
    #   1. comm membership  (world ranks participating, size, uniqueness)
    #   2. data going in    (Allgather of every rank's local_val)
    #   3. reduction sanity (SUM/MAX/LOR + a rank-sum check whose
    #                        expected value is n*(n-1)/2)
    #   4. consistency      (compare Allgather sum to Allreduce SUM
    #                        to tell "input was wrong" from
    #                        "Allreduce is wrong")
    # Remove before merging to main. See PR description for hypothesis tree.
    import os
    import socket
    import sys
    sz = self.mpicomm.Get_size()
    cyl_rk = self.mpicomm.Get_rank()
    world_rk = MPI.COMM_WORLD.Get_rank()
    host = socket.gethostname()
    pid = os.getpid()

    # Normalize to 0/1 up front so every buffer below carries a clean flag.
    local_int = 1 if val else 0
    local_int32 = np.array([local_int], dtype='int32')
    local_int8 = np.array([local_int], dtype='int8')

    # (3) Reductions — three ops in parallel, plus a rank-sum sanity check.
    sum_out = np.zeros(1, dtype='int32')
    self.mpicomm.Allreduce(local_int32, sum_out, op=MPI.SUM)

    max_out = np.zeros(1, dtype='int32')
    self.mpicomm.Allreduce(local_int32, max_out, op=MPI.MAX)

    lor_out = np.zeros(1, dtype='int8')
    self.mpicomm.Allreduce(local_int8, lor_out, op=MPI.LOR)

    rank_in = np.array([cyl_rk], dtype='int32')
    rank_out = np.zeros(1, dtype='int32')
    self.mpicomm.Allreduce(rank_in, rank_out, op=MPI.SUM)
    # Sum of 0..sz-1; any other result means SUM misbehaved on this comm.
    expected_rank_sum = sz * (sz - 1) // 2

    # (1) + (2) Allgather of (world_rk, cyl_rk, local_int) so we see
    # exactly which ranks participated and what each one contributed.
    report = np.array([world_rk, cyl_rk, local_int], dtype='int32')
    all_reports = np.zeros(3 * sz, dtype='int32')
    self.mpicomm.Allgather(report, all_reports)

    # Track a per-instance call counter so logs are correlatable.
    self._lor_diag_count = getattr(self, "_lor_diag_count", 0) + 1
    call_n = self._lor_diag_count

    if cyl_rk == 0:
        rows = all_reports.reshape(sz, 3)
        wr = rows[:, 0].tolist()
        nonzero_rows = rows[rows[:, 2] != 0]
        gather_sum = int(rows[:, 2].sum())
        cls = type(self).__name__
        # Best effort: the name is diagnostic decoration only, so a failure
        # to fetch it must not break the reduction itself.
        try:
            comm_name = self.mpicomm.Get_name()
        except Exception:
            comm_name = ""
        print(
            f"[LOR_bug call={call_n} cls={cls} "
            f"world_rk={world_rk} host={host} pid={pid}] "
            f"mpicomm size={sz} name={comm_name!r}",
            flush=True,
        )
        print(
            f" world_ranks: min={min(wr)} max={max(wr)} "
            f"count={len(wr)} unique={len(set(wr))}",
            flush=True,
        )
        print(
            f" reductions: sum={int(sum_out[0])} max={int(max_out[0])} "
            f"lor={int(lor_out[0])} rank_sum={int(rank_out[0])} "
            f"expected_rank_sum={expected_rank_sum}",
            flush=True,
        )
        print(
            f" gather: gather_sum={gather_sum} "
            f"nonzero_reports={len(nonzero_rows)}",
            flush=True,
        )
        # "Bad" = invariant-violating, NOT just "nonzero result." A
        # legitimate shutdown signal returns sum=lor=1 with
        # gather_sum=1 (consistent), which is fine. The real bug
        # signature is gather_sum disagreeing with the Allreduce SUM
        # (the reducer lying), or the rank-sum sanity check failing
        # (SUM broken on this comm), or duplicate world ranks
        # (group membership corrupted), or some rank packing >1
        # (non-boolean input — only possible under memory aliasing).
        bad = (
            int(rank_out[0]) != expected_rank_sum
            or int(sum_out[0]) != gather_sum
            or len(set(wr)) != len(wr)
            or int(max_out[0]) > 1
        )
        if bad:
            limit = min(64, len(nonzero_rows))
            for w, c, v in nonzero_rows[:limit].tolist():
                print(
                    f" nonzero: world_rk={w} cyl_rk={c} local_val={v}",
                    flush=True,
                )
            if len(nonzero_rows) > limit:
                print(
                    f" (... {len(nonzero_rows) - limit} more nonzero rows truncated ...)",
                    flush=True,
                )
            # Also dump the full world-rank list once so we can see exactly
            # who is participating in this comm.
            print(
                f" ALL world_ranks: {wr}",
                flush=True,
            )
        sys.stdout.flush()
    # ====== END DEBUG ======

    # bool(...) so callers get a plain Python bool rather than a numpy scalar.
    return bool(lor_out[0] > 0)