From 5dcc278bd3136ceebfc585e1396685bac456fbb8 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Fri, 26 Jun 2026 08:30:27 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/CMakeLists.txt | 14 + backends/webgpu/runtime/ops/sdpa/Sdpa.cpp | 20 +- .../ops/sdpa_fd_decode/SdpaFdDecode.cpp | 260 ++++++++++++++++++ .../runtime/ops/sdpa_fd_decode/SdpaFdDecode.h | 38 +++ .../ops/sdpa_fd_decode/sdpa_fd_reduce.wgsl | 58 ++++ .../ops/sdpa_fd_decode/sdpa_fd_reduce_wgsl.h | 82 ++++++ .../ops/sdpa_fd_decode/sdpa_fd_split.wgsl | 131 +++++++++ .../ops/sdpa_fd_decode/sdpa_fd_split_wgsl.h | 155 +++++++++++ 8 files changed, 755 insertions(+), 3 deletions(-) create mode 100644 backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.cpp create mode 100644 backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.h create mode 100644 backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce.wgsl create mode 100644 backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce_wgsl.h create mode 100644 backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split.wgsl create mode 100644 backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split_wgsl.h diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index f7cd85f9758..1759035c7e2 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -94,6 +94,20 @@ if(EXECUTORCH_BUILD_WEBGPU_PROFILING) ) endif() +# Split-KV FlashDecoding decode path (sdpa_fd_decode). Default ON: selected at +# runtime for decode (S==1) shapes it supports (head dim <= kSdpaFdMaxHeadDim); +# other shapes use the materialized SDPA path. Set OFF as a kill-switch to drop +# all FlashDecoding code from the build. 
+option(EXECUTORCH_BUILD_WEBGPU_SDPA_FD + "Enable split-KV FlashDecoding SDPA decode path" ON +) +if(EXECUTORCH_BUILD_WEBGPU_SDPA_FD) + target_sources( + webgpu_backend PRIVATE runtime/ops/sdpa_fd_decode/SdpaFdDecode.cpp + ) + target_compile_definitions(webgpu_backend PRIVATE WEBGPU_SDPA_FD) +endif() + # Link with --whole-archive for static registration of backend + ops executorch_target_link_options_shared_lib(webgpu_backend) diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index b1aa689a09d..42314162256 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -12,6 +12,9 @@ #include #include #include +#if defined(WEBGPU_SDPA_FD) +#include +#endif #include #include @@ -427,9 +430,6 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { static_cast(S) * static_cast(dynamic_pos ? Cmax : context_len); const uint64_t aw_bytes = aw_cap_floats * sizeof(float); - // Prefill scratch scales as Hq·S·Cmax; can be large for long-context prefill. - WGPUBuffer attn_weights = graph.create_scratch_buffer(aw_bytes); - WGPUBuffer attn_weights_softmax = graph.create_scratch_buffer(aw_bytes); // Dynamic input_pos: the resize hook rewrites these per step. WGPUBuffer uc_k_buf = nullptr, uc_v_buf = nullptr, qk_buf = nullptr, @@ -473,6 +473,20 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { dynamic_pos, "update_cache(V)"); +#ifdef WEBGPU_SDPA_FD + // FlashDecoding decode (S==1, static pos). Shapes FD can't handle (head dim + // > kSdpaFdMaxHeadDim or D % 4 != 0) fall through to the materialized path. + if (S == 1 && !dynamic_pos && D <= kSdpaFdMaxHeadDim && D % 4 == 0) { + sdpa_fd_decode_dispatch( + graph, q, k_cache, v_cache, out, Hq, Hkv, D, context_len, g, scale); + return; + } +#endif + + // QK/softmax scratch — allocated only on the non-FD path (Hq*S*Cmax prefill). 
+ WGPUBuffer attn_weights = graph.create_scratch_buffer(aw_bytes); + WGPUBuffer attn_weights_softmax = graph.create_scratch_buffer(aw_bytes); + // --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile. { if (aw_floats > UINT32_MAX) { diff --git a/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.cpp b/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.cpp new file mode 100644 index 00000000000..7d8db43f68f --- /dev/null +++ b/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Split-KV FlashDecoding decode dispatch (split + reduce passes). + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +// MUST match the .wgsl: MAX_SPLITS and WG_SIZE*MAX_D_PER_LANE. +constexpr uint32_t kSdpaFdSplitTile = 64; // KV positions per split +constexpr uint32_t kSdpaFdMaxSplits = 128; // == MAX_SPLITS in both .wgsl files +// Public head-dim limit (kSdpaFdMaxHeadDim) must equal the kernel's lane-owns-D +// reach; tie them so a WG_SIZE change can't silently desync the Sdpa.cpp gate. 
+static_assert( + kSdpaFdMaxHeadDim == kSdpaFdSplitWorkgroupSizeX * 2u, + "kSdpaFdMaxHeadDim must match WG_SIZE * MAX_D_PER_LANE"); + +struct FdSplitParams { + uint32_t Hq; + uint32_t Hkv; + uint32_t D; + uint32_t context_len; + uint32_t g; + uint32_t num_splits; + uint32_t split_len; + float scale; +}; +static_assert(sizeof(FdSplitParams) == 32, "FdSplitParams must be 32B"); + +struct FdReduceParams { + uint32_t D; + uint32_t num_splits; + uint32_t _pad0; + uint32_t _pad1; +}; +static_assert(sizeof(FdReduceParams) == 16, "FdReduceParams must be 16B"); + +struct BufferBinding { + WGPUBuffer buffer; + uint64_t size; +}; + +WGPUBuffer +make_uniform_buffer(WebGPUGraph& graph, const void* data, size_t size) { + WGPUDevice device = graph.device(); + WGPUBufferDescriptor desc = {}; + desc.size = size; + desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + desc.mappedAtCreation = true; + WGPUBuffer buffer = wgpuDeviceCreateBuffer(device, &desc); + void* mapped = wgpuBufferGetMappedRange(buffer, 0, size); + std::memcpy(mapped, data, size); + wgpuBufferUnmap(buffer); + graph.add_uniform_buffer_bytes(size); + return buffer; +} + +// Mirrors Sdpa.cpp build_dispatch; n_rw leading bindings are read_write. 
+void build_dispatch( + WebGPUGraph& graph, + const char* wgsl_source, + const BufferBinding* storage_bindings, + uint32_t n_storage, + uint32_t n_rw, + WGPUBuffer uniform_buffer, + uint64_t uniform_size, + uint32_t workgroup_count_x, + const char* kernel_name) { + WGPUDevice device = graph.device(); + + constexpr uint32_t kMaxEntries = 8; + if (n_storage >= kMaxEntries) { + wgpuBufferRelease(uniform_buffer); // we own it; don't leak on throw + throw std::runtime_error( + "WebGPU sdpa FlashDecoding: bind group entry count exceeds kMaxEntries"); + } + WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {wgsl_source, WGPU_STRLEN}; + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); + WGPUBindGroupLayoutEntry bgl_entries[kMaxEntries] = {}; + const uint32_t uniform_binding = n_storage; + for (uint32_t i = 0; i < n_storage; i++) { + bgl_entries[i].binding = i; + bgl_entries[i].visibility = WGPUShaderStage_Compute; + bgl_entries[i].buffer.type = (i < n_rw) + ? 
WGPUBufferBindingType_Storage + : WGPUBufferBindingType_ReadOnlyStorage; + } + bgl_entries[uniform_binding].binding = uniform_binding; + bgl_entries[uniform_binding].visibility = WGPUShaderStage_Compute; + bgl_entries[uniform_binding].buffer.type = WGPUBufferBindingType_Uniform; + + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = n_storage + 1; + bgl_desc.entries = bgl_entries; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); + + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pipeline_layout = + wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + WGPUComputePipelineDescriptor pipeline_desc = {}; + pipeline_desc.layout = pipeline_layout; + pipeline_desc.compute.module = shader; + pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + WGPUComputePipeline pipeline = + wgpuDeviceCreateComputePipeline(device, &pipeline_desc); + + WGPUBindGroupEntry bg_entries[kMaxEntries] = {}; + for (uint32_t i = 0; i < n_storage; i++) { + bg_entries[i].binding = i; + bg_entries[i].buffer = storage_bindings[i].buffer; + bg_entries[i].size = storage_bindings[i].size; + } + bg_entries[uniform_binding].binding = uniform_binding; + bg_entries[uniform_binding].buffer = uniform_buffer; + bg_entries[uniform_binding].size = uniform_size; + + WGPUBindGroupDescriptor bg_desc = {}; + bg_desc.layout = bgl; + bg_desc.entryCount = n_storage + 1; + bg_desc.entries = bg_entries; + WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); + + graph.add_dispatch({pipeline, bind_group, workgroup_count_x, kernel_name}); + + wgpuShaderModuleRelease(shader); + wgpuBindGroupLayoutRelease(bgl); + wgpuPipelineLayoutRelease(pipeline_layout); + wgpuBufferRelease(uniform_buffer); +} + +} // namespace + +void sdpa_fd_decode_dispatch( + WebGPUGraph& graph, + const WebGPUTensor& q, + const WebGPUTensor& k_cache, + const WebGPUTensor& v_cache, + const 
WebGPUTensor& out, + int64_t Hq, + int64_t Hkv, + int64_t D, + int64_t context_len, + int64_t g, + float scale) { + // Defensive contract guard: the Sdpa.cpp gate only routes D <= this here, but + // keep the check (lane-owns-D reach) so a future caller can't silently overrun. + if (D > kSdpaFdMaxHeadDim) { + throw std::runtime_error( + "WebGPU sdpa FlashDecoding: head dim must be <= 128"); + } + if (D % 4 != 0) { + throw std::runtime_error( + "WebGPU sdpa FlashDecoding: head dim must be a multiple of 4"); + } + + // Split factor: one split per kSdpaFdSplitTile KV rows, capped. + uint32_t num_splits = static_cast( + (context_len + kSdpaFdSplitTile - 1) / kSdpaFdSplitTile); + if (num_splits < 1u) { + num_splits = 1u; + } + if (num_splits > kSdpaFdMaxSplits) { + num_splits = kSdpaFdMaxSplits; + } + const uint32_t split_len = + static_cast((context_len + num_splits - 1) / num_splits); + + // Scratch: per-(head,split) partials at kSdpaFdMaxSplits stride. + const uint64_t po_floats = static_cast(Hq) * + static_cast(kSdpaFdMaxSplits) * static_cast(D); + const uint64_t pml_floats = static_cast(Hq) * + static_cast(kSdpaFdMaxSplits) * 2ull; + WGPUBuffer part_o = graph.create_scratch_buffer(po_floats * sizeof(float)); + WGPUBuffer part_ml = graph.create_scratch_buffer(pml_floats * sizeof(float)); + + // Pass 1: split (Hq*num_splits WGs) -> writes part_o, part_ml. 
+ FdSplitParams sp = {}; + sp.Hq = static_cast(Hq); + sp.Hkv = static_cast(Hkv); + sp.D = static_cast(D); + sp.context_len = static_cast(context_len); + sp.g = static_cast(g); + sp.num_splits = num_splits; + sp.split_len = split_len; + sp.scale = scale; + WGPUBuffer ub_split = make_uniform_buffer(graph, &sp, sizeof(sp)); + BufferBinding split_bindings[5] = { + {part_o, po_floats * sizeof(float)}, + {part_ml, pml_floats * sizeof(float)}, + {q.buffer, q.nbytes}, + {k_cache.buffer, k_cache.nbytes}, + {v_cache.buffer, v_cache.nbytes}}; + const uint32_t wgc_split = utils::compute_1d_workgroup_count( + graph.device(), + static_cast(Hq) * num_splits * kSdpaFdSplitWorkgroupSizeX, + kSdpaFdSplitWorkgroupSizeX, + "fd_split"); + build_dispatch( + graph, + kSdpaFdSplitWGSL, + split_bindings, + 5, + 2, + ub_split, + sizeof(sp), + wgc_split, + "fd_split"); + + // Pass 2: reduce (Hq WGs) -> reads part_o, part_ml; writes out. + FdReduceParams rp = {}; + rp.D = static_cast(D); + rp.num_splits = num_splits; + WGPUBuffer ub_reduce = make_uniform_buffer(graph, &rp, sizeof(rp)); + BufferBinding reduce_bindings[3] = { + {out.buffer, out.nbytes}, + {part_o, po_floats * sizeof(float)}, + {part_ml, pml_floats * sizeof(float)}}; + const uint32_t wgc_reduce = utils::compute_1d_workgroup_count( + graph.device(), + static_cast(Hq) * kSdpaFdReduceWorkgroupSizeX, + kSdpaFdReduceWorkgroupSizeX, + "fd_reduce"); + build_dispatch( + graph, + kSdpaFdReduceWGSL, + reduce_bindings, + 3, + 1, + ub_reduce, + sizeof(rp), + wgc_reduce, + "fd_reduce"); +} + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.h b/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.h new file mode 100644 index 00000000000..1b161cd8ec0 --- /dev/null +++ b/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace executorch::backends::webgpu { + +// FlashDecoding's lane-owns-D layout covers head dims up to WG_SIZE(64) * +// MAX_D_PER_LANE(2). Decode shapes above this fall through to the materialized +// SDPA path (the FD selection predicate in Sdpa.cpp checks this). +constexpr int64_t kSdpaFdMaxHeadDim = 128; + +// Split-KV FlashDecoding decode dispatch (S==1): a split pass over +// Hq*num_splits workgroups + a reduce pass over Hq workgroups. Called from the +// Sdpa.cpp WEBGPU_SDPA_FD branch. +void sdpa_fd_decode_dispatch( + WebGPUGraph& graph, + const WebGPUTensor& q, + const WebGPUTensor& k_cache, + const WebGPUTensor& v_cache, + const WebGPUTensor& out, + int64_t Hq, + int64_t Hkv, + int64_t D, + int64_t context_len, + int64_t g, + float scale); + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce.wgsl b/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce.wgsl new file mode 100644 index 00000000000..60f86d217ff --- /dev/null +++ b/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce.wgsl @@ -0,0 +1,58 @@ +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_part_o: array; +@group(0) @binding(2) var t_part_ml: array; + +struct Params { + D: u32, + num_splits: u32, + _pad0: u32, + _pad1: u32, +} +@group(0) @binding(3) var params: Params; + +const WG_SIZE: u32 = 64u; +const MAX_SPLITS: u32 = 128u; +const MAX_D_PER_LANE: u32 = 2u; +const NEG_INF: f32 = -1.0e30; + +// w_i = exp(m_i - M) per split, computed once and reused for the L-sum and every output dim. +var sh_w: array; + +// FlashDecoding pass 2: online-softmax merge of the per-split partials, then normalize. 
+@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let h = wid.x; + let t = lid.x; + let D = params.D; + let ns = params.num_splits; + let head_base = h * MAX_SPLITS; + + var M: f32 = NEG_INF; + for (var i: u32 = 0u; i < ns; i = i + 1u) { + M = max(M, t_part_ml[(head_base + i) * 2u + 0u]); + } + // Compute w_i = exp(m_i - M) once per split into shared memory (was recomputed per output dim). + for (var i: u32 = t; i < ns; i = i + WG_SIZE) { + sh_w[i] = exp(t_part_ml[(head_base + i) * 2u + 0u] - M); + } + workgroupBarrier(); + + var L: f32 = 0.0; + for (var i: u32 = 0u; i < ns; i = i + 1u) { + L = L + sh_w[i] * t_part_ml[(head_base + i) * 2u + 1u]; + } + let inv = select(0.0, 1.0 / L, L > 0.0); + + for (var nd: u32 = 0u; nd < MAX_D_PER_LANE; nd = nd + 1u) { + let d = t + nd * WG_SIZE; + if (d < D) { + var acc: f32 = 0.0; + for (var i: u32 = 0u; i < ns; i = i + 1u) { + acc = acc + sh_w[i] * t_part_o[(head_base + i) * D + d]; + } + t_out[h * D + d] = acc * inv; + } + } +} diff --git a/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce_wgsl.h b/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce_wgsl.h new file mode 100644 index 00000000000..6e5a515c065 --- /dev/null +++ b/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce_wgsl.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from sdpa_fd_reduce.wgsl - DO NOT EDIT. 
+// wgsl-sha256: 3debe9b52adece82b31a067a876105121afb4e09cb31d95596ea699d1a179d18 +inline constexpr const char* kSdpaFdReduceWGSL = R"( +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_part_o: array; +@group(0) @binding(2) var t_part_ml: array; + +struct Params { + D: u32, + num_splits: u32, + _pad0: u32, + _pad1: u32, +} +@group(0) @binding(3) var params: Params; + +const WG_SIZE: u32 = 64u; +const MAX_SPLITS: u32 = 128u; +const MAX_D_PER_LANE: u32 = 2u; +const NEG_INF: f32 = -1.0e30; + +// w_i = exp(m_i - M) per split, computed once and reused for the L-sum and every output dim. +var sh_w: array; + +// FlashDecoding pass 2: online-softmax merge of the per-split partials, then normalize. +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let h = wid.x; + let t = lid.x; + let D = params.D; + let ns = params.num_splits; + let head_base = h * MAX_SPLITS; + + var M: f32 = NEG_INF; + for (var i: u32 = 0u; i < ns; i = i + 1u) { + M = max(M, t_part_ml[(head_base + i) * 2u + 0u]); + } + // Compute w_i = exp(m_i - M) once per split into shared memory (was recomputed per output dim). 
+ for (var i: u32 = t; i < ns; i = i + WG_SIZE) { + sh_w[i] = exp(t_part_ml[(head_base + i) * 2u + 0u] - M); + } + workgroupBarrier(); + + var L: f32 = 0.0; + for (var i: u32 = 0u; i < ns; i = i + 1u) { + L = L + sh_w[i] * t_part_ml[(head_base + i) * 2u + 1u]; + } + let inv = select(0.0, 1.0 / L, L > 0.0); + + for (var nd: u32 = 0u; nd < MAX_D_PER_LANE; nd = nd + 1u) { + let d = t + nd * WG_SIZE; + if (d < D) { + var acc: f32 = 0.0; + for (var i: u32 = 0u; i < ns; i = i + 1u) { + acc = acc + sh_w[i] * t_part_o[(head_base + i) * D + d]; + } + t_out[h * D + d] = acc * inv; + } + } +} +)"; + +inline constexpr uint32_t kSdpaFdReduceWorkgroupSizeX = 64; +inline constexpr uint32_t kSdpaFdReduceWorkgroupSizeY = 1; +inline constexpr uint32_t kSdpaFdReduceWorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split.wgsl b/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split.wgsl new file mode 100644 index 00000000000..6e716141ee5 --- /dev/null +++ b/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split.wgsl @@ -0,0 +1,131 @@ +@group(0) @binding(0) var t_part_o: array; +@group(0) @binding(1) var t_part_ml: array; +@group(0) @binding(2) var t_q: array; +@group(0) @binding(3) var t_k_cache: array; +@group(0) @binding(4) var t_v_cache: array; + +struct Params { + Hq: u32, + Hkv: u32, + D: u32, + context_len: u32, + g: u32, + num_splits: u32, + split_len: u32, + scale: f32, +} +@group(0) @binding(5) var params: Params; + +const WG_SIZE: u32 = 64u; +const MAX_SPLITS: u32 = 128u; +const MAX_D_PER_LANE: u32 = 2u; +const NEG_INF: f32 = -1.0e30; + +// sh_s: block scores then softmax weights; sh_red: max/sum reduction scratch. +var sh_s: array; +var sh_red: array; + +// FlashDecoding pass 1: per-(head,split) unnormalized softmax partial. 
+@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let h = wid.x / params.num_splits; + let split_i = wid.x % params.num_splits; + let t = lid.x; + let D = params.D; + let D4 = D / 4u; // D is a multiple of 4 (guarded host-side); vec4 QK dot + let ctx = params.context_len; + let kv = h / params.g; + let q_base = h * D; + let kv_row_stride = params.Hkv * D; + + let c0 = split_i * params.split_len; + var c1 = c0 + params.split_len; + if (c1 > ctx) { c1 = ctx; } + + var m: f32 = NEG_INF; + var l: f32 = 0.0; + var o_acc: array; + for (var nd: u32 = 0u; nd < MAX_D_PER_LANE; nd = nd + 1u) { o_acc[nd] = 0.0; } + + // Stream the split in blocks of WG_SIZE KV positions. + var block: u32 = c0; + loop { + if (block >= c1) { break; } + var n: u32 = c1 - block; + if (n > WG_SIZE) { n = WG_SIZE; } + + // Phase 1: lane t computes the full QK dot for position block+t (vec4), one + // K row read once. Out-of-block lanes hold NEG_INF (safe for the max). + var s: f32 = NEG_INF; + if (t < n) { + let kvbase = (block + t) * kv_row_stride + kv * D; + var acc4 = vec4(0.0, 0.0, 0.0, 0.0); + for (var i4: u32 = 0u; i4 < D4; i4 = i4 + 1u) { + let qi = q_base + i4 * 4u; + let ki = kvbase + i4 * 4u; + let qv = vec4(t_q[qi], t_q[qi + 1u], t_q[qi + 2u], t_q[qi + 3u]); + let kvv = vec4( + t_k_cache[ki], t_k_cache[ki + 1u], + t_k_cache[ki + 2u], t_k_cache[ki + 3u]); + acc4 = acc4 + qv * kvv; + } + s = (acc4.x + acc4.y + acc4.z + acc4.w) * params.scale; + } + sh_s[t] = s; + + // Phase 2a: block max via tree reduction (sh_red written from register s). 
+ sh_red[t] = s; + workgroupBarrier(); + for (var stride: u32 = WG_SIZE / 2u; stride > 0u; stride = stride >> 1u) { + if (t < stride) { sh_red[t] = max(sh_red[t], sh_red[t + stride]); } + workgroupBarrier(); + } + let m_new = max(m, sh_red[0]); + let rescale = exp(m - m_new); + + // Phase 2b: each lane exponentiates ITS position once -> p (reuse sh_s), + // and reduce the block sum of p. + var p_t: f32 = 0.0; + if (t < n) { p_t = exp(sh_s[t] - m_new); } + workgroupBarrier(); // all reads of sh_s (the scores) done before overwrite + sh_s[t] = p_t; + sh_red[t] = p_t; + workgroupBarrier(); + for (var stride: u32 = WG_SIZE / 2u; stride > 0u; stride = stride >> 1u) { + if (t < stride) { sh_red[t] = sh_red[t] + sh_red[t + stride]; } + workgroupBarrier(); + } + l = rescale * l + sh_red[0]; + + // Phase 2c: each lane accumulates V for its own output dims over the block, + // reading the shared per-position weights (no exp in this loop). + for (var nd: u32 = 0u; nd < MAX_D_PER_LANE; nd = nd + 1u) { + let d = t + nd * WG_SIZE; + if (d < D) { + var acc: f32 = rescale * o_acc[nd]; + for (var j: u32 = 0u; j < n; j = j + 1u) { + let vbase = (block + j) * kv_row_stride + kv * D; + acc = acc + sh_s[j] * t_v_cache[vbase + d]; + } + o_acc[nd] = acc; + } + } + m = m_new; + workgroupBarrier(); // before the next block overwrites sh_s / sh_red + block = block + WG_SIZE; + } + + let part = h * MAX_SPLITS + split_i; + for (var nd: u32 = 0u; nd < MAX_D_PER_LANE; nd = nd + 1u) { + let d = t + nd * WG_SIZE; + if (d < D) { + t_part_o[part * D + d] = o_acc[nd]; + } + } + if (t == 0u) { + t_part_ml[part * 2u + 0u] = m; + t_part_ml[part * 2u + 1u] = l; + } +} diff --git a/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split_wgsl.h b/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split_wgsl.h new file mode 100644 index 00000000000..958c16d26dc --- /dev/null +++ b/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split_wgsl.h @@ -0,0 +1,155 @@ +/* + * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from sdpa_fd_split.wgsl - DO NOT EDIT. +// wgsl-sha256: f8a392ab021e4f1453abc2dc615254985d8045bfa187fc6a1d8fab722276ec17 +inline constexpr const char* kSdpaFdSplitWGSL = R"( +@group(0) @binding(0) var t_part_o: array; +@group(0) @binding(1) var t_part_ml: array; +@group(0) @binding(2) var t_q: array; +@group(0) @binding(3) var t_k_cache: array; +@group(0) @binding(4) var t_v_cache: array; + +struct Params { + Hq: u32, + Hkv: u32, + D: u32, + context_len: u32, + g: u32, + num_splits: u32, + split_len: u32, + scale: f32, +} +@group(0) @binding(5) var params: Params; + +const WG_SIZE: u32 = 64u; +const MAX_SPLITS: u32 = 128u; +const MAX_D_PER_LANE: u32 = 2u; +const NEG_INF: f32 = -1.0e30; + +// sh_s: block scores then softmax weights; sh_red: max/sum reduction scratch. +var sh_s: array; +var sh_red: array; + +// FlashDecoding pass 1: per-(head,split) unnormalized softmax partial. +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let h = wid.x / params.num_splits; + let split_i = wid.x % params.num_splits; + let t = lid.x; + let D = params.D; + let D4 = D / 4u; // D is a multiple of 4 (guarded host-side); vec4 QK dot + let ctx = params.context_len; + let kv = h / params.g; + let q_base = h * D; + let kv_row_stride = params.Hkv * D; + + let c0 = split_i * params.split_len; + var c1 = c0 + params.split_len; + if (c1 > ctx) { c1 = ctx; } + + var m: f32 = NEG_INF; + var l: f32 = 0.0; + var o_acc: array; + for (var nd: u32 = 0u; nd < MAX_D_PER_LANE; nd = nd + 1u) { o_acc[nd] = 0.0; } + + // Stream the split in blocks of WG_SIZE KV positions. 
+ var block: u32 = c0; + loop { + if (block >= c1) { break; } + var n: u32 = c1 - block; + if (n > WG_SIZE) { n = WG_SIZE; } + + // Phase 1: lane t computes the full QK dot for position block+t (vec4), one + // K row read once. Out-of-block lanes hold NEG_INF (safe for the max). + var s: f32 = NEG_INF; + if (t < n) { + let kvbase = (block + t) * kv_row_stride + kv * D; + var acc4 = vec4(0.0, 0.0, 0.0, 0.0); + for (var i4: u32 = 0u; i4 < D4; i4 = i4 + 1u) { + let qi = q_base + i4 * 4u; + let ki = kvbase + i4 * 4u; + let qv = vec4(t_q[qi], t_q[qi + 1u], t_q[qi + 2u], t_q[qi + 3u]); + let kvv = vec4( + t_k_cache[ki], t_k_cache[ki + 1u], + t_k_cache[ki + 2u], t_k_cache[ki + 3u]); + acc4 = acc4 + qv * kvv; + } + s = (acc4.x + acc4.y + acc4.z + acc4.w) * params.scale; + } + sh_s[t] = s; + + // Phase 2a: block max via tree reduction (sh_red written from register s). + sh_red[t] = s; + workgroupBarrier(); + for (var stride: u32 = WG_SIZE / 2u; stride > 0u; stride = stride >> 1u) { + if (t < stride) { sh_red[t] = max(sh_red[t], sh_red[t + stride]); } + workgroupBarrier(); + } + let m_new = max(m, sh_red[0]); + let rescale = exp(m - m_new); + + // Phase 2b: each lane exponentiates ITS position once -> p (reuse sh_s), + // and reduce the block sum of p. + var p_t: f32 = 0.0; + if (t < n) { p_t = exp(sh_s[t] - m_new); } + workgroupBarrier(); // all reads of sh_s (the scores) done before overwrite + sh_s[t] = p_t; + sh_red[t] = p_t; + workgroupBarrier(); + for (var stride: u32 = WG_SIZE / 2u; stride > 0u; stride = stride >> 1u) { + if (t < stride) { sh_red[t] = sh_red[t] + sh_red[t + stride]; } + workgroupBarrier(); + } + l = rescale * l + sh_red[0]; + + // Phase 2c: each lane accumulates V for its own output dims over the block, + // reading the shared per-position weights (no exp in this loop). 
+ for (var nd: u32 = 0u; nd < MAX_D_PER_LANE; nd = nd + 1u) { + let d = t + nd * WG_SIZE; + if (d < D) { + var acc: f32 = rescale * o_acc[nd]; + for (var j: u32 = 0u; j < n; j = j + 1u) { + let vbase = (block + j) * kv_row_stride + kv * D; + acc = acc + sh_s[j] * t_v_cache[vbase + d]; + } + o_acc[nd] = acc; + } + } + m = m_new; + workgroupBarrier(); // before the next block overwrites sh_s / sh_red + block = block + WG_SIZE; + } + + let part = h * MAX_SPLITS + split_i; + for (var nd: u32 = 0u; nd < MAX_D_PER_LANE; nd = nd + 1u) { + let d = t + nd * WG_SIZE; + if (d < D) { + t_part_o[part * D + d] = o_acc[nd]; + } + } + if (t == 0u) { + t_part_ml[part * 2u + 0u] = m; + t_part_ml[part * 2u + 1u] = l; + } +} +)"; + +inline constexpr uint32_t kSdpaFdSplitWorkgroupSizeX = 64; +inline constexpr uint32_t kSdpaFdSplitWorkgroupSizeY = 1; +inline constexpr uint32_t kSdpaFdSplitWorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu