Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 95 additions & 47 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from collections.abc import Iterable
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from contextlib import contextmanager, AbstractContextManager, ContextDecorator, nullcontext
from functools import lru_cache
from dataclasses import dataclass
import math
Expand Down Expand Up @@ -918,7 +918,10 @@ def fork(self, name: str = "model-parallel-rng"):


def reduce_scatter_along_first_dim(
inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
inp: torch.Tensor,
tp_group: dist_group_type,
async_op: bool = False,
output: torch.Tensor = None,
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_distributed_world_size(tp_group)
Expand All @@ -936,7 +939,8 @@ def reduce_scatter_along_first_dim(

dim_size[0] = dim_size[0] // world_size

output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device())
if output is None:
output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device())
handle = torch.distributed.reduce_scatter_tensor(
output, inp.contiguous(), group=tp_group, async_op=async_op
)
Expand Down Expand Up @@ -1281,11 +1285,13 @@ def _post_process_nvfp4_gather(
handle = None

# Fix the interleaved transposed data from gathering along first dim.
out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
# In-place .copy_() (not `=` rebind) to keep the storage address stable
# for CUDA graph capture — replays see the same pointer they captured.
out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

# Optionally pad the scaling inverse if needed.
out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
# Optionally pad the scaling inverse if needed (same in-place pattern).
out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))

Comment on lines 1287 to 1295
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Shape mismatch in `_post_process_nvfp4_gather` breaks any K that is not a multiple of 128

`out._columnwise_scale_inv` is allocated by `NVFP4Quantizer.make_empty` with shape `(round_up(K, 128), round_up(ceil(M_total/16), 4))` — the fully-padded shape. The intermediate result from `_swap_first_dims(columnwise_scale_inv_interleaved, world_size)` has the unpadded shape `(K_stripped, world_size * unpadded_dim1)`, because the gather side strips padding before the NCCL collective. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the two shapes diverge and `out._columnwise_scale_inv.copy_(...)` raises a `RuntimeError` at the first all-gather call.

The pre-PR code used `=` rebinding, which handled arbitrary shapes. Replacing it with `.copy_()` is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which `make_empty` does not do. The GTP-prefetched `output_tensor` path has the same problem on the step-1 copy, before the `pad_columnwise_scale_inv` call can correct things.


@dataclass
Expand All @@ -1299,17 +1305,25 @@ class _NVFP4AllGatherAsyncHandle:
async_handle: torch.distributed.Work
_synchronized: bool = False

def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
self.async_handle.wait()
def post_process_nvfp4_gather(self) -> None:
"""Fix interleaved transposed data + pad scale_inv after the async AG completes.

Idempotent: gated by ``_synchronized`` in :meth:`wait`.
"""
_post_process_nvfp4_gather(
self.output,
self.columnwise_data_interleaved,
self.columnwise_scale_inv_interleaved,
self.world_size,
)

def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
if self.async_handle is not None:
self.async_handle.wait()
self.post_process_nvfp4_gather()
self._synchronized = True
Comment thread
fanshiqing marked this conversation as resolved.


