Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions transformer_engine/debug/features/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,17 @@ def output_assertions_hook(self, api_name, ret, **kwargs):
if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
assert ret is None
if api_name == "modify_tensor":
# Per TEDefaultFeatures.modify_tensor spec: if `out` is provided
# the implementation must write into it and return None. This path
# is exercised by DebugQuantizer.update_quantized() (weight cache
# write-back). Without this branch, any modify_tensor feature
# (FakeQuant, PerTensorScaling, ...) configured on the `weight`
# tensor crashes here when get_weight_workspace updates the cache.
if kwargs.get("out", None) is not None:
assert (
ret is None
), f"modify_tensor with out != None must return None (got {type(ret)})."
return
assert type(ret) in get_all_tensor_types()
if (
type(ret) == torch.Tensor # pylint: disable=unidiomatic-typecheck
Expand Down
241 changes: 208 additions & 33 deletions transformer_engine/debug/features/fake_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

"""FakeQuant Feature support for nvidia-dlframework-inspect"""

from typing import Optional
import math
from typing import Optional, Tuple

import torch

Expand All @@ -16,48 +17,194 @@
import transformer_engine_torch as tex
from transformer_engine.debug.features.api import TEConfigAPIMapper
from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.quantization import _default_sf_compute


def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
"""Input tensor is quantized to fp8 and then dequantized."""
# Block length used by Float8BlockQuantizer (hard-coded to 128 in TE).
_FP8_BLOCKWISE_BLOCK_LEN = 128


def _build_per_tensor_fp8_quantizer(tensor: torch.Tensor, fp8_dtype: tex.DType) -> Quantizer:
    """Create a per-tensor current-scaling FP8 quantizer sized for ``tensor``.

    The scale is derived from the absolute maximum of ``tensor`` and the
    forward max of the target format via ``_default_sf_compute``.

    Args:
        tensor: Tensor whose absolute maximum seeds the scale.
        fp8_dtype: Target FP8 dtype. ``kFloat8E4M3`` selects the E4M3 range;
            any other value falls back to the E5M2 range.

    Returns:
        A ``Float8Quantizer`` built from the computed scale and amax.
    """
    if fp8_dtype == tex.DType.kFloat8E4M3:
        format_max = Format.E4M3.value.max_fwd
    else:
        format_max = Format.E5M2.value.max_fwd

    abs_max = tensor.abs().max().float()
    unit = torch.ones(1, device=tensor.device)
    scale = _default_sf_compute(abs_max, unit, format_max, 0)
    return Float8Quantizer(scale, abs_max, fp8_dtype)


def _build_mxfp8_quantizer(_tensor: torch.Tensor, fp8_dtype: tex.DType) -> Quantizer:
    """Create an MXFP8 (1x32 block scaling) quantizer for ``fp8_dtype``.

    ``_tensor`` is not read; it is accepted so this factory can be called
    with the same positional arguments as the other quantizer factories.
    """
    mxfp8_quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype)
    return mxfp8_quantizer


def _build_fp8_blockwise_quantizer(
    _tensor: torch.Tensor, fp8_dtype: tex.DType, *, block_scaling_dim: int
) -> Quantizer:
    """Create a Float8 blockwise quantizer (128x128 2D tiles or 1x128 1D rows).

    Args:
        _tensor: Unused; present so all factories share one call signature.
        fp8_dtype: Target FP8 dtype forwarded to the quantizer.
        block_scaling_dim: Forwarded to ``Float8BlockQuantizer`` to choose
            between the 2D and 1D block layouts.

    Returns:
        A ``Float8BlockQuantizer`` with row-wise output enabled and
        column-wise output disabled.
    """
    quantizer_options = {
        "fp8_dtype": fp8_dtype,
        "rowwise": True,
        "columnwise": False,
        "block_scaling_dim": block_scaling_dim,
    }
    return Float8BlockQuantizer(**quantizer_options)


def _check_blockwise_shape(tensor: torch.Tensor, block_size: int, fp8_format: str) -> None:
"""Validate that tensor shape is compatible with a blockwise quantizer.

For blockwise formats, the last dim must be a multiple of block_size (true
hard requirement of the quantizer kernel). The leading dim is NOT required
to be a multiple of block_size: when it is not, ``_pad_for_blockwise()``
pads it transparently and ``fake_quantize`` slices the padded tail off
after dequantize. This matches the behavior needed for MoE GroupedLinear
where the per-expert M-dim is routing-dependent and rarely 128-aligned.
"""
if tensor.ndim < 2:
raise ValueError(
f"[NVTORCH INSPECT ERROR] FakeQuant quant_format={fp8_format} requires a tensor with "
f"ndim >= 2, got shape {tuple(tensor.shape)}."
)
last = tensor.shape[-1]
if last % block_size != 0:
raise ValueError(
f"[NVTORCH INSPECT ERROR] FakeQuant quant_format={fp8_format} requires "
f"tensor.shape[-1] ({last}) to be divisible by block_size={block_size}. "
f"Got shape {tuple(tensor.shape)}."
)


