
[All] Remove max512 backend #2949

Open
cyanguwa wants to merge 2 commits into NVIDIA:main from cyanguwa:remove_max512_subbackend

Conversation

@cyanguwa (Collaborator)

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa (Collaborator, Author)

/te-ci

@greptile-apps (Contributor)

greptile-apps Bot commented Apr 30, 2026

Greptile Summary

This PR removes the legacy NVTE_F16_max512_seqlen (backend 0) cuDNN fused-attention kernel, which only supported FP16/BF16 with sequence lengths ≤ 512, head_dim = 64, and no GQA. All remaining F16 traffic now routes through NVTE_F16_arbitrary_seqlen (backend 1).

  • P1 — FP8 + T3HD context-parallel crash: The T3HD-specific unpacking guard was removed in context_parallel.py, but pytorch/csrc/extensions/attention.cpp still allocates a 3-element aux_ctx_tensors list ([M, ZInv, rng_state]) for the FP8 + NVTE_T3HD layout. Both cp_p2p_fwd_fused_attn (line 982) and AttnFuncWithCPAndKVAllGather.forward (line 3202) now unpack exactly 2 elements, causing an immediate ValueError: too many values to unpack for any FP8 + T3HD context-parallel workload (see the minimal reproduction after this list).
  • P1 — same crash, KV-allgather backward path (lines 3565–3568): aux_ctx_tensors is built with only 2 elements for FP8 + T3HD, where the backward kernel expects 3.
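
The failure mode is easy to reproduce in isolation. A minimal standalone sketch, with placeholder strings standing in for the real M, ZInv, and rng_state tensors:

    # FP8 + T3HD allocates three auxiliary tensors (placeholders used here).
    aux_ctx_tensors = ["M", "ZInv", "rng_state"]

    # The post-PR code unpacks exactly two elements:
    softmax_lse_per_step, rng_states = aux_ctx_tensors
    # ValueError: too many values to unpack (expected 2)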

Confidence Score: 3/5

Not safe to merge as-is: FP8 + T3HD context-parallel paths will crash at runtime.

Two P1 bugs in context_parallel.py — removing the T3HD guard for FP8 aux-tensor unpacking breaks a clearly exercised code path with a definite ValueError at runtime. The rest of the max512 removal is clean and well-scoped across C++, JAX, and PyTorch layers.

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py — lines 982-983 and 3202-3203 for the forward unpack, and lines 1169/1185 for the backward aux_tensors construction in the FP8+T3HD path.

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: Removed the T3HD-specific FP8 aux-tensor handling, introducing a ValueError crash for FP8 + T3HD layouts in both the P2P and KV-allgather context-parallel paths.
  • transformer_engine/common/fused_attn/fused_attn.cpp: Removed the max512 backend selection logic and its dispatch in the fwd/bwd kernels; F16 backend selection now always resolves to NVTE_F16_arbitrary_seqlen when flag_arb is true.
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py: Removed F16_max512_seqlen from the FusedAttnBackend dict and simplified the rng_elts_per_thread logic; the constant BACKEND_F16m512_FP8_THREADS_PER_CTA was renamed to BACKEND_FP8_THREADS_PER_CTA with the same value.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Removed max512-specific backend filters (sliding window, post_scale_bias shape, and the forced NVTE_FUSED_ATTN_BACKEND=1 override).
  • transformer_engine/jax/cpp_extensions/attention.py: Removed the max512-specific softmax shape allocation branch; the JAX forward now only handles the F16_arbitrary_seqlen and FP8 backends.
  • docs/envvars.rst: Updated the NVTE_FUSED_ATTN_BACKEND docs to remove backend 0; however, the env-var override for backend 1 is no longer enforced in C++.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_get_fused_attn_backend] --> B{q/k dtype}
    B -->|FP8| C[FP8 backend logic unchanged]
    B -->|FP16/BF16| D{flag_arb?}
    D -->|true| E[NVTE_F16_arbitrary_seqlen]
    D -->|false| F[NVTE_No_Backend]
    E --> G{cudnn >= 8900?}
    G -->|yes| H[Return F16_arbitrary_seqlen]
    G -->|no| I[Return No_Backend]

    subgraph REMOVED [Removed max512 path]
        R1[flag_m512 check seqlen<=512 head_dim=64 no GQA]
        R2[NVTE_F16_max512_seqlen]
        R3[env var override NVTE_FUSED_ATTN_BACKEND]
    end

    subgraph CP_BUG [Bug in context_parallel.py]
        CP1[fp8 + T3HD forward: aux_ctx_tensors has 3 elements M ZInv rng_state]
        CP2[new code unpacks 2 only -> ValueError]
    end

Comments Outside Diff (1)

  1. docs/envvars.rst, lines 7-8

    P2 NVTE_FUSED_ATTN_BACKEND=1 is now silently ignored for F16

    The env-var override for the F16 case was only checked inside the removed seqlen <= 512 block in fused_attn.cpp. Now that the block is gone, setting NVTE_FUSED_ATTN_BACKEND=1 for an F16 workload has no effect; the backend is determined purely by flag_arb. The docs still describe value 1 as a valid override, which is misleading. Consider adding a note that the env var is now effectively a no-op for F16 (only one F16 backend exists), or re-add env-var gating for the arbitrary-seqlen backend selection, as sketched below.
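
    One possible shape for re-added gating, sketched as Python pseudocode (the real check would be C++ in fused_attn.cpp; the function name and string returns here are hypothetical):

        import os

        # Hypothetical sketch: honor NVTE_FUSED_ATTN_BACKEND even with a single
        # F16 backend, so an explicit override away from backend 1 disables fusion.
        def select_f16_backend(flag_arb: bool) -> str:
            override = os.environ.get("NVTE_FUSED_ATTN_BACKEND")
            if override is not None and override != "1":
                return "NVTE_No_Backend"
            return "NVTE_F16_arbitrary_seqlen" if flag_arb else "NVTE_No_Backend"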

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines 982 to +983
 if fp8:
-    if qkv_layout != "t3hd":
-        softmax_lse_per_step, rng_states = aux_ctx_tensors
-    else:
-        softmax_lse_per_step, _, rng_states = aux_ctx_tensors
+    softmax_lse_per_step, rng_states = aux_ctx_tensors

P1 Unpack crash for FP8 + T3HD layout

When fp8=True and qkv_layout == "t3hd", aux_ctx_tensors contains 3 elements, [M, ZInv, rng_state], as confirmed by pytorch/csrc/extensions/attention.cpp lines 285-292 (NVTE_T3HD allocates an extra ZInv tensor). Unpacking into only 2 variables raises ValueError: too many values to unpack. The original guard that handled this case was removed without a replacement for the T3HD-specific path; a fix sketch follows below.

The same crash exists in AttnFuncWithCPAndKVAllGather.forward at lines 3202-3203.
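
A straightforward fix is to restore the removed branch, shown as the pre-PR code in the diff above (sketch):

    if fp8:
        if qkv_layout != "t3hd":
            softmax_lse_per_step, rng_states = aux_ctx_tensors
        else:
            # FP8 + T3HD: aux_ctx_tensors is [M, ZInv, rng_state]; skip ZInv.
            softmax_lse_per_step, _, rng_states = aux_ctx_tensors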

Comment on lines 3202 to +3203
 if fp8:
-    if qkv_layout != "t3hd":
-        softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors
-    else:
-        softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
+    softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors

P1 Same FP8 + T3HD unpack crash (KV-allgather path)

Identical issue to cp_p2p_fwd_fused_attn: for fp8=True and qkv_layout == "t3hd", aux_ctx_tensors is a 3-element list [M, ZInv, rng_state], but this line unpacks exactly 2, raising ValueError: too many values to unpack. Restoring the guard, sketched below, avoids the crash.
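
As above, a sketch mirroring the removed pre-PR branch from the diff:

    if fp8:
        if qkv_layout != "t3hd":
            softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors
        else:
            # FP8 + T3HD: three aux tensors [M, ZInv, rng_state]; skip ZInv.
            softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors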