Expand All @@ -1320,6 +1334,8 @@ def _all_gather_nvfp4(
async_op: bool = False,
quantizer: NVFP4Quantizer,
out_shape: Optional[list[int]] = None,
output_tensor=None,
grouped=False,
) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather NVFP4 tensor along first dimension."""

Expand Down Expand Up @@ -1383,6 +1399,12 @@ def _all_gather_nvfp4(
out = quantizer(out)
return out, None

# Construct NVFP4 output tensor
if output_tensor is not None:
out = output_tensor
else:
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

# Cast input tensor to NVFP4 with required data
if not isinstance(inp, NVFP4TensorStorage):
inp = quantizer(inp)
Expand All @@ -1395,17 +1417,19 @@ def _all_gather_nvfp4(
)
inp = quantizer(inp.dequantize(dtype=dtype))

# Construct NVFP4 output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

# Coalesce NCCL collectives for gathering data and scale inverses.
with torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
) as gather_coalescing_manager:
if not grouped:
# Coalesce NCCL collectives for gathering data and scale inverses.
gather_coalescing_manager = torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
)
else:
gather_coalescing_manager = nullcontext()

with gather_coalescing_manager as coalesced_handle:
# Gather NVFP4 data for row-wise usage
out_columnwise_data = None
if quantizer.rowwise_usage:

# Remove padding from NVFP4 scale-inverses
Expand Down Expand Up @@ -1433,8 +1457,9 @@ def _all_gather_nvfp4(
group=process_group,
)

# Transfer amax to output.
out._amax_rowwise = inp._amax_rowwise
# Transfer amax to output via in-place .copy_() so the storage
# address stays stable for CUDA graph capture.
out._amax_rowwise.copy_(inp._amax_rowwise)

# Gather the transposed NVFP4 data along first dimension. Fix format later.
if quantizer.columnwise_usage:
Expand Down Expand Up @@ -1483,17 +1508,24 @@ def _all_gather_nvfp4(
)

# Transfer amax to output.
out._amax_columnwise = inp._amax_columnwise
out._amax_columnwise.copy_(inp._amax_columnwise)

handle = gather_coalescing_manager if async_op else None
handle = coalesced_handle if async_op else None

# Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed.
if async_op and quantizer.columnwise_usage:
handle = _NVFP4AllGatherAsyncHandle(
out, out_columnwise_data, out_scale_inv, world_size, handle
)
elif quantizer.columnwise_usage:
_post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle)
if quantizer.columnwise_usage:
if async_op or grouped:
# Defer post-processing: either the async op hasn't completed yet, or an
# external coalescing manager owns the NCCL ops and hasn't flushed them.
inner_handle = handle if async_op else None
handle = _NVFP4AllGatherAsyncHandle(
out, out_columnwise_data, out_scale_inv, world_size, inner_handle
)
else:
_post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle)
else:
if handle is not None:
handle.output = out

return out, handle

Expand All @@ -1505,6 +1537,8 @@ def _all_gather_mxfp8(
async_op: bool = False,
quantizer: MXFP8Quantizer,
out_shape: Optional[list[int]] = None,
output_tensor: torch.Tensor = None,
grouped: bool = False,
) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather MXFP8 tensor along first dimension."""

Expand Down Expand Up @@ -1570,15 +1604,22 @@ def _all_gather_mxfp8(
inp = quantizer(inp.dequantize(dtype=dtype))

# Construct MXFP8 output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
if output_tensor is not None:
out = output_tensor
else:
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

# Coalesce NCCL collectives
with torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
) as coalescing_manager:
if not grouped:
# Coalesce NCCL collectives for gathering data and scale inverses.
gather_coalescing_manager = torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
)
else:
gather_coalescing_manager = nullcontext()

with gather_coalescing_manager as coalesced_handle:
# Gather MXFP8 data for row-wise usage
if quantizer.rowwise_usage:

Expand Down Expand Up @@ -1625,7 +1666,7 @@ def _all_gather_mxfp8(
group=process_group,
)

handle = coalescing_manager if async_op else None
handle = coalesced_handle if async_op else None
return out, handle


Expand All @@ -1634,6 +1675,8 @@ def gather_along_first_dim(
process_group: dist_group_type,
async_op: bool = False,
quantizer: Optional[Quantizer] = None,
output_tensor: torch.Tensor = None,
grouped: bool = False,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
"""
All-gather tensors and concatenate along first dimension.
Expand Down Expand Up @@ -1724,6 +1767,8 @@ def gather_along_first_dim(
async_op=async_op,
quantizer=quantizer,
out_shape=out_shape,
output_tensor=output_tensor,
grouped=grouped,
)

# NVFP4 case
Expand All @@ -1738,6 +1783,8 @@ def gather_along_first_dim(
async_op=async_op,
quantizer=quantizer,
out_shape=out_shape,
output_tensor=output_tensor,
grouped=grouped,
)

# High-precision communication for quantized tensors
Expand Down Expand Up @@ -1767,19 +1814,20 @@ def gather_along_first_dim(
inp = inp.dequantize()

# Communication for plain PyTorch tensors
out = torch.empty(
out_shape,
dtype=inp.dtype,
device=inp.device,
memory_format=torch.contiguous_format,
)
if output_tensor is None:
output_tensor = torch.empty(
out_shape,
dtype=inp.dtype,
device=inp.device,
memory_format=torch.contiguous_format,
)
handle = torch.distributed.all_gather_into_tensor(
out,
output_tensor,
inp.contiguous(),
group=process_group,
async_op=async_op,
)
return out, handle
return output_tensor, handle


# Global cache to store symmetric memory tensors
Expand Down
Loading
Loading