Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ set(WEBGPU_SRCS
runtime/ops/unsqueeze/Unsqueeze.cpp
runtime/ops/slice/Slice.cpp
runtime/ops/permute/Permute.cpp
runtime/ops/sdpa_fd_decode/SdpaFdDecode.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
13 changes: 13 additions & 0 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) {
return buffer;
}

WGPUBuffer WebGPUGraph::make_uniform_buffer(const void* data, size_t size) {
WGPUBufferDescriptor desc = {};
desc.size = size;
desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
desc.mappedAtCreation = true;
WGPUBuffer buffer = wgpuDeviceCreateBuffer(device_, &desc);
void* mapped = wgpuBufferGetMappedRange(buffer, 0, size);
std::memcpy(mapped, data, size);
wgpuBufferUnmap(buffer);
uniform_buffer_bytes_ += size;
return buffer;
}

void WebGPUGraph::update_symints_from_inputs(
const std::vector<InputData>& inputs) {
for (const auto& src : symint_sources_) {
Expand Down
4 changes: 4 additions & 0 deletions backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ class WebGPUGraph {
// Graph-owned scratch storage buffer for fused-op intermediates (e.g. SDPA).
WGPUBuffer create_scratch_buffer(size_t nbytes);

// Create a mapped-at-creation uniform buffer from `size` bytes and track it
// in the memory stats. Shared helper for ops needing a uniform Params buffer.
WGPUBuffer make_uniform_buffer(const void* data, size_t size);

WGPUShaderModule get_or_create_shader(
const std::string& key,
const char* wgsl_source);
Expand Down
41 changes: 17 additions & 24 deletions backends/webgpu/runtime/ops/sdpa/Sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
#include <executorch/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h>
#include <executorch/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h>
#include <executorch/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h>
#include <executorch/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.h>
#include <executorch/backends/webgpu/runtime/ops/update_cache/update_cache_wgsl.h>

#include <webgpu/webgpu.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>

Expand Down Expand Up @@ -128,22 +128,6 @@ static ComputeOutParams make_compute_out_params(
return p;
}

// Create a uniform buffer initialized with the given bytes.
WGPUBuffer
make_uniform_buffer(WebGPUGraph& graph, const void* data, size_t size) {
WGPUDevice device = graph.device();
WGPUBufferDescriptor desc = {};
desc.size = size;
desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
desc.mappedAtCreation = true;
WGPUBuffer buffer = wgpuDeviceCreateBuffer(device, &desc);
void* mapped = wgpuBufferGetMappedRange(buffer, 0, size);
std::memcpy(mapped, data, size);
wgpuBufferUnmap(buffer);
graph.add_uniform_buffer_bytes(size);
return buffer;
}

// A buffer + its byte size, for binding.
struct BufferBinding {
WGPUBuffer buffer;
Expand Down Expand Up @@ -262,7 +246,7 @@ static WGPUBuffer record_update_cache_dispatch(
device, static_cast<uint32_t>(kv_numel), uc_wg, label);
UpdateCacheParams uc =
make_update_cache_params(kv_numel, kv_dst_offset, cache_numel);
WGPUBuffer ubuf = make_uniform_buffer(graph, &uc, sizeof(uc));
WGPUBuffer ubuf = graph.make_uniform_buffer(&uc, sizeof(uc));
BufferBinding bindings[2] = {
{cache.buffer, cache.nbytes}, {src.buffer, src.nbytes}};
build_dispatch(
Expand Down Expand Up @@ -429,9 +413,6 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
static_cast<uint64_t>(S) *
static_cast<uint64_t>(dynamic_pos ? Cmax : context_len);
const uint64_t aw_bytes = aw_cap_floats * sizeof(float);
// Prefill scratch scales as Hq·S·Cmax; can be large for long-context prefill.
WGPUBuffer attn_weights = graph.create_scratch_buffer(aw_bytes);
WGPUBuffer attn_weights_softmax = graph.create_scratch_buffer(aw_bytes);

// Dynamic input_pos: the resize hook rewrites these per step.
WGPUBuffer uc_k_buf = nullptr, uc_v_buf = nullptr, qk_buf = nullptr,
Expand Down Expand Up @@ -475,6 +456,18 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
dynamic_pos,
"update_cache(V)");

// FlashDecoding decode (S==1, static pos). Shapes FD can't handle (head dim
// > kSdpaFdMaxHeadDim) fall through to the materialized path below.
if (S == 1 && !dynamic_pos && D <= kSdpaFdMaxHeadDim) {
sdpa_fd_decode_dispatch(
graph, q, k_cache, v_cache, out, Hq, Hkv, D, context_len, g, scale);
return;
}

// QK/softmax scratch — allocated only on the non-FD path (Hq*S*Cmax prefill).
WGPUBuffer attn_weights = graph.create_scratch_buffer(aw_bytes);
WGPUBuffer attn_weights_softmax = graph.create_scratch_buffer(aw_bytes);

// --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile.
{
if (aw_floats > UINT32_MAX) {
Expand All @@ -487,7 +480,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
device, static_cast<uint32_t>(qk_tiles), qk_wg, "QK");
AttnWeightsParams p = make_attn_weights_params(
S, Hq, Hkv, D, context_len, input_pos, g, scale);
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
WGPUBuffer ubuf = graph.make_uniform_buffer(&p, sizeof(p));
BufferBinding bindings[3] = {
{attn_weights, aw_bytes},
{q.buffer, q.nbytes},
Expand All @@ -513,7 +506,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const uint32_t wgc = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(Hq * S), 1, "softmax");
SoftmaxParams p = make_softmax_params(Hq, S, context_len);
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
WGPUBuffer ubuf = graph.make_uniform_buffer(&p, sizeof(p));
BufferBinding bindings[2] = {
{attn_weights_softmax, aw_bytes}, {attn_weights, aw_bytes}};
build_dispatch(
Expand All @@ -537,7 +530,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const uint32_t wgc = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(av_tiles), av_wg, "AV");
ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g);
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
WGPUBuffer ubuf = graph.make_uniform_buffer(&p, sizeof(p));
BufferBinding bindings[3] = {
{out.buffer, out.nbytes},
{attn_weights_softmax, aw_bytes},
Expand Down
Loading
Loading