From ba0f4b27591605f1699701f8e33265daea1a5e14 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Fri, 5 Jun 2026 09:20:12 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/runtime/WebGPUGraph.cpp | 16 ++++++++++++++++ backends/webgpu/runtime/WebGPUGraph.h | 6 ++++++ 2 files changed, 22 insertions(+) diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index 19620e679b1..a11b188f428 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -48,6 +48,17 @@ size_t vk_datatype_size(vkgraph::VkDataType dtype) { WebGPUGraph::WebGPUGraph() = default; +WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) { + WGPUBufferDescriptor buf_desc = {}; + buf_desc.size = nbytes > 0 ? nbytes : 4; + buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst | + WGPUBufferUsage_CopySrc; + buf_desc.mappedAtCreation = false; + WGPUBuffer buffer = wgpuDeviceCreateBuffer(device_, &buf_desc); + scratch_buffers_.push_back(buffer); + return buffer; +} + WebGPUGraph::~WebGPUGraph() { for (size_t i = 0; i < tensors_.size(); i++) { if (tensors_[i].buffer && @@ -60,6 +71,11 @@ WebGPUGraph::~WebGPUGraph() { wgpuBufferRelease(buf); } } + for (auto& buf : scratch_buffers_) { + if (buf) { + wgpuBufferRelease(buf); + } + } for (auto& buf : output_staging_buffers_) { if (buf) { wgpuBufferRelease(buf); diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index ac88a42ff60..aa3dadc13ab 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -119,6 +119,9 @@ class WebGPUGraph { uniform_buffer_bytes_ += bytes; } + // Graph-owned scratch storage buffer for fused-op intermediates (e.g. SDPA). + WGPUBuffer create_scratch_buffer(size_t nbytes); + WGPUShaderModule get_or_create_shader( const std::string& key, const char* wgsl_source); @@ -173,6 +176,9 @@ class WebGPUGraph { std::vector shared_buffers_; std::vector shared_buffer_sizes_; + // Long-lived scratch storage buffers for fused ops (e.g. SDPA temporaries). + std::vector scratch_buffers_; + // Staging buffers for reading back outputs (MapRead | CopyDst). std::vector output_staging_buffers_;