def _pad_for_blockwise(tensor: torch.Tensor, block_size: int) -> Tuple[torch.Tensor, Optional[int]]:
"""Pad leading dim up to a multiple of ``block_size``.

Returns ``(padded_tensor, original_leading)``. ``original_leading`` is
``None`` when no padding was needed, otherwise it is the original size of
the flattened leading dim, used to slice the dequantized output back to
the caller's shape.

Padding is done with zeros along a flattened 2D view; rows containing pad
zeros end up forming the partial last block, which the blockwise quantizer
handles cleanly (a zero block has scale=1 and contributes no error after
we discard the pad).
"""
if tensor.ndim < 2:
return tensor, None
last = tensor.shape[-1]
leading = math.prod(tensor.shape[:-1])
if leading % block_size == 0:
return tensor, None

pad_rows = block_size - (leading % block_size)
flat = tensor.reshape(leading, last)
pad = flat.new_zeros((pad_rows, last))
padded = torch.cat([flat, pad], dim=0)
return padded, leading


# Format string -> (factory(tensor, fp8_dtype, **factory_kwargs) -> Quantizer,
# fp8_dtype: tex.DType,
# factory_kwargs: dict,
# block_size: Optional[int] for shape validation, None for per-tensor formats)
_FORMAT_DISPATCH = {
# Per-tensor current scaling FP8
"FP8E4M3": (_build_per_tensor_fp8_quantizer, tex.DType.kFloat8E4M3, {}, None),
"FP8E5M2": (_build_per_tensor_fp8_quantizer, tex.DType.kFloat8E5M2, {}, None),
# MXFP8 (1x32 block scaling)
"MXFP8E4M3": (_build_mxfp8_quantizer, tex.DType.kFloat8E4M3, {}, MXFP8_BLOCK_SCALING_SIZE),
"MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, MXFP8_BLOCK_SCALING_SIZE),
# Float8 blockwise: 2D 128x128 tiles
"FP8_BLOCKWISE_E4M3": (
Comment on lines +122 to +124
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The docstring states "MXFP8*: shape[-1] and prod(shape[:-1]) must both be divisible by 32", but because block_size is non-None for MXFP8 entries, _pad_for_blockwise silently zero-pads the leading dim when it is not 32-aligned. The _pad_for_blockwise docstring explicitly describes Float8BlockQuantizer's clean zero-block behaviour, but does not guarantee the same for MXFP8Quantizer. If MXFP8Quantizer does not handle padded rows the same way, the output for the non-padded slice can be subtly wrong without any error.

Suggested change
"MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, MXFP8_BLOCK_SCALING_SIZE),
# Float8 blockwise: 2D 128x128 tiles
"FP8_BLOCKWISE_E4M3": (
# MXFP8 (1x32 block scaling) - block_size=None: shape check and padding are
# NOT applied because MXFP8Quantizer does not guarantee clean zero-block
# behaviour for padded rows. Both dims must already be 32-aligned (caller
# responsibility, same as the previous implementation).
"MXFP8E4M3": (_build_mxfp8_quantizer, tex.DType.kFloat8E4M3, {}, None),
"MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, None),

_build_fp8_blockwise_quantizer,
tex.DType.kFloat8E4M3,
{"block_scaling_dim": 2},
_FP8_BLOCKWISE_BLOCK_LEN,
),
"FP8_BLOCKWISE_E5M2": (
_build_fp8_blockwise_quantizer,
tex.DType.kFloat8E5M2,
{"block_scaling_dim": 2},
_FP8_BLOCKWISE_BLOCK_LEN,
),
# Float8 blockwise: 1D 1x128 rows
"FP8_BLOCKWISE_1D_E4M3": (
_build_fp8_blockwise_quantizer,
tex.DType.kFloat8E4M3,
{"block_scaling_dim": 1},
_FP8_BLOCKWISE_BLOCK_LEN,
),
"FP8_BLOCKWISE_1D_E5M2": (
_build_fp8_blockwise_quantizer,
tex.DType.kFloat8E5M2,
{"block_scaling_dim": 1},
_FP8_BLOCKWISE_BLOCK_LEN,
),
}


def fake_quantize(tensor: torch.Tensor, fp8_format: str, out=None):
    """Quantize ``tensor`` to the requested FP8 format and immediately dequantize it.

    Supports per-tensor FP8 (FP8E4M3 / FP8E5M2), MXFP8 (MXFP8E4M3 / MXFP8E5M2) and
    Float8 blockwise scaling (FP8_BLOCKWISE_{,1D_}E4M3 / FP8_BLOCKWISE_{,1D_}E5M2).
    The set of accepted names is exactly the key set of ``_FORMAT_DISPATCH``.

    For every format whose ``_FORMAT_DISPATCH`` entry carries a block size, the
    last dim must be divisible by that block size. If ``prod(shape[:-1])`` is not
    a multiple of it, the leading dim is zero-padded internally and the
    dequantized output is sliced back to the original shape.

    Args:
        tensor: CUDA tensor of dtype float32, float16 or bfloat16.
        fp8_format: Name of the format to emulate (a key of ``_FORMAT_DISPATCH``).
        out: Optional destination. When given, the result is written into it
            and ``None`` is returned.

    Returns:
        The dequantized tensor with the same shape as ``tensor``, or ``None``
        when ``out`` is provided.

    Raises:
        AssertionError: If ``tensor`` has an unsupported dtype or is not on GPU.
        ValueError: If ``fp8_format`` is unknown, or the tensor shape is not
            compatible with a block-scaled format.
    """

    assert tensor.dtype in (
        torch.float,
        torch.float16,
        torch.bfloat16,
    ), "[NVTORCH INSPECT ERROR] Unsupported tensor type."
    assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor."

    # The dispatch table is the single source of truth for supported formats;
    # a separate hard-coded allow-list here would reject the blockwise formats.
    if fp8_format not in _FORMAT_DISPATCH:
        raise ValueError(
            "[NVTORCH INSPECT ERROR] Unsupported FakeQuant quant_format "
            f"{fp8_format!r}. Supported formats: {sorted(_FORMAT_DISPATCH)}."
        )

    factory, fp8_dtype, factory_kwargs, block_size = _FORMAT_DISPATCH[fp8_format]

    original_shape = tensor.shape
    qinput = tensor
    original_leading: Optional[int] = None
    if block_size is not None:
        _check_blockwise_shape(tensor, block_size, fp8_format)
        qinput, original_leading = _pad_for_blockwise(tensor, block_size)

    # Build the quantizer from the (possibly padded) input and run the
    # quantize -> dequantize round trip exactly once.
    quantizer = factory(qinput, fp8_dtype, **factory_kwargs)
    dequantized = quantizer(qinput).dequantize()

    if original_leading is not None:
        # Slice off the padded rows and restore the caller's logical shape.
        dequantized = dequantized[:original_leading].reshape(original_shape)

    if out is not None:
        # Called from DebugQuantizer.update_quantized() (weight workspace
        # cache write-back). `out` may be a QuantizedTensor or a plain
        # torch.Tensor. When `out` exposes quantize_(), use it so the
        # fake-quantized result is re-encoded into the destination's own
        # format; otherwise copy the values directly.
        if hasattr(out, "quantize_"):
            out.quantize_(dequantized, noop_flag=None)
        else:
            out.copy_(dequantized)
        return None
    return dequantized


@Registry.register_feature(namespace="transformer_engine")
Expand Down Expand Up @@ -94,10 +241,38 @@ class FakeQuant(TEConfigAPIMapper):
- dgrad

quant_format: str
specifies the FP8 format to use:
specifies the FP8 format / scaling strategy to emulate:

Per-tensor current scaling FP8:

- FP8E4M3
- FP8E5M2

MXFP8 (1x32 block scaling):

- MXFP8E4M3
- MXFP8E5M2

Float8 blockwise scaling - 128x128 2D tiles (default `Float8BlockScaling`):

- FP8_BLOCKWISE_E4M3
- FP8_BLOCKWISE_E5M2

Float8 blockwise scaling - 1x128 1D rows:

- FP8_BLOCKWISE_1D_E4M3
- FP8_BLOCKWISE_1D_E5M2

Shape requirements:

- FP8E5M2
- FP8E4M3
- MXFP8*: ``shape[-1]`` and ``prod(shape[:-1])`` must both
be divisible by 32.
- FP8_BLOCKWISE_*: ``shape[-1]`` must be divisible by 128.
``prod(shape[:-1])`` does NOT need to be 128-aligned;
FakeQuant pads it internally and slices the
dequantized output back to the caller's shape.
This makes the feature work with MoE GroupedLinear
where per-expert token counts are routing-dependent.

Example
-------
Expand All @@ -110,7 +285,7 @@ class FakeQuant(TEConfigAPIMapper):
transformer_engine:
FakeQuant:
enabled: True
quant_format: FP8E5M2
quant_format: FP8_BLOCKWISE_E4M3
gemms_struct:
- gemm: fprop
tensors: [activation, weight]
Expand All @@ -120,7 +295,7 @@ class FakeQuant(TEConfigAPIMapper):

def _supported_formats(self):
    """Returns formats that one can fake quantize tensor to."""
    # Derived from the dispatch table so the reported list always matches
    # what fake_quantize() accepts, including the blockwise formats.
    return list(_FORMAT_DISPATCH)

@api_method
def fp8_gemm_enabled(
Expand Down