Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions iron/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,38 @@ def _to_numpy(x):
expected_np = _to_numpy(reference).reshape((-1,))
output = _to_numpy(output).reshape((-1,))

# Detect NaNs early; if any are present, dump the affected arrays for post-mortem analysis.
try:
nan_out_count = int(np.isnan(output).sum())
nan_ref_count = int(np.isnan(expected_np).sum())
except Exception:
nan_out_count = 0
nan_ref_count = 0
if nan_out_count or nan_ref_count:
import os
import time

dbg_dir = os.path.join("build", "debug")
try:
os.makedirs(dbg_dir, exist_ok=True)
except Exception:
dbg_dir = "."
ts = int(time.time() * 1000)
if nan_out_count:
out_path = os.path.join(dbg_dir, f"{buf_name}_output_nan_{ts}.npy")
try:
np.save(out_path, output)
print(f"{buf_name}: AIE execution produced {nan_out_count} NaN values; dumped output to {out_path}")
except Exception as e:
print(f"{buf_name}: Failed to dump output with NaNs: {e}")
if nan_ref_count:
ref_path = os.path.join(dbg_dir, f"{buf_name}_reference_nan_{ts}.npy")
try:
np.save(ref_path, expected_np)
print(f"{buf_name}: Reference contains {nan_ref_count} NaN values; dumped reference to {ref_path}")
except Exception as e:
print(f"{buf_name}: Failed to dump reference with NaNs: {e}")

if len(output) < len(expected_np):
# Allow larger buffers - binning may have allocated more space than needed
print(
Expand Down Expand Up @@ -205,8 +237,42 @@ def run_test(
# Run timed iterations and measure NPU execution time
total_npu_ns = 0
for _ in range(timed_iters):
result = op_func(*args)
total_npu_ns += result.npu_time
try:
result = op_func(*args)
total_npu_ns += result.npu_time
except Exception as e:
import os
import time
import traceback

dbg_dir = os.path.join("build", "debug")
try:
os.makedirs(dbg_dir, exist_ok=True)
except Exception:
dbg_dir = "."
ts = int(time.time() * 1000)
# Dump provided input buffers (torch tensors)
for in_name, in_tensor in input_buffers.items():
try:
arr = _to_numpy(in_tensor)
p = os.path.join(dbg_dir, f"{in_name}_input_{ts}.npy")
np.save(p, arr)
print(f"Dumped input {in_name} to {p}")
except Exception as ee:
print(f"Failed to dump input {in_name}: {ee}")
# Dump each allocated output XRTTensor (via to_torch()) in whatever state the failed run left it
for out_name, out_buf in output_map.items():
try:
torch_out = out_buf.to_torch()
arr = _to_numpy(torch_out)
p = os.path.join(dbg_dir, f"{out_name}_output_before_failure_{ts}.npy")
np.save(p, arr)
print(f"Dumped output {out_name} (pre-verify) to {p}")
except Exception as ee:
print(f"Failed to dump output {out_name}: {ee}")
print("Exception during NPU execution:")
traceback.print_exc()
raise
latency_us = (total_npu_ns / timed_iters) / 1e3

# Verify outputs
Expand Down