
[PyTorch][Core] Fix CUBLAS GGEMM when weight dims are not divisible by 128 #2954

Merged
vthumbe1503 merged 7 commits into NVIDIA:main from vthumbe1503:fix_cublas_grouped_gemm_gptoss_sizes
May 4, 2026

Conversation

@vthumbe1503
Collaborator

Description

This PR fixes the MXFP8 scale-inverse pointer offsets used by the cuBLAS grouped GEMM (GGEMM) when weight dimensions are divisible by 32 but not by 128 (for example 2880 = 90 x 32). The kernel previously derived each expert's scale pointer as data_byte_offset / 32, which ignores the 128x4 swizzle-tile padding of the scale layout; the offsets are now computed from the padded layout so that every expert reads the correct scale block.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Compute MXFP8 scale-inverse pointer offsets in the cuBLAS grouped GEMM setup from the padded (128x4 swizzle-tile) scale layout instead of data_byte_offset / 32.
  • Promote the swizzle tile constants TILE_DIM_X=4 and TILE_DIM_Y=128 to namespace scope so the grouped-GEMM kernel can reuse them.
  • Add a weight dimension that is divisible by 32 but not by 128 (2880) to the grouped-GEMM tests and refactor the MXFP8 quantization test helpers.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits May 1, 2026 21:00
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 1, 2026

Greptile Summary

This PR fixes incorrect MXFP8 scale-inverse pointer offsets in the cuBLAS grouped GEMM kernel when weight dimensions are not divisible by 128. The old code computed scale offsets as data_byte_offset / 32, which is only correct when dimensions are exact multiples of 32 and there is no tile padding; the new code uses padded_mxfp8_scale_inv_bytes(), which correctly accounts for the 128×4 swizzle-tile padding boundary. A new test shape (2, 256, 2880, 2880) (2880 = 90×32, not divisible by 128) directly exercises the fixed path.
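
To make the offset change concrete, here is a minimal Python sketch of the two formulas, not the actual CUDA implementation. The rounding rule assumed here (rows padded to multiples of TILE_DIM_Y=128, per-row scale counts padded to multiples of TILE_DIM_X=4, one E8M0 scale byte per 32 elements) is an illustration of the idea; the authoritative formula is padded_mxfp8_scale_inv_bytes in cublaslt_grouped_gemm.cu.

# Illustrative sketch only; mirrors the idea described above, not the TE source.
# Assumed layout: one E8M0 scale byte per 32 elements along the scaled dimension,
# scales swizzled in 128x4 tiles (TILE_DIM_Y x TILE_DIM_X), FP8 data at 1 byte/element.
SCALE_BLOCK = 32
TILE_DIM_Y = 128
TILE_DIM_X = 4

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

def naive_scale_inv_bytes(rows: int, cols: int) -> int:
    # Old behavior: derive the scale offset from the data byte offset divided by 32.
    return (rows * cols) // SCALE_BLOCK

def padded_scale_inv_bytes(rows: int, cols: int) -> int:
    # New behavior (sketch): pad the scale matrix out to whole 128x4 swizzle tiles
    # before accumulating per-expert offsets.
    scale_cols = cols // SCALE_BLOCK
    return round_up(rows, TILE_DIM_Y) * round_up(scale_cols, TILE_DIM_X)

# Dims divisible by 128: the two formulas agree, so the old code worked.
assert naive_scale_inv_bytes(2048, 2048) == padded_scale_inv_bytes(2048, 2048) == 131072

# 2880 = 90 * 32 but 2880 % 128 != 0: the naive offset undercounts the padded layout,
# so per-expert scale pointers drift for every expert after the first.
assert naive_scale_inv_bytes(2880, 2880) == 259200
assert padded_scale_inv_bytes(2880, 2880) == 270848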

Confidence Score: 4/5

PR is safe to merge; the fix is narrowly scoped to the MXFP8 scale-offset calculation and is well-validated by a new non-128-divisible test shape.

The core logic in padded_mxfp8_scale_inv_bytes and compute_grouped_tensor_mxfp8_scale_inv_offset is consistent with the existing compute_grouped_tensor_offset pattern and with the swizzle tile constants. rowwise is correctly propagated in all three entry-point code paths. No P0 issues found. Score capped at 4 because the fix touches GPU kernel pointer arithmetic that is hard to validate statically.

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — specifically padded_mxfp8_scale_inv_bytes and its consumers; reviewer should confirm that on-the-fly swizzling produces a buffer whose stride exactly matches the padded formula for the grouped (non-discrete) weight path.

Important Files Changed

  • transformer_engine/common/gemm/cublaslt_grouped_gemm.cu: Core fix. Replaces data_offset/32 with compute_grouped_tensor_mxfp8_scale_inv_offset() for the padded MXFP8 scale layout, adds a rowwise field to GroupedOperandSelection, and threads it through setup_grouped_gemm_kernel; the logic is correct across all three grouped-GEMM entry points.
  • transformer_engine/common/cast/mxfp8/swizzle.cuh: Promotes TILE_DIM_X=4 and TILE_DIM_Y=128 to namespace-level constexpr so they can be reused in the grouped-GEMM kernel; no functional change to gemm_swizzled_scale_idx.
  • tests/pytorch/test_numerics.py: Refactors _make_grouped_tensor_quantized_mxfp8 to use explicit rowwise/columnwise/is_weight args, adds a _per_tensor_quantize_mxfp8 helper, and adds the non-128-divisible shape (2880) as a parametrize case (see the sketch below).
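
For the test side, here is a minimal sketch of how discrete per-tensor MXFP8 operands for the new (2, 256, 2880, 2880) case might be built. The import paths, the bf16 operand layout, and the rowwise/columnwise choices are assumptions for illustration; the MXFP8Quantizer usage mirrors the _per_tensor_quantize_mxfp8 helper shown in the review thread further down, and the real wiring lives in tests/pytorch/test_numerics.py.

import torch
import transformer_engine_torch as tex  # assumed import alias
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer  # assumed import path

num_gemms, m, k, n = 2, 256, 2880, 2880  # new case: 2880 = 90 * 32, not divisible by 128
device = "cuda"

# Per-expert activations (m x k) and weights (n x k) in high precision (assumed layout).
inputs = [torch.randn(m, k, device=device, dtype=torch.bfloat16) for _ in range(num_gemms)]
weights = [torch.randn(n, k, device=device, dtype=torch.bfloat16) for _ in range(num_gemms)]

def per_tensor_quantize_mxfp8(tensors, *, rowwise, columnwise):
    # Mirrors the _per_tensor_quantize_mxfp8 helper discussed in the review thread below.
    quantizer = MXFP8Quantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=rowwise,
        columnwise=columnwise,
    )
    return [quantizer(t) for t in tensors]

# Discrete MXFP8 operands for the "discrete_in" grouped-GEMM path. Whether the scales
# must already be GEMM-swizzled here is what the P1 thread below debates; per the author,
# TE swizzles them on the fly before the GEMM if needed.
a_fp8 = per_tensor_quantize_mxfp8(inputs, rowwise=True, columnwise=False)
b_fp8 = per_tensor_quantize_mxfp8(weights, rowwise=True, columnwise=True)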

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_grouped_gemm / nvte_grouped_gemm_with_discrete_out\n(GroupedTensor A & B)"] --> B["select_grouped_operand(A/B)\nsets sel.rowwise = true/false\nbased on MXFP8 + trans flags"]
    C["nvte_grouped_gemm_with_discrete_inputA\n(discrete A_list, GroupedTensor B)"] --> D["choose_grouped_operand_storage\nA_sel.rowwise = choice.use_rowwise\n(new in this PR)"]
    C --> E["select_grouped_operand(B)\nsets B_sel.rowwise"]
    B --> F["launch_grouped_gemm_setup\na_rowwise = A_sel.rowwise\nb_rowwise = B_sel.rowwise"]
    D --> F
    E --> F
    F --> G["setup_grouped_gemm_kernel<<<...>>>"]
    G --> H{"scaling_mode ==\nNVTE_MXFP8_1D_SCALING?"}
    H -- "Yes (old)" --> I["scale_ptr = base + data_offset/32\n❌ wrong when dims not divisible by 128"]
    H -- "Yes (new)" --> J["scale_offset = compute_grouped_tensor_mxfp8_scale_inv_offset\n(uses padded_mxfp8_scale_inv_bytes with 128x4 tile padding)\n✅ correct for all shapes"]
    H -- "No (tensor scaling)" --> K["scale_ptr = base + tensor_idx\n(unchanged)"]
    J --> L["a_scale_inv_ptrs[idx] / b_scale_inv_ptrs[idx]\npointed at correct per-expert scale block"]


Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
Comment thread tests/pytorch/test_numerics.py Outdated
Comment thread tests/pytorch/test_numerics.py Outdated
Collaborator

@zhongbozhu zhongbozhu left a comment


LGTM

@zhongbozhu
Collaborator

For future reference, this fix should be applied to the NVFP4 recipe as well.

@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 requested a review from timmoon10 May 1, 2026 22:22
Collaborator

@timmoon10 timmoon10 left a comment


Overall this looks good. My only serious suggestion is to change the name of GroupedOperandSelection.scale_rowwise to rowwise, since that actually changes the intent of the kernel.

Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
Comment thread tests/pytorch/test_numerics.py Outdated
timmoon10 previously approved these changes May 2, 2026
@vthumbe1503
Collaborator Author

/te-ci

@vthumbe1503 vthumbe1503 changed the title from "Fix CUBLAS GGEMM when weight dims are not divisible by 128" to "[PyTorch][Core] Fix CUBLAS GGEMM when weight dims are not divisible by 128" May 2, 2026
Comment on lines +3163 to +3175

    quantizer.optimize_for_gemm = not is_weight
    grouped_input = torch.cat(tensors, dim=0)
    first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
    if is_weight:
        first_dims = None
    else:
        first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
    return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims)


def _per_tensor_quantize_mxfp8(
    tensors: List[torch.Tensor],
    *,
    rowwise: bool,
Contributor


P1: _per_tensor_quantize_mxfp8 produces non-swizzled scales that break the discrete_in case

MXFP8Quantizer.optimize_for_gemm defaults to False (set in Quantizer.__init__), so every tensor returned by this helper has with_gemm_swizzled_scales=False. When case == "discrete_in", the test passes A_fp8 directly to general_grouped_gemm_for_grouped_tensor, which routes to nvte_grouped_gemm_with_discrete_inputA. That function contains a hard NVTE_CHECK(A_list_info.with_gemm_swizzled_scales, "MXFP8 grouped GEMM: A scales must be swizzled for GEMM."), so the test will throw for every MXFP8 shape when case == "discrete_in".

The old code called grouped_A.split_into_quantized_tensors() on a grouped tensor built with optimize_for_gemm=True, so the split tensors inherited with_gemm_swizzled_scales=True.

Suggested change

-    quantizer.optimize_for_gemm = not is_weight
-    grouped_input = torch.cat(tensors, dim=0)
-    first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
-    if is_weight:
-        first_dims = None
-    else:
-        first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
-    return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims)
-def _per_tensor_quantize_mxfp8(
-    tensors: List[torch.Tensor],
-    *,
-    rowwise: bool,
+def _per_tensor_quantize_mxfp8(
+    tensors: List[torch.Tensor],
+    *,
+    rowwise: bool,
+    columnwise: bool,
+) -> List:
+    """Quantize each tensor individually with MXFP8.
+    Used to build reference discrete inputs for grouped GEMM.
+    """
+    quantizer = MXFP8Quantizer(
+        fp8_dtype=tex.DType.kFloat8E4M3,
+        rowwise=rowwise,
+        columnwise=columnwise,
+    )
+    quantizer.optimize_for_gemm = True
+    return [quantizer(t) for t in tensors]

Collaborator Author


If inputs are not swizzled already, then TE ensures they are swizzled before the GEMM.

Collaborator Author


So not needed for a functionality test

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci

timmoon10 previously approved these changes May 2, 2026
Total dim should be divisible by 128

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 merged commit ad4b3fd into NVIDIA:main May 4, 2026
21 of 24 checks passed
