From 4ad50d42fdee5b5831a24188fd83acb622b5e63b Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Thu, 30 Apr 2026 16:05:03 -0700
Subject: [PATCH 1/2] remove max512 subbackend

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 docs/envvars.rst                              |    4 +-
 tests/pytorch/utils.py                        |    4 +-
 transformer_engine/common/CMakeLists.txt      |    1 -
 .../common/fused_attn/fused_attn.cpp          |   61 +-
 .../fused_attn_f16_max512_seqlen.cu           | 1343 -----------------
 .../fused_attn/fused_attn_f16_max512_seqlen.h |   41 -
 .../include/transformer_engine/fused_attn.h   |    4 +-
 .../common/util/pybind_helper.h               |    1 -
 .../jax/cpp_extensions/attention.py           |    5 +-
 .../jax/csrc/extensions/attention.cpp         |   10 -
 .../jax/csrc/extensions/pybind.cpp            |    1 -
 .../dot_product_attention/backends.py         |   25 +-
 .../dot_product_attention/context_parallel.py |   49 +-
 .../attention/dot_product_attention/utils.py  |   30 -
 .../pytorch/cpp_extensions/fused_attn.py      |   38 +-
 .../pytorch/csrc/extensions/attention.cpp     |    1 -
 16 files changed, 32 insertions(+), 1586 deletions(-)
 delete mode 100644 transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
 delete mode 100644 transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h

diff --git a/docs/envvars.rst b/docs/envvars.rst
index 1e040b4c3e..29ca498148 100644
--- a/docs/envvars.rst
+++ b/docs/envvars.rst
@@ -142,9 +142,9 @@ Attention Backend Selection

 .. envvar:: NVTE_FUSED_ATTN_BACKEND

-   :Type: ``int`` (0, 1, or 2)
+   :Type: ``int`` (1 or 2)
    :Default: Auto-selected
-   :Description: Force a specific FusedAttention backend. ``0`` = F16_max512_seqlen (cuDNN, ≤512 seq len), ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.
+   :Description: Force a specific FusedAttention backend. ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.

 .. envvar:: NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT
diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py
index 8f8852edc2..8439f6cc1a 100644
--- a/tests/pytorch/utils.py
+++ b/tests/pytorch/utils.py
@@ -390,11 +390,11 @@ def test():
         _attention_backends["backend_selection_requires_update"] = False
         return available_backends, flash_attention_backend, fused_attention_backend

-    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
+    backends = {1: "F16_arbitrary_seqlen", 2: "FP8"}
     if AttentionLogging._is_logging_setup is False:
         AttentionLogging.setup_logging()

-    for i in range(3):
+    for i in backends:
         os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
         _attention_backends["backend_selection_requires_update"] = True
         available_backends, flash_attention_backend, fused_attention_backend = test()
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index 781fe48814..b9bcca6b17 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -181,7 +181,6 @@ list(APPEND transformer_engine_cuda_sources
     dropout/dropout.cu
     fused_attn/context_parallel.cu
     fused_attn/kv_cache.cu
-    fused_attn/fused_attn_f16_max512_seqlen.cu
     fused_attn/fused_attn_f16_arbitrary_seqlen.cu
     fused_attn/fused_attn_fp8.cu
     fused_attn/utils.cu
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 141767b803..b0d37c07e2 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -11,7 +11,6 @@
 #include "../util/cuda_runtime.h"
 #include "../util/system.h"
 #include "fused_attn_f16_arbitrary_seqlen.h"
-#include "fused_attn_f16_max512_seqlen.h"
 #include "fused_attn_fp8.h"
 #include "utils.h"
@@ -304,28 +303,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
                 << std::endl;
     }
   } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
-    bool flag_m512 = false;
     bool flag_arb = false;
-    if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) &&
-        (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) &&
-        (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) &&
-        ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
-         (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
-        ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
-         (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
-         (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
-          max_seqlen_q == max_seqlen_kv) ||
-         (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) &&
-        ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) ||
-         (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) ||
-         (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) ||
-         (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
-         (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
-        ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
-        !requires_64bit_ragged_offset &&
-        (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) {
-      flag_m512 = true;
-    }
     if (  // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
         // architecture
@@ -498,31 +476,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
              dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) {
       flag_arb = true;
     }
-    if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
+    if (flag_arb) {
       backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
     }
-    if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-      if (flag_arb == true) {
-        backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
-      } else if ((flag_arb == false) && (flag_m512 == true)) {
-        backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
-      }
-      int env_backend = static_cast<int>(backend);
-      env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend);
-      if (((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) &&
-           flag_m512) ||
-          ((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) &&
-           flag_arb)) {
-        backend = static_cast<NVTE_Fused_Attn_Backend>(env_backend);
-      }
-    }
-    if (cudnn_runtime_version < 8901 &&
-        backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
-      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
-      std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
-                   " Please upgrade your cuDNN version if possible."
-                << std::endl;
-    }
     if (cudnn_runtime_version < 8900 &&
         backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
       backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
@@ -667,12 +623,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                                       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v,
                                       window_size_left, window_size_right, return_max_logit,
                                       cuda_graph, false);
-  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
-    fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale,
-                           dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K,
-                           input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
-                           input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
-  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
     fused_attn_arbitrary_seqlen_fwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
         page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
@@ -753,13 +704,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                                       h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v,
                                       window_size_left, window_size_right, false, cuda_graph,
                                       deterministic);
-  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
-                           qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
-                           input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias,
-                           input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle);
-  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
     size_t i = 0;
     Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
deleted file mode 100644
index d5151a51f1..0000000000
--- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
+++ /dev/null
@@ -1,1343 +0,0 @@
-/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include -#include - -#include "../common.h" -#include "../cudnn_utils.h" -#include "fused_attn_f16_max512_seqlen.h" -#include "utils.h" - -#define Q_ID 1 -#define K_ID 2 -#define V_ID 3 -#define O_ID 4 -#define S_ID 5 -#define B_ID 6 -#define DROPOUT_CONST_ID 7 -#define S_CONST_ID 8 -#define Q_SEQLEN_ID 9 -#define K_SEQLEN_ID 10 -#define dQ_ID 11 -#define dK_ID 12 -#define dV_ID 13 -#define dO_ID 14 -#define MASK_VAL_ID 15 -#define dS_ID 16 -#define dBias_ID 17 -#define DROPOUT_SEED_ID 18 -#define DROPOUT_OFFSET_ID 19 - -#define VIRTUAL_ID 20 - -namespace transformer_engine { -namespace fused_attn { - -static void createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - std::vector &ops) { - // scale - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - int64_t k_dim[4] = {b, h, d, s_kv}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - - auto scaleTensor = - tensor_create(tensorType, S_CONST_ID, scale_dim, scale_stride, false, true); // is by value - auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false); - auto afterScaleKTensor = - tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual - - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a Scale Node. 
- auto scale_op = binary_pw_op_create(kTensor, scaleTensor, afterScaleKTensor, scaleDesc); - - ops.push_back(std::move(scale_op)); -} - -static cudnn_frontend::Tensor createBMM1(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - bool zero_s, std::vector &ops) { - // Creates the necessary tensor descriptors - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, d, s_kv}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); - auto afterScaleKTensor = - tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual - // first GEMM output - auto pTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, p_dim, p_stride, true, - false); // is virtual - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - - // Define the matmul 1 desc - // set padding value optionally to 0 for writing zeros to S tensor (if not set, old behaviour) - auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); - - if (zero_s) { - matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - } - - // Create a matmul 1 Node - auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(qTensor) - .setbMatDesc(afterScaleKTensor) - .setcMatDesc(pTensor) - .setmOverrideDesc(seqlenQTensor) - .setnOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_1_Desc) - .build(); - - ops.push_back(std::move(matmul_op1)); - - return pTensor; -} - -static cudnn_frontend::Tensor createBias(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "Bias op constructed incorrectly as the first one."); - - int64_t b_dim[4] = {1, h, s_q, s_kv}; - int64_t b_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t afterBias_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBias_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, afterBias_stride, layout, - NVTE_QKV_Matrix::NVTE_S_Matrix); - - // bias - auto bTensor = tensor_create(tensorType, B_ID, b_dim, b_stride, false, false); - // output - auto afterBiasTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 50, afterBias_dim, - afterBias_stride, true, false); // is virtual - - // Define the bias descriptor - auto biasDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD); - - // Create a Bias Node. 
- auto bias_op = binary_pw_op_create(prevBlockOutputTensor, bTensor, afterBiasTensor, biasDesc); - - ops.push_back(std::move(bias_op)); - - return afterBiasTensor; -} - -static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type, - cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor, - bool is_bprop) { - NVTE_CHECK(ops.size() != 0, "Padding mask constructed incorrectly as the first one."); - - // subtraction output - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t maskVal_dim[4] = {1, 1, 1, 1}; - int64_t maskVal_stride[4] = {1, 1, 1, 1}; - - // mask value to put in the masked pixels - auto maskValTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, - false, true); // is by value - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - // gen index row output - auto rowIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // gen index column output - auto columnIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // less than row output - auto lessThanRowTensor = - tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 102, afterBMM1_dim, afterBMM1_stride, true, - false); // is virtual - // less than column output - auto lessThanColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 103, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // padding mask (lessthanRow && lessthanCol) - auto paddingMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 104, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // row >= col check for causal mask - auto rowGreaterColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 105, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // create causal mask (padding && row >= col) - auto causalMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // output after masking - int64_t maskOutputTensor_id = VIRTUAL_ID + 107; - int64_t maskOutputTensor_virtual = true; - cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT; - auto maskOutputTensor_reorderType = cudnn_frontend::TensorReordering_t::NONE; - - if (is_bprop) { - maskOutputTensor_id = dS_ID; - maskOutputTensor_virtual = false; - maskOutputTensor_dataType = tensorType; - maskOutputTensor_reorderType = cudnn_frontend::TensorReordering_t::F16x16; - } - - auto maskOutputTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setByValue(false) - .setDataType(maskOutputTensor_dataType) - .setVirtual(maskOutputTensor_virtual) - .setId(maskOutputTensor_id) - .setReorderType(maskOutputTensor_reorderType) - .build(); - - // Define the gen index for row descriptor - auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setAxis(2) - .setComputeType(CUDNN_DATA_FLOAT) - 
.build(); - - // Create a gen index Node. - auto genIndexRow_op = unary_pw_op_create(prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc); - - // Define the gen index for row descriptor - auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setAxis(3) - .setComputeType(CUDNN_DATA_FLOAT) - .build(); - - // Create a gen index Node. - auto genIndexColumn_op = - unary_pw_op_create(prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc); - - // Define the less than comparison for row descriptor - auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); - - // Create a less than comparison for row Node. - auto lessThanRow_op = - binary_pw_op_create(rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc); - - // Define the less than comparison for column descriptor - auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); - - // Create a less than comparison for col Node. - auto lessThanCol_op = - binary_pw_op_create(columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc); - - // Define the less than comparison for column descriptor - auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); - - // Create a and node for combining lessThanRow and lessThanCol - auto paddingMaskAnd_op = binary_pw_op_create(lessThanRowTensor, lessThanColTensor, - paddingMaskTensor, paddingMaskAndDesc); - - // Define the greater than equal to comparison descriptor - auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE); - - // Create a greater than equal to Node. - auto rowGreaterCol_op = binary_pw_op_create(rowIndexTensor, columnIndexTensor, - rowGreaterColTensor, rowGreaterColDesc); - - // Define the and to create causal mask descriptor - auto causalMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); - - // Create a causal Mask Node. - auto causalMaskAnd_op = binary_pw_op_create(paddingMaskTensor, rowGreaterColTensor, - causalMaskTensor, causalMaskAndDesc); - - /////////////////// Apply the mask ////////////////////////// - - auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - ? std::move(causalMaskTensor) - : std::move(paddingMaskTensor); - - // Define the binary select to perform masking descriptor - auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); - - // Create a binary select Node. 
- auto mask_op = ternary_pw_op_create(prevBlockOutputTensor, maskValTensor, maskTensor, - maskOutputTensor, maskDesc); - - ops.push_back(std::move(genIndexRow_op)); - ops.push_back(std::move(genIndexColumn_op)); - ops.push_back(std::move(lessThanRow_op)); - ops.push_back(std::move(lessThanCol_op)); - ops.push_back(std::move(paddingMaskAnd_op)); - if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) { - ops.push_back(std::move(rowGreaterCol_op)); - ops.push_back(std::move(causalMaskAnd_op)); - } - ops.push_back(std::move(mask_op)); - - return maskOutputTensor; -} - -static cudnn_frontend::Tensor createSoftmaxForward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - bool enable_dropout, bool softmax_output_virtual, cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor) { - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t afterReduction_dim[4] = {b, h, s_q, 1}; - int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; - - cudnnDataType_t softmaxOutputType = enable_dropout ? CUDNN_DATA_FLOAT : tensorType; - uint64_t softmaxOutputName = softmax_output_virtual ? VIRTUAL_ID + 154 : S_ID; - - // max (x) - auto afterMaxReductionTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim, afterReduction_stride, - true, false); // is virtual - // x - max(x) - auto afterSubtractionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // e^(x - max(x)) - auto afterExponentTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual; - // sum (e^(x - max(x))) - auto afterAddReductionTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim, afterReduction_stride, - true, false); // is virtual - // divide (e/ sum(e)) - - auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; - - auto afterDivisionTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(softmaxOutputName) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(softmaxOutputType) - .setVirtual(softmax_output_virtual) - .setByValue(false) - .setReorderType(reorder_type) - .build(); - - // Define the reduction descriptor - auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) - .build(); - - // Create a reduction max Node. - auto reductionMax_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(prevBlockOutputTensor) - .setyDesc(afterMaxReductionTensor) - .setreductionDesc(reductionMaxDesc) - .build(); - - // Define the subtract descriptor - auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - - // Create a subtract Node. - auto subtract_op = binary_pw_op_create(prevBlockOutputTensor, afterMaxReductionTensor, - afterSubtractionTensor, subtractDesc); - - // Define the exponent descriptor - auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); - - // Create a exponent Node. 
- auto exponent_op = unary_pw_op_create(afterSubtractionTensor, afterExponentTensor, exponentDesc); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add Node. - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(afterExponentTensor) - .setyDesc(afterAddReductionTensor) - .setreductionDesc(reductionAddDesc) - .build(); - - // Define the division descriptor - auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV); - - // Create a subtract Node. - auto division_op = binary_pw_op_create(afterExponentTensor, afterAddReductionTensor, - afterDivisionTensor, divisionDesc); - - ops.push_back(std::move(reductionMax_op)); - ops.push_back(std::move(subtract_op)); - ops.push_back(std::move(exponent_op)); - ops.push_back(std::move(reductionAdd_op)); - ops.push_back(std::move(division_op)); - - return afterDivisionTensor; -} - -static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, double probability, - cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "Dropout DAG constructed incorrectly as the first one"); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // mask for the dropout - auto dropoutMaskTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; - - // after dropout tensor - auto afterDropoutTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(S_ID) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(tensorType) - .setVirtual(false) - .setByValue(false) - .setReorderType(reorder_type) - .build(); - // scale after dropout - auto scaleDropoutTensor = - tensor_create(tensorType, DROPOUT_CONST_ID, scale_dim, scale_stride, false, - true); // is by value - // after Scale - auto afterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 201, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); - - auto dropoutSeed = - tensor_create(CUDNN_DATA_INT64, DROPOUT_SEED_ID, scale_dim, scale_stride, false, false); - auto dropoutOffset = - tensor_create(CUDNN_DATA_INT64, DROPOUT_OFFSET_ID, scale_dim, scale_stride, false, false); - - // Create a rng Node. - auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeed) - .setOffsetDesc(dropoutOffset) - .setRngDesc(rngDesc) - .build(); - - // Define the multiply mask descriptor - auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node. 
- auto maskMul_op = binary_pw_op_create(prevBlockOutputTensor, dropoutMaskTensor, - afterDropoutTensor, maskMulDesc); - - // Define the multiply scale descriptor - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node. - auto scaleMul_op = - binary_pw_op_create(afterDropoutTensor, scaleDropoutTensor, afterScaleTensor, scaleMulDesc); - - ops.push_back(std::move(rng_op)); - ops.push_back(std::move(maskMul_op)); - ops.push_back(std::move(scaleMul_op)); - - return afterScaleTensor; -} - -static void createBMM2(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "BMM2 op constructed incorrectly as the first one"); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); - // second GEMM output - auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false); - - // Define the matmul 2 desc - // set padding value optionally to 0 for writing zeros to O tensor (if not set, old behaviour) - auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - // Create a matmul 2 Node - auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(prevBlockOutputTensor) - .setbMatDesc(vTensor) - .setcMatDesc(oTensor) - .setmOverrideDesc(seqlenQTensor) - .setkOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_2_Desc) - .build(); - - ops.push_back(std::move(matmul_op2)); -} - -static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &yTensor, - cudnn_frontend::Tensor const &dyTensor) { - NVTE_CHECK(ops.size() != 0, "Softmax backward constructed incorrectly as the first one"); - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t p_reduction_dim[4] = {b, h, s_q, 1}; - int64_t p_reduction_stride[4]; - - p_reduction_stride[3] = 1; - p_reduction_stride[2] = 1; - p_reduction_stride[1] = s_q; - p_reduction_stride[0] = s_q * h; - - int64_t const_dim[4] = {1, 1, 1, 1}; - int64_t const_stride[4] = {1, 1, 1, 1}; - - // creating all tensors - auto softmaxScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, const_dim, const_stride, false, true); - auto dyMulYTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 250, p_dim, p_stride, true, false); - auto dxAfterReductionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 251, p_reduction_dim, - p_reduction_stride, true, false); - auto dxAfterSubtractionTensor = - tensor_create(CUDNN_DATA_FLOAT, 
VIRTUAL_ID + 252, p_dim, p_stride, true, false); - auto dxUnscaleTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 253, p_dim, p_stride, true, false); - auto dxTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 254, p_dim, p_stride, true, false); - - // creating all ops - // mul (y * dy) - auto mul_1_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_1_op = binary_pw_op_create(yTensor, dyTensor, dyMulYTensor, mul_1_desc); - - // reduction add sum (y * dy) - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(dyMulYTensor) - .setyDesc(dxAfterReductionTensor) - .setreductionDesc(reductionAddDesc) - .build(); - - // subtraction (dy - sum(y * dy)) - auto sub_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - auto sub_0_op = - binary_pw_op_create(dyTensor, dxAfterReductionTensor, dxAfterSubtractionTensor, sub_0_desc); - - // mul (y * (dy - sum(y * dy))) - auto mul_2_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_2_op = - binary_pw_op_create(yTensor, dxAfterSubtractionTensor, dxUnscaleTensor, mul_2_desc); - - // mul (scale * dx) - auto mul_3_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_3_op = binary_pw_op_create(dxUnscaleTensor, softmaxScaleTensor, dxTensor, mul_3_desc); - - ops.push_back(std::move(mul_1_op)); - ops.push_back(std::move(reductionAdd_op)); - ops.push_back(std::move(sub_0_op)); - ops.push_back(std::move(mul_2_op)); - ops.push_back(std::move(mul_3_op)); - - return dxTensor; -} - -void fused_attn_max_512_fwd_impl( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, void *devPtrV, - void *devPtrS, void *devPtrO, void *devPtrBias, void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size, - cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) { - try { - FADescriptor descriptor{b, h, - s_q, s_kv, - d, scaling_factor, - is_training, dropout_probability, - layout, bias_type, - mask_type, tensorType, - false}; - - using CacheType = std::map; - static thread_local CacheType fmha_fprop_cache; - - // softmax auxiliary is only used in the training mode - bool enable_dropout = is_training && (dropout_probability != 0.0f); - - // two conditions that make softmax auxiliary in virtual - // 1. inference mode (not is_training) - // 2. dropout enabled: the auxiliary becomes the dropout output - bool softmax_output_virtual = !is_training || enable_dropout; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { - // if hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto plan = it->second; - return plan; - } - - // otherwise, build the op_graph and the plan. 
Then update cache - std::vector all_ops; - std::vector ops; - - createScale(b, h, s_q, s_kv, d, layout, tensorType, ops); - - // if bias, we need to memset the S buffer to correctly computate dbias - // WAR: causal_mask without bias needs memset the S buffer - // inference mode doesn't need the S auxiliary - auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) || - (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) && - is_training; - std::shared_ptr maskInput; - auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops); - - NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS, - "NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS has not been implemented."); - - if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { - auto bias_output = createBias(b, h, s_q, s_kv, d, layout, tensorType, ops, bmm1_output); - maskInput = std::make_shared(std::move(bias_output)); - } - if (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) { - maskInput = std::make_shared(std::move(bmm1_output)); - } - - auto mask_output = createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, - *maskInput.get(), false); - - NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0."); - - auto softmax_output = - createSoftmaxForward(b, h, s_q, s_kv, d, layout, enable_dropout, softmax_output_virtual, - tensorType, ops, mask_output); - - if (enable_dropout) { - auto dropout_output = - createDropout(b, h, s_q, s_kv, d, dropout_probability, tensorType, ops, softmax_output); - createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, dropout_output); - } else { - createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, softmax_output); - } - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_fprop: No config returned by the heuristics"); - } - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fmha_fprop_cache, descriptor); - - auto plan_workspace_size = plan.getWorkspaceSize(); - - // Exit to request upper level API to allocate memory if needed - if (workspace == nullptr) { - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. 
- NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - // Prepare actual seqlen - constexpr size_t nthreads_per_block = 128; - const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); - cu_seqlens_to_actual_seqlens<<>>( - b, b, static_cast(devPtrCuSeqlenQ), - static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), - static_cast(devActualSeqlenK)); - NVTE_CHECK_CUDA(cudaGetLastError()); - - // change this if you have access to float_min - float negInfinity = -1.0E+10; - float scale_dropout = 1 / (1 - dropout_probability); - - std::set> data_ptrs; - // add all the data pointers to be used in the variant pack - data_ptrs.insert(std::pair(Q_ID, devPtrQ)); - data_ptrs.insert(std::pair(K_ID, devPtrK)); - data_ptrs.insert(std::pair(V_ID, devPtrV)); - data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); - data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); - data_ptrs.insert(std::pair(MASK_VAL_ID, &negInfinity)); - - __half half_cast_scaling_factor{scaling_factor}; - __nv_bfloat16 bfloat_cast_scaling_factor{scaling_factor}; - - if (tensorType == CUDNN_DATA_FLOAT) { - data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); - } else if (tensorType == CUDNN_DATA_HALF) { - data_ptrs.insert(std::pair(S_CONST_ID, &half_cast_scaling_factor)); - } else if (tensorType == CUDNN_DATA_BFLOAT16) { - data_ptrs.insert(std::pair(S_CONST_ID, &bfloat_cast_scaling_factor)); - } else { - NVTE_ERROR("Unsupported tensor type."); - } - - data_ptrs.insert(std::pair(O_ID, devPtrO)); - - if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { - data_ptrs.insert(std::pair(B_ID, devPtrBias)); - } - - // if enable_dropout, S is the result after dropout - // if not enable dropout, S is the result after softmax - if (enable_dropout || !softmax_output_virtual) { - data_ptrs.insert(std::pair(S_ID, devPtrS)); - } - - __half half_cast_scale_dropout{scale_dropout}; - __nv_bfloat16 bfloat16_cast_scale_dropout{scale_dropout}; - - if (enable_dropout) { - // TODO(rewang): make a util func - if (tensorType == CUDNN_DATA_FLOAT) { - data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &scale_dropout)); - } else if (tensorType == CUDNN_DATA_HALF) { - data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &half_cast_scale_dropout)); - } else if (tensorType == CUDNN_DATA_BFLOAT16) { - data_ptrs.insert( - std::pair(DROPOUT_CONST_ID, &bfloat16_cast_scale_dropout)); - } else { - NVTE_ERROR("Unsupported tensor type."); - } - data_ptrs.insert(std::pair(DROPOUT_SEED_ID, devPtrDropoutSeed)); - data_ptrs.insert(std::pair(DROPOUT_OFFSET_ID, devPtrDropoutOffset)); - } - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace) - .setDataPointers(data_ptrs) - .build(); - - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); - } -} - -void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type, - NVTE_Bias_Type bias_type, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrS, void *devPtrdQ, void *devPtrdK, - void *devPtrdV, void *devPtrdO, void *devPtrdS, void *devPtrdBias, - void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void *workspace, - size_t *workspace_size, cudnnDataType_t tensorType, - 
cudaStream_t stream, cudnnHandle_t handle) { - try { - FADescriptor descriptor{ - b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability, - layout, bias_type, mask_type, tensorType, false}; - - using CacheType = std::map; - static thread_local CacheType fmha_bprop_cache; - - auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { - auto it = cache.find(descriptor); - if (it != cache.end()) { - return it->second; - } - - std::vector all_ops; - std::vector ops; - - // Creates the necessary tensor descriptors - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides( - b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); // type is correct as K is not transposed - - int64_t v_dim[4] = {b, h, d, s_kv}; - int64_t v_stride[4]; - generateMatrixStrides( - b, h, s_q, s_kv, d, v_stride, layout, - NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); // type is correct as V is transposed - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t p_transpose_dim[4] = {b, h, s_kv, s_q}; - int64_t p_transpose_stride[4]; - p_transpose_stride[0] = p_stride[0]; - p_transpose_stride[1] = p_stride[1]; - p_transpose_stride[2] = p_stride[3]; - p_transpose_stride[3] = p_stride[2]; - - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // inputs to fprop - auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); - auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false); - auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - - // gradient of the output - auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false); - - auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; - - // activation from fprop - auto pTensor = cudnn_frontend::TensorBuilder() - .setDim(4, p_dim) - .setStride(4, p_stride) - .setId(S_ID) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(tensorType) - .setVirtual(false) - .setByValue(false) - .setReorderType(reorder_type) - .build(); - - // outputs from bprop - auto dqTensor = tensor_create(tensorType, dQ_ID, q_dim, q_stride, false, false); - auto dkTensor = tensor_create(tensorType, dK_ID, k_dim, k_stride, false, false); - auto dvTensor = tensor_create(tensorType, dV_ID, k_dim, k_stride, false, - false); // not transposed therefore k_dim and k_stride - - //////////////////////////////////////////////////////// - // start creating the ops and the intermediate tensors - auto pReshapeTensor = tensor_create(tensorType, VIRTUAL_ID + 300, p_transpose_dim, - p_transpose_stride, true, false); - - // reshape to perform transpose and make pReshape - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - 
.setxDesc(pTensor) - .setyDesc(pReshapeTensor) - .build(); - - ops.push_back(std::move(reshape_op)); - - // scale dropout - auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, DROPOUT_CONST_ID, scale_dim, - scale_stride, false, true); // is by value - auto pAfterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 301, p_transpose_dim, - p_transpose_stride, true, false); - - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto scaleMul_op = - binary_pw_op_create(pReshapeTensor, dropoutScaleTensor, pAfterScaleTensor, scaleMulDesc); - ops.push_back(std::move(scaleMul_op)); - - // perform absolute operation to remove the mask bit - auto pTransposeAfterAbsTensor = tensor_create(tensorType, VIRTUAL_ID + 302, p_transpose_dim, - p_transpose_stride, true, false); - - auto absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS); - auto abs_op = unary_pw_op_create(pAfterScaleTensor, pTransposeAfterAbsTensor, absDesc); - ops.push_back(std::move(abs_op)); - - // matmul to calculate dvTensor - // set padding value optionally to 0 for writing zeros to dV tensor (if not set, old - // behaviour) - auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - auto matmul_op0 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(pTransposeAfterAbsTensor) - .setbMatDesc(doTensor) - .setcMatDesc(dvTensor) - .setmOverrideDesc(seqlenKTensor) - .setkOverrideDesc(seqlenQTensor) - .setmatmulDesc(matmul_0_Desc) - .build(); - - ops.push_back(std::move(matmul_op0)); - - // matmul to calculate dpTensor - auto dpTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 303, p_dim, p_stride, true, false); - - auto matmul_1_Desc = - cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); - - auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(doTensor) - .setbMatDesc(vTensor) - .setcMatDesc(dpTensor) - .setmOverrideDesc(seqlenQTensor) - .setnOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_1_Desc) - .build(); - - ops.push_back(std::move(matmul_op1)); - - // mask the values which were dropped in dropout - auto pAbsTensor = tensor_create(tensorType, VIRTUAL_ID + 304, p_dim, p_stride, true, false); - - auto p_absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS); - auto p_abs_op = unary_pw_op_create(pTensor, pAbsTensor, p_absDesc); - ops.push_back(std::move(p_abs_op)); - - // create the dropout mask - auto zeroTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, scale_dim, scale_stride, false, - true); // is by value - auto dropoutMaskTensor = - tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 305, p_dim, p_stride, true, false); - - auto greater_than_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_GT); - auto greater_than_0_op = - binary_pw_op_create(pTensor, zeroTensor, dropoutMaskTensor, greater_than_0_desc); - ops.push_back(std::move(greater_than_0_op)); - - // scale for the dropout - auto dpAfterScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 306, p_dim, p_stride, true, false); - - auto mul_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_0_op = - binary_pw_op_create(dpTensor, dropoutScaleTensor, dpAfterScaleTensor, mul_0_desc); - ops.push_back(std::move(mul_0_op)); - - // drop the values based on the dropout mask - auto dpAfterDropoutTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 307, p_dim, p_stride, true, false); - - 
auto selection_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); - auto selection_0_op = ternary_pw_op_create(dpAfterScaleTensor, zeroTensor, dropoutMaskTensor, - dpAfterDropoutTensor, selection_0_desc); - ops.push_back(std::move(selection_0_op)); - - // softmax backward - auto dsTensor = createSoftmaxBackward(b, h, s_q, s_kv, d, layout, tensorType, ops, pAbsTensor, - dpAfterDropoutTensor); - - // mask - auto dsAfterMaskTensor = - createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, dsTensor, true); - - // dbias tensor - int64_t dbias_dim[4] = {1, h, s_q, s_kv}; - int64_t dbias_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - auto dBiasTensor = tensor_create(tensorType, dBias_ID, dbias_dim, dbias_stride, false, false); - - if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { - auto softmaxScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim, scale_stride, false, true); - auto softmaxScaleReciprocalTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 401, scale_dim, scale_stride, true, false); - auto dbiasBeforeScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 402, dbias_dim, dbias_stride, true, false); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add node to compute the dbias - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(dsAfterMaskTensor) - .setyDesc(dbiasBeforeScaleTensor) - .setreductionDesc(reductionAddDesc) - .build(); - ops.push_back(std::move(reductionAdd_op)); - - // take the reciprocal of the scale - auto reciprocal_scale_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); - auto reciprocal_scale_op = unary_pw_op_create( - softmaxScaleTensor, softmaxScaleReciprocalTensor, reciprocal_scale_desc); - ops.push_back(std::move(reciprocal_scale_op)); - - // apply the scale - auto dBias_scale_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto dBias_scale_op = binary_pw_op_create( - dbiasBeforeScaleTensor, softmaxScaleReciprocalTensor, dBiasTensor, dBias_scale_desc); - ops.push_back(std::move(dBias_scale_op)); - } - - // matmul to calculate dqTensor - // set padding value optionally to 0 for writing zeros to dqTensor (if not set, old - // behaviour) - auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dsAfterMaskTensor) - .setbMatDesc(kTensor) - .setcMatDesc(dqTensor) - .setmOverrideDesc(seqlenQTensor) - .setkOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_2_Desc) - .build(); - - ops.push_back(std::move(matmul_op2)); - - // reshape for transpose of ds - auto dsAfterMaskReshapeTensor = tensor_create(tensorType, VIRTUAL_ID + 308, p_transpose_dim, - p_transpose_stride, true, false); - - auto reshape_2_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(dsAfterMaskTensor) - .setyDesc(dsAfterMaskReshapeTensor) - .build(); - - ops.push_back(std::move(reshape_2_op)); - - // matmul to calculate dkTensor - // set padding value optionally to 0 for writing zeros to dktensor (if not set, old - // behaviour) - auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - 
.setPaddingValue(0.0f) - .build(); - - auto matmul_op3 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dsAfterMaskReshapeTensor) - .setbMatDesc(qTensor) - .setcMatDesc(dkTensor) - .setmOverrideDesc(seqlenKTensor) - .setkOverrideDesc(seqlenQTensor) - .setmatmulDesc(matmul_3_Desc) - .build(); - - ops.push_back(std::move(matmul_op3)); - - ///////////////////////////////////////////////////////////////// - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_bprop: No config returned by the heuristics"); - } - - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fmha_bprop_cache, descriptor); - - auto plan_workspace_size = plan.getWorkspaceSize(); - - // Exit to request upper level API to allocate memory if needed - if (workspace == nullptr) { - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. 
- NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - constexpr size_t nthreads_per_block = 128; - const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); - cu_seqlens_to_actual_seqlens<<>>( - b, b, static_cast(devPtrCuSeqlenQ), - static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), - static_cast(devActualSeqlenK)); - NVTE_CHECK_CUDA(cudaGetLastError()); - - std::set> data_ptrs; - // add all the data pointers to be used in the variant pack - data_ptrs.insert(std::pair(dQ_ID, devPtrdQ)); - data_ptrs.insert(std::pair(dK_ID, devPtrdK)); - data_ptrs.insert(std::pair(dV_ID, devPtrdV)); - - data_ptrs.insert(std::pair(Q_ID, devPtrQ)); - data_ptrs.insert(std::pair(K_ID, devPtrK)); - data_ptrs.insert(std::pair(V_ID, devPtrV)); - data_ptrs.insert(std::pair(S_ID, devPtrS)); - data_ptrs.insert(std::pair(dO_ID, devPtrdO)); - data_ptrs.insert(std::pair(dS_ID, devPtrdS)); - data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); - data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); - - if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { - data_ptrs.insert(std::pair(dBias_ID, devPtrdBias)); - } - - float zeroVal = 0.0f; - float dropoutScale = 1.0f / (1.0f - dropout_probability); - - data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &dropoutScale)); - data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); - data_ptrs.insert(std::pair(MASK_VAL_ID, &zeroVal)); - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace) - .setDataPointers(data_ptrs) - .build(); - - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); - } -} - -} // namespace fused_attn - -using namespace transformer_engine::fused_attn; -void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_K->data.dptr; - void *devPtrV = input_V->data.dptr; - - void *devPtrBias = input_Bias->data.dptr; - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - const DType q_type = input_Q->data.dtype; - const DType kv_type = input_K->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; - output_S->data.dtype = q_type; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devQCuSeqlen = q_cu_seqlens->data.dptr; - void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - - 
const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, - Tensor *output_dBias, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_K->data.dptr; - void *devPtrV = input_V->data.dptr; - - void *devPtrdO = input_dO->data.dptr; - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dK->data.dptr; - void *devPtrdV = output_dV->data.dptr; - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; - void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; - - const auto q_type = input_Q->data.dtype; - const auto kv_type = input_K->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h deleted file mode 100644 index 1e59d4dc8f..0000000000 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ /dev/null @@ -1,41 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
- *
- * See LICENSE for license information.
- ************************************************************************/
-
-/*! \file fused_attn_fp16_bf16_max_seqlen_512.h
- * \brief Functions for fused attention for half precision with seqlen <= 512
- */
-
-#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
-#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
-
-#include <cudnn.h>
-
-#include "common/common.h"
-#include "transformer_engine/fused_attn.h"
-
-namespace transformer_engine {
-void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
-                            size_t kv_max_seqlen, size_t head_dim, bool is_training,
-                            float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                            NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                            const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
-                            const Tensor *input_Bias, Tensor *output_O,
-                            NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
-                            const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace,
-                            cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
-                            size_t kv_max_seqlen, size_t head_dim, float attn_scale,
-                            float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                            NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K,
-                            const Tensor *input_V, const Tensor *input_dO, Tensor *output_S,
-                            Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
-                            Tensor *output_dBias, const Tensor *q_cu_seqlens,
-                            const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream,
-                            cudnnHandle_t handle);
-} // namespace transformer_engine
-
-#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index 912dc32d35..ab69d243ac 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -156,11 +156,9 @@ enum NVTE_Softmax_Type {
 
 enum NVTE_Fused_Attn_Backend {
   /*! No supported backend */
   NVTE_No_Backend = -1,
-  /*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */
-  NVTE_F16_max512_seqlen = 0,
   /*! cuDNN-based FP16/BF16 fused attention for any sequence length */
   NVTE_F16_arbitrary_seqlen = 1,
-  /*! cuDNN-based FP8 fused attention for <= 512 sequence length */
+  /*! cuDNN-based FP8 fused attention */
   NVTE_FP8 = 2,
 };
diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h
index fdfa47da8f..ef7687e3e9 100644
--- a/transformer_engine/common/util/pybind_helper.h
+++ b/transformer_engine/common/util/pybind_helper.h
@@ -79,7 +79,6 @@
       .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD)        \
       .value("NVTE_BHSD_BHSD_BHSD", NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD);                       \
   pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \
-      .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)          \
       .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)    \
       .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)                                      \
       .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);                       \
diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index 40d02f40e1..489bfde997 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -353,10 +353,7 @@ def abstract(
             config.window_size,
         ).get_fused_attn_backend()
 
-        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
-            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
-            softmax_dtype = q_dtype
-        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
+        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
             # cuDNN 9.6 reduces the required softmax shape
             if get_cudnn_version() >= (9, 6, 0):
                 if config.qkv_layout.is_thd():
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index 76f2d92891..ae489b5730 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -28,7 +28,6 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
 
 /* NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
    attention forward kernels in:
-    - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
     - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
 */
 void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
@@ -40,7 +39,6 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
                                        void *bias_buf = nullptr, void *softmax_offset_buf = nullptr) {
   // all backends need softmax but expect different shapes/dtypes
-  // start with the max512 sequence length softmax shape/dtype and correct later
   tensor_pack->size = 1;
   NVTETensor &softmax_aux = tensor_pack->tensors[0];
   NVTEBasicTensor softmax_aux_data;
@@ -128,14 +126,6 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
                                       dummy_backend, softmax_buf, rng_state_buf, bias_buf,
                                       softmax_offset_buf);
 
-  // correct softmax shape for max512 sequence length kernel
-  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
-    NVTEBasicTensor softmax_aux_data =
-        nvte_get_tensor_param(tensor_pack->tensors[0], kNVTERowwiseData);
-    softmax_aux_data.shape.data[3] = kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
-    softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
-    nvte_set_tensor_param(&(tensor_pack->tensors[0]), kNVTERowwiseData, &softmax_aux_data);
-  }
 }
 
 pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
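For context (a sketch, not part of the patch): with the max512 sub-backend removed, the softmax auxiliary tensor kept for the backward pass no longer needs the {B,H,Qs,1} -> {B,H,Qs,Ks} correction deleted above; the remaining F16_arbitrary_seqlen path stores per-row softmax statistics of shape [b, h, sq, 1] in FP32 (see the fused_attn_fwd docstring later in this patch). Below is a minimal PyTorch illustration of that statistic, matching the log(sum(e^(x - max(x)))) definition quoted there; the helper name and tensor names are illustrative only.

import torch

def softmax_stats(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    # q: [b, h, sq, d], k: [b, h, skv, d]  ->  stats: [b, h, sq, 1], float32
    s = torch.matmul(q, k.transpose(-2, -1)) * scale   # attention scores [b, h, sq, skv]
    m = s.float().amax(dim=-1, keepdim=True)           # row-wise max(x)
    return torch.log(torch.exp(s.float() - m).sum(dim=-1, keepdim=True))
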
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index b002643942..70d0403b3e 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -189,7 +189,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
 
   pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
       .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
-      .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
       .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
      .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 4104820a1c..dd452e3e89 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -1859,31 +1859,10 @@ def backward(ctx, d_out, *_args):
 
 
 class FusedAttention(torch.nn.Module):
-    """Dot product attention, with multiple backends:
+    """Dot product attention, using the cuDNN fused attention backend:
 
-    1. FusedAttnBackend["F16_max512_seqlen"]
-       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
-    2. FusedAttnBackend["F16_arbitrary_seqlen"]
+    FusedAttnBackend["F16_arbitrary_seqlen"]
        cuDNN based fused attention for FP16/BF16 and any sequence length.
-
-    Support matrix:
-
-    | backend       | 1                       | 2                              |
-    | flash based   | no                      | yes                            |
-    | cuDNN based   | yes                     | yes                            |
-    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
-    | attn_type     | self/cross              | self/cross                     |
-    | qkv_layout    |                         |                                |
-    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
-    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
-    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
-    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
-    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
-    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
-    | dropout       | yes                     | yes                            |
-    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
-    | head_dim      | 64                      | <=128, multiple of 8           |
-    | output dtype  | fp16/bf16               | fp16/bf16                      |
     """
 
     def __init__(
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 7b10593acf..32eb1b597a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -980,10 +980,7 @@ def cp_p2p_fwd_fused_attn(
         )
         if fp8:
-            if qkv_layout != "t3hd":
-                softmax_lse_per_step, rng_states = aux_ctx_tensors
-            else:
-                softmax_lse_per_step, _, rng_states = aux_ctx_tensors
+            softmax_lse_per_step, rng_states = aux_ctx_tensors
         else:
             softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors
             attn_bias = rest[0] if len(rest) > 0 else None
@@ -1169,17 +1166,7 @@ def cp_p2p_bwd_fused_attn(
     section,
 ):
     """Per-tile backward call of CP P2P with FusedAttention backend"""
-    if fp8:
-        if qkv_layout == "t3hd":
-            aux_tensors = [
-                softmax_lse,
-                softmax_lse,
-                rng_states[cp_size - step - 1],
-            ]
-        else:
-            aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]]
-    else:
-        aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]]
+    aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]]
 
     max_seqlen_q_ = max_seqlen_q
     max_seqlen_kv_ = max_seqlen_kv
@@ -1195,17 +1182,7 @@ def cp_p2p_bwd_fused_attn(
             attn_mask_type_ = "padding" if 
"padding" in attn_mask_type else "no_mask" elif section == "upper-triangle": q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] - if fp8: - if qkv_layout == "t3hd": - aux_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - step - 1], - ] - else: - aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] - else: - aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] max_seqlen_q_ = max_seqlen_q // 2 cu_seqlens_q_padded_ = None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 @@ -3223,10 +3200,7 @@ def forward( **fp8_meta_kwargs, ) if fp8: - if qkv_layout != "t3hd": - softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors else: softmax_lse_per_step[i], rng_states[i], *_ = aux_ctx_tensors if return_max_logit: @@ -3588,17 +3562,10 @@ def backward(ctx, dout, *_args): out_part = out.select(seq_dim_o, i).contiguous() dout_part = dout.select(seq_dim_o, i).contiguous() if ctx.use_fused_attention: - if ctx.fp8 and ctx.qkv_layout == "t3hd": - aux_ctx_tensors = [ - softmax_lse_per_step[i], - softmax_lse_per_step[i], - rng_states[i], - ] - else: - aux_ctx_tensors = [ - softmax_lse_per_step[i], - rng_states[i], - ] + aux_ctx_tensors = [ + softmax_lse_per_step[i], + rng_states[i], + ] fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} new_qkv_layout = ctx.qkv_layout diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ed87423534..7df5daabe5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1217,10 +1217,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "Disabling FusedAttention as dbias calculation is not supported for 111s" ) use_fused_attention = False - elif not fu_core_attention_bias_requires_grad: - # max512 backend will only support [1, h, s, s] - os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - # Filter: cuDNN support fused_attention_backend = None if use_fused_attention: @@ -1254,32 +1250,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None - if ( - use_fused_attention - and window_size is not None - and (window_size[0] != -1 or window_size[1] not in [-1, 0]) - and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - ): - logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " - "slidng window attention", - int(fused_attention_backend), - ) - use_fused_attention = False - fused_attention_backend = None - if ( - use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - and fu_core_attention_bias_type == "post_scale_bias" - and fu_core_attention_bias_shape != "1hss" - ): - logger.debug( - "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in" - " [1, H, S, S] shape" - ) - use_fused_attention = False - fused_attention_backend = None - # Filter: Determinism # backend | deterministic # --------------------------------------------- diff --git 
a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 01e139da46..d8f3011445 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -98,13 +98,12 @@ } FusedAttnBackend = { - "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, "FP8": NVTE_Fused_Attn_Backend.NVTE_FP8, "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, } -BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 +BACKEND_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT @@ -249,22 +248,18 @@ def fused_attn_fwd( if is_training is False, aux_ctx_tensors = None softmax-related tensors: - 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - softmax: torch.Tensor - Softmax(Q*K.T) - shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32 - 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + 1. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] softmaxStats: torch.Tensor log(sum(e^(x - max(x)))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - 3. if fused_attention_backend == FusedAttnBackend["FP8"] + 2. if fused_attention_backend == FusedAttnBackend["FP8"] M: torch.Tensor max(Q*K.T) shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 ZInv: torch.Tensor, only allocated for T3HD path 1/sum(e^(x - max(x))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen + rng_state: torch.Tensor state of the random number generator; [seed, offset], dtype uint64 max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None @@ -299,19 +294,13 @@ def fused_attn_fwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS # FP8 fused attention API from fmha_v2 elif fused_attention_backend == FusedAttnBackend["FP8"]: rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA + max_seqlen_q * max_seqlen_q + BACKEND_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_FP8_THREADS_PER_CTA else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -566,13 +555,12 @@ def fused_attn_bwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - if len(aux_ctx_tensors) < 1: - raise ValueError( - "aux_ctx_tensors must contain rng_state as its last element," - f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" - f" for backend={fused_attention_backend}." - ) + if len(aux_ctx_tensors) < 1: + raise ValueError( + "aux_ctx_tensors must contain rng_state as its last element," + f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" + f" for backend={fused_attention_backend}." 
+ ) output_tensors = tex.fused_attn_bwd( max_seqlen_q, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index e6781bd58a..8d7a24dcec 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -271,7 +271,6 @@ std::vector fused_attn_fwd( nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); }; // allocate memory for nvte_aux_tensor_pack.tensors - // f16_max512 : S [b, h, sq, skv] // f16_arbitrary: // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] // return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] From e6cfec21cbdde7425fb339ba369b2783355c2eae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:07:45 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/csrc/extensions/attention.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index ae489b5730..ed136d7b9e 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -125,7 +125,6 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_ q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type, dummy_backend, softmax_buf, rng_state_buf, bias_buf, softmax_offset_buf); - } pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
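For context (a sketch, not part of the patch): after this change the RNG sizing logic in fused_attn_fwd reduces to two branches. The standalone Python sketch below mirrors that dispatch using the constants from the hunk above; the helper function itself is hypothetical.

BACKEND_FP8_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16

def rng_elts_per_thread(backend: str, max_seqlen_q: int) -> int:
    # Mirrors the post-removal branches of fused_attn_fwd(): a fixed per-thread
    # count for the F16_arbitrary_seqlen backend, a CTA-rounded count for FP8.
    if backend == "F16_arbitrary_seqlen":
        return BACKEND_F16arb_ELTS_PER_THREADS
    if backend == "FP8":
        return (
            max_seqlen_q * max_seqlen_q + BACKEND_FP8_THREADS_PER_CTA - 1
        ) // BACKEND_FP8_THREADS_PER_CTA
    raise ValueError(f"Unsupported backend {backend}")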