[ExecuTorch][WebGPU] Add update_cache op (llama.update_cache)#20083

Open. JulianCloudNTH wants to merge 6 commits into gh/JulianCloudNTH/16/base from gh/JulianCloudNTH/16/head
@JulianCloudNTH JulianCloudNTH commented Jun 6, 2026

Stack from ghstack (oldest at bottom):

Add llama.update_cache.default: an in-place KV-cache write. The shader scatters the new K/V ([1,S,H,D]) into the cache ([1,Cmax,H,D]) at dst_offset = input_pos*n_heads*head_dim, bounds-checked against the cache size. The handler validates shape (batch==1, matching n_heads/head_dim) and sizes the 1D dispatch from the device limit via WebGPUUtils before allocating. Mirrors the Vulkan sdpa_kv_cache_update reference. The export/delegation test is the follow-up diff stacked directly above. Authored with assistance from Claude.
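The offset arithmetic in the description can be checked with a small CPU reference. This is only a sketch under the assumption that both tensors are contiguous in row-major order; `flat_index` and `update_cache_reference` are hypothetical names used for illustration and are not part of the PR.

```cpp
#include <cstddef>
#include <vector>

// Flat index of element (pos, h, d) in a contiguous [1, N, H, D] tensor.
inline std::size_t flat_index(std::size_t pos, std::size_t h, std::size_t d,
                              std::size_t n_heads, std::size_t head_dim) {
  return (pos * n_heads + h) * head_dim + d;
}

// CPU reference for the in-place write (illustrative, not the PR's code):
// copies value [1, S, H, D] into cache [1, Cmax, H, D] starting at sequence
// position input_pos. Elements that would land past the end of the cache are
// skipped, mirroring the bounds check described for the shader.
// Returns the number of elements actually written.
inline std::size_t update_cache_reference(const std::vector<float>& value,
                                          std::vector<float>& cache,
                                          std::size_t input_pos,
                                          std::size_t n_heads,
                                          std::size_t head_dim) {
  const std::size_t dst_offset = input_pos * n_heads * head_dim;
  std::size_t written = 0;
  for (std::size_t i = 0; i < value.size(); ++i) {
    if (dst_offset + i >= cache.size()) {
      break;  // out of range: nothing further can be in bounds
    }
    cache[dst_offset + i] = value[i];
    ++written;
  }
  return written;
}
```

Because both layouts share H and D, value element (s, h, d) lands at cache element (input_pos + s, h, d), so the scatter reduces to one contiguous copy starting at dst_offset.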
@exported-using-ghexport

Differential Revision: D107547308

pytorch-bot Bot commented Jun 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20083

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Unclassified Failure

As of commit 53c3eb6 with merge base ff2bf9c (image):

UNCLASSIFIED FAILURE - DrCI could not classify the following job because the workflow did not run on the merge base. The failure may be pre-existing on trunk or introduced by this PR:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

github-actions Bot commented Jun 6, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@JulianCloudNTH

@claude review

claude Bot commented Jun 8, 2026

Claude finished @JulianCloudNTH's task in 1m 56s


Code Review: [ExecuTorch][WebGPU] Add update_cache op

  • Read PR diff and changed files
  • Review UpdateCache.cpp implementation
  • Review WGSL shader
  • Review generated header file
  • Review CMakeLists.txt changes
  • Post review feedback

Overall this is a clean, well-structured addition that follows the existing WebGPU op patterns (BinaryOp, RmsNorm) closely. The shader is simple and correct, the host-side validation is thorough, and the WGSL hash in the generated header is verified. A few observations:

Issues

1. Potential integer overflow in dst_offset calculation (UpdateCache.cpp:91)

params.dst_offset = static_cast<uint32_t>(
    static_cast<uint64_t>(input_pos) * n_heads * head_dim);

If input_pos is negative (e.g., sentinel or error), the static_cast<uint64_t> of a negative int64_t wraps to a very large value, which then truncates to a uint32_t. While the shader has a bounds check (dst_offset + i >= cache_numel), it would be cleaner to validate input_pos >= 0 on the host before computing the offset, producing a clear error message rather than silently relying on the GPU bounds check to discard all writes.
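To make the wrap-around concrete, the sketch below reproduces the quoted cast chain in a standalone function. The name `unchecked_dst_offset` and the shapes used in the usage note are illustrative assumptions, not taken from the PR.

```cpp
#include <cstdint>

// Same arithmetic as the quoted snippet, with no validation of input_pos.
// A negative input_pos is first reinterpreted modulo 2^64, and the product
// is then truncated modulo 2^32 by the narrowing cast.
inline std::uint32_t unchecked_dst_offset(std::int64_t input_pos,
                                          std::uint32_t n_heads,
                                          std::uint32_t head_dim) {
  return static_cast<std::uint32_t>(
      static_cast<std::uint64_t>(input_pos) * n_heads * head_dim);
}
```

With n_heads=32 and head_dim=128, input_pos=-1 yields 4294963200 (2^32 - 4096), which a shader bounds check would reject. However, input_pos=-1048576 yields exactly 0, which would pass the shader check and overwrite the start of the cache, so the GPU-side guard alone may not cover every negative value.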

2. No validation that dst_offset + numel <= cache_numel on the host (UpdateCache.cpp:88-93)

The shader performs this bounds check per-thread, but if the offset is beyond the cache, every thread will early-return and no data is written — a silent no-op. A host-side check (or at minimum a warning) would make debugging easier when input_pos is accidentally out-of-range.
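One possible shape for such a host-side check is sketched below. `checked_dst_offset` is a hypothetical helper, and returning `std::optional` rather than throwing is a choice made for this sketch; the PR's actual error-handling convention may differ.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical helper: returns the destination offset only if input_pos is
// non-negative and the whole write [offset, offset + numel) fits inside the
// cache. Returns std::nullopt otherwise.
inline std::optional<std::uint32_t> checked_dst_offset(
    std::int64_t input_pos, std::uint32_t n_heads, std::uint32_t head_dim,
    std::uint32_t numel, std::uint32_t cache_numel) {
  if (input_pos < 0) {
    return std::nullopt;
  }
  // Product of two u32 values always fits in u64.
  const std::uint64_t per_pos = static_cast<std::uint64_t>(n_heads) * head_dim;
  const std::uint64_t pos = static_cast<std::uint64_t>(input_pos);
  // Division-based guard: avoids overflowing pos * per_pos in u64.
  if (per_pos != 0 && pos > cache_numel / per_pos) {
    return std::nullopt;
  }
  const std::uint64_t offset = pos * per_pos;  // now known to be <= cache_numel
  // Subtraction cannot underflow because offset <= cache_numel.
  if (numel > cache_numel - offset) {
    return std::nullopt;
  }
  return static_cast<std::uint32_t>(offset);
}
```

A caller could turn an empty result into a descriptive error before any GPU resources are allocated, which avoids the silent no-op described above.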

3. The out argument (args[3]) is unused (UpdateCache.cpp:38)

The op signature is [value, cache, input_pos, out] but args.at(3) is never accessed. This is fine if out is the same buffer as cache (in-place semantics), but it might be worth adding a comment or an assertion (ET_CHECK(args.at(3) == cache_id) or similar) to make the in-place contract explicit, matching how the Vulkan backend handles it (where out is -1 / unused).

Nits (non-blocking)

4. Namespace style inconsistency: UpdateCache.cpp uses the C++17 nested namespace (namespace executorch::backends::webgpu {), matching RmsNorm.cpp, but differs from BinaryOp.cpp, which uses separate namespace blocks. This is fine; just noting it's already inconsistent in the existing code.

5. WebGPUUtils.h include: UpdateCache.cpp:10 includes WebGPUUtils.h for clamp_workgroup_size / compute_1d_workgroup_count, which is correct and matches BinaryOp.cpp.

Shader (update_cache.wgsl)

The shader is clean and minimal. The dual bounds checks (lines 17 and 20) are correct:

  • i >= params.numel — standard thread-count guard
  • params.dst_offset + i >= params.cache_numel — prevents out-of-bounds writes

One minor note: the addition params.dst_offset + i could theoretically wrap around in u32 if the sum reaches 2^32, but that only happens when dst_offset is already far outside any valid cache (cache_numel is itself a u32), so it is not a concern for in-range offsets.

Summary

The implementation is solid and ready to ship with minor improvements. The main actionable item is adding host-side validation for input_pos >= 0 (issue #1) — it's a one-liner that prevents a confusing failure mode. The other items are suggestions for clarity.


claude Bot commented Jun 8, 2026

Claude finished @JulianCloudNTH's task in 1m 56s


Code Review: [ExecuTorch][WebGPU] Add update_cache op

  • Read changed files and understand the implementation
  • Review WGSL shader code
  • Review C++ host code (UpdateCache.cpp)
  • Review CMakeLists.txt changes
  • Post review feedback

Overall this is a clean, well-structured addition that follows the existing WebGPU op patterns closely. The shader is minimal and correct for its purpose. A few issues to flag:

Issues

1. No negative input_pos guard — UpdateCache.cpp:88

params.dst_offset = static_cast<uint32_t>(
    static_cast<uint64_t>(input_pos) * n_heads * head_dim);

input_pos is int64_t from graph.get_int(). If it's negative, the static_cast<uint64_t> wraps to a huge value, which then gets truncated to uint32_t. The shader's bounds check (dst_offset + i >= cache_numel) would likely catch it at runtime, but silently producing a garbage offset is confusing to debug. Consider adding an early check:

if (input_pos < 0) {
  throw std::runtime_error("WebGPU update_cache: input_pos must be non-negative");
}

2. Potential uint32_t overflow in dst_offset calculation — UpdateCache.cpp:88-89

For large caches (e.g. input_pos=4096, n_heads=32, head_dim=128), dst_offset = 4096 * 32 * 128 = 16,777,216, which fits comfortably in a uint32_t. But the intermediate uint64_t product isn't range-checked before the narrowing cast to uint32_t, so any product of 2^32 or more would silently truncate (with n_heads=32 and head_dim=128, that means input_pos >= 1,048,576). The risk is low for current LLaMA shapes but a bounds assertion would be defensive:

uint64_t dst_offset_u64 = static_cast<uint64_t>(input_pos) * n_heads * head_dim;
if (dst_offset_u64 >= cache_numel) {
  throw std::runtime_error("WebGPU update_cache: dst_offset exceeds cache capacity");
}
params.dst_offset = static_cast<uint32_t>(dst_offset_u64);

This also catches the negative-input_pos case above (since the wrapped uint64_t would exceed cache_numel for any realistic shape).

3. Shader u32 addition overflow — update_cache.wgsl:20

if (params.dst_offset + i >= params.cache_numel) {

If dst_offset + i overflows u32, the sum wraps around and the bounds check becomes ineffective. In WGSL, unsigned addition wraps modulo 2^32. For realistic cache sizes this won't happen (it would require dst_offset + i to reach 2^32), but documenting the assumption or switching to a subtraction-based check (i >= params.cache_numel - params.dst_offset) would be more robust, provided dst_offset <= cache_numel is guaranteed first, since the subtraction itself wraps otherwise. Low priority.
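The difference between the two guard styles can be emulated on the host with fixed-width unsigned integers, which wrap modulo 2^32 just as WGSL u32 does. The function names below are hypothetical and the sketch is in C++ rather than WGSL.

```cpp
#include <cstdint>

// Addition-based guard, as in the quoted shader line. The u32 sum wraps
// modulo 2^32, so a huge dst_offset can make the sum look small.
inline bool in_bounds_add(std::uint32_t dst_offset, std::uint32_t i,
                          std::uint32_t cache_numel) {
  return static_cast<std::uint32_t>(dst_offset + i) < cache_numel;
}

// Wrap-free variant: dst_offset is compared first, so the subtraction
// cache_numel - dst_offset can never underflow.
inline bool in_bounds_sub(std::uint32_t dst_offset, std::uint32_t i,
                          std::uint32_t cache_numel) {
  return dst_offset <= cache_numel && i < cache_numel - dst_offset;
}
```

For example, with dst_offset=0xFFFFFFF0, i=0x20 and cache_numel=100, the addition-based guard wraps to 16 and reports the write as in bounds, while the guarded subtraction rejects it. For in-range inputs the two agree.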

Nits

4. Unused out arg — UpdateCache.cpp:36

The comment says args: [value, cache, input_pos, out] but out (args[3]) is never accessed. This is correct for an in-place op (cache is both input and output), but the comment implies 4 args are consumed. If out is truly unused, you could omit it from the comment to avoid confusion, or add a brief note that cache is mutated in-place.

5. Namespace style — minor inconsistency

BinaryOp.cpp uses namespace executorch { namespace backends { namespace webgpu { while this file uses namespace executorch::backends::webgpu {. Both are valid C++17, and the nested declaration is arguably cleaner. Just noting the inconsistency — no action needed unless the team has a preference.

Positive Notes

  • The static_assert on UpdateCacheParams size is good practice
  • Shape validation mirrors the Vulkan reference appropriately
  • Workgroup size clamping + dispatch validation via WebGPUUtils is correct
  • Resource lifecycle management (release pattern) is consistent with existing ops
  • The WGSL shader is minimal and does exactly what's needed
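For readers unfamiliar with the dispatch sizing mentioned above, here is a rough sketch of how a 1D dispatch can be derived from device limits. These functions are hypothetical stand-ins written for illustration; they are not the actual clamp_workgroup_size or compute_1d_workgroup_count from WebGPUUtils, whose signatures are not shown in this thread.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// Clamp a requested workgroup size to a device limit (hypothetical helper).
inline std::uint32_t clamp_size_sketch(std::uint32_t requested,
                                       std::uint32_t device_max) {
  return std::min(requested, device_max);
}

// Number of workgroups needed to cover numel elements, one thread each.
// Returns std::nullopt if the count would exceed the per-dimension limit
// (or if workgroup_size is zero), so the caller can fail before allocating.
inline std::optional<std::uint32_t> workgroup_count_1d_sketch(
    std::uint32_t numel, std::uint32_t workgroup_size,
    std::uint32_t max_workgroups_per_dim) {
  if (workgroup_size == 0) {
    return std::nullopt;
  }
  // Ceiling division in 64 bits so numel + workgroup_size - 1 cannot wrap.
  const std::uint64_t count =
      (static_cast<std::uint64_t>(numel) + workgroup_size - 1) / workgroup_size;
  if (count > max_workgroups_per_dim) {
    return std::nullopt;
  }
  return static_cast<std::uint32_t>(count);
}
```

Since the last workgroup is usually only partly filled, the shader still needs its thread-count guard (i >= params.numel) to discard the surplus invocations.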

Verdict: Looks good overall. Issues 1-2 are worth addressing (defensive checks for input_pos range); issue 3 is low-priority. The rest are nits.



Labels: CLA Signed, meta-exported