[PyTorch] Add pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) #2596

sudhakarsingh27 wants to merge 16 commits into NVIDIA:main from flash_attn_pad_bw_seqs

Conversation
Greptile Summary

This PR extends pad_between_seqs support to the FlashAttention 3 backend with THD (varlen) format, covering both the non-CP path and the A2A and P2P context-parallel paths.

Confidence Score: 5/5. Safe to merge; previous P1 issues (FA4 not disabled, wrong cu_seqlens in non-CP FA3) are correctly addressed, and the implementation is consistent across all CP paths. All previously flagged P1 issues are resolved: utils.py now sets use_flash_attention_4 = False and use_unfused_attention = False; the backends.py non-CP FA3 path passes cu_seqlens_q_padded for memory layout and seqused_q/k as kwargs; the P2P and A2A paths derive seqused from per-step actual cu_seqlens and swap to padded cu_seqlens for memory layout. The only remaining issue is a P2 edge case around passing seqused_q/k to flash_attn_with_kvcache_v3 when inference_params is set, which is an untested combination (backends.py lines 1138-1157: seqused kwarg forwarding to the kvcache path when pad_between_seqs and inference_params are both set).
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["DotProductAttention.forward\npad_between_seqs, cu_seqlens_q_padded"] --> B{CP enabled?}
    B -->|Yes| C["attn_forward_func_with_cp\npad_between_seqs threaded through"]
    B -->|No| D["FlashAttention.forward\npad_between_seqs, padded cu_seqlens"]
    C --> E{A2A or P2P?}
    E -->|A2A| F["AttnFuncWithCPAndQKVOA2A\nseqused = cu_seqlens[1:] - cu_seqlens[:-1]\ncu_seqlens ← cu_seqlens_padded"]
    E -->|P2P| G["AttnFuncWithCPAndKVP2P\ncp_p2p_fwd_flash_attn per tile"]
    G --> H{tile section?}
    H -->|diagonal/all| I["seqused from per_step cu_seqlens\ncu_seqlens ← cu_seqlens_padded"]
    H -->|lower-triangle| J["seqused from per_step cu_seqlens\ncu_seqlens_kv ← cu_seqlens_kv_padded // 2"]
    H -->|upper-triangle| K["seqused from per_step cu_seqlens\ncu_seqlens_q ← cu_seqlens_q_padded // 2"]
    D --> L{FA version?}
    L -->|FA3 non-CP| M["cu_seqlens_padded → positional args\nseqused_q/k → kwargs\nfa3 flash_attn_varlen_func_v3"]
    L -->|FA4 / FA2| N["Blocked by utils.py:\nuse_flash_attention_2=False\nuse_flash_attention_4=False"]
    F --> O["get_fa_args with seqused_q/k\nflash_attn_varlen_func_v3"]
    I --> O
    J --> O
    K --> O
```
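Condensing the P2P branch of the flowchart into code, a minimal sketch with a hypothetical helper name (`per_tile_seqlens`) and only the seqlen-related arguments; the real `cp_p2p_fwd_flash_attn` threads many more tensors:

```python
def per_tile_seqlens(section, cu_seqlens_q, cu_seqlens_kv,
                     cu_seqlens_q_padded, cu_seqlens_kv_padded):
    """Hypothetical condensation of the per-tile overrides in the flowchart."""
    # seqused always reflects the per-step *actual* token counts
    seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    # the padded cu_seqlens describe the memory layout for this tile
    if section == "diagonal":
        return cu_seqlens_q_padded, cu_seqlens_kv_padded, seqused_q, seqused_k
    if section == "lower-triangle":   # KV side uses the halved padded layout
        return cu_seqlens_q_padded, cu_seqlens_kv_padded // 2, seqused_q, seqused_k
    if section == "upper-triangle":   # Q side uses the halved padded layout
        return cu_seqlens_q_padded // 2, cu_seqlens_kv_padded, seqused_q, seqused_k
    raise ValueError(f"unknown tile section: {section}")
```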
Reviews (32): Last reviewed commit: "Merge branch 'flash_attn_pad_bw_seqs' of..."
```python
# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
    fa_3_optional_forward_kwargs["seqused_q"] = (
        cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    )
    fa_3_optional_forward_kwargs["seqused_k"] = (
        cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    )
```
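To make the derivation concrete, here is a runnable toy with made-up numbers: two sequences of 5 and 7 real tokens, each padded to an 8-token slot.

```python
import torch

# Made-up example: actual offsets vs padded (memory-layout) offsets.
cu_seqlens_q = torch.tensor([0, 5, 12], dtype=torch.int32)
cu_seqlens_q_padded = torch.tensor([0, 8, 16], dtype=torch.int32)

# seqused_q tells FA3 how many tokens of each slot are real...
seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
print(seqused_q)  # tensor([5, 7], dtype=torch.int32)
# ...while cu_seqlens_q_padded tells it where each slot begins in memory.
```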
style: verify that flash_attn_3 with seqused_q/seqused_k truly avoids writing to padding positions. The related issue #2391 mentions that "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for THD format). If flash_attn_3 doesn't zero these internally, the output may have garbage values in padded positions. Have you verified that flash_attn_3 correctly handles padding internally with these parameters?
Force-pushed from ea51821 to e338049
/te-ci pytorch L2
```python
pad_between_seqs = False
if qkv_format == "thd" and cu_seqlens_q_padded is not None:
    pad_between_seqs = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
```
Can pad_between_seqs be decided ahead of time, passed by the user or something? This wouldn't be CUDA Graph-compatible right?
This pattern exists in dpa.py as well. But yes, it's definitely redundant here
/te-ci pytorch L1

/te-ci pytorch L3
Force-pushed from b0a3c64 to 057f406
/te-ci pytorch L3

/te-ci pytorch L3
Force-pushed from 00bdc92 to 0f48ebc
```python
if not FlashAttentionUtils.v3_is_installed:
    pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
    pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
```
```python
if pad_between_seqs:
    dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]]
else:
    dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
```
Just to confirm, we can't do this for fwd, right? Because fwd output is not allocated by us.
It's a limitation in the Flash Attention code: forward never mutates `out` (so pre-zeroing is overwritten), while backward treats `dq`/`dk`/`dv` as in-place mutable (so pre-zeroing sticks). Also, this zeroing only works in the CP code, where we can provide the args.

None of the zeroing works for the non-CP path because we only have the forward call in TE.

FA3 / Hopper (hopper/flash_attn_interface.py):
- Forward: `mutates_args=()`, namespace `flash_attn_3::_flash_attn_forward`
- Backward: `mutates_args=("dq", "dk", "dv")`, namespace `flash_attn_3::_flash_attn_backward`
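As a toy illustration of this asymmetry (not FA code; the function names are made up): a function that allocates and returns its output ignores any pre-zeroed buffer the caller holds, while a function that writes into a caller-provided buffer preserves pre-zeroing in the regions it skips.

```python
import torch

def fwd_like(x):
    # allocates its own output: whatever the caller pre-zeroed is irrelevant,
    # and any region the kernel skips is left as uninitialized garbage
    out = torch.empty_like(x)
    out[:4] = x[:4]          # pretend the kernel only writes the "real" tokens
    return out               # out[4:] is garbage

def bwd_like(x, dq):
    # writes in place into a caller-provided buffer (mutates_args=("dq", ...)):
    # the skipped region keeps whatever the caller put there
    dq[:4] = x[:4]

x = torch.ones(8)
dq = torch.zeros(8)          # pre-zeroing sticks for the in-place backward path
bwd_like(x, dq)              # dq[4:] stays 0.0
out = fwd_like(x)            # out[4:] may be garbage; pre-zeroing can't help here
```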
/te-ci pytorch L3
Add support for padding between sequences (pad_between_seqs) in the FlashAttention 3 backend when used with context parallelism (CP). Key changes:
- backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward
- context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths, zero FA3 padding garbage in CP forward, fix a2a backward alignment
- dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens
- utils.py: Gate FA3 deterministic backward for hdim>=256, fix flash_attn_supported override for cross-attention and large head_dim, disable UnfusedDotProductAttention for pad_between_seqs, add SM100+ FA3 skip

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add test parametrization for pad_between_seqs in flash attention tests. Update run_attention_with_cp.py to support the new parameter and fix batch boundary alignment in the non-CP FA3 path. Run tests in parallel when multiple GPUs are available.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add deterministic CP test runs to L3 FA versions test. Support TE_PATH positional arg and fix GPU threshold for parallel test execution.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…raint

The previous check disabled FA3 for deterministic mode whenever head_dim_qk > 128, which was overly conservative: FA3 forward supports deterministic execution at any head dim. The actual constraint from flash_api.cpp is that the backward pass does not support deterministic mode when max(head_size, head_size_v) >= 256. Narrow the gate to only disable FA3 during training (backward) and raise the threshold to >= 256, checking both head_dim_qk and head_dim_v to handle MLA configs with asymmetric head dimensions.

Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
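A minimal sketch of the narrowed condition, assuming illustrative names; the real gate lives in TE's backend-selection logic:

```python
def fa3_deterministic_supported(is_training: bool, head_dim_qk: int, head_dim_v: int) -> bool:
    """Illustrative condensation of the commit above; not verbatim TE code.

    FA3 forward is deterministic at any head dim; only the backward pass lacks
    deterministic support when max(head_size, head_size_v) >= 256.
    """
    return not (is_training and max(head_dim_qk, head_dim_v) >= 256)
```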
Force-pushed from 9c01601 to 4745f98
The pad_between_seqs gate in get_attention_backend only disabled FlashAttention 2, letting FA4 leak through to the test-time fused-vs-flash comparison. On B200 runners that install flash-attn-4, this caused test_dpa_qkv_layout_thd to compare FusedAttention against an FA4 output whose padded positions contain garbage, producing 48 numerics failures in L3_pytorch_FA_versions_test--B200_1GPU. The log message already claimed FA4 would be disabled; this change makes the code match the message: set use_flash_attention_4 = False alongside use_flash_attention_2 when pad_between_seqs is True. FA3 continues to support pad_between_seqs via seqused_q/seqused_k.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
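Condensed into a sketch (flag names follow the PR summary above; this is not the verbatim get_attention_backend code):

```python
def gate_backends_for_pad_between_seqs(qkv_format: str, pad_between_seqs: bool,
                                       flags: dict) -> dict:
    """Sketch of the gate described in this commit; flag names are assumptions."""
    if pad_between_seqs and qkv_format == "thd":
        flags["use_flash_attention_2"] = False  # FA2 has no seqused mechanism
        flags["use_flash_attention_4"] = False  # FA4 leaves garbage in padded positions
        flags["use_unfused_attention"] = False  # unfused path can't skip inter-seq padding
        # use_flash_attention_3 stays eligible: FA3 handles padding via seqused_q/k
    return flags
```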
pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen)
…_attn_pad_bw_seqs
/te-ci pytorch L3
FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass` adds cutlass/base_dsl/ to sys.path. That directory contains a utils/ package that shadows tests/pytorch/utils.py, breaking collection of test_attention_with_cp.py with:

ImportError: cannot import name 'ModelConfig' from 'utils'

Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py is always resolved first, regardless of what FA4 dependencies install.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…_attn_pad_bw_seqs
…s it's a known cudnn issue

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…_attn_pad_bw_seqs
for more information, see https://pre-commit.ci
…ransformerEngine into flash_attn_pad_bw_seqs
/te-ci pytorch L3
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L3

/te-ci pytorch L3
…_attn_pad_bw_seqs
…ransformerEngine into flash_attn_pad_bw_seqs
/te-ci pytorch L3
Description
TLDR
Enable `pad_between_seqs=True` for FlashAttention 3 with THD format — both for context parallelism (A2A and P2P comm types) and non-CP paths. Previously `pad_between_seqs` was only supported with FusedAttention.

Problem

When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With `pad_between_seqs=True`, the attention kernel needs to know the actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via `cu_seqlens_q_padded`, but FlashAttention (both FA2 and FA3) had `pad_between_seqs` hardcoded to `False` in the CP path, and FA2 was entirely disabled for `pad_between_seqs` + thd. FA3 can natively handle this via its `seqused_q`/`seqused_k` mechanism.

Solution

Use FA3's `seqused_q`/`seqused_k` tensors to communicate actual token counts per batch element. Pass `cu_seqlens_q_padded` for tensor memory layout while deriving `seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]` from the real `cu_seqlens`. This applies to both the CP path (A2A and P2P) and the non-CP path.

Fixes #2399
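In code, the solution reduces to roughly the following wiring. A sketch only: the helper name is hypothetical, unrelated arguments are elided, and `flash_attn_varlen_func_v3` is TE's alias for the FA3 varlen kernel per the flowchart above.

```python
def build_fa3_varlen_args(cu_seqlens_q, cu_seqlens_kv,
                          cu_seqlens_q_padded, cu_seqlens_kv_padded,
                          pad_between_seqs):
    """Sketch of the cu_seqlens/seqused wiring; not the verbatim backends.py code."""
    kwargs = {}
    if pad_between_seqs:
        # actual per-sequence token counts, from the real (unpadded) offsets
        kwargs["seqused_q"] = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
        kwargs["seqused_k"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
        # padded offsets describe where each sequence starts in memory
        cu_seqlens_q, cu_seqlens_kv = cu_seqlens_q_padded, cu_seqlens_kv_padded
    return cu_seqlens_q, cu_seqlens_kv, kwargs

# The result is then passed along the lines of
# flash_attn_varlen_func_v3(q, k, v, cu_seqlens_q, cu_seqlens_kv, ..., **kwargs)
```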
Type of change
Changes
Please list the changes introduced in this PR:
context_parallel.py

- `get_fa_args()`: Add `seqused_q`/`seqused_k` parameters; pass them through to the FA3 forward and backward positional arg lists (replacing hardcoded `None`s).
- `cp_p2p_fwd_flash_attn()`/`cp_p2p_bwd_flash_attn()`: Accept `pad_between_seqs`, `cu_seqlens_q_padded`, `cu_seqlens_kv_padded`. When enabled, derive `seqused` tensors and override `cu_seqlens` with the padded versions (with half-padding for lower-triangle/upper-triangle sections).
- `AttnFuncWithCPAndKVP2P`: Thread `pad_between_seqs` and padded cu_seqlens through all forward/backward `cp_p2p_fwd/bwd_flash_attn` call sites. Save `ctx.pad_between_seqs` for backward.
- `AttnFuncWithCPAndQKVOA2A.forward()`: Add `pad_between_seqs` parameter. When enabled with FA3+THD, derive `seqused` and swap `cu_seqlens` for the padded versions before calling `get_fa_args()`.
- `AttnFuncWithCPAndQKVOA2A.backward()`: Same seqused/cu_seqlens override. Use `zeros_like` (not `empty_like`) for gradient init when `pad_between_seqs`, since FA3 skips padding positions. Add an extra `None` in the return tuple for the new `pad_between_seqs` gradient slot.
- `attn_forward_func_with_cp()`: Pass `pad_between_seqs` in the A2A args list.

backends.py

- `FlashAttention.forward()`: Accept `cu_seqlens_q_padded`/`cu_seqlens_kv_padded`. Detect `pad_between_seqs` by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to the CP path. For the non-CP FA3 path, derive and pass `seqused_q`/`seqused_k`.

dot_product_attention.py

- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` through to `FlashAttention`.

utils.py

- Disable FA2 (and FA4) and UnfusedDotProductAttention for `pad_between_seqs` + thd. FA3 handles this natively via `seqused`.

test_attention_with_cp.py

- Add `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to flash attention CP tests.
- Skip `pad_between_seqs=True` for non-THD formats, when FA3 is not installed, and for the `a2a+p2p` comm type (not yet supported).

run_attention_with_cp.py

- Pass `pad_between_seqs` through `generate_input_shapes()` and `run_dpa_with_cp()`.
- With `pad_between_seqs`, set `cu_seqlens_q` to actual lengths (not just for FusedAttention).
- Use `nan_to_num(nan=0.0)`.

test_attention.py

- Use padded inputs in `_run_dot_product_attention()` (previously FlashAttention used the original unpadded inputs).
- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` and `pad_between_seqs` to the DPA call for the FlashAttention backend.
- Add `pad_between_seqs=True` to the parametrize, with a skip for non-THD formats.

New Tests
CP tests (test_attention_with_cp.py)

Added `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to `test_cp_with_flash_attention`. Skip conditions: non-THD formats, FA3 not installed, `a2a+p2p` comm type.

5 new tests that run (all `pad_between_seqs=True`, thd, bf16):

- `True-p2p-thd-cp_1_0-bf16`
- `True-p2p-thd-cp_2_1-bf16`
- `True-a2a-thd-cp_1_0-bf16`
- `True-a2a-thd-cp_1_2-bf16`
- `True-a2a-thd-cp_2_1-bf16`

Non-CP tests (test_attention.py)

Added `True` to `@pytest.mark.parametrize("pad_between_seqs", [False, True])` on `test_dot_product_attention`, with a skip for non-THD. Also changed `_run_dot_product_attention` so FlashAttention uses padded inputs/cu_seqlens and receives `pad_between_seqs=True`.

48 new test IDs collected, but all are skipped because the main parametrize uses `qkv_layout=None` (defaults to sbhd, not thd). The non-CP `pad_between_seqs` + FA3 code path is exercised indirectly when other test functions call `test_dot_product_attention` with `qkv_layout="thd_thd_thd"` (e.g., `test_dpa_softmax_thd`).

Checklist: