[PyTorch] Make modules.GroupedLinear graph-safe#3038
Conversation
Greptile SummaryThis PR introduces a cuBLASLt grouped GEMM path for
Confidence Score: 4/5Safe to merge for the default (legacy) path; the new grouped-tensor path is opt-in and guarded by an env var and SM100+ capability check. The single_grouped_bias + delay_wgrad_compute crash flagged in the prior review is now fixed and covered by a dedicated regression test. The env-var-leak issues in the test suite are cleaned up with monkeypatch. The only remaining concern is that _forward_grouped_tensor and _backward_grouped_tensor both use only input_quantizers[0] / grad_output_quantizers[0] when calling tex.group_quantize, silently ignoring per-group quantizer configurations if they were ever to diverge; this does not affect the current GroupedLinear module (which creates identical quantizers per group) but could become a latent bug for future callers. transformer_engine/pytorch/module/grouped_linear.py — specifically the _forward_grouped_tensor and _backward_grouped_tensor quantizer indexing. Important Files Changed
Reviews (4): Last reviewed commit: "make modules.GroupedLinear graph-safe" | Re-trigger Greptile |
Signed-off-by: Xin Yao <xiny@nvidia.com>
d176247 to
698383e
Compare
|
/te-ci pytorch |
Description
Enable grouped quantization and cuBLASLt grouped gemm for
modules.GroupedLinearto benefit cases where cuteDSL fused grouped gemm is not available.Move grouped gemm and grouped linear related tests to a standalone file.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: