[PyTorch][Core] Fix cuBLAS grouped GEMM when weight dims are not divisible by 128 #2954
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary: This PR fixes incorrect MXFP8 scale-inverse pointer offsets in the cuBLAS grouped GEMM kernel when weight dimensions are not divisible by 128. The old code computed scale offsets as `data_offset / 32`, which is only correct when the scale-inverse layout needs no padding; the new code computes offsets from the padded 128x4-tile scale-inverse sizes.

Confidence Score: 4/5. The PR is safe to merge; the fix is narrowly scoped to the MXFP8 scale-offset calculation and is well validated by a new non-128-divisible test shape. The core logic is in transformer_engine/common/gemm/cublaslt_grouped_gemm.cu.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_grouped_gemm / nvte_grouped_gemm_with_discrete_out\n(GroupedTensor A & B)"] --> B["select_grouped_operand(A/B)\nsets sel.rowwise = true/false\nbased on MXFP8 + trans flags"]
    C["nvte_grouped_gemm_with_discrete_inputA\n(discrete A_list, GroupedTensor B)"] --> D["choose_grouped_operand_storage\nA_sel.rowwise = choice.use_rowwise\n(new in this PR)"]
    C --> E["select_grouped_operand(B)\nsets B_sel.rowwise"]
    B --> F["launch_grouped_gemm_setup\na_rowwise = A_sel.rowwise\nb_rowwise = B_sel.rowwise"]
    D --> F
    E --> F
    F --> G["setup_grouped_gemm_kernel<<<...>>>"]
    G --> H{"scaling_mode ==\nNVTE_MXFP8_1D_SCALING?"}
    H -- "Yes (old)" --> I["scale_ptr = base + data_offset/32\n❌ wrong when dims not divisible by 128"]
    H -- "Yes (new)" --> J["scale_offset = compute_grouped_tensor_mxfp8_scale_inv_offset\n(uses padded_mxfp8_scale_inv_bytes with 128x4 tile padding)\n✅ correct for all shapes"]
    H -- "No (tensor scaling)" --> K["scale_ptr = base + tensor_idx\n(unchanged)"]
    J --> L["a_scale_inv_ptrs[idx] / b_scale_inv_ptrs[idx]\npointed at correct per-expert scale block"]
```
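The padded-offset change in the flowchart can be sketched in pure Python. The helper name `padded_mxfp8_scale_inv_bytes` and the 128x4-tile padding are taken from the diagram above; the exact TE layout may differ, so treat this as an illustrative model, not the real implementation:

```python
import math

def old_scale_offset(rows: int, cols: int) -> int:
    # Old behavior: one scale byte per 32-element block, no padding.
    return (rows * cols) // 32

def padded_mxfp8_scale_inv_bytes(rows: int, cols: int) -> int:
    # New behavior (sketch): MXFP8 1D scaling stores one scale per
    # 32-element block along a row, and the per-tensor scale matrix is
    # padded out to whole 128x4 tiles before the next tensor's scales begin.
    scale_rows = rows
    scale_cols = cols // 32
    padded_rows = math.ceil(scale_rows / 128) * 128
    padded_cols = math.ceil(scale_cols / 4) * 4
    return padded_rows * padded_cols

# A 96x256 expert: 96 is not divisible by 128, so the unpadded offset
# (96*256/32 = 768) undercounts the padded block (128*8 = 1024) and every
# subsequent expert's scale pointer lands in the wrong place -- exactly the
# bug this PR fixes. For 128-divisible shapes the two formulas agree.
```

For a fully 128-divisible shape such as 128x128, both formulas give 512 bytes, which is why the bug only shows up on non-divisible dimensions.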
Reviews (4): Last reviewed commit: "Fix test case dimensions in test_numeric..."

For future reference, this fix PR should be applied to the NVFP4 recipe as well.

/te-ci pytorch

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci

/te-ci
The changed quantization helper (rendered from the diff view):

```python
quantizer.optimize_for_gemm = not is_weight
grouped_input = torch.cat(tensors, dim=0)
if is_weight:
    first_dims = None
else:
    first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims)
```

```python
def _per_tensor_quantize_mxfp8(
    tensors: List[torch.Tensor],
    *,
    rowwise: bool,
```
**`_per_tensor_quantize_mxfp8` produces non-swizzled scales that break the `discrete_in` case**

`MXFP8Quantizer.optimize_for_gemm` defaults to `False` (set in `Quantizer.__init__`), so every tensor returned by this helper has `with_gemm_swizzled_scales=False`. When `case == "discrete_in"`, the test passes `A_fp8` directly to `general_grouped_gemm_for_grouped_tensor`, which routes to `nvte_grouped_gemm_with_discrete_inputA`. That function contains a hard `NVTE_CHECK(A_list_info.with_gemm_swizzled_scales, "MXFP8 grouped GEMM: A scales must be swizzled for GEMM.")`, so the test will throw for every MXFP8 shape when `case == "discrete_in"`.

The old code called `grouped_A.split_into_quantized_tensors()` on a grouped tensor built with `optimize_for_gemm=True`, so the split tensors inherited `with_gemm_swizzled_scales=True`.
Suggested change:

```python
def _per_tensor_quantize_mxfp8(
    tensors: List[torch.Tensor],
    *,
    rowwise: bool,
    columnwise: bool,
) -> List:
    """Quantize each tensor individually with MXFP8.

    Used to build reference discrete inputs for grouped GEMM.
    """
    quantizer = MXFP8Quantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=rowwise,
        columnwise=columnwise,
    )
    quantizer.optimize_for_gemm = True
    return [quantizer(t) for t in tensors]
```
If the inputs are not swizzled already, TE ensures they are swizzled before the GEMM.

So this is not needed for a functionality test.
/te-ci

Total dim should be divisible by 128.
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

/te-ci pytorch
Description

Type of change

Changes

Checklist: