Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(WEBGPU_SRCS
runtime/WebGPUGraph.cpp
runtime/WebGPUDelegateHeader.cpp
runtime/WebGPUDevice.cpp
runtime/WebGPUQueryPool.cpp
runtime/ops/OperatorRegistry.cpp
runtime/ops/add/BinaryOp.cpp
runtime/ops/rms_norm/RmsNorm.cpp
Expand Down
11 changes: 11 additions & 0 deletions backends/webgpu/runtime/WebGPUDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <cstdlib>
#include <memory>
#include <stdexcept>
#include <vector>

namespace executorch {
namespace backends {
Expand Down Expand Up @@ -137,6 +138,16 @@ WebGPUContext create_webgpu_context() {
WGPUStatus_Success) {
device_desc.requiredLimits = &supported_limits;
}

// Bench: enable TimestampQuery if available; fail-open (skip timing if not).
std::vector<WGPUFeatureName> required_features;
if (wgpuAdapterHasFeature(ctx.adapter, WGPUFeatureName_TimestampQuery)) {
required_features.push_back(WGPUFeatureName_TimestampQuery);
device_desc.requiredFeatureCount = required_features.size();
device_desc.requiredFeatures = required_features.data();
ctx.timestamp_supported = true;
}

device_desc.uncapturedErrorCallbackInfo.callback = on_device_error;

WGPUWaitStatus device_wait = webgpu_wait(
Expand Down
8 changes: 8 additions & 0 deletions backends/webgpu/runtime/WebGPUDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

#include <webgpu/webgpu.h>

#include <executorch/backends/webgpu/runtime/WebGPUQueryPool.h>

#include <memory>

namespace executorch {
namespace backends {
namespace webgpu {
Expand All @@ -19,6 +23,10 @@ struct WebGPUContext {
WGPUAdapter adapter = nullptr;
WGPUDevice device = nullptr;
WGPUQueue queue = nullptr;
// True if the device was created with the TimestampQuery feature (bench).
bool timestamp_supported = false;
// Bench-only: timestamp-query pool, lazily created in execute() (env-gated).
std::unique_ptr<WebGPUQueryPool> querypool;
};

WebGPUContext create_webgpu_context();
Expand Down
56 changes: 55 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <executorch/backends/webgpu/runtime/WebGPUCompat.h>
#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>

#include <cstdlib>
#include <cstring>
#include <stdexcept>

Expand Down Expand Up @@ -496,18 +497,48 @@ void WebGPUGraph::copy_inputs(
}
}

namespace {
// Bench gate: WEBGPU_TIMESTAMP_QUERY enables per-pass GPU timestamp queries.
// Reports whether per-pass GPU timestamp queries were requested through the
// WEBGPU_TIMESTAMP_QUERY environment variable. Only the variable's presence
// is tested (its value is not inspected), and the lookup happens once: the
// answer is cached in a function-local static for every later call.
bool should_timestamp_query() {
  static const bool enabled = [] {
    const char* value = std::getenv("WEBGPU_TIMESTAMP_QUERY");
    return value != nullptr;
  }();
  return enabled;
}
} // namespace

void WebGPUGraph::execute() {
const size_t n = dispatches_.size();
const size_t chunk = execute_config_.chunk_size;

if (chunk == 0 || n <= chunk) {
// Bench: timestamp-query pool, null unless env-gated + feature present.
WebGPUQueryPool* qp = nullptr;
if (should_timestamp_query() && n > 0) {
if (auto* ctx = get_default_webgpu_context()) {
if (ctx->timestamp_supported) {
if (!ctx->querypool || ctx->querypool->capacity() < n) {
ctx->querypool = std::make_unique<WebGPUQueryPool>();
ctx->querypool->initialize(device_, static_cast<uint32_t>(n));
}
qp = ctx->querypool.get();
qp->reset(static_cast<uint32_t>(n));
}
}
}

WGPUCommandEncoderDescriptor enc_desc = {};
WGPUCommandEncoder encoder =
wgpuDeviceCreateCommandEncoder(device_, &enc_desc);

// One pass per dispatch: enforces storage RAW ordering across deps.
for (const auto& dispatch : dispatches_) {
for (size_t i = 0; i < n; i++) {
const auto& dispatch = dispatches_[i];
// tw must outlive BeginComputePass (the descriptor points at it).
WGPUPassTimestampWrites tw = {};
WGPUComputePassDescriptor pass_desc = {};
if (qp) {
tw = qp->writes_for(static_cast<uint32_t>(i));
pass_desc.timestampWrites = &tw;
}
WGPUComputePassEncoder pass =
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
wgpuComputePassEncoderSetPipeline(pass, dispatch.pipeline);
Expand All @@ -517,22 +548,45 @@ void WebGPUGraph::execute() {
pass, dispatch.workgroup_count_x, 1, 1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
if (qp) {
qp->record(
static_cast<uint32_t>(i),
dispatch.kernel_name,
{dispatch.workgroup_count_x, 1, 1},
{1, 1, 1});
}
}

for (const auto& copy : output_copies_) {
wgpuCommandEncoderCopyBufferToBuffer(
encoder, copy.src_buffer, 0, copy.staging_buffer, 0, copy.nbytes);
}

if (qp) {
qp->resolve(encoder);
}

WGPUCommandBufferDescriptor cmd_desc = {};
WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, &cmd_desc);
wgpuQueueSubmit(queue_, 1, &cmd);

wgpuCommandBufferRelease(cmd);
wgpuCommandEncoderRelease(encoder);

if (qp) {
qp->extract_results(instance_);
qp->print_results();
}
return;
}

// GPU timestamp queries assume one submit; chunked execute is multi-submit.
if (should_timestamp_query()) {
throw std::runtime_error(
"WebGPU: WEBGPU_TIMESTAMP_QUERY is incompatible with chunked execute "
"(multi-submit); disable chunking to use GPU timestamp queries");
}

const size_t first_chunk = execute_config_.initial_chunk_size > 0
? execute_config_.initial_chunk_size
: chunk;
Expand Down
1 change: 1 addition & 0 deletions backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct WebGPUDispatch {
WGPUComputePipeline pipeline = nullptr;
WGPUBindGroup bind_group = nullptr;
uint32_t workgroup_count_x = 1;
std::string kernel_name; // bench label
};

struct OutputCopy {
Expand Down
220 changes: 220 additions & 0 deletions backends/webgpu/runtime/WebGPUQueryPool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUCompat.h>
#include <executorch/backends/webgpu/runtime/WebGPUQueryPool.h>

#include <cstdio>
#include <map>
#include <stdexcept>
#include <string>

namespace executorch::backends::webgpu {

namespace {

// Size in bytes of one timestamp slot in the resolve/readback buffers.
constexpr uint64_t kTimestampBytes = sizeof(uint64_t);

// Result slot filled in by map_callback. It starts out as Error, so a caller
// that checks for Success treats "callback never ran" as a failure.
struct MapCallbackData {
  WGPUMapAsyncStatus status{WGPUMapAsyncStatus_Error};
};

// Buffer-map completion callback: copies the reported status into the
// MapCallbackData that was passed through userdata1. The message and
// userdata2 arguments are ignored.
void map_callback(
    WGPUMapAsyncStatus status,
    WGPUStringView /*message*/,
    void* userdata1,
    void* /*userdata2*/) {
  static_cast<MapCallbackData*>(userdata1)->status = status;
}

} // namespace

// Releases the pool's GPU handles (readback buffer, resolve buffer, then the
// query set). Any handle that is null is skipped.
WebGPUQueryPool::~WebGPUQueryPool() {
  if (readback_buf_ != nullptr) {
    wgpuBufferRelease(readback_buf_);
  }
  if (resolve_buf_ != nullptr) {
    wgpuBufferRelease(resolve_buf_);
  }
  if (qset_ != nullptr) {
    wgpuQuerySetRelease(qset_);
  }
}

// Allocates the timestamp QuerySet plus the two buffers used to read it back.
//
// device:    device on which the QuerySet and buffers are created.
// max_pairs: number of (begin, end) timestamp pairs to reserve; two query
//            slots are allocated per pair.
//
// The call is a no-op when max_pairs is 0 or when a QuerySet already exists;
// the latter guard keeps a previously created QuerySet from being leaked.
//
// NOTE(review): the WebGPU spec is believed to cap a QuerySet at 4096 queries
// (2048 pairs); nothing here checks that limit -- confirm against the target
// implementation.
void WebGPUQueryPool::initialize(WGPUDevice device, uint32_t max_pairs) {
  if (max_pairs == 0 || qset_ != nullptr) {
    return;
  }

  const uint32_t query_count = 2 * max_pairs;
  const uint64_t buffer_bytes =
      static_cast<uint64_t>(query_count) * kTimestampBytes;
  capacity_pairs_ = max_pairs;

  WGPUQuerySetDescriptor query_set_desc = {};
  query_set_desc.type = WGPUQueryType_Timestamp;
  query_set_desc.count = query_count;
  qset_ = wgpuDeviceCreateQuerySet(device, &query_set_desc);

  // GPU-side target of ResolveQuerySet; also the source of the readback copy.
  WGPUBufferDescriptor resolve_desc = {};
  resolve_desc.size = buffer_bytes;
  resolve_desc.usage = WGPUBufferUsage_QueryResolve | WGPUBufferUsage_CopySrc;
  resolve_buf_ = wgpuDeviceCreateBuffer(device, &resolve_desc);

  // CPU-mappable destination that extract_results() maps for reading.
  WGPUBufferDescriptor readback_desc = {};
  readback_desc.size = buffer_bytes;
  readback_desc.usage = WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst;
  readback_buf_ = wgpuDeviceCreateBuffer(device, &readback_desc);

  // WebGPU timestamps are already nanoseconds, so ns_per_tick_ stays 1.0.
}

// Prepares the pool for a new run of num_dispatches timed passes: remembers
// the pair count and discards the records from the previous run.
//
// Throws std::runtime_error when num_dispatches exceeds the capacity set by
// initialize(), rather than letting later writes overrun the QuerySet.
void WebGPUQueryPool::reset(uint32_t num_dispatches) {
  const bool fits = num_dispatches <= capacity_pairs_;
  if (!fits) {
    std::string message = "WebGPUQueryPool: num_dispatches ";
    message += std::to_string(num_dispatches);
    message += " exceeds capacity ";
    message += std::to_string(capacity_pairs_);
    throw std::runtime_error(message);
  }
  durations_.clear();
  num_pairs_ = num_dispatches;
}

// Builds the timestamp-writes descriptor for pass i: the pass writes its
// beginning timestamp to query slot 2*i and its end timestamp to slot 2*i+1
// of the pool's QuerySet. The descriptor is returned by value.
WGPUPassTimestampWrites WebGPUQueryPool::writes_for(uint32_t i) {
  const uint32_t begin_slot = 2 * i;
  WGPUPassTimestampWrites writes = {};
  writes.querySet = qset_;
  writes.beginningOfPassWriteIndex = begin_slot;
  writes.endOfPassWriteIndex = begin_slot + 1;
  return writes;
}

void WebGPUQueryPool::record(
uint32_t i,
const std::string& name,
std::array<uint32_t, 3> gwg,
std::array<uint32_t, 3> lwg) {
ShaderDuration d;
d.idx = i;
d.kernel_name = name;
d.global_wg = gwg;
d.local_wg = lwg;
durations_.push_back(d);
}

// Encodes the commands that make this run's timestamps readable: resolve the
// first 2*num_pairs_ queries into the resolve buffer, then copy those bytes
// into the mappable readback buffer. Does nothing when no pairs are active.
void WebGPUQueryPool::resolve(WGPUCommandEncoder encoder) {
  if (num_pairs_ != 0) {
    const uint32_t query_count = 2 * num_pairs_;
    const uint64_t copy_bytes =
        static_cast<uint64_t>(query_count) * kTimestampBytes;
    wgpuCommandEncoderResolveQuerySet(
        encoder, qset_, 0, query_count, resolve_buf_, 0);
    wgpuCommandEncoderCopyBufferToBuffer(
        encoder, resolve_buf_, 0, readback_buf_, 0, copy_bytes);
  }
}

// Maps the readback buffer and converts the raw timestamps into start, end
// and duration values on every recorded dispatch.
//
// instance: WebGPU instance handed to webgpu_wait() to wait on the map
//           request.
//
// Does nothing when no pairs are active. If the map does not report Success,
// a message is printed and the records are left untouched. Records whose
// index lies outside the current run (idx >= num_pairs_) are skipped so the
// loop never reads past the mapped range. A pair whose end timestamp is
// earlier than its begin timestamp gets a duration of 0.
void WebGPUQueryPool::extract_results(WGPUInstance instance) {
  if (num_pairs_ == 0) {
    return;
  }
  const uint32_t count = 2 * num_pairs_;
  const uint64_t bytes = static_cast<uint64_t>(count) * kTimestampBytes;

  // cb.status starts as Error, so anything other than an explicit Success
  // from the callback is handled as a failure below.
  MapCallbackData cb;
  WGPUBufferMapCallbackInfo cb_info = {};
  cb_info.mode = WGPUCallbackMode_WaitAnyOnly;
  cb_info.callback = map_callback;
  cb_info.userdata1 = &cb;
  webgpu_wait(
      instance,
      wgpuBufferMapAsync(readback_buf_, WGPUMapMode_Read, 0, bytes, cb_info));

  if (cb.status != WGPUMapAsyncStatus_Success) {
    printf(
        "WebGPUQueryPool: readback map failed (status %d)\n",
        static_cast<int>(cb.status));
    return;
  }

  const uint64_t* ticks = static_cast<const uint64_t*>(
      wgpuBufferGetConstMappedRange(readback_buf_, 0, bytes));
  if (ticks != nullptr) {
    for (auto& d : durations_) {
      // Only slots [0, 2 * num_pairs_) are mapped; ignore stray indices.
      if (d.idx >= num_pairs_) {
        continue;
      }
      const uint64_t begin_slot = 2 * static_cast<uint64_t>(d.idx);
      const uint64_t t0 = ticks[begin_slot];
      const uint64_t t1 = ticks[begin_slot + 1];
      d.start_time_ns = static_cast<uint64_t>(t0 * ns_per_tick_);
      d.end_time_ns = static_cast<uint64_t>(t1 * ns_per_tick_);
      d.execution_duration_ns =
          (t1 >= t0) ? static_cast<uint64_t>((t1 - t0) * ns_per_tick_) : 0;
    }
  }
  wgpuBufferUnmap(readback_buf_);
}

// Prints one line per recorded dispatch (index, kernel label, global
// workgroup counts, duration in microseconds) to stdout.
//
// tsv: when true, a tab-separated header and rows are printed and nothing
//      else; when false, a banner precedes the rows and a per-kernel
//      mean/total summary follows them.
//
// Records with an empty kernel name are shown and aggregated under the label
// "dispatch". The per-kernel mean is computed in floating point so fractional
// nanoseconds are not truncated before the conversion to microseconds.
void WebGPUQueryPool::print_results(bool tsv) const {
  const char* sep = tsv ? "\t" : " ";
  if (tsv) {
    printf("idx%skernel%sgwg%sduration_us\n", sep, sep, sep);
  } else {
    printf("=== WebGPUQueryPool: per-dispatch GPU time ===\n");
  }
  for (const auto& d : durations_) {
    const char* label =
        d.kernel_name.empty() ? "dispatch" : d.kernel_name.c_str();
    const double us = d.execution_duration_ns / 1000.0;
    printf(
        "%u%s%s%s(%u,%u,%u)%s%.3f\n",
        d.idx,
        sep,
        label,
        sep,
        d.global_wg[0],
        d.global_wg[1],
        d.global_wg[2],
        sep,
        us);
  }
  if (tsv) {
    return;
  }

  // kernel label -> (summed duration in ns, number of dispatches).
  std::map<std::string, std::pair<uint64_t, uint32_t>> totals;
  for (const auto& d : durations_) {
    auto& t = totals[d.kernel_name.empty() ? "dispatch" : d.kernel_name];
    t.first += d.execution_duration_ns;
    t.second += 1;
  }
  printf("--- per-kernel mean / total (us) ---\n");
  for (const auto& [kernel, agg] : totals) {
    const uint64_t total_ns = agg.first;
    const uint32_t dispatch_count = agg.second; // >= 1 for every map entry
    const double total_us = total_ns / 1000.0;
    const double mean_us =
        static_cast<double>(total_ns) / dispatch_count / 1000.0;
    printf(
        "%s%smean %.3f%stotal %.3f (n=%u)\n",
        kernel.c_str(),
        sep,
        mean_us,
        sep,
        total_us,
        dispatch_count);
  }
}

// Returns the mean execution duration, in whole nanoseconds (integer
// division), of all recorded dispatches whose kernel name equals kernel_name.
// Returns 0 when no record matches.
uint64_t WebGPUQueryPool::get_mean_shader_ns(
    const std::string& kernel_name) const {
  uint64_t total_ns = 0;
  uint32_t matches = 0;
  for (const auto& entry : durations_) {
    if (entry.kernel_name != kernel_name) {
      continue;
    }
    total_ns += entry.execution_duration_ns;
    ++matches;
  }
  if (matches == 0) {
    return 0;
  }
  return total_ns / matches;
}

} // namespace executorch::backends::webgpu
Loading
Loading