
Implement per-token NVFP4 fprop recipe#2931

Draft
zianglih wants to merge 23 commits into NVIDIA:main from zianglih:fp4-per-token

Conversation


@zianglih zianglih commented Apr 27, 2026

Description


Implement a per-token NVFP4 recipe with fprop only.
Currently, the per-token scaling is handled by separate PyTorch code.
The quantization kernels are bitwise exact with the existing TE reference implementation.

The following tests passed on B200:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a per_token_activation field to the NVFP4 recipe; it can be enabled via NVTE_NVFP4_PER_TOKEN_ACTIVATION
  • New per-token NVFP4 quantize kernels in transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh, bitwise exact with the existing TE PyTorch reference implementation and the per-tensor NVFP4 emulated implementation
  • Expand the dequant kernel transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh to correctly handle per-token NVFP4
  • In transformer_engine/pytorch/cpp_extensions/gemm.py, when per-token NVFP4 is detected, separate per-token scaling is performed in PyTorch code after the cuBLAS GEMM
  • Broad test coverage by expanding 7 test files
  • Modify the 1D quant reference implementation in tests/cpp/operator/test_cast_nvfp4_transpose.cu to align with the PyTorch reference numerics

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zianglih zianglih marked this pull request as draft April 27, 2026 06:24

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR implements per-token (per-row) NVFP4 fprop-only quantization for GroupedLinear, adding new CUDA kernels, amax buffer management, and a post-GEMM per-token scaling step in PyTorch. The kernel implementations are clean and the per-tensor fallback for backward is correctly enforced.

  • P1 — out_views[i].numel() crash in grouped GEMM (non-single_output path): When single_output=False the per-token early-return sets out_views = out. If the caller passes [None] * num_gemms, out_views[i].numel() immediately raises AttributeError. The contract that callers must pre-allocate outputs is not enforced or documented.
  • P2 — Bias fused into cuBLAS then manually subtracted: The bias is not stripped from gemm_args before the per-token cuBLAS call. The post-processing (out - bias) * scales + bias is algebraically correct but risks fp32 catastrophic cancellation when |bias| greatly exceeds |Z|. Stripping bias from the cuBLAS call and adding it once after scaling would be cleaner and more stable.
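The cancellation risk flagged in P2 is easy to demonstrate with plain fp32 arithmetic (a NumPy sketch with illustrative magnitudes, not code from the PR):

```python
import numpy as np

# Illustrative magnitudes: a small GEMM contribution z and a large bias.
z, bias, scale = np.float32(1.25), np.float32(2.0 ** 26), np.float32(3.0)

# Fused path: bias is baked into the cuBLAS output, then subtracted
# before per-token scaling. At fp32, z is absorbed by the large bias.
fused_out = z + bias                       # rounds to exactly 2**26
scaled_fused = (fused_out - bias) * scale  # z's contribution is lost: 0.0

# Stripped path: the GEMM runs without bias; bias would be added once
# after scaling, so z's contribution survives.
scaled_strip = z * scale                   # 3.75
```

The fp32 spacing near 2**26 is 8, so adding 1.25 to the bias is a no-op, and subtracting the bias afterwards returns exactly zero.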

Confidence Score: 3/5

Safe to merge for fprop-only workloads without bias; grouped-GEMM path can crash at runtime when outputs are not pre-allocated.

One confirmed P1 (AttributeError on NoneType in grouped GEMM) plus the already-flagged previous P1s around the early-return return signature. Extensive test coverage reported on B200, but the pre-allocation contract is neither enforced nor documented.

transformer_engine/pytorch/cpp_extensions/gemm.py (grouped per-token path); transformer_engine/common/recipe/__init__.py (missing backward_override validation).

Important Files Changed

transformer_engine/pytorch/cpp_extensions/gemm.py: Adds per-token NVFP4 detection and GEMM dispatch; the grouped-GEMM early return crashes with AttributeError when outputs are not pre-allocated, and the bias forward-then-subtract pattern risks fp32 cancellation.
transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh: New per-token NVFP4 CUDA kernels for rowwise and columnwise quantization; the columnwise num_rows divisibility check is only enforced inside the launch helper, not at the public entry point.
transformer_engine/pytorch/csrc/quantizer.cpp: Propagates the per_token_activation flag through create_tensor, convert_and_update_tensor, and quantize_impl; amax buffer size and shape are correctly adjusted for per-token mode.
transformer_engine/pytorch/csrc/extensions/cast.cpp: Adds per-token amax buffer support in bulk_allocate_nvfp4_tensors and split_quantize_nvfp4_impl_helper; the standalone quantize_nvfp4_per_token function looks correct.
transformer_engine/pytorch/quantization.py: Correctly disables per_token_activation for backward quantizers; the forward quantizer role assignment via idx%3!=1 is undocumented.
transformer_engine/common/recipe/__init__.py: Adds the per_token_activation field to the NVFP4BlockScaling recipe with an env-var default; no post-init validation enforces backward_override when per-token mode is on.
transformer_engine/pytorch/tensor/nvfp4_tensor.py: Adds the per_token_activation field and correctly sizes amax buffers using flat_first_dim rows instead of 1.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py: Correctly computes total_amax_elements and per-tensor amax offsets for the per-token path.
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh: Dequant kernel correctly parameterised by amax_numel; selects tensor_amax[0] for per-tensor or tensor_amax[y] for per-token.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Activation Tensor] --> B[split_quantize per-token]
    B --> C[Per-token CUDA kernel\nFP4 data plus FP8 block scales\nper-row amax vector]
    C --> D[NVFP4TensorStorage\namax shape equals num-rows]
    D --> E[general_gemm detection\namax numel greater than 1]
    E --> F[Strip global amax\nset amax to ones]
    F --> G[cuBLAS GEMM\nblock-scaled only fp32 out]
    G --> H[Multiply by per-token scales\nactivation-amax times weight-amax]
    H --> I[Add bias then cast\nFinal output]

Reviews (6). Last reviewed commit: "Improve accuracy by unfolding weight per..."

Comment thread: transformer_engine/pytorch/cpp_extensions/gemm.py (outdated)
// Compute "correct" per-block encoding scaling factor
// Before:
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
// After:
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
    fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
Contributor Author

We had to change this to stay aligned with the PyTorch reference.
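For illustration, the two encoding-scale formulas can be compared in a Python sketch (FP32_MAX stands in for Numeric_Traits<float>::maxNorm; this is not the PR's code). Both compute S_enc / S_dec_b mathematically, but the reciprocal-based ordering changes the floating-point rounding, which is what bitwise alignment with the PyTorch reference requires, and the clamp prevents overflow for very small block decode scales:

```python
import numpy as np

FP32_MAX = float(np.finfo(np.float32).max)  # stand-in for Numeric_Traits<float>::maxNorm

def s_enc_b_old(s_enc, s_dec_b):
    # Original formula: direct division, no clamp.
    return 0.0 if s_dec_b == 0.0 else s_enc / s_dec_b

def s_enc_b_new(s_enc, s_dec_b):
    # Changed formula: reciprocal-based ordering, clamped to the fp32 max.
    if s_dec_b == 0.0:
        return 0.0
    return min(1.0 / (s_dec_b * (1.0 / s_enc)), FP32_MAX)
```

The two versions agree to rounding error in the normal range, but can differ in the last bit because the intermediate 1/S_enc is itself rounded.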

@zianglih zianglih marked this pull request as ready for review April 27, 2026 09:14
@zianglih zianglih marked this pull request as draft May 2, 2026 18:22
zianglih and others added 14 commits May 2, 2026 11:27
Signed-off-by: Ziang Li <ziangli@umich.edu>
Co-authored-by: Yigong Qin <qqqyyy1233@outlook.com>
@ziang-and ziang-and force-pushed the fp4-per-token branch 2 times, most recently from 6998f64 to 5b2f606 Compare May 2, 2026 19:10
zianglih added 5 commits May 2, 2026 16:33
Signed-off-by: Ziang Li <ziangli@umich.edu>

zianglih commented May 2, 2026

The following extended tests all passed:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

cd /root/TransformerEngine/tests/cpp
cmake --build build -j200
TEST_BIN="$(find build -type f -name test_operator -perm -u+x | head -n 1)"
"$TEST_BIN" --gtest_filter='*FusedCastTransposeNVFP4*:*DequantizeNVFP4*'

@zianglih zianglih marked this pull request as ready for review May 2, 2026 23:54
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
zianglih added 2 commits May 2, 2026 17:23
…clean up

Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Comment on lines +350 to +354
else:
    out_views = out
for i in range(num_gemms):
    if out_views[i].numel() == 0:
        continue
Contributor

P1 out_views iteration crashes with AttributeError when single_output=False and outputs are not pre-allocated

When single_output=False, out_views = out (line 351). If the caller passes [None] * num_gemms, every out_views[i] is NoneType. The loop immediately calls out_views[i].numel(), raising AttributeError: 'NoneType' object has no attribute 'numel'.

All existing callers appear to pre-allocate, but this contract is not enforced or documented. At minimum, add a guard:

if not all(isinstance(v, torch.Tensor) for v in out_views):
    raise RuntimeError(
        "Per-token NVFP4 grouped GEMM requires pre-allocated output tensors."
    )

@ptrendx ptrendx added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 4, 2026
@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented May 4, 2026

Hi @zianglih, could you clarify why you needed the new quantization kernels? The existing NVFP4 quantization kernels should already work if you only use the rowwise mode there, no?

@zianglih zianglih marked this pull request as draft May 5, 2026 05:56
Signed-off-by: Ziang Li <ziangli@umich.edu>
