docs/envvars.rst (4 changes: 2 additions & 2 deletions)

@@ -142,9 +142,9 @@ Attention Backend Selection

 .. envvar:: NVTE_FUSED_ATTN_BACKEND

-   :Type: ``int`` (0, 1, or 2)
+   :Type: ``int`` (1 or 2)
    :Default: Auto-selected
-   :Description: Force a specific FusedAttention backend. ``0`` = F16_max512_seqlen (cuDNN, ≤512 seq len), ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.
+   :Description: Force a specific FusedAttention backend. ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.

 .. envvar:: NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT
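For context, NVTE_FUSED_ATTN_BACKEND is read when fused-attention backend selection runs, so it has to be set before the first attention call. A minimal sketch of forcing the remaining cuDNN F16 backend from Python (the variable name and its values come from the docs above; the rest is plain standard-library usage):

```python
import os

# Force the cuDNN F16_arbitrary_seqlen backend (use "2" for the FP8 backend).
# Must be set before the first fused-attention call so backend selection
# sees the override; otherwise the backend is auto-selected.
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
```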
tests/pytorch/utils.py (4 changes: 2 additions & 2 deletions)

@@ -390,11 +390,11 @@ def test():
         _attention_backends["backend_selection_requires_update"] = False
         return available_backends, flash_attention_backend, fused_attention_backend

-    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
+    backends = {1: "F16_arbitrary_seqlen", 2: "FP8"}
     if AttentionLogging._is_logging_setup is False:
         AttentionLogging.setup_logging()

-    for i in range(3):
+    for i in backends:
         os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
         _attention_backends["backend_selection_requires_update"] = True
         available_backends, flash_attention_backend, fused_attention_backend = test()
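Iterating over the `backends` dict rather than a hard-coded `range(3)` keeps the probe loop in step with the backend IDs that remain valid after this change. A condensed, self-contained sketch of the pattern (here `probe()` is a hypothetical stand-in for the `test()` closure and its surrounding state in tests/pytorch/utils.py):

```python
import os

backends = {1: "F16_arbitrary_seqlen", 2: "FP8"}

def probe():
    # Hypothetical stand-in: the real test() closure re-runs backend
    # selection and reports what is available for the forced ID.
    return os.environ.get("NVTE_FUSED_ATTN_BACKEND")

for backend_id in backends:
    # Force one fused-attention backend at a time, then re-query.
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(backend_id)
    print(backends[backend_id], probe())
```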
transformer_engine/common/CMakeLists.txt (1 change: 0 additions & 1 deletion)

@@ -181,7 +181,6 @@ list(APPEND transformer_engine_cuda_sources
     dropout/dropout.cu
     fused_attn/context_parallel.cu
     fused_attn/kv_cache.cu
-    fused_attn/fused_attn_f16_max512_seqlen.cu
     fused_attn/fused_attn_f16_arbitrary_seqlen.cu
     fused_attn/fused_attn_fp8.cu
     fused_attn/utils.cu
transformer_engine/common/fused_attn/fused_attn.cpp (61 changes: 3 additions & 58 deletions)

@@ -11,7 +11,6 @@
#include "../util/cuda_runtime.h"
#include "../util/system.h"
#include "fused_attn_f16_arbitrary_seqlen.h"
#include "fused_attn_f16_max512_seqlen.h"
#include "fused_attn_fp8.h"
#include "utils.h"

@@ -304,28 +303,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
               << std::endl;
     }
   } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
-    bool flag_m512 = false;
     bool flag_arb = false;
-    if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) &&
-        (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) &&
-        (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) &&
-        ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
-         (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
-        ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
-         (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
-         (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
-          max_seqlen_q == max_seqlen_kv) ||
-         (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) &&
-        ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) ||
-         (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) ||
-         (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) ||
-         (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
-         (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
-        ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
-        !requires_64bit_ragged_offset &&
-        (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) {
-      flag_m512 = true;
-    }
     if (
         // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
         // architecture
@@ -498,31 +476,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
              dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) {
       flag_arb = true;
     }
-    if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
+    if (flag_arb) {
       backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
     }
-    if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-      if (flag_arb == true) {
-        backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
-      } else if ((flag_arb == false) && (flag_m512 == true)) {
-        backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
-      }
-      int env_backend = static_cast<int>(backend);
-      env_backend = transformer_engine::getenv<int>("NVTE_FUSED_ATTN_BACKEND", env_backend);
-      if (((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) &&
-           flag_m512) ||
-          ((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) &&
-           flag_arb)) {
-        backend = static_cast<NVTE_Fused_Attn_Backend>(env_backend);
-      }
-    }
-    if (cudnn_runtime_version < 8901 &&
-        backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
-      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
-      std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
-                   " Please upgrade your cuDNN version if possible."
-                << std::endl;
-    }
     if (cudnn_runtime_version < 8900 &&
         backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
       backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
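With the F16_max512_seqlen branch gone, the F16 path reduces to a single question: is the arbitrary-sequence-length backend supported for this configuration and cuDNN version? A Python-level sketch of the resulting decision (illustrative only, not the library's code; the names mirror the enum values above):

```python
def select_f16_backend(arbitrary_seqlen_supported: bool, cudnn_runtime_version: int) -> str:
    """Sketch of the simplified F16 fused-attention selection after this change."""
    # The former fallback to F16_max512_seqlen, and the env-var override that
    # chose between the two F16 backends, no longer exist.
    if not arbitrary_seqlen_supported:
        return "NVTE_No_Backend"
    if cudnn_runtime_version < 8900:
        # Mirrors the remaining cuDNN-version guard in nvte_get_fused_attn_backend.
        return "NVTE_No_Backend"
    return "NVTE_F16_arbitrary_seqlen"
```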
@@ -667,12 +623,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
       return_max_logit, cuda_graph, false);

-  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
-    fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale,
-                           dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K,
-                           input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
-                           input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
-  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
     fused_attn_arbitrary_seqlen_fwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
         page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
@@ -753,13 +704,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
       h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
       cuda_graph, deterministic);

-  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
-                           qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
-                           input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias,
-                           input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle);
-  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
     size_t i = 0;
     Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);