
Generalized Tensor Parallelism (GTP) #3005

Open
fanshiqing wants to merge 1 commit into NVIDIA:main from fanshiqing:gtp_release

Member

@fanshiqing fanshiqing commented May 18, 2026

Design doc: GTP.docx

Description

Core idea: add Generalized Tensor Parallelism (GTP), a flexible, fine-grained scheme for sharding and just-in-time materialization of both activations and parameters, with efficient computation-communication overlap.

Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.

Summary of features

  1. Fine-grained materialization & gradient reduction
  • Weights, gradients, and optimizer states are sharded along the GTP group.
  • Weights are temporarily materialized through prefetching in both the forward and backward passes.
  2. Composability with TP / SP / EP / DDP with efficient overlapping of computation and communication
  • GEMM + TP/EP communication + GTP communication + DDP communication.
  3. GTP + partial CUDA graphs with fine-grained synchronization across graphs
  • GTP reduce-scatter overlapping across graphs.
  4. Low-precision quantize-then-gather
  • MXFP8 / NVFP4
  • Auto-padding/stripping to satisfy low-precision alignment requirements.
  5. Parallel folding for MoE layers
  • Support configuring the GTP size for dense layers and MoE layers separately.
  6. Distributed checkpointing
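
The sharding and auto-padding/stripping items above can be illustrated with a framework-free sketch. This is not GTP's implementation: the helper names (`round_up`, `shard_rows`, `gather_rows`), the alignment of 4 rows, and the use of `None` as a padding row are all assumptions, and the all-gather is simulated by list concatenation in a single process.

```python
def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def shard_rows(rows, group_size, align):
    """Pad `rows` at the end so every rank gets an equal, aligned slice.

    Hypothetical helper: padding rows are represented by None.
    """
    rows_per_rank = round_up(-(-len(rows) // group_size), align)  # ceil, then align
    padded = rows + [None] * (rows_per_rank * group_size - len(rows))
    return [
        padded[rank * rows_per_rank:(rank + 1) * rows_per_rank]
        for rank in range(group_size)
    ]


def gather_rows(shards, original_len):
    """Simulated all-gather: concatenate all shards, then strip the padding."""
    full = [row for shard in shards for row in shard]
    return full[:original_len]


# A 10-row "weight" sharded across a group of 2 with 4-row alignment.
weight = [[float(i)] * 2 for i in range(10)]
shards = shard_rows(weight, group_size=2, align=4)
restored = gather_rows(shards, original_len=len(weight))
```

Each rank holds 8 rows here (5 rounded up to the alignment), and stripping after the gather recovers the original 10 rows.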

How Mcore interacts with TE

① Mcore registers callbacks into TE at import time.

② TE calls back into Mcore runtime during te.Linear(gtp_group=…) init AND during fwd/bwd (weight.all_gather_and_prefetch / wgrad_reduce_scatter).

③ Mcore extensions forward gtp_group= at module init.

④ TE provides FP8 / MXFP8 / NVFP4 tensor types AND the quantize-then-AG / RS collectives (gather_along_first_dim, reduce_scatter_along_first_dim) — imported by Mcore runtime; GTP wraps them with its own schedule, buffer cache, and stream choreography.
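
The callback flow in steps ① to ③ can be sketched as a small hook registry. The names `register_gtp_hooks` and `maybe_wrap_gtp` come from this PR's change list, but the signatures, the single hook slot, the `(rank, size)` tuple standing in for a process group, and the `ToyLinear` class are all assumptions made for illustration.

```python
from typing import Callable, Optional

# Hook slot owned by the library side; None means no integrator is registered.
_gtp_wrap_fn: Optional[Callable] = None


def register_gtp_hooks(wrap_fn: Callable) -> None:
    """Called by the integrator (e.g. at import time) to install its callback."""
    global _gtp_wrap_fn
    _gtp_wrap_fn = wrap_fn


def maybe_wrap_gtp(param, gtp_group=None):
    """Return `param` unchanged unless a hook and a group are both present."""
    if _gtp_wrap_fn is None or gtp_group is None:
        return param
    return _gtp_wrap_fn(param, gtp_group)


class ToyLinear:
    """Stand-in module that forwards gtp_group at init, as in step 3."""

    def __init__(self, weight, gtp_group=None):
        self.weight = maybe_wrap_gtp(weight, gtp_group)


def toy_wrap(param, gtp_group):
    """Hypothetical integrator callback: keep only this rank's slice."""
    rank, size = gtp_group
    per_rank = len(param) // size
    return {"shard": param[rank * per_rank:(rank + 1) * per_rank], "group_size": size}


full_weight = [1.0, 2.0, 3.0, 4.0]
group = (0, 2)  # assumed (rank, group size) placeholder for a process group

before = ToyLinear(full_weight, gtp_group=group)  # no hook yet: no-op
register_gtp_hooks(toy_wrap)
after = ToyLinear(full_weight, gtp_group=group)   # hook installed: sharded
no_group = ToyLinear(full_weight)                 # hook installed, no group: no-op
```

The point of the pattern is that the library never imports the integrator; it only calls whatever was registered, and falls back to a no-op otherwise.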

image

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • transformer_engine/pytorch/module/base.py (+76 / −2)
    • GTP hook registry: register_gtp_hooks(), maybe_wrap_gtp()
  • transformer_engine/pytorch/module/linear.py (+72 / −2)
    • Linear(gtp_group=…) kwarg
    • fwd: optional all_gather_and_prefetch rebind and skip workspace save;
    • bwd: re-gather + wgrad_reduce_scatter + main_grad write-back guard + sharded
      wgrad_shape.
  • transformer_engine/pytorch/module/layernorm_linear.py (+60 / −5)
    • same pattern mirrored for the fused LN+Linear path
  • transformer_engine/pytorch/module/grouped_linear.py (+115 / −16)
    • GroupedLinear(gtp_group=…) + maybe_wrap_gtp(..., is_grouped=True); dual saved-tensor
      carving (with/without GTP);
    • batched_all_gather_and_prefetch + batched_all_gather_and_prefetch_bwd + batched_wgrad_reduce_scatter
  • transformer_engine/pytorch/distributed.py (+142 / −53)
    • in-place .copy_() for amax/scale_inv/data so storage addresses stay stable across CUDA-graph replay.
    • GTP runtime depends on this for prefetch overlap.
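
The difference between an in-place copy and an attribute rebind can be shown without any GPU code. In this sketch a `bytearray` stands in for tensor storage and a saved reference stands in for the pointer a captured CUDA graph would replay against; the `Holder` class is hypothetical, and the analogy only models object identity, not shape checking.

```python
class Holder:
    """Stand-in for a quantized tensor that owns an internal buffer."""

    def __init__(self, size):
        self.data = bytearray(size)


out = Holder(4)
captured = out.data  # what a captured graph would keep pointing at

# In-place update: same object, so the captured reference sees the new bytes.
out.data[:] = b"\x01\x02\x03\x04"
in_place_visible = captured is out.data and bytes(captured) == b"\x01\x02\x03\x04"

# Rebind: the attribute now names a new object; the captured reference is stale.
out.data = bytearray(b"\x09\x09\x09\x09")
rebind_stale = captured is not out.data
```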

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Contributor

greptile-apps Bot commented May 18, 2026

Greptile Summary

This PR introduces Generalized Tensor Parallelism (GTP), a fine-grained sharding mechanism for weights, gradients, and optimizer states with computation–communication overlap for Linear, LayerNormLinear, and GroupedLinear modules. Integration is callback-based: Megatron registers hooks into TE at import time, keeping TE free of a hard Megatron dependency.

  • base.py: Three GTP hook slots (_gtp_slice_fn, _gtp_finalize_fn, _gtp_wrap_fn) registered via register_gtp_hooks; reset_parameters calls the slice hook per weight and a finalize hook after the loop; all paths are no-ops when no GTP integrator is present.
  • linear.py / layernorm_linear.py / grouped_linear.py: Each module gains a gtp_group= kwarg; gtp_size is threaded through FwdArgs/BwdArgs dataclasses; forward AG-prefetch and backward wgrad reduce-scatter are conditionally dispatched around the existing GEMM calls.
  • distributed.py: gather_along_first_dim and its NVFP4/MXFP8 helpers gain output_tensor= and grouped= parameters for buffer reuse and external coalescing-group compatibility; _NVFP4AllGatherAsyncHandle gains a split post_process_nvfp4_gather method and a null-guard on async_handle.

Confidence Score: 3/5

The NVFP4 all-gather post-processing regression can crash any non-GTP NVFP4 workload with non-128-aligned K dimensions; the module files are structurally sound but depend on a separate unreviewable GTP runtime.

The switch from attribute rebinding to .copy_() in _post_process_nvfp4_gather is the critical issue: make_empty pre-allocates _columnwise_scale_inv at the padded shape, but _swap_first_dims returns the unpadded K dimension. Any model with K not a multiple of 128 hits a RuntimeError at the first NVFP4 all-gather — even without GTP enabled.

transformer_engine/pytorch/distributed.py — specifically _post_process_nvfp4_gather and the .copy_() pattern for _columnwise_scale_inv.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/distributed.py | Adds output_tensor/grouped params to all-gather helpers and async handle for GTP prefetch; shape-mismatch regression in _post_process_nvfp4_gather when .copy_() is used with a padded destination but an unpadded source. |
| transformer_engine/pytorch/module/base.py | Introduces three GTP hook slots with a clean register/call pattern; no-op when GTP is absent; changes to reset_parameters are well-guarded. |
| transformer_engine/pytorch/module/grouped_linear.py | Adds gtp_group kwarg and GTP-aware forward/backward dispatch; weight_names set unconditionally before maybe_wrap_gtp, fixing the prior AttributeError on gtp_group=None. |
| transformer_engine/pytorch/module/layernorm_linear.py | Mirrors Linear GTP wiring; gtp_size plumbed through forward/backward args; wgrad RS deferred correctly after the fused wgrad GEMM. |
| transformer_engine/pytorch/module/linear.py | Adds gtp_size to both dataclasses; AG-prefetch and wgrad RS hooks wired symmetrically for forward and backward. |

Reviews (7): Last reviewed commit: "Generalized Tensor Parallelism (GTP) ini..."

Member Author

/te-ci L1 pytorch

Comment thread transformer_engine/pytorch/distributed.py
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines 1287 to 1295
  # Fix the interleaved transposed data from gathering along first dim.
- out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
- out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
+ # In-place .copy_() (not `=` rebind) to keep the storage address stable
+ # for CUDA graph capture — replays see the same pointer they captured.
+ out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
+ out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

- # Optionally pad the scaling inverse if needed.
- out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
+ # Optionally pad the scaling inverse if needed (same in-place pattern).
+ out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))

Contributor


P1 Shape mismatch in _post_process_nvfp4_gather breaks any K not a multiple of 128

out._columnwise_scale_inv is allocated by NVFP4Quantizer.make_empty with shape (round_up(K, 128), round_up(ceil(M_total/16), 4)) — the fully-padded shape. The intermediate result from _swap_first_dims(columnwise_scale_inv_interleaved, world_size) has the unpadded shape (K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge and out._columnwise_scale_inv.copy_(...) raises a RuntimeError at the first all-gather call.

The pre-PR code used = rebinding, which handled arbitrary shapes. Replacing it with .copy_() is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which make_empty does not do. The GTP-prefetched output_tensor path has the same problem on the step-1 copy before the pad_columnwise_scale_inv call can correct things.
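
The mismatch described above can be checked with plain shape arithmetic. The padded shape follows the formula quoted in the comment; the unpadded shape assumes `unpadded_dim1 = ceil(M_local / 16)` per rank, which is a guess for illustration, and the function names are hypothetical. Exact shape equality is used as a conservative stand-in for what an in-place copy would accept.

```python
import math


def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def padded_scale_inv_shape(k: int, m_local: int, world_size: int):
    """Destination shape, per the formula quoted in the review comment."""
    m_total = m_local * world_size
    return (round_up(k, 128), round_up(math.ceil(m_total / 16), 4))


def unpadded_scale_inv_shape(k: int, m_local: int, world_size: int):
    """Source shape after the swap; dim-1 formula is an assumption."""
    return (k, world_size * math.ceil(m_local / 16))


def in_place_copy_is_safe(k: int, m_local: int, world_size: int) -> bool:
    """True only when source and destination shapes agree exactly."""
    return (padded_scale_inv_shape(k, m_local, world_size)
            == unpadded_scale_inv_shape(k, m_local, world_size))


# K=64 is padded to 128 on the destination side only, so the shapes diverge.
mismatch_case = in_place_copy_is_safe(k=64, m_local=64, world_size=2)
# K=128 needs no padding, so both sides agree.
aligned_case = in_place_copy_is_safe(k=128, m_local=64, world_size=2)
```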


2 participants