Fix device sync in generic benchmarking functions for TPU/Pallas#1773
Conversation
torch.accelerator.synchronize() does not reliably block on torch_tpu, causing wall-clock benchmarks to return before device computation finishes. Add _synchronize(result) helper that uses torch_tpu's tensor-level sync for TPU tensors, falling back to torch.accelerator.synchronize() otherwise. Update all _generic benchmarking functions to capture fn() return values and sync on them.
Use the multiline import format so the line fits within the 88-character limit, preventing ruff from reformatting it and breaking the pyrefly suppression comment.
Can you explain why
Good question. The call goes here: https://github.com/pytorch/pytorch/blob/main/torch/csrc/DeviceAccelerator.cpp#L71-L82. The reason is that torch_tpu does not have it implemented. With that said, I only see a method that waits on specific tensors in torch_tpu: https://github.com/google-ml-infra/torch_tpu/blob/main/torch_tpu/_internal/sync/sync.py#L31, so it may require a bigger change on the torch_tpu side to implement a method that syncs all ops on the device. I also had a brief discussion with @v0i0. Because the observation here shows that the autotuner is basically optimizing the wrong objective function for TPU, we want to fix this ASAP to unblock ourselves.
Summary

- `torch.accelerator.synchronize()` does not reliably block on `torch_tpu`, causing wall-clock benchmarks to return before device computation finishes.
- Add a `_synchronize(result)` helper that uses `torch_tpu._internal.sync.synchronize(tensor, wait=True)` for TPU tensors, falling back to `torch.accelerator.synchronize()` otherwise.
- Update all `_generic` benchmarking functions (`compute_repeat_generic`, `interleaved_bench_generic`, `do_bench_generic`) to capture `fn()` return values and sync on them.

Note
Wall-time (latency) measures the time for a single kernel call including device execution. It syncs after every call:
sync → start timer → fn() → sync → stop timer
This is what helion's _generic benchmarking functions measure, and what this PR fixes.
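The wall-time pattern above can be sketched as follows. This is an illustrative stand-in, not helion's actual code: `synchronize` represents the PR's result-aware sync helper, and the warmup/iteration counts are arbitrary:

```python
import time


def bench_wall_time(fn, synchronize, warmup=3, iters=10):
    """Median wall time of a single fn() call, syncing after every call."""
    result = None
    for _ in range(warmup):
        result = fn()
    synchronize(result)  # drain pending device work before timing starts
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        result = fn()
        synchronize(result)  # stop the clock only after device work finishes
        times.append(time.perf_counter() - t0)
    times.sort()
    return times[len(times) // 2]
```

Without the per-call sync, `times` would record only the dispatch overhead, which is exactly the failure mode this PR fixes on TPU.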
Throughput measures the sustained dispatch rate by firing many calls without syncing in between, then syncing once at the end:
start timer → fn() × N → sync → stop timer → divide by N
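The throughput pattern can be sketched the same way, again with `synchronize` as a hypothetical stand-in for the sync helper:

```python
import time


def bench_throughput(fn, synchronize, n_calls=100):
    """Average time per call when firing N calls with one sync at the end."""
    result = fn()  # warm up once so compilation/caching is excluded
    synchronize(result)
    start = time.perf_counter()
    for _ in range(n_calls):
        result = fn()
    # A single final sync covers all N calls only if `synchronize` acts as a
    # device barrier; a per-tensor sync may leave earlier calls in flight.
    synchronize(result)
    return (time.perf_counter() - start) / n_calls
```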
Profiling script: https://gist.github.com/norx1991/0066f29f4d98078d89404c26b78083af
Before (`torch.accelerator.synchronize`, N=104,857,600, grid=100, block=1048576) vs. after (`torch_tpu._internal.sync.synchronize`):

Wall times for the torch paths now match device times. xprof captures all 1010 iterations across the board (was 90-168 with broken sync). JAX paths are unchanged (they use `block_until_ready`).
block_until_ready).It is worth noticing that torch_tpu._internal.sync.synchronize(tensor, wait=True) is a per-tensor sync, not a device barrier. It seems that it only waits for that specific PjRt buffer's future to resolve, but earlier independent computations may still be in-flight. In contrast, JAX's block_until_ready() on the last array effectively acts as a device barrier — it does not return until all prior dispatched work has completed. This is reflected in the throughput measurement.