
[ExecuTorch][WebGPU] Enable FlashDecoding by default for decode SDPA (runtime shape gate)#20544

Open
JulianCloudNTH wants to merge 4 commits into gh/JulianCloudNTH/64/base from gh/JulianCloudNTH/64/head



JulianCloudNTH commented Jun 26, 2026

Contributor

Stack from ghstack (oldest at bottom):

Makes split-KV FlashDecoding the default decode-attention path (it was shipped dormant behind a default-OFF compile flag). FD is the fastest WebGPU SDPA decode arm (+178% vs naive, M4 Pro, isolated op); this turns it on for production and selects it at runtime by a shape-capability predicate.

{F1991715077}

Problem: the FD kernel is correct and measured (+178%) but compile-gated OFF, so no production build used it. A device-limit gate (web-llm-style maxStorageBufferBindingSize) was considered but is dead code here: FD's resource needs (workgroup size 64, 512 B shared memory, 5 storage bindings) are all below WebGPU's baseline minimum limits, and FD binds the same K/V caches as the materialized fallback — so no spec-compliant device can run materialized decode but fail FD. The only selection criterion with real effect is shape.

Solution: enable FD by default and select it at runtime on shape, not device.

  • Before: EXECUTORCH_BUILD_WEBGPU_SDPA_FD default OFF; FD code unlinked; every decode used the materialized QK/softmax/AV path.
  • After: flag default ON (kept as a build-time kill-switch); decode (S == 1, static input_pos) with head dim <= kSdpaFdMaxHeadDim uses FD; other shapes (including head dim > 128) fall through to the materialized path.

Implementation:

  • Sdpa.cpp: extend the FD selection predicate with D <= kSdpaFdMaxHeadDim so unsupported head dims fall through instead of throwing.
  • SdpaFdDecode.h: expose kSdpaFdMaxHeadDim (FD's lane-owns-D reach) as the single source of truth; SdpaFdDecode.cpp ties it to WG_SIZE * MAX_D_PER_LANE with a static_assert.
  • CMakeLists.txt (fbcode + xplat): flip the option default to ON; OFF remains a kill-switch that drops all FlashDecoding code.
  • test_webgpu_native_ci.sh: drop the now-redundant explicit =ON flag so CI builds and tests the default.
  • Mirrors Vulkan backends/vulkan/runtime/graph/ops/impl/SDPA.cpp shape-based kernel selection (is_single_token); no device-adaptive gate, matching the Vulkan delegate.

Constraints: decode-only (S == 1), static input_pos (dynamic-pos decode still uses the materialized path); fp32, buffer-only; the FD kernels are unchanged by this diff.
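The selection rule above can be sketched as a single predicate. This is an illustration only, not the code in `Sdpa.cpp`: the `SdpaShape` struct, the `use_flash_decoding` name, and `kMaxDPerLane = 2` are assumptions (the last chosen so that the stated workgroup size of 64 gives the stated limit of 128).

```cpp
#include <cstdint>

constexpr uint32_t kWgSize = 64;       // workgroup size stated in the description
constexpr uint32_t kMaxDPerLane = 2;   // assumption: chosen so that 64 * 2 == 128
constexpr uint32_t kSdpaFdMaxHeadDim = 128;

// Ties the public limit to the kernel's lane layout, as the description says
// SdpaFdDecode.cpp does with a static_assert.
static_assert(kSdpaFdMaxHeadDim == kWgSize * kMaxDPerLane,
              "kSdpaFdMaxHeadDim must equal WG_SIZE * MAX_D_PER_LANE");

// Hypothetical shape summary; field names are illustrative.
struct SdpaShape {
  int64_t seq_len;        // S
  int64_t head_dim;       // D
  bool static_input_pos;  // false for dynamic-pos decode
};

// True when the FD path applies; every other shape falls through to the
// materialized QK/softmax/AV path instead of throwing.
constexpr bool use_flash_decoding(const SdpaShape& s) {
  return s.seq_len == 1 && s.static_input_pos && s.head_dim > 0 &&
         s.head_dim <= static_cast<int64_t>(kSdpaFdMaxHeadDim);
}
```

Under these assumptions a prefill call (`S > 1`), a dynamic-pos decode, or a head dim of 256 all return false and take the fallback.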

Co-authored with Claude Code.

Differential Revision: D109520722

[ghstack-poisoned]

pytorch-bot Bot commented Jun 26, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20544

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1d697cb with merge base 51729bb (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label Jun 26, 2026
@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@JulianCloudNTH

Contributor Author

@claude Check for overengineering, areas for modularization, and functional correctness


claude Bot commented Jun 26, 2026


Claude finished @JulianCloudNTH's task in 9m 2s


PR Review: Enable FlashDecoding by default for decode SDPA

I read the gate change in Sdpa.cpp, the new SdpaFdDecode.{h,cpp}, both WGSL kernels, and compared against the materialized path and shared utils.

  • Read changed files & diff
  • Check for overengineering
  • Identify modularization opportunities
  • Verify functional correctness

Overall this is a clean, well-scoped change. The gate is correct, the scratch-buffer move is a nice win, and the kernels are numerically equivalent to the materialized path for the decode case. Findings below, ordered by impact.


Functional correctness ✅ (with notes)

The gate and the no-mask assumption are correct. For decode (S == 1), s = 0 and context_len = input_pos + 1, so the materialized causal mask c > s + input_pos masks nothing in [0, context_len). The FD split kernel iterates the full [0, context_len) range with no mask — equivalent. Scale, GQA head mapping (h / g in both sdpa_compute_attn_weights.wgsl:69 and sdpa_fd_split.wgsl:56), Q/K/V/out layouts, and update-cache ordering (FD is recorded after the K/V update dispatches, so it reads the freshly written token) all line up. 👍
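To make the no-mask argument concrete, here is a small sketch that counts how many positions the materialized mask would exclude. The function names are hypothetical; only the mask condition `c > s + input_pos` and the decode relations `s = 0`, `context_len = input_pos + 1` come from the paragraph above.

```cpp
#include <cstdint>

// Counts cache positions c in [0, context_len) that the materialized causal
// mask (c > s + input_pos) would exclude for query row s.
constexpr uint32_t masked_positions(uint32_t s, uint32_t input_pos,
                                    uint32_t context_len) {
  uint32_t masked = 0;
  for (uint32_t c = 0; c < context_len; ++c) {
    if (c > s + input_pos) {
      ++masked;
    }
  }
  return masked;
}

// Decode case: S == 1 so the only query row is s == 0, and
// context_len == input_pos + 1.
constexpr uint32_t masked_in_decode(uint32_t input_pos) {
  return masked_positions(0, input_pos, input_pos + 1);
}
```

The largest `c` in the decode range is `input_pos`, which is never greater than `0 + input_pos`, so the count is zero for every `input_pos`. A non-decode call such as `masked_positions(0, 2, 5)` does exclude positions, which is why the equivalence holds only for the decode shape.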

Empty splits in the capped case are handled correctly, but it is worth confirming you intended this. num_splits is ceil(context_len/64) capped at 128, and split_len = ceil(context_len/num_splits). When the cap binds (e.g. context_len = 8193 → num_splits = 128, split_len = 65), the last split starts past context_len (127*65 = 8255 > 8193), so c0 >= c1 and that workgroup's loop never runs. It writes m = NEG_INF, l = 0. The reduce pass then computes sh_w[i] = exp(NEG_INF - M) = 0 (sdpa_fd_reduce.wgsl:55), contributing nothing to L or the output accumulator. So it's correct, but it does mean large contexts can silently dispatch dead workgroups. Fine functionally; just flagging that it's load-bearing on exp(-1e30) == 0.
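The arithmetic can be checked with a host-side sketch. The constants 64, 128, and -1e30 come from the paragraph above; all function and struct names are hypothetical, and the merge formula is the standard log-sum-exp combination, assumed here rather than copied from sdpa_fd_reduce.wgsl. The accumulator is reduced to one scalar for brevity.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

constexpr uint32_t kTileLen = 64;     // from ceil(context_len/64)
constexpr uint32_t kMaxSplits = 128;  // from "capped at 128"
constexpr float kNegInf = -1.0e30f;   // finite sentinel; exp(-1e30) underflows to 0

struct SplitPlan {
  uint32_t num_splits;
  uint32_t split_len;
};

inline uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

inline SplitPlan plan_splits(uint32_t context_len) {
  assert(context_len >= 1);  // the caller is said to reject context_len <= 0
  uint32_t n = std::min(ceil_div(context_len, kTileLen), kMaxSplits);
  return {n, ceil_div(context_len, n)};
}

// Number of splits whose [c0, c1) range is empty, i.e. dead workgroups.
inline uint32_t count_empty_splits(uint32_t context_len) {
  SplitPlan p = plan_splits(context_len);
  uint32_t empty = 0;
  for (uint32_t i = 0; i < p.num_splits; ++i) {
    uint32_t c0 = i * p.split_len;
    uint32_t c1 = std::min(c0 + p.split_len, context_len);
    if (c0 >= c1) {
      ++empty;
    }
  }
  return empty;
}

// Per-split partial result: running max m, softmax denominator l, and one
// element of the weighted-V accumulator.
struct Partial {
  float m;
  float l;
  float acc;
};

// Merges partials. An empty split (m = kNegInf, l = 0) gets weight
// exp(kNegInf - M) == 0 and so contributes nothing. Requires at least one
// non-empty split, otherwise L is zero.
inline float reduce_partials(const std::vector<Partial>& parts) {
  float M = kNegInf;
  for (const Partial& p : parts) {
    M = std::max(M, p.m);
  }
  float L = 0.0f;
  float out = 0.0f;
  for (const Partial& p : parts) {
    float w = std::exp(p.m - M);
    L += w * p.l;
    out += w * p.acc;
  }
  return out / L;
}
```

With context_len = 8193 this plan has 128 splits of length 65 and exactly one empty split (index 127); with context_len = 8192 the cap does not bind and no split is empty.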

Defensive guard message is stale. SdpaFdDecode.cpp:172-175 throws "head dim must be <= 128" with a hardcoded literal, while the rest of the change makes kSdpaFdMaxHeadDim the single source of truth. If WG_SIZE/MAX_D_PER_LANE ever change, this message lies. Minor, but cheap to fix:

"WebGPU sdpa FlashDecoding: head dim must be <= " +
    std::to_string(kSdpaFdMaxHeadDim);

Modularization — main finding

build_dispatch + make_uniform_buffer + BufferBinding are now duplicated. SdpaFdDecode.cpp:56-154 is a near-verbatim copy of Sdpa.cpp:135-249 (~90 lines). On top of that, utils::make_uniform already exists in WebGPUUtils.h:59, so uniform-buffer creation now has three copies. The two build_dispatch versions differ only in:

  • FD parameterizes read-write bindings via n_rw; Sdpa.cpp hardcodes "only binding 0 is rw" (Sdpa.cpp:186).
  • FD has no wg_size override constant (its kernels use a literal @workgroup_size(64)).
  • FD passes a kernel_name to the dispatch label.

These differences are all expressible in one signature: n_rw (defaulting to 1), wg_size = 0 to skip the override constant, and an optional label. Hoisting a shared build_dispatch/BufferBinding into a common header (e.g. alongside WebGPUUtils.h) and having both call sites use it would remove the largest block of duplication and keep bind-group semantics in one place.

I understand the kill-switch design wants FD code to compile out cleanly when WEBGPU_SDPA_FD=OFF, but a shared helper in the always-compiled core doesn't conflict with that — only the FD .cpp and its #ifdef call site get dropped.
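One possible shape for the shared signature, sketched without any WebGPU types so it stays self-contained. The `DispatchOptions` struct and helper names are hypothetical, and "the first n_rw bindings are read-write" is my reading of the two call sites, not something confirmed from the sources.

```cpp
#include <cstdint>

// Hypothetical options bundle covering the three differences listed above.
struct DispatchOptions {
  uint32_t n_rw = 1;            // assumed: number of leading read-write bindings
  uint32_t wg_size = 0;         // 0 means "emit no wg_size override constant"
  const char* label = nullptr;  // optional dispatch label
};

// Whether a binding slot should be created as read-write storage.
constexpr bool is_read_write(uint32_t binding, const DispatchOptions& o) {
  return binding < o.n_rw;
}

// Whether the pipeline should receive a wg_size override constant.
constexpr bool has_wg_size_override(const DispatchOptions& o) {
  return o.wg_size != 0;
}

// Settings matching how each existing call site is described.
constexpr DispatchOptions materialized_options(uint32_t wg_size) {
  return DispatchOptions{1, wg_size, nullptr};  // only binding 0 is rw
}

constexpr DispatchOptions fd_options(uint32_t n_rw, const char* kernel_name) {
  return DispatchOptions{n_rw, 0, kernel_name};  // literal @workgroup_size(64)
}
```

A single `build_dispatch` taking such a struct would let the materialized path keep its defaults while the FD path passes its own `n_rw` and label.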



Overengineering — minor

  • if (num_splits < 1u) num_splits = 1u; (SdpaFdDecode.cpp:184) is dead. context_len >= 1 is already guaranteed (Sdpa.cpp:393 throws on context_len <= 0), so ceil(context_len/64) is always >= 1 on an unsigned. Drop it or keep as a one-line assert if you prefer the contract documented.
  • The defensive D guards (SdpaFdDecode.cpp:172,176) re-check what the Sdpa.cpp gate already enforces. This is reasonable belt-and-suspenders for a public entry point — I'd keep it, just fix the message above.
  • FdReduceParams _pad0/_pad1 to 16B is correct (WebGPU uniform alignment), not overengineering.
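For reference, the padding point can be pinned down at compile time. In this sketch only `_pad0` and `_pad1` are named in the review; the other two fields are hypothetical stand-ins for whatever the real struct carries.

```cpp
#include <cstdint>

// Illustrative layout: two payload words plus two pad words gives 16 bytes.
struct FdReduceParams {
  uint32_t num_splits;  // hypothetical field name
  uint32_t head_dim;    // hypothetical field name
  uint32_t _pad0;       // padding named in the review
  uint32_t _pad1;       // padding named in the review
};

// Fails the build if a later edit breaks the 16-byte multiple that the
// review attributes to WebGPU uniform alignment.
static_assert(sizeof(FdReduceParams) % 16 == 0,
              "FdReduceParams must be a multiple of 16 bytes");
```

A `static_assert` like this documents the contract next to the struct, so removing a pad field shows up as a compile error rather than a binding-size mismatch at runtime.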

Build / packaging — verify outside OSS

The PR description says the option default is flipped in "fbcode + xplat" CMake and that test_webgpu_native_ci.sh drops the explicit =ON, but the actual diff touches only backends/webgpu/CMakeLists.txt — no CI script and no buck targets.bzl change are present. If there's an internal buck target that lists sdpa/Sdpa.cpp sources, sdpa_fd_decode/SdpaFdDecode.cpp must be added there too (and the WEBGPU_SDPA_FD define wired up), or the internal build will fail to link sdpa_fd_decode_dispatch. I can't see internal targets from here — worth double-checking before landing. The OSS CMakeLists.txt change itself is correct.


Nice work — the runtime shape gate is the right call and the dead-code analysis in the description (FD's limits all below WebGPU baseline) is convincing. The one thing I'd act on before landing is the build_dispatch duplication; the rest are minor.

[ghstack-poisoned]
JulianCloudNTH added a commit that referenced this pull request Jun 26, 2026
…(runtime shape gate)

Pull Request resolved: #20544

ghstack-source-id: 397435149
@exported-using-ghexport

Differential Revision: [D109520722](https://our.internmc.facebook.com/intern/diff/D109520722/)
[ghstack-poisoned]
JulianCloudNTH added a commit that referenced this pull request Jun 26, 2026
…(runtime shape gate)

Pull Request resolved: #20544

ghstack-source-id: 397454762
@exported-using-ghexport

Differential Revision: [D109520722](https://our.internmc.facebook.com/intern/diff/D109520722/)

SS-JIA left a comment

Contributor


Review automatically exported from Phabricator review in Meta.

[ghstack-poisoned]
JulianCloudNTH added a commit that referenced this pull request Jun 27, 2026
…(runtime shape gate)

Pull Request resolved: #20544

ghstack-source-id: 397631266
@exported-using-ghexport

Differential Revision: [D109520722](https://our.internmc.facebook.com/intern/diff/D109520722/)

Labels

CLA Signed, meta-exported
