[PyTorch] Pad V when Q/V head dims differ (MLA) for THD#2629
Conversation
Greptile SummaryThis PR fixes MLA (DeepSeek V3-style) attention for the THD format by padding the V tensor to the Q/K head dimension before calling flash attention, then trimming the output back to the original V head dimension. The guard that previously disabled FA2 for all mismatched-head-dim cases is removed in favor of the new runtime padding approach.
Confidence Score: 3/5Two real defects remain: the FA2 guard removal exposes non-THD MLA paths to a backend that cannot handle them, and the backend-agnostic padding breaks FA3/FA4 which already natively support the DeepSeek (192, 128) shape. The FA2 guard was removed for all qkv_format values, but the padding workaround only activates for THD — any non-THD MLA call that previously fell back gracefully now reaches FA2 without protection. Separately, on systems with FA3 or FA4 installed, the padding transforms a (192, 128) MLA configuration into (192, 192), invalidating the shape contracts those backends were selected against; FA4 on SM100/110 in particular has no valid kernel for (192, 192) and would produce a runtime error. Both changed files need attention: Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[DotProductAttention.forward] --> B{use_flash_attention?}
B -- Yes --> C{q_format==thd AND kv_format==thd AND NOT FP8 AND head_dim_v < head_dim_qk?}
C -- Yes --> D[_pad_thd_value_layer: pad V to head_dim_qk]
C -- No --> E[Use V as-is]
D --> F[self.flash_attention with padded V]
E --> F
F --> G{orig_v_dim is not None?}
G -- Yes --> H[_trim_thd_output: reshape and slice to orig_v_dim]
G -- No --> I[return attn_out]
H --> I
subgraph utils.py: get_attention_backend
J{head_dim_qk != head_dim_v?} -- Yes --> K[FA2: guard REMOVED - now unrestricted]
J -- Yes --> L[FA3: _is_fa3_supported checks 192/128 natively]
J -- Yes --> M[FA4: _validate_head_dims checks 192/128 natively]
end
style K fill:#f88,stroke:#f00
style D fill:#ffa,stroke:#fa0
Reviews (7): Last reviewed commit: "[PyTorch] Pad V when Q/V head dims diffe..." | Re-trigger Greptile |
There was a problem hiding this comment.
Pull request overview
This PR adds support for Multi-head Latent Attention (MLA) with mismatched Q/V head dimensions in the THD (Total-Hidden-Dimension) format. When the value tensor has a smaller head dimension than the query/key tensors, the code pads the value tensor to match the Q/K head dimension, runs the attention operation, and then trims the output back to the original V dimension.
Changes:
- Added padding logic for V tensor when head dimensions differ in THD format
- Implemented trimming function to restore correct output dimensions after attention
- Added test case for THD attention with mismatched Q/V head dimensions
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Implements padding of V tensor before attention and trimming of output after attention for THD format with mismatched Q/V head dimensions |
| tests/pytorch/attention/test_attention.py | Adds test case to verify THD attention works with different Q/V head dimensions |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
This change should only be required by the FlashAttention backend. The other two backends FusedAttention and UnfusedDPA do support MLA (head_dim_qk != head_dim_v). I'd propose a few changes:
@vcherepanov-nv, could you help push this PR through the finish line? Thanks! |
|
Thank you @cyanguwa, I just cleaned up the PR and also follow your requirements. Please let me know what you think @vcherepanov-nv. |
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Description
For MLA, we shall pad V when Q/V head dims differ for THD
Similar to NVIDIA/Megatron-LM#3003
Fixes NVIDIA/Megatron-LM#1698
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: