Generalized Tensor Parallelism (GTP) #3005
Conversation
Greptile SummaryThis PR introduces Generalized Tensor Parallelism (GTP), a fine-grained sharding mechanism for weights, gradients, and optimizer states with computation–communication overlap for
Confidence Score: 3/5The NVFP4 all-gather post-processing regression can crash any non-GTP NVFP4 workload with non-128-aligned K dimensions; the module files are structurally sound but depend on a separate unreviewable GTP runtime. The switch from attribute rebinding to transformer_engine/pytorch/distributed.py — specifically Important Files Changed
Reviews (7): Last reviewed commit: "Generalized Tensor Parallelism (GTP) ini..." | Re-trigger Greptile |
|
/te-ci L1 pytorch |
3e70bdf to
ed9ce68
Compare
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com> Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
| # Fix the interleaved transposed data from gathering along first dim. | ||
| out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) | ||
| out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) | ||
| # In-place .copy_() (not `=` rebind) to keep the storage address stable | ||
| # for CUDA graph capture — replays see the same pointer they captured. | ||
| out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) | ||
| out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) | ||
|
|
||
| # Optionally pad the scaling inverse if needed. | ||
| out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) | ||
| # Optionally pad the scaling inverse if needed (same in-place pattern). | ||
| out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) | ||
|
|
There was a problem hiding this comment.
Shape mismatch in
_post_process_nvfp4_gather breaks any K not a multiple of 128
out._columnwise_scale_inv is allocated by NVFP4Quantizer.make_empty with shape (round_up(K, 128), round_up(ceil(M_total/16), 4)) — the fully-padded shape. The intermediate result from _swap_first_dims(columnwise_scale_inv_interleaved, world_size) has the unpadded shape (K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge and out._columnwise_scale_inv.copy_(...) raises a RuntimeError at the first all-gather call.
The pre-PR code used = rebinding, which handled arbitrary shapes. Replacing it with .copy_() is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which make_empty does not do. The GTP-prefetched output_tensor path has the same problem on the step-1 copy before the pad_columnwise_scale_inv call can correct things.
Deisgn doc: GTP.docx
Description
Core-idea: add Generalized Tensor Parallelism (GTP), which is a flexible fine-grained sharding/just-in time materialization of both activations and parameters with efficient computation-communication overlap.
Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.
Summary of features
How Mcore interacts with TE
① Mcore registers callbacks into TE at import time.
② TE calls back into Mcore runtime during te.Linear(gtp_group=…) init AND during fwd/bwd (weight.all_gather_and_prefetch / wgrad_reduce_scatter).
③ Mcore extensions forward gtp_group= at module init.
④ TE provides FP8 / MXFP8 / NVFP4 tensor types AND the quantize-then-AG / RS collectives (gather_along_first_dim, reduce_scatter_along_first_dim) — imported by Mcore runtime; GTP wraps them with its own schedule, buffer cache, and stream choreography.
Type of change
Changes
Please list the changes introduced in this PR:
wgrad_shape.
carving (with/without GTP);
Checklist: