[All] Remove max512 backend #2949
Conversation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
/te-ci
**Greptile Summary**

This PR removes the legacy max512 fused-attention backend (`NVTE_F16_max512_seqlen`) across the C++, JAX, and PyTorch layers.
**Confidence Score: 3/5**

Not safe to merge as-is: FP8 + T3HD context-parallel paths will crash at runtime. There are two P1 bugs in `context_parallel.py`: removing the T3HD guard for FP8 aux-tensor unpacking breaks a clearly exercised code path with a definite `ValueError` at runtime. The rest of the max512 removal is clean and well-scoped across the C++, JAX, and PyTorch layers.

Affected file: `transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py`, lines 982-983 and 3202-3203 for the forward unpack, and lines 1169 and 1185 for the backward `aux_tensors` construction in the FP8 + T3HD path.
**Flowchart**

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_get_fused_attn_backend] --> B{q/k dtype}
    B -->|FP8| C[FP8 backend logic unchanged]
    B -->|FP16/BF16| D{flag_arb?}
    D -->|true| E[NVTE_F16_arbitrary_seqlen]
    D -->|false| F[NVTE_No_Backend]
    E --> G{"cudnn >= 8900?"}
    G -->|yes| H[Return F16_arbitrary_seqlen]
    G -->|no| I[Return No_Backend]
    subgraph REMOVED [Removed max512 path]
        R1["flag_m512 check seqlen<=512 head_dim=64 no GQA"]
        R2[NVTE_F16_max512_seqlen]
        R3[env var override NVTE_FUSED_ATTN_BACKEND]
    end
    subgraph CP_BUG [Bug in context_parallel.py]
        CP1["fp8 + T3HD forward: aux_ctx_tensors has 3 elements M ZInv rng_state"]
        CP2["new code unpacks 2 only -> ValueError"]
    end
```
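For readability, here is a Python sketch of the dispatch the flowchart describes. The real function is C++ (`nvte_get_fused_attn_backend`); `flag_arb`, the parameters, and the return values below follow the diagram and are illustrative assumptions, not the actual TE source.

```python
# Hypothetical Python mirror of the post-removal backend dispatch,
# following the flowchart above. Names come from the diagram.
def get_fused_attn_backend(q_dtype: str, flag_arb: bool, cudnn_version: int) -> str:
    if q_dtype == "fp8":
        return "FP8_backend"  # FP8 logic is unchanged by this PR
    # FP16/BF16 path: the max512 branch (flag_m512 with seqlen<=512,
    # head_dim=64, no GQA, plus the NVTE_FUSED_ATTN_BACKEND env
    # override) has been removed; only arbitrary_seqlen remains.
    if flag_arb and cudnn_version >= 8900:
        return "NVTE_F16_arbitrary_seqlen"
    return "NVTE_No_Backend"
```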
```diff
 if fp8:
-    if qkv_layout != "t3hd":
-        softmax_lse_per_step, rng_states = aux_ctx_tensors
-    else:
-        softmax_lse_per_step, _, rng_states = aux_ctx_tensors
+    softmax_lse_per_step, rng_states = aux_ctx_tensors
```
**Unpack crash for FP8 + T3HD layout**
When `fp8=True` and `qkv_layout == "t3hd"`, `aux_ctx_tensors` contains three elements, `[M, ZInv, rng_state]`, as confirmed by `pytorch/csrc/extensions/attention.cpp` lines 285-292 (`NVTE_T3HD` allocates an extra `ZInv` tensor). Unpacking into only two variables raises `ValueError: too many values to unpack`. The original guard that handled this case was removed without replacing the T3HD-specific path.
The same crash exists in `AttnFuncWithCPAndKVAllGather.forward` at lines 3202-3203 (diff below).
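A minimal standalone repro of the failure mode; the strings are placeholders standing in for the real saved tensors:

```python
# FP8 + T3HD saves three aux tensors, but the new code unpacks
# into only two targets, so the unpack fails.
aux_ctx_tensors = ["M", "ZInv", "rng_state"]

softmax_lse_per_step, rng_states = aux_ctx_tensors
# ValueError: too many values to unpack (expected 2)
```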
```diff
 if fp8:
-    if qkv_layout != "t3hd":
-        softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors
-    else:
-        softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
+    softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors
```
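One way to address both sites is simply to restore the layout guard the PR removed; a sketch against the first crash site (the per-step variant above is analogous):

```python
# Keep the T3HD-specific unpack: FP8 + T3HD saves [M, ZInv, rng_state],
# while other layouts save only two aux tensors.
if fp8:
    if qkv_layout != "t3hd":
        softmax_lse_per_step, rng_states = aux_ctx_tensors
    else:
        # Discard ZInv; it is not needed at this point.
        softmax_lse_per_step, _, rng_states = aux_ctx_tensors
```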
**Description**

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

**Type of change**

**Changes**

Please list the changes introduced in this PR:

**Checklist:**