
[PyTorch] Add pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen)#2596

Open
sudhakarsingh27 wants to merge 16 commits into NVIDIA:main from sudhakarsingh27:flash_attn_pad_bw_seqs

Conversation


sudhakarsingh27 (Collaborator) commented Jan 14, 2026

Description

TLDR

Enable pad_between_seqs=True for FlashAttention 3 with THD format — both for context parallelism (A2A and P2P comm types) and non-CP paths. Previously pad_between_seqs was only supported with FusedAttention.

Problem

When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With pad_between_seqs=True, the attention kernel needs to know actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via cu_seqlens_q_padded, but FlashAttention (both FA2 and FA3) had pad_between_seqs hardcoded to False in the CP path, and FA2 was entirely disabled for pad_between_seqs + thd. FA3 can natively handle this via its seqused_q/seqused_k mechanism.
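
For concreteness, a small illustrative example of the two offset tensors (all values hypothetical):

```python
import torch

# Hypothetical example: three sequences with actual lengths 5, 11, and 3 tokens,
# each padded up to a multiple of 8 for divisibility across CP ranks.
seqlens = torch.tensor([5, 11, 3])
padded_seqlens = ((seqlens + 7) // 8) * 8          # tensor([ 8, 16,  8])

zero = torch.zeros(1, dtype=torch.long)
cu_seqlens_q = torch.cat([zero, seqlens.cumsum(0)])                # [0,  5, 16, 19]
cu_seqlens_q_padded = torch.cat([zero, padded_seqlens.cumsum(0)])  # [0,  8, 24, 32]
# cu_seqlens_q_padded describes where each sequence lives in the padded THD
# buffer; cu_seqlens_q still encodes the actual (unpadded) token counts.
```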

Solution

Use FA3's seqused_q/seqused_k tensors to communicate actual token counts per batch element. Pass cu_seqlens_q_padded for tensor memory layout while deriving seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] from the real cu_seqlens. This applies to both the CP path (A2A and P2P) and the non-CP path.
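
A minimal sketch of the mechanism, assuming FA3's hopper-style flash_attn_varlen_func with seqused_q/seqused_k kwargs (the wrapper name and import path in TE differ):

```python
import torch
# FA3 entry point; the import path varies by install (hopper/flash_attn_interface.py).
from flash_attn_interface import flash_attn_varlen_func

def fa3_varlen_with_padding(q, k, v, cu_seqlens_q, cu_seqlens_kv,
                            cu_seqlens_q_padded, cu_seqlens_kv_padded,
                            max_seqlen_q, max_seqlen_kv):
    # Actual per-sequence token counts come from the real cu_seqlens...
    seqused_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(torch.int32)
    seqused_k = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).to(torch.int32)
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q_padded, cu_seqlens_kv_padded,  # ...while padded offsets give the memory layout
        max_seqlen_q, max_seqlen_kv,
        seqused_q=seqused_q,  # kernel attends only to the first seqused_q tokens
        seqused_k=seqused_k,
        causal=True,
    )
```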

Fixes #2399

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

Changes introduced in this PR, by file:

context_parallel.py

  • get_fa_args(): Add seqused_q/seqused_k parameters, pass through to FA3 forward and backward positional arg lists (replacing hardcoded Nones).
  • cp_p2p_fwd_flash_attn() / cp_p2p_bwd_flash_attn(): Accept pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded. When enabled, derive seqused tensors and override cu_seqlens to the padded versions (with half-padding for lower-triangle/upper-triangle sections; see the sketch after this list).
  • AttnFuncWithCPAndKVP2P: Thread pad_between_seqs and padded cu_seqlens through all forward/backward cp_p2p_fwd/bwd_flash_attn call sites. Save ctx.pad_between_seqs for backward.
  • AttnFuncWithCPAndQKVOA2A.forward(): Add pad_between_seqs parameter. When enabled with FA3+THD, derive seqused and swap cu_seqlens for padded versions before calling get_fa_args().
  • AttnFuncWithCPAndQKVOA2A.backward(): Same seqused/cu_seqlens override. Use zeros_like (not empty_like) for gradient init when pad_between_seqs since FA3 skips padding positions. Add extra None in return tuple for the new pad_between_seqs gradient slot.
  • attn_forward_func_with_cp(): Pass pad_between_seqs in A2A args list.
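
As referenced above, a hedged sketch of the per-tile cu_seqlens override (function and section names are illustrative, not TE's exact code):

```python
# In the P2P ring, off-diagonal tiles attend to only half of each sequence on
# one side, so the padded offsets are halved for that side. seqused_q/seqused_k
# are still derived from the per-step *actual* cu_seqlens.
def override_cu_seqlens(section, cu_seqlens_q_padded, cu_seqlens_kv_padded):
    if section == "lower_triangle":   # full q, kv restricted to one half
        return cu_seqlens_q_padded, cu_seqlens_kv_padded // 2
    if section == "upper_triangle":   # q restricted to one half, full kv
        return cu_seqlens_q_padded // 2, cu_seqlens_kv_padded
    return cu_seqlens_q_padded, cu_seqlens_kv_padded  # diagonal / full tile
```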

backends.py

  • FlashAttention.forward(): Accept cu_seqlens_q_padded/cu_seqlens_kv_padded. Detect pad_between_seqs by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to CP path. For non-CP FA3 path, derive and pass seqused_q/seqused_k.

dot_product_attention.py

  • Pass cu_seqlens_q_padded/cu_seqlens_kv_padded through to FlashAttention.

utils.py

  • Only disable FA2 (not FA3) when pad_between_seqs + thd; FA3 handles this natively via seqused (see the sketch below).
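
A sketch of the resulting gate, with flag names taken from the PR discussion rather than the exact utils.py code:

```python
def gate_backends_for_pad_between_seqs(pad_between_seqs: bool, qkv_format: str,
                                       flags: dict) -> dict:
    # Hedged sketch, not TE's exact code: only FA3 keeps pad_between_seqs support.
    if pad_between_seqs and qkv_format == "thd":
        flags["use_flash_attention_2"] = False  # FA2 has no seqused mechanism
        flags["use_flash_attention_4"] = False  # FA4 leaves garbage at padding
        flags["use_unfused_attention"] = False
        # use_flash_attention_3 stays as-is: FA3 handles this via seqused_q/k
    return flags
```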

test_attention_with_cp.py

  • Add @pytest.mark.parametrize("pad_between_seqs", [False, True]) to flash attention CP tests.
  • Skip pad_between_seqs=True for non-THD formats, when FA3 is not installed, and for a2a+p2p comm type (not yet supported).

run_attention_with_cp.py

  • Thread pad_between_seqs through generate_input_shapes() and run_dpa_with_cp().
  • When pad_between_seqs, set cu_seqlens_q to actual lengths (not just for FusedAttention).
  • Handle FA3 backward NaN at padding positions: nan_to_num(nan=0.0).
  • Zero padding positions explicitly before comparison, since FA3 doesn't guarantee zeros at padding slots (see the sketch after this list).
  • Add tensor names to NaN/Inf assertion messages for debuggability.
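
A sketch of the padding cleanup referenced above (the helper name is hypothetical; the test derives the slots from the actual/padded cu_seqlens pair):

```python
import torch

def zero_padding_positions(x, cu_seqlens, cu_seqlens_padded):
    # Zero the padding rows of a padded THD tensor before comparison.
    for i in range(len(cu_seqlens) - 1):
        actual_len = cu_seqlens[i + 1] - cu_seqlens[i]
        valid_end = cu_seqlens_padded[i] + actual_len
        x[valid_end : cu_seqlens_padded[i + 1]] = 0   # padding rows of sequence i
    return torch.nan_to_num(x, nan=0.0)  # FA3 backward may leave NaN at padding
```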

test_attention.py

  • Group FlashAttention with FusedAttention for padded input/output handling in _run_dot_product_attention() (previously FlashAttention used original unpadded inputs).
  • Pass cu_seqlens_q_padded/cu_seqlens_kv_padded and pad_between_seqs to DPA call for FlashAttention backend.
  • Add pad_between_seqs=True to parametrize with skip for non-THD formats.

New Tests

CP tests (test_attention_with_cp.py)

Added @pytest.mark.parametrize("pad_between_seqs", [False, True]) to test_cp_with_flash_attention. Skip conditions: non-THD formats, FA3 not installed, a2a+p2p comm type.
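
Paraphrased shape of the parametrize and skips (fixture names assumed; the FA3 and a2a+p2p skip messages match the diff quoted further below):

```python
import pytest
# FlashAttentionUtils import omitted; it comes from TE's attention utils.

@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs):
    if pad_between_seqs:
        if qkv_format != "thd":
            pytest.skip("pad_between_seqs requires THD format")
        if not FlashAttentionUtils.v3_is_installed:
            pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
        if cp_comm_type == "a2a+p2p":
            pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
    ...
```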

5 new tests that run (all pad_between_seqs=True, thd, bf16):

| Test | CP comm | Model config |
| --- | --- | --- |
| True-p2p-thd-cp_1_0-bf16 | P2P | causal, 1 head |
| True-p2p-thd-cp_2_1-bf16 | P2P | causal, 2 heads |
| True-a2a-thd-cp_1_0-bf16 | A2A | causal, 1 head |
| True-a2a-thd-cp_1_2-bf16 | A2A | causal, sliding window |
| True-a2a-thd-cp_2_1-bf16 | A2A | causal, 2 heads |

Non-CP tests (test_attention.py)

Added True to the @pytest.mark.parametrize("pad_between_seqs", [False, True]) list on test_dot_product_attention, with a skip for non-THD formats. Also changed _run_dot_product_attention so FlashAttention uses padded inputs/cu_seqlens and receives pad_between_seqs=True.

48 new test IDs collected, but all are skipped because the main parametrize uses qkv_layout=None (defaults to sbhd, not thd). The non-CP pad_between_seqs + FA3 code path is exercised indirectly when other test functions call test_dot_product_attention with qkv_layout="thd_thd_thd" (e.g., test_dpa_softmax_thd).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 self-assigned this on Jan 14, 2026

greptile-apps Bot commented Jan 14, 2026

Greptile Summary

This PR extends pad_between_seqs=True support to FlashAttention 3 with THD (varlen) format, covering both the non-CP path (FlashAttention.forward) and the CP paths (A2A and P2P). The core mechanism uses FA3's seqused_q/seqused_k kwargs to communicate actual token counts while cu_seqlens_q_padded describes the padded memory layout. Previous review comments flagging missing use_flash_attention_4 = False and wrong cu_seqlens in the non-CP FA3 path are both addressed in this revision.

Confidence Score: 5/5

Safe to merge; previous P1 issues (FA4 not disabled, wrong cu_seqlens in non-CP FA3) are correctly addressed, and the implementation is consistent across all CP paths.

All previously flagged P1 issues are resolved: utils.py now sets use_flash_attention_4 = False and use_unfused_attention = False; backends.py non-CP FA3 path passes cu_seqlens_q_padded for memory layout and seqused_q/k as kwargs; P2P and A2A paths derive seqused from per-step actual cu_seqlens and swap to padded cu_seqlens for memory layout. The only remaining issue is a P2 edge case around passing seqused_q/k to flash_attn_with_kvcache_v3 when inference_params is set, which is an untested combination.

The affected span is backends.py lines 1138-1157 (seqused kwarg forwarding to the kvcache path when pad_between_seqs and inference_params are both set).

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Correctly disables FA2, FA4, and UnfusedAttention for pad_between_seqs=True (the FA4 disable was missing in the previous version); also fixes the FA2 paged/non-paged page_size check and the FA3 deterministic head_dim condition. |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Non-CP FA3 path now correctly passes cu_seqlens_q_padded for memory layout and seqused_q/k for actual token counts when pad_between_seqs=True; the CP path threads pad_between_seqs and padded cu_seqlens through to the context_parallel helpers. |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | get_fa_args and cp_p2p_fwd/bwd_flash_attn are plumbed with seqused_q/k; A2A and P2P forward/backward correctly derive seqused from per-step actual cu_seqlens and override cu_seqlens to padded versions; zeros_like is used for gradient init when pad_between_seqs. |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Passes pad_between_seqs and padded cu_seqlens to FlashAttention.forward; minimal targeted change. |
| tests/pytorch/attention/run_attention_with_cp.py | Threads fa_pad_between_seqs through generate_input_shapes and run_dpa_with_cp; handles FA3 garbage at padding positions by explicitly zeroing before comparison; verifies CP backward tensors (dq_/dk_/dv_) have clean padding. |
| tests/pytorch/attention/test_attention_with_cp.py | Adds the pad_between_seqs parametrize to test_cp_with_flash_attention with appropriate skips for non-THD formats, missing FA3, and the unsupported a2a+p2p comm type. |
| tests/pytorch/attention/test_attention.py | Adds pad_between_seqs=True to the DPA test parametrize; FlashAttention now uses padded inputs and receives cu_seqlens_q_padded/cu_seqlens_kv_padded consistent with the new backends.py interface. |
| qa/L1_pytorch_distributed_unittest/test.sh | Parallelizes CP test runs on ≥8-GPU machines using disjoint GPU sets; falls back to sequential on smaller hosts. |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["DotProductAttention.forward\npad_between_seqs, cu_seqlens_q_padded"] --> B{CP enabled?}
    B -->|Yes| C["attn_forward_func_with_cp\npad_between_seqs threaded through"]
    B -->|No| D["FlashAttention.forward\npad_between_seqs, padded cu_seqlens"]

    C --> E{A2A or P2P?}
    E -->|A2A| F["AttnFuncWithCPAndQKVOA2A\nseqused = cu_seqlens[1:] - cu_seqlens[:-1]\ncu_seqlens ← cu_seqlens_padded"]
    E -->|P2P| G["AttnFuncWithCPAndKVP2P\ncp_p2p_fwd_flash_attn per tile"]

    G --> H{tile section?}
    H -->|diagonal/all| I["seqused from per_step cu_seqlens\ncu_seqlens ← cu_seqlens_padded"]
    H -->|lower-triangle| J["seqused from per_step cu_seqlens\ncu_seqlens_kv ← cu_seqlens_kv_padded // 2"]
    H -->|upper-triangle| K["seqused from per_step cu_seqlens\ncu_seqlens_q ← cu_seqlens_q_padded // 2"]

    D --> L{FA version?}
    L -->|FA3 non-CP| M["cu_seqlens_padded → positional args\nseqused_q/k → kwargs\nfa3 flash_attn_varlen_func_v3"]
    L -->|FA4 / FA2| N["Blocked by utils.py:\nuse_flash_attention_2=False\nuse_flash_attention_4=False"]

    F --> O["get_fa_args with seqused_q/k\nflash_attn_varlen_func_v3"]
    I --> O
    J --> O
    K --> O
```

Reviews (32): Last reviewed commit: "Merge branch 'flash_attn_pad_bw_seqs' of..."


@greptile-apps greptile-apps Bot left a comment

4 files reviewed, 1 comment


Comment on lines +974 to +983
```python
# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
    fa_3_optional_forward_kwargs["seqused_q"] = (
        cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    )
    fa_3_optional_forward_kwargs["seqused_k"] = (
        cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    )
```
greptile-apps (Bot) commented:
style: verify that flash_attn_3 with seqused_q/seqused_k truly avoids writing to padding positions - the related issue #2391 mentions "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for THD format). if flash_attn_3 doesn't zero these internally, output may have garbage values in padded positions. have you verified that flash_attn_3 correctly handles padding internally with these parameters?

sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from ea51821 to e338049 on March 10, 2026 23:37
sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L2

sudhakarsingh27 changed the title from "Flash attn pad bw seqs" to "[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD" on Mar 11, 2026

```python
pad_between_seqs = False
if qkv_format == "thd" and cu_seqlens_q_padded is not None:
    pad_between_seqs = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
```
Collaborator:

Can pad_between_seqs be decided ahead of time, passed by the user or something? This wouldn't be CUDA Graph-compatible, right?

sudhakarsingh27 (Collaborator, Author) replied:

This pattern exists in dpa.py as well. But yes, it's definitely redundant here.

sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L1

sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from b0a3c64 to 057f406 on April 9, 2026 05:18
sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

1 similar comment
sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from 00bdc92 to 0f48ebc on April 10, 2026 15:04
```python
if not FlashAttentionUtils.v3_is_installed:
    pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
    pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
```
Collaborator:

What about AG?

```python
if pad_between_seqs:
    dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]]
else:
    dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
```
Collaborator:

Just to confirm, we can't do this for fwd, right? Because fwd output is not allocated by us.

sudhakarsingh27 (Collaborator, Author) replied:

It's a limitation in the Flash Attention code: forward never mutates out (FA3 allocates and returns its own output, so pre-zeroing has no effect), while backward treats dq/dk/dv as in-place mutable (so pre-zeroing sticks). Also, this zeroing works only in the CP code, where we allocate and pass those buffers ourselves.

None of the zeroing works for the non-CP path, because there TE only makes the forward call.

FA3 / Hopper (hopper/flash_attn_interface.py):
- Forward: mutates_args=() → namespace flash_attn_3::_flash_attn_forward
- Backward: mutates_args=("dq", "dk", "dv") → namespace flash_attn_3::_flash_attn_backward
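
A runnable toy illustration of that asymmetry (fake_fa3_backward is a stand-in, not the real FA3 API): because backward mutates caller-provided buffers and skips padded rows, zeros_like pre-zeroing survives; forward returns a freshly allocated out, so there is nothing to pre-zero.

```python
import torch

def fake_fa3_backward(dq, dk, dv, seqused):
    """Stand-in for FA3's in-place backward: writes only the valid rows."""
    dq[:seqused] = 1.0  # rows past seqused are never touched
    dk[:seqused] = 1.0
    dv[:seqused] = 1.0

q_part = k_part = v_part = torch.randn(8, 2, 16)  # 8 rows; last 3 are padding
dq, dk, dv = (torch.zeros_like(t) for t in (q_part, k_part, v_part))
fake_fa3_backward(dq, dk, dv, seqused=5)
assert dq[5:].eq(0).all()  # pre-zeroing survives because backward is in-place
# Forward, by contrast, allocates and returns its own `out` (mutates_args=()),
# so there is no caller-owned buffer to pre-zero.
```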

ptrendx added this to the 2.15 milestone on Apr 23, 2026
sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

Add support for padding between sequences (pad_between_seqs) in the
FlashAttention 3 backend when used with context parallelism (CP).

Key changes:
- backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward
- context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths,
  zero FA3 padding garbage in CP forward, fix a2a backward alignment
- dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens
- utils.py: Gate FA3 deterministic backward for hdim>=256, fix
  flash_attn_supported override for cross-attention and large head_dim,
  disable UnfusedDotProductAttention for pad_between_seqs, add SM100+
  FA3 skip

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add test parametrization for pad_between_seqs in flash attention tests.
Update run_attention_with_cp.py to support the new parameter and fix
batch boundary alignment in the non-CP FA3 path. Run tests in parallel
when multiple GPUs are available.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add deterministic CP test runs to L3 FA versions test. Support TE_PATH
positional arg and fix GPU threshold for parallel test execution.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…raint

The previous check disabled FA3 for deterministic mode whenever
head_dim_qk > 128, which was overly conservative — FA3 forward supports
deterministic execution at any head dim. The actual constraint from
flash_api.cpp is that the backward pass does not support deterministic
mode when max(head_size, head_size_v) >= 256.

Narrow the gate to only disable FA3 during training (backward) and
raise the threshold to >= 256, checking both head_dim_qk and head_dim_v
to handle MLA configs with asymmetric head dimensions.

Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
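
A hedged sketch of the narrowed gate described in this commit (variable names assumed, not TE's exact code):

```python
def fa3_deterministic_supported(is_training: bool, head_dim_qk: int,
                                head_dim_v: int) -> bool:
    # FA3 forward is deterministic at any head dim; only the backward lacks
    # deterministic support when max(head_size, head_size_v) >= 256.
    return not (is_training and max(head_dim_qk, head_dim_v) >= 256)
```
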
sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from 9c01601 to 4745f98 on April 24, 2026 23:02
The pad_between_seqs gate in get_attention_backend only disabled
FlashAttention 2, letting FA4 leak through to the test-time
fused-vs-flash comparison. On B200 runners that install flash-attn-4,
this caused test_dpa_qkv_layout_thd to compare FusedAttention against
an FA4 output whose padded positions contain garbage, producing 48
numerics failures in L3_pytorch_FA_versions_test--B200_1GPU.

The log message already claimed FA4 would be disabled — this change
makes the code match the message: set use_flash_attention_4 = False
alongside use_flash_attention_2 when pad_between_seqs is True. FA3
continues to support pad_between_seqs via seqused_q/seqused_k.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
sudhakarsingh27 changed the title from "[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD" to "[PyTorch] Add pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen)" on Apr 24, 2026
sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

cyanguwa previously approved these changes on Apr 25, 2026
FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass`
adds cutlass/base_dsl/ to sys.path. That directory contains a utils/
package that shadows tests/pytorch/utils.py, breaking collection of
test_attention_with_cp.py with:
  ImportError: cannot import name 'ModelConfig' from 'utils'

Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py
is always resolved first, regardless of what FA4 dependencies install.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

sudhakarsingh27 (Collaborator, Author):

/te-ci pytorch L3

Linked issue: Support FlashAttention with pad_between_seqs=True (#2399)