
Fix device sync in generic benchmarking functions for TPU/Pallas #1773

Merged
norx1991 merged 2 commits into main from yifeixu/fix_sync_pallas
Mar 24, 2026
Conversation

@norx1991 (Contributor) commented Mar 21, 2026

Summary

  • torch.accelerator.synchronize() does not reliably block on torch_tpu, causing wall-clock benchmarks to return before device computation finishes
  • Add _synchronize(result) helper that uses torch_tpu._internal.sync.synchronize(tensor, wait=True) for TPU tensors, falling back to torch.accelerator.synchronize() otherwise
  • Update all _generic benchmarking functions (compute_repeat_generic, interleaved_bench_generic, do_bench_generic) to capture fn() return values and sync on them

Note

  • Wall-time (latency) measures the time for a single kernel call including device execution. It syncs after every call:
    sync → start timer → fn() → sync → stop timer
    This is what helion's _generic benchmarking functions measure, and what this PR fixes.

  • Throughput measures the sustained dispatch rate by firing many calls without syncing in between, then syncing once at the end:
    start timer → fn() × N → sync → stop timer → divide by N
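
The two timing patterns above can be sketched generically (hypothetical names; `sync` stands in for whatever synchronization primitive the backend provides):

```python
import time

def bench_wall(fn, sync, iters=100):
    """Wall-time (latency): sync -> start -> fn() -> sync -> stop, per call."""
    sync()
    best = float("inf")
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()  # wait for device work before stopping the clock
        best = min(best, time.perf_counter() - start)
    return best

def bench_throughput(fn, sync, iters=100):
    """Throughput: fire N calls back to back, sync once, divide by N."""
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # single barrier at the end
    return (time.perf_counter() - start) / iters
```

With a broken `sync`, `bench_wall` stops the clock while the device is still running, which is exactly the failure mode this PR addresses.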

Profiling script: https://gist.github.com/norx1991/0066f29f4d98078d89404c26b78083af

Before (torch.accelerator.synchronize, N=104,857,600, grid=100, block=1048576):

| Implementation | Wall (ms) | Throughput (ms) | Device (ms, xprof) | xprof occ |
| --- | --- | --- | --- | --- |
| torch.exp | 0.117 | 0.049 | 0.262 | 90/1010 |
| helion exp_fwd (g=100) | 0.145 | 0.121 | 0.771 | 167/1010 |
| jax.numpy.exp | 0.446 | 0.270 | 0.263 | 1010/1010 |
| pallas exp jax (g=100) | 0.436 | 0.270 | 0.263 | 1010/1010 |
| pallas exp torch_tpu (g=100) | 0.078 | 0.096 | 0.771 | 168/1010 |

After (torch_tpu._internal.sync.synchronize):

| Implementation | Wall (ms) | Throughput (ms) | Device (ms, xprof) | xprof occ |
| --- | --- | --- | --- | --- |
| torch.exp | 0.781 | 0.058 | 0.260 | 1010/1010 |
| helion exp_fwd (g=100) | 1.077 | 0.144 | 0.771 | 1010/1010 |
| jax.numpy.exp | 0.422 | 0.270 | 0.265 | 1010/1010 |
| pallas exp jax (g=100) | 0.418 | 0.270 | 0.263 | 1010/1010 |
| pallas exp torch_tpu (g=100) | 1.066 | 0.109 | 0.771 | 1010/1010 |

Wall times for the torch paths now match device times. xprof captures all 1010 iterations across the board (previously 90-168 with the broken sync). JAX paths are unchanged (they use block_until_ready).

It is worth noting that torch_tpu._internal.sync.synchronize(tensor, wait=True) is a per-tensor sync, not a device barrier. It appears to wait only for that specific PjRt buffer's future to resolve, while earlier independent computations may still be in flight. In contrast, JAX's block_until_ready() on the last array effectively acts as a device barrier: it does not return until all previously dispatched work has completed. This difference is reflected in the throughput measurements.
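
The per-tensor-sync vs. device-barrier distinction can be illustrated with plain Python futures (an analogy only, not the actual PjRt mechanics): waiting on a single future does not guarantee that other in-flight work has finished, while waiting on all of them does:

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait

# Dispatch several independent "kernels"; the first one finishes quickly.
pool = ThreadPoolExecutor(max_workers=4)
futures = [pool.submit(time.sleep, d) for d in (0.01, 0.2, 0.2, 0.2)]

# "Per-tensor sync": wait on one result only. Its siblings may still be
# running afterwards, much like waiting on a single PjRt buffer's future.
futures[0].result()
still_running = [f for f in futures if not f.done()]

# "Device barrier": wait for everything dispatched so far, which is how
# the PR describes block_until_ready() behaving in practice.
wait(futures)
assert all(f.done() for f in futures)
pool.shutdown()
```

This is why the after-fix throughput numbers for the torch paths remain lower than the wall times: the single end-of-run sync only guarantees the final buffer, not every dispatched call.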

@meta-cla meta-cla bot added the CLA Signed label Mar 21, 2026
torch.accelerator.synchronize() does not reliably block on torch_tpu,
causing wall-clock benchmarks to return before device computation
finishes. Add _synchronize(result) helper that uses torch_tpu's
tensor-level sync for TPU tensors, falling back to
torch.accelerator.synchronize() otherwise. Update all _generic
benchmarking functions to capture fn() return values and sync on them.
@norx1991 norx1991 force-pushed the yifeixu/fix_sync_pallas branch from 90d5fda to a6cadb8 on March 21, 2026 16:58
Use multiline import format so the line fits within 88-char limit,
preventing ruff from reformatting and breaking the pyrefly suppression comment.
@oulgen (Contributor) commented Mar 22, 2026

Can you explain why torch.accelerator.synchronize() doesn't work? We should instead fix that in PyTorch.

@norx1991 (Contributor, Author)

> Can you explain why torch.accelerator.synchronize() doesn't work? We should instead fix that in PyTorch.

Good question. The call goes here: https://github.com/pytorch/pytorch/blob/main/torch/csrc/DeviceAccelerator.cpp#L71-L82. The reason is that torch_tpu does not have a _lazy_init() method, so the call silently returns at line 76 and has no effect. Otherwise, it would hit this error: https://github.com/pytorch/pytorch/blob/3f14378080012bafc834a927560b6a1555b5bf80/c10/core/impl/DeviceGuardImplInterface.h#L245.
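
The silent-return behavior described above can be paraphrased in Python (a hypothetical mock of the guard in DeviceAccelerator.cpp, not the actual C++ implementation): if the backend module lacks a _lazy_init hook, the sync call returns without doing anything:

```python
from types import SimpleNamespace

def accelerator_synchronize(backend):
    """Paraphrased guard: a backend without _lazy_init silently no-ops."""
    lazy_init = getattr(backend, "_lazy_init", None)
    if lazy_init is None:
        return "no-op"  # what the PR author reports happening for torch_tpu
    lazy_init()
    backend.synchronize()
    return "synced"

# A backend like torch_tpu, which does not expose _lazy_init.
tpu_like = SimpleNamespace(synchronize=lambda: None)
# A backend that does expose _lazy_init, so the sync actually runs.
cuda_like = SimpleNamespace(_lazy_init=lambda: None, synchronize=lambda: None)
```

In this toy model, `accelerator_synchronize(tpu_like)` returns without ever calling `synchronize()`, which mirrors why the wall-clock timers stopped early.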

With that said, the only method I see in torch_tpu waits on specific tensors: https://github.com/google-ml-infra/torch_tpu/blob/main/torch_tpu/_internal/sync/sync.py#L31, so it may require a bigger change on the torch_tpu side to implement a method that syncs all ops on the device.

I also had a brief discussion with @v0i0. Because the observation here shows that the autotuner is basically optimizing the wrong objective function on TPU, we want to fix this ASAP to unblock ourselves.

@norx1991 norx1991 requested review from oulgen and v0i0 March 23, 2026 05:31
@norx1991 norx1991 marked this pull request as ready for review March 23, 2026 05:35
@norx1991 norx1991 merged commit 3ae128b into main Mar 24, 2026
19 of 21 checks passed
@norx1991 norx1991 deleted the yifeixu/fix_sync_pallas branch March 24, 2026 04:36