Skip to content

[PyTorch] Make modules.GroupedLinear graph-safe#3038

Open
yaox12 wants to merge 1 commit into
NVIDIA:mainfrom
yaox12:xiny/enable-grouped-quantize-cublaslt
Open

[PyTorch] Make modules.GroupedLinear graph-safe#3038
yaox12 wants to merge 1 commit into
NVIDIA:mainfrom
yaox12:xiny/enable-grouped-quantize-cublaslt

Conversation

@yaox12
Copy link
Copy Markdown
Member

@yaox12 yaox12 commented May 22, 2026

Description

  • Enable grouped quantization and cuBLASLt grouped gemm for modules.GroupedLinear to benefit cases where cuteDSL fused grouped gemm is not available.

    1. Reduce CPU overhead by reducing number of kernels.
    2. Be CUDA-Graph-safe.
    3. Improve kernel performance.
  • Move grouped gemm and grouped linear related tests to a standalone file.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR introduces a cuBLASLt grouped GEMM path for GroupedLinear backed by GroupedTensor metadata, targeting SM100+ (Blackwell) hardware. The new path is opt-in via NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1, reduces CPU kernel-launch overhead, and is CUDA-graph-safe by accepting a CUDA tensor for m_splits instead of a Python list.

  • Adds _forward_grouped_tensor / _backward_grouped_tensor statics (~300 LOC) inside _GroupedLinear, selected at runtime via _is_grouped_tensor_path_supported; the legacy path is completely unchanged when the new path is disabled.
  • Migrates all grouped-linear and grouped-GEMM tests from test_numerics.py into a new tests/pytorch/test_grouped_linear.py, extending them with MXFP8 grouped-tensor GEMM tests, a CUDA-graph-capture smoke test, and a regression for the single_grouped_bias + delay_wgrad_compute crash identified in prior review rounds.

Confidence Score: 4/5

Safe to merge for the default (legacy) path; the new grouped-tensor path is opt-in and guarded by an env var and SM100+ capability check.

The single_grouped_bias + delay_wgrad_compute crash flagged in the prior review is now fixed and covered by a dedicated regression test. The env-var-leak issues in the test suite are cleaned up with monkeypatch. The only remaining concern is that _forward_grouped_tensor and _backward_grouped_tensor both use only input_quantizers[0] / grad_output_quantizers[0] when calling tex.group_quantize, silently ignoring per-group quantizer configurations if they were ever to diverge; this does not affect the current GroupedLinear module (which creates identical quantizers per group) but could become a latent bug for future callers.

transformer_engine/pytorch/module/grouped_linear.py — specifically the _forward_grouped_tensor and _backward_grouped_tensor quantizer indexing.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Adds the grouped-tensor GEMM path (~300 LOC) for SM100+: new _forward_grouped_tensor / _backward_grouped_tensor statics, quantizer-dispatch helpers, and a ctx.use_grouped_tensor_path flag that routes the legacy backward unchanged. The previously-flagged single_grouped_bias + delay_wgrad_compute crash is resolved via the has_grad_biases guard. Only the first quantizer from each list is used when calling tex.group_quantize, which could silently misapply config if per-group quantizers ever diverge.
tests/pytorch/test_grouped_linear.py New standalone test file (1739 lines) migrated from test_numerics.py, extended with grouped-tensor GEMM tests. The _reset_fp8_state autouse fixture uses monkeypatch.setenv to safely manage env vars; the env-var-leak issues flagged in previous review threads are now fixed. Includes CUDA-graph-safe capture test and single_grouped_bias + delay_wgrad regression.
tests/pytorch/test_numerics.py Removes all grouped-linear / grouped-GEMM test code (~600 lines) migrated to test_grouped_linear.py; drops unused imports. Remaining tests are unaffected.
benchmarks/linear/benchmark_grouped_linear.py Conditionally converts m_splits to a CUDA tensor when NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1, consistent with the module's own env-var default of 0.
qa/L0_pytorch_unittest/test.sh Adds test_grouped_linear.py to the CI run with the same env-var flags as test_numerics.py.

Reviews (4): Last reviewed commit: "make modules.GroupedLinear graph-safe" | Re-trigger Greptile

Comment thread tests/pytorch/test_grouped_linear.py
Comment thread tests/pytorch/test_grouped_linear.py
Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread benchmarks/linear/benchmark_grouped_linear.py
Comment thread tests/pytorch/test_grouped_linear.py Outdated
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/enable-grouped-quantize-cublaslt branch from d176247 to 698383e Compare May 25, 2026 03:56
@yaox12
Copy link
Copy Markdown
Member Author

yaox12 commented May 25, 2026

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant