[PyTorch debug] FakeQuant: support Float8BlockScaling and fix MoE / w…#3040

Draft
shangxiaokang wants to merge 2 commits into NVIDIA:main from shangxiaokang:fake_quant_bwfp8

Conversation

@shangxiaokang

…eight-cache paths

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@shangxiaokang shangxiaokang marked this pull request as draft May 25, 2026 04:01
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 25, 2026
@greptile-apps
Contributor

greptile-apps Bot commented May 25, 2026

Greptile Summary

This PR extends FakeQuant to support Float8BlockScaling (128x128 2D tiles and 1x128 1D rows), refactors the old inline quantizer construction into a dispatch table, and fixes the weight-cache write-back path in api.py so modify_tensor with out != None does not assert on the None return value.

  • fake_quant.py: Per-tensor, MXFP8, and new FP8_BLOCKWISE_* formats are now unified through _FORMAT_DISPATCH; a zero-padding helper (_pad_for_blockwise) enables MoE GroupedLinear where the per-expert M-dim is not 128-aligned.
  • api.py: output_assertions_hook gains an early-return branch for the out != None case, preventing a spurious tensor-type assertion on the None that a well-behaved modify_tensor implementation must return when writing to out.
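The dispatch-plus-padding idea in the fake_quant.py bullet can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `FORMAT_DISPATCH`, `pad_rows`, `fake_quantize_sketch`, and the identity stand-in quantizer are hypothetical names, tensors are modelled as lists of rows, and the table is simplified to (quantizer, block size) pairs.

```python
import math

BLOCK = 128  # assumed block size for the blockwise formats


def _identity_qdq(rows):
    """Stand-in for a real quantize-then-dequantize round trip (assumption)."""
    return [list(r) for r in rows]


# Hypothetical, simplified dispatch table: format name -> (qdq callable, block size or None).
FORMAT_DISPATCH = {
    "FP8E4M3": (_identity_qdq, None),
    "FP8_BLOCKWISE_E4M3": (_identity_qdq, BLOCK),
}


def pad_rows(rows, block_size):
    """Zero-pad the leading dim up to the next multiple of block_size.

    Returns the padded rows and the original row count.
    """
    m = len(rows)
    padded_m = math.ceil(m / block_size) * block_size
    width = len(rows[0])
    padding = [[0.0] * width for _ in range(padded_m - m)]
    return rows + padding, m


def fake_quantize_sketch(rows, fmt):
    """Look up the format, pad if it is block-scaled, round-trip, then slice back."""
    if fmt not in FORMAT_DISPATCH:
        raise ValueError(f"Unsupported format: {fmt}")
    qdq, block_size = FORMAT_DISPATCH[fmt]
    orig_m = None
    if block_size is not None:
        rows, orig_m = pad_rows(rows, block_size)
    result = qdq(rows)
    if orig_m is not None:
        result = result[:orig_m]  # drop the zero rows added by padding
    return result
```

With a per-expert M-dim of 130, the input is padded to 256 rows before the round trip and sliced back to 130 rows afterwards.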

Confidence Score: 3/5

The weight-cache write-back path works correctly after the api.py fix, but modify_tensor will crash with an AttributeError whenever out and dtype are both provided.

The modify_tensor body calls .to(dtype) on the None returned by fake_quantize when out is non-None, crashing for any caller that passes both arguments simultaneously. MXFP8 formats now also route through zero-padding, but that path's correctness has only been argued for Float8BlockQuantizer.

transformer_engine/debug/features/fake_quant.py - specifically the modify_tensor dtype handling and the MXFP8 entries in _FORMAT_DISPATCH.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/debug/features/fake_quant.py | Adds Float8BlockScaling support and a zero-padding strategy for non-aligned leading dims. One bug: modify_tensor calls .to(dtype) on the None return value of fake_quantize when out is non-None. MXFP8 entries now go through the blockwise padding path whose suitability for MXFP8Quantizer is unverified. |
| transformer_engine/debug/features/api.py | Adds an early-return branch in output_assertions_hook for the modify_tensor / out != None case, correctly skipping the tensor-type assertion and returning early. Logic is sound. |
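The early-return branch described for api.py might look something like the sketch below. The function name, its signature, and the exact assertions are assumptions made for illustration; the real hook checks the tensor type of the return value rather than just non-None.

```python
def output_assertions_hook_sketch(api_name, ret, **kwargs):
    """Validate the return value of a debug API call (hypothetical signature)."""
    if api_name == "modify_tensor" and kwargs.get("out") is not None:
        # The result was written into `out`, so the call is expected to
        # return None; skip the tensor-type assertion entirely.
        assert ret is None, "modify_tensor must return None when `out` is given"
        return
    if api_name == "modify_tensor":
        # Simplified stand-in for the tensor-type assertion.
        assert ret is not None, "modify_tensor must return a tensor when `out` is None"
```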

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fake_quantize called] --> B{format in dispatch?}
    B -- No --> C[raise ValueError]
    B -- Yes --> D[unpack factory and block_size]
    D --> E{block_size not None?}
    E -- Yes --> F[_check_blockwise_shape]
    F --> G[_pad_for_blockwise]
    G --> H[build quantizer]
    E -- No --> H
    H --> I[dequantized = quantize then dequantize]
    I --> J{padding applied?}
    J -- Yes --> K[slice and reshape]
    K --> L{out not None?}
    J -- No --> L
    L -- Yes --> M{out has quantize_?}
    M -- Yes --> N[out.quantize_ dequantized]
    M -- No --> O[out.copy_ dequantized]
    N --> P[return None]
    O --> P
    L -- No --> Q[return dequantized]
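The tail of the flowchart (the `out` branch) can be sketched as below. `PlainBuffer`, `QuantizedBuffer`, and `write_result` are hypothetical stand-ins; only the `quantize_` / `copy_` method names and the return-None convention come from the flowchart.

```python
class PlainBuffer:
    """Stand-in for a high-precision output tensor that only supports copy_."""

    def __init__(self):
        self.data = None

    def copy_(self, src):
        self.data = list(src)


class QuantizedBuffer(PlainBuffer):
    """Stand-in for a quantized output tensor that exposes quantize_."""

    def __init__(self):
        super().__init__()
        self.requantized = False

    def quantize_(self, src):
        self.data = list(src)
        self.requantized = True


def write_result(dequantized, out=None):
    """Return the result, or write it into `out` and return None."""
    if out is None:
        return dequantized
    if hasattr(out, "quantize_"):
        out.quantize_(dequantized)  # re-quantize into the cached tensor
    else:
        out.copy_(dequantized)  # plain in-place copy
    return None
```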

Comments Outside Diff (1)

  1. transformer_engine/debug/features/fake_quant.py, lines 354-357

    P1 When out is not None, fake_quantize() returns None (line 208). The very next line then calls .to(dtype) on that None, which raises AttributeError: 'NoneType' object has no attribute 'to' whenever both out and dtype are provided by the caller.
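One possible guard for this failure mode, sketched with stand-in types. `FakeTensor`, `fake_quantize_stub`, and `modify_tensor_sketch` are hypothetical; this is not the fix adopted in the PR, only an illustration of skipping the cast when the in-place path returned None.

```python
class FakeTensor:
    """Minimal stand-in for a tensor with .to() and .copy_() (assumption)."""

    def __init__(self, values, dtype="float32"):
        self.values = list(values)
        self.dtype = dtype

    def to(self, dtype):
        return FakeTensor(self.values, dtype)

    def copy_(self, src):
        self.values = list(src.values)


def fake_quantize_stub(tensor, out=None):
    """Mimics the return convention: None when writing into `out`."""
    if out is not None:
        out.copy_(tensor)
        return None
    return FakeTensor(tensor.values, tensor.dtype)


def modify_tensor_sketch(tensor, dtype=None, out=None):
    result = fake_quantize_stub(tensor, out=out)
    if result is None:
        # In-place path: the data already lives in `out`, nothing to cast.
        return None
    if dtype is not None:
        result = result.to(dtype)
    return result
```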

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +122 to +124
"MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, MXFP8_BLOCK_SCALING_SIZE),
# Float8 blockwise: 2D 128x128 tiles
"FP8_BLOCKWISE_E4M3": (
Contributor

P2 The docstring states "MXFP8*: shape[-1] and prod(shape[:-1]) must both be divisible by 32", but because block_size is non-None for MXFP8 entries, _pad_for_blockwise silently zero-pads the leading dim when it is not 32-aligned. The _pad_for_blockwise docstring explicitly describes Float8BlockQuantizer's clean zero-block behaviour, but does not guarantee the same for MXFP8Quantizer. If MXFP8Quantizer does not handle padded rows the same way, the output for the non-padded slice can be subtly wrong without any error.
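The alignment rule quoted from the docstring could be enforced with an explicit check instead of silent padding. A minimal sketch, assuming a hypothetical helper name `check_mxfp8_shape` and a shape given as a tuple of ints:

```python
import math

MXFP8_BLOCK = 32  # block size quoted in the docstring


def check_mxfp8_shape(shape):
    """Raise ValueError unless shape[-1] and prod(shape[:-1]) are both 32-aligned."""
    if len(shape) < 2:
        raise ValueError(f"Expected at least 2 dims, got shape {shape}")
    last = shape[-1]
    leading = math.prod(shape[:-1])
    if last % MXFP8_BLOCK != 0 or leading % MXFP8_BLOCK != 0:
        raise ValueError(
            f"MXFP8 requires shape[-1]={last} and prod(shape[:-1])={leading} "
            f"to be divisible by {MXFP8_BLOCK}"
        )
```

Failing loudly here puts the alignment burden back on the caller, which is the trade-off the comment argues for.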

Suggested change
- "MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, MXFP8_BLOCK_SCALING_SIZE),
- # Float8 blockwise: 2D 128x128 tiles
- "FP8_BLOCKWISE_E4M3": (
+ # MXFP8 (1x32 block scaling) - block_size=None: shape check and padding are
+ # NOT applied because MXFP8Quantizer does not guarantee clean zero-block
+ # behaviour for padded rows. Both dims must already be 32-aligned (caller
+ # responsibility, same as the previous implementation).
+ "MXFP8E4M3": (_build_mxfp8_quantizer, tex.DType.kFloat8E4M3, {}, None),
+ "MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, None),

@pggPL pggPL self-requested a review May 25, 2026 08:31