Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 240 additions & 0 deletions lor_bug_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
###############################################################################
# mpi-sppy: MPI-based Stochastic Programming in PYthon
#
# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for
# Sustainable Energy, LLC, The Regents of the University of California, et al.
# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for
# full copyright and license information.
###############################################################################
"""Summarize LOR_bug diagnostic output from PR #717.

PR #717 instruments mpisppy/spbase.py::SPBase.allreduce_or to print a 4-line
block on every call (from cyl_rk == 0 of the cylinder's mpicomm). This script
parses such a log and writes a short report on the four hypotheses being
tested:

H1. self.mpicomm has wider membership than the cylinder it should.
H2. Buffer memory underneath local_val was nonzero / non-boolean.
H3. The Allreduce reducer path is malfunctioning.
H4. Duplicate rank participation in self.mpicomm.

Usage:
python lor_bug_report.py <stdout_file>
"""

import re
import sys
from collections import defaultdict


HEADER_RE = re.compile(
r"^\[LOR_bug call=(?P<call>\d+) cls=(?P<cls>\S+) "
r"world_rk=(?P<world_rk>\d+) host=(?P<host>\S+) pid=(?P<pid>\d+)\] "
r"mpicomm size=(?P<size>\d+) name=(?P<name>.+)$"
)
WORLD_RANKS_RE = re.compile(
r"^\s*world_ranks: min=(?P<wr_min>\d+) max=(?P<wr_max>\d+) "
r"count=(?P<count>\d+) unique=(?P<unique>\d+)$"
)
REDUCTIONS_RE = re.compile(
r"^\s*reductions: sum=(?P<sum>-?\d+) max=(?P<max>-?\d+) "
r"lor=(?P<lor>-?\d+) rank_sum=(?P<rank_sum>-?\d+) "
r"expected_rank_sum=(?P<expected_rank_sum>-?\d+)$"
)
GATHER_RE = re.compile(
r"^\s*gather: gather_sum=(?P<gather_sum>-?\d+) "
r"nonzero_reports=(?P<nonzero_reports>\d+)$"
)


def parse(path):
"""Return a list of dicts, one per [LOR_bug ...] block."""
with open(path) as f:
lines = f.readlines()

entries = []
i = 0
n = len(lines)
while i < n:
m = HEADER_RE.match(lines[i].rstrip())
if not m:
i += 1
continue
entry = {
"call": int(m["call"]),
"cls": m["cls"],
"world_rk": int(m["world_rk"]),
"host": m["host"],
"pid": int(m["pid"]),
"size": int(m["size"]),
"name": m["name"],
}
i += 1
# The next three lines should be world_ranks / reductions / gather,
# in that order. Tolerate missing lines defensively.
for pat in (WORLD_RANKS_RE, REDUCTIONS_RE, GATHER_RE):
if i >= n:
break
mm = pat.match(lines[i].rstrip())
if not mm:
break
for k, v in mm.groupdict().items():
entry[k] = int(v)
i += 1
entries.append(entry)
return entries


def _examples(rows, fields, n=5):
"""Format up to n example rows, showing call ID + the listed fields."""
out = []
for e in rows[:n]:
extras = " ".join(f"{f}={e.get(f, '?')}" for f in fields)
out.append(
f" cls={e['cls']} call={e['call']} "
f"world_rk={e['world_rk']} host={e['host']} {extras}"
)
if len(rows) > n:
out.append(f" (... {len(rows) - n} more truncated ...)")
return "\n".join(out)


# MPI implementations often leave new communicators with an empty or
# generic default name. When that happens, grouping by (cls, name) can
# collapse distinct physical comms into one bucket and falsely trip H1.
_DEFAULT_COMM_NAMES = {"''", '""', "'MPI_COMM_WORLD'", "'MPI_COMMUNICATOR'",
"<unknown>", "'<unknown>'"}


def report(entries, path):
print(f"LOR_bug report for: {path}")
print(f"Parsed {len(entries)} [LOR_bug ...] blocks.")
if not entries:
print("\nNo diagnostic blocks found. Was the run on the LOR_bug branch?")
return

# ---------- per-comm summary ----------
by_comm = defaultdict(list)
for e in entries:
by_comm[(e["cls"], e["name"])].append(e)

print("\nPer-comm summary (one printer per comm; cyl_rk == 0 only):")
for (cls, name), es in sorted(by_comm.items()):
sizes = sorted({e["size"] for e in es})
wrs = sorted({e["world_rk"] for e in es})
print(f" cls={cls} name={name}")
print(f" calls={len(es)} sizes={sizes} printer_world_rk={wrs}")

# ---------- H1: wider membership ----------
# Signal: size varies within a single (cls, name) bucket, OR printer
# world_rk varies across calls for the same logical comm (meaning
# different ranks took the "rank 0" role — only possible if comm
# membership shifted).
print("\nH1 — wider mpicomm membership than expected:")
h1_hits = []
for (cls, name), es in by_comm.items():
sizes = {e["size"] for e in es}
printers = {e["world_rk"] for e in es}
if len(sizes) > 1 or len(printers) > 1:
h1_hits.append((cls, name, sorted(sizes), sorted(printers)))
if h1_hits:
print(" WARNING: comm membership is not stable across calls:")
for cls, name, sizes, printers in h1_hits:
print(f" cls={cls} name={name} sizes={sizes} "
f"printer_world_rks={printers}")
else:
print(" OK: every comm has a stable size and stable rank-0 printer.")
defaulted = sorted({n for (_, n) in by_comm if n in _DEFAULT_COMM_NAMES})
if defaulted:
print(f" NOTE: some comms have default/empty names ({defaulted}); "
"distinct physical comms may collapse into one bucket here "
"and produce spurious H1 hits. Check `printer_world_rk` "
"in the per-comm summary above.")

# Also: if two different comms share the same printer world rank, that
# rank straddles two cylinders -- possible cross-cylinder contamination.
printer_to_comms = defaultdict(set)
for (cls, name), es in by_comm.items():
for e in es:
printer_to_comms[e["world_rk"]].add((cls, name))
shared = {wr: cs for wr, cs in printer_to_comms.items() if len(cs) > 1}
if shared:
print(" NOTE: world ranks acting as printer for multiple comms:")
for wr, cs in sorted(shared.items()):
print(f" world_rk={wr} comms={sorted(cs)}")

# ---------- H2: buffer aliasing / non-boolean input ----------
# Signature per PR description: nonzero local_val where it should be 0.
# The unambiguous tell is max > 1 (input was not a Python bool).
print("\nH2 — buffer aliasing / non-boolean input:")
nonbool = [e for e in entries if e.get("max", 0) > 1]
nonzero = [e for e in entries if e.get("gather_sum", 0) > 0]
print(f" Calls with any nonzero local_val: {len(nonzero)} / {len(entries)}"
f" (these may be legitimate True returns)")
if nonbool:
print(f" STRONG SIGNAL: {len(nonbool)} calls had max > 1 "
f"(input was not boolean)")
print(_examples(nonbool, ["max", "gather_sum"]))
else:
print(" OK: every nonzero local_val was 1 (boolean).")

# ---------- H3: reducer malfunction ----------
# (a) Allreduce SUM disagrees with the Allgather-summed local_vals.
# (b) rank_sum != expected sum-of-ranks for a comm of this size.
print("\nH3 — Allreduce reducer malfunction:")
sum_mismatch = [e for e in entries
if "sum" in e and "gather_sum" in e
and e["sum"] != e["gather_sum"]]
rank_sum_fail = [e for e in entries
if "rank_sum" in e and "expected_rank_sum" in e
and e["rank_sum"] != e["expected_rank_sum"]]
print(f" sum != gather_sum (reducer disagreeing with gather): "
f"{len(sum_mismatch)}")
if sum_mismatch:
print(_examples(sum_mismatch, ["sum", "gather_sum"]))
print(f" rank_sum sanity failures (SUM broken on this comm): "
f"{len(rank_sum_fail)}")
if rank_sum_fail:
print(_examples(rank_sum_fail, ["rank_sum", "expected_rank_sum"]))

# ---------- H4: duplicate rank participation ----------
print("\nH4 — duplicate rank participation in mpicomm:")
dups = [e for e in entries
if "unique" in e and "count" in e and e["unique"] < e["count"]]
print(f" Calls with duplicate world ranks: {len(dups)}")
if dups:
print(_examples(dups, ["count", "unique"]))

# ---------- Verdict ----------
print("\nVerdict:")
triggered = []
if h1_hits:
triggered.append("H1 (wider/unstable membership)")
if nonbool:
triggered.append("H2 (non-boolean input)")
if sum_mismatch or rank_sum_fail:
triggered.append("H3 (reducer)")
if dups:
triggered.append("H4 (duplicate ranks)")
if triggered:
print(" Hypotheses triggered: " + ", ".join(triggered))
else:
if nonzero:
print(" No invariant violations. Some calls returned nonzero;"
" consistent with legitimate shutdown signals.")
else:
print(" Clean log: no anomalies on any of the four hypotheses.")


def main(argv):
if len(argv) != 2:
print(f"Usage: {argv[0]} <stdout_file>", file=sys.stderr)
return 2
path = argv[1]
entries = parse(path)
report(entries, path)
return 0


if __name__ == "__main__":
sys.exit(main(sys.argv))
122 changes: 118 additions & 4 deletions mpisppy/spbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,10 +642,124 @@ def spcomm(self, value):


def allreduce_or(self, val):
local_val = np.array([val], dtype='int8')
global_val = np.zeros(1, dtype='int8')
self.mpicomm.Allreduce(local_val, global_val, op=MPI.LOR)
if global_val[0] > 0:
# ====== DEBUG: LOR_bug instrumentation ======
# Yields per call (on cyl_rk == 0 of self.mpicomm) the full picture
# needed to localize an Allreduce(LOR) returning nonzero when every
# rank intends a zero. Probes four axes:
# 1. comm membership (world ranks participating, size, uniqueness)
# 2. data going in (Allgather of every rank's local_val)
# 3. reduction sanity (SUM/MAX/LOR + a rank-sum check whose
# expected value is n*(n-1)/2)
# 4. consistency (compare Allgather sum to Allreduce SUM
# to tell "input was wrong" from
# "Allreduce is wrong")
# Remove before merging to main. See PR description for hypothesis tree.
import os
import socket
import sys
sz = self.mpicomm.Get_size()
cyl_rk = self.mpicomm.Get_rank()
world_rk = MPI.COMM_WORLD.Get_rank()
host = socket.gethostname()
pid = os.getpid()

local_int = 1 if val else 0
local_int32 = np.array([local_int], dtype='int32')
local_int8 = np.array([local_int], dtype='int8')

# (3) Reductions — three ops in parallel, plus a rank-sum sanity check.
sum_out = np.zeros(1, dtype='int32')
self.mpicomm.Allreduce(local_int32, sum_out, op=MPI.SUM)

max_out = np.zeros(1, dtype='int32')
self.mpicomm.Allreduce(local_int32, max_out, op=MPI.MAX)

lor_out = np.zeros(1, dtype='int8')
self.mpicomm.Allreduce(local_int8, lor_out, op=MPI.LOR)

rank_in = np.array([cyl_rk], dtype='int32')
rank_out = np.zeros(1, dtype='int32')
self.mpicomm.Allreduce(rank_in, rank_out, op=MPI.SUM)
expected_rank_sum = sz * (sz - 1) // 2

# (1) + (2) Allgather of (world_rk, cyl_rk, local_int) so we see
# exactly which ranks participated and what each one contributed.
report = np.array([world_rk, cyl_rk, local_int], dtype='int32')
all_reports = np.zeros(3 * sz, dtype='int32')
self.mpicomm.Allgather(report, all_reports)

# Track a per-instance call counter so logs are correlatable.
self._lor_diag_count = getattr(self, "_lor_diag_count", 0) + 1
call_n = self._lor_diag_count

if cyl_rk == 0:
rows = all_reports.reshape(sz, 3)
wr = rows[:, 0].tolist()
nonzero_rows = rows[rows[:, 2] != 0]
gather_sum = int(rows[:, 2].sum())
cls = type(self).__name__
try:
comm_name = self.mpicomm.Get_name()
except Exception:
comm_name = "<unknown>"
print(
f"[LOR_bug call={call_n} cls={cls} "
f"world_rk={world_rk} host={host} pid={pid}] "
f"mpicomm size={sz} name={comm_name!r}",
flush=True,
)
print(
f" world_ranks: min={min(wr)} max={max(wr)} "
f"count={len(wr)} unique={len(set(wr))}",
flush=True,
)
print(
f" reductions: sum={int(sum_out[0])} max={int(max_out[0])} "
f"lor={int(lor_out[0])} rank_sum={int(rank_out[0])} "
f"expected_rank_sum={expected_rank_sum}",
flush=True,
)
print(
f" gather: gather_sum={gather_sum} "
f"nonzero_reports={len(nonzero_rows)}",
flush=True,
)
# "Bad" = invariant-violating, NOT just "nonzero result." A
# legitimate shutdown signal returns sum=lor=1 with
# gather_sum=1 (consistent), which is fine. The real bug
# signature is gather_sum disagreeing with the Allreduce SUM
# (the reducer lying), or the rank-sum sanity check failing
# (SUM broken on this comm), or duplicate world ranks
# (group membership corrupted), or some rank packing >1
# (non-boolean input — only possible under memory aliasing).
bad = (
int(rank_out[0]) != expected_rank_sum
or int(sum_out[0]) != gather_sum
or len(set(wr)) != len(wr)
or int(max_out[0]) > 1
)
if bad:
limit = min(64, len(nonzero_rows))
for w, c, v in nonzero_rows[:limit].tolist():
print(
f" nonzero: world_rk={w} cyl_rk={c} local_val={v}",
flush=True,
)
if len(nonzero_rows) > limit:
print(
f" (... {len(nonzero_rows) - limit} more nonzero rows truncated ...)",
flush=True,
)
# Also dump the full world-rank list once so we can see exactly
# who is participating in this comm.
print(
f" ALL world_ranks: {wr}",
flush=True,
)
sys.stdout.flush()
# ====== END DEBUG ======

if lor_out[0] > 0:
return True
else:
return False
Expand Down
Loading