[PyTorch] Pad V when Q/V head dims differ (MLA) for THD #2629

Open
HollowMan6 wants to merge 1 commit into NVIDIA:main from HollowMan6:mla_thd

@HollowMan6 (Member) commented Jan 27, 2026

Description

For MLA, V should be padded to the Q head dim when the Q and V head dims differ in the THD format.

Similar to NVIDIA/Megatron-LM#3003

Fixes NVIDIA/Megatron-LM#1698

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • pad V when Q/V head dims differ for THD

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI review requested due to automatic review settings January 27, 2026 23:31
greptile-apps Bot (Contributor) commented Jan 27, 2026

Greptile Summary

This PR fixes MLA (DeepSeek V3-style) attention for the THD format by padding the V tensor to the Q/K head dimension before calling flash attention, then trimming the output back to the original V head dimension. The guard that previously disabled FA2 for all mismatched-head-dim cases is removed in favor of the new runtime padding approach.

  • V padding/trimming for THD MLA: _pad_thd_value_layer and _trim_thd_output are added as helpers; padding is gated on q_format == "thd", kv_format == "thd", non-FP8 V, and head_dim_v < head_dim_qk.
  • FA2 guard removal: The unconditional use_flash_attention_2 = False guard for head_dim_qk != head_dim_v is removed entirely, but the padding workaround only covers the THD path, leaving non-THD layouts with mismatched head dims unprotected.
  • Backend-agnostic padding: The padding is applied regardless of whether FA2, FA3, or FA4 is the selected backend; FA3 and FA4 already handle (192, 128) natively, so padding to (192, 192) breaks those backends.
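
The pad-then-trim step described in this summary could look roughly like the sketch below. It is not the PR's code: `pad_thd_value` and `trim_thd_output` are hypothetical stand-ins for `_pad_thd_value_layer` and `_trim_thd_output`, and it assumes V is laid out as `[t, h, d_v]` and that the attention output arrives flattened as `[t, h * d_padded]`.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def pad_thd_value(v: torch.Tensor, head_dim_qk: int) -> Tuple[torch.Tensor, int]:
    """Zero-pad the last dim of a THD value tensor [t, h, d_v] up to head_dim_qk.

    Hypothetical helper; returns the (possibly padded) tensor and the original d_v.
    """
    orig_v_dim = v.shape[-1]
    if orig_v_dim >= head_dim_qk:
        return v, orig_v_dim
    # F.pad pads the last dimension first: (left, right) = (0, missing columns).
    return F.pad(v, (0, head_dim_qk - orig_v_dim)), orig_v_dim


def trim_thd_output(out: torch.Tensor, num_heads: int, orig_v_dim: int) -> torch.Tensor:
    """Drop padded columns from a flattened output [t, h * d_padded] (assumed layout)."""
    t = out.shape[0]
    out = out.reshape(t, num_heads, -1)[..., :orig_v_dim]
    return out.reshape(t, num_heads * orig_v_dim)


# Shapes for the DeepSeek-style (192, 128) case mentioned in the review.
v = torch.randn(10, 4, 128)
v_padded, orig_v_dim = pad_thd_value(v, head_dim_qk=192)
```

Padding on the right of the last dimension keeps the original V columns at the front, so the trim is a plain slice per head.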

Confidence Score: 3/5

Two real defects remain: the FA2 guard removal exposes non-THD MLA paths to a backend that cannot handle them, and the backend-agnostic padding breaks FA3/FA4 which already natively support the DeepSeek (192, 128) shape.

The FA2 guard was removed for all qkv_format values, but the padding workaround only activates for THD — any non-THD MLA call that previously fell back gracefully now reaches FA2 without protection. Separately, on systems with FA3 or FA4 installed, the padding transforms a (192, 128) MLA configuration into (192, 192), invalidating the shape contracts those backends were selected against; FA4 on SM100/110 in particular has no valid kernel for (192, 192) and would produce a runtime error.

Both changed files need attention: utils.py for the scope of the guard removal, and dot_product_attention.py for the missing backend check around the padding logic.
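
If the review's reading is right, both defects come down to missing conditions. The sketch below shows one possible shape for them; `AttnCall`, `can_pad_v`, `should_pad_v`, `fa2_allowed`, and the backend labels are all hypothetical and are not Transformer Engine APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnCall:
    """Hypothetical summary of one attention call (not a Transformer Engine type)."""
    backend: str      # assumed labels: "fa2", "fa3", "fa4"
    q_format: str
    kv_format: str
    head_dim_qk: int
    head_dim_v: int
    v_is_fp8: bool


def can_pad_v(call: AttnCall) -> bool:
    """Layout/dtype gate from the summary: THD on both sides, non-FP8 V, smaller V."""
    return (
        call.q_format == "thd"
        and call.kv_format == "thd"
        and not call.v_is_fp8
        and call.head_dim_v < call.head_dim_qk
    )


def should_pad_v(call: AttnCall) -> bool:
    """Pad only for FA2, so FA3/FA4 keep the shape they were selected against."""
    return call.backend == "fa2" and can_pad_v(call)


def fa2_allowed(call: AttnCall) -> bool:
    """Keep FA2 disabled for mismatched head dims unless padding covers the call."""
    return call.head_dim_qk == call.head_dim_v or can_pad_v(call)
```

Under these assumptions a THD call with head dims (192, 128) pads only when FA2 is the selected backend, and a non-THD call with the same head dims still falls back from FA2 as it did before the guard was removed.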

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Removed FA2 MLA guard unconditionally; non-THD layouts with mismatched head dims can now select FA2 without any padding applied, causing runtime failures. |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Adds _pad_thd_value_layer and _trim_thd_output helpers for THD+MLA V-padding, but padding is applied for all FA backends (including FA3/FA4 that natively support DeepSeek MLA), breaking those backends. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[DotProductAttention.forward] --> B{use_flash_attention?}
    B -- Yes --> C{q_format==thd AND kv_format==thd AND NOT FP8 AND head_dim_v < head_dim_qk?}
    C -- Yes --> D[_pad_thd_value_layer: pad V to head_dim_qk]
    C -- No --> E[Use V as-is]
    D --> F[self.flash_attention with padded V]
    E --> F
    F --> G{orig_v_dim is not None?}
    G -- Yes --> H[_trim_thd_output: reshape and slice to orig_v_dim]
    G -- No --> I[return attn_out]
    H --> I

    subgraph utils.py: get_attention_backend
    J{head_dim_qk != head_dim_v?} -- Yes --> K[FA2: guard REMOVED - now unrestricted]
    J -- Yes --> L[FA3: _is_fa3_supported checks 192/128 natively]
    J -- Yes --> M[FA4: _validate_head_dims checks 192/128 natively]
    end

    style K fill:#f88,stroke:#f00
    style D fill:#ffa,stroke:#fa0

Reviews (7): Last reviewed commit: "[PyTorch] Pad V when Q/V head dims diffe..."

greptile-apps Bot (Contributor) left a comment

2 files reviewed, 2 comments

Comment thread tests/pytorch/attention/test_attention.py Outdated
Copilot AI (Contributor) left a comment

Pull request overview

This PR adds support for Multi-head Latent Attention (MLA) with mismatched Q/V head dimensions in the THD format (a packed variable-length layout indexed by total tokens, heads, and head dimension). When the value tensor has a smaller head dimension than the query/key tensors, the code zero-pads the value tensor to match the Q/K head dimension, runs the attention operation, and then trims the output back to the original V dimension.
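
Why padding and trimming leaves the result unchanged: the attention weights depend only on Q and K, so extra zero columns in V produce extra zero columns in the output and do not touch the original ones. The following is a small pure-Python check of that property on made-up numbers, with a hypothetical `attention` helper (single head, no masking), not the kernel the PR calls.

```python
import math


def attention(q, k, v):
    """Single-head softmax(q k^T / sqrt(d)) v on nested lists (reference only)."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append(
            [sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))]
        )
    return out


# Toy sizes: head_dim_qk = 3, head_dim_v = 2.
q = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
k = [[0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
v = [[1.0, 2.0], [3.0, 4.0]]

# Zero-pad V to the Q/K head dim, run attention, then trim back.
v_padded = [row + [0.0] * (len(q[0]) - len(row)) for row in v]
reference = attention(q, k, v)
padded_out = attention(q, k, v_padded)
trimmed = [row[: len(v[0])] for row in padded_out]
```

Here `trimmed` matches `reference` column for column, and the extra column of `padded_out` is all zeros.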

Changes:

  • Added padding logic for V tensor when head dimensions differ in THD format
  • Implemented trimming function to restore correct output dimensions after attention
  • Added test case for THD attention with mismatched Q/V head dimensions

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Implements padding of V tensor before attention and trimming of output after attention for THD format with mismatched Q/V head dimensions |
| tests/pytorch/attention/test_attention.py | Adds test case to verify THD attention works with different Q/V head dimensions |


Comment thread tests/pytorch/attention/test_attention.py Outdated
greptile-apps Bot (Contributor) left a comment

1 file reviewed, 1 comment

greptile-apps Bot (Contributor) left a comment

2 files reviewed, no comments

@ptrendx added the community-contribution label Apr 9, 2026
@cyanguwa cyanguwa requested a review from vcherepanov-nv April 22, 2026 22:21
@cyanguwa (Collaborator)

This change should only be required by the FlashAttention backend. The other two backends FusedAttention and UnfusedDPA do support MLA (head_dim_qk != head_dim_v). I'd propose a few changes:

@vcherepanov-nv, could you help get this PR over the finish line? Thanks!

@HollowMan6 (Member, Author)

Thank you @cyanguwa, I just cleaned up the PR and followed your suggestions. Please let me know what you think, @vcherepanov-nv.

Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.



Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py
Signed-off-by: Hollow Man <hollowman@opensuse.org>

Labels

2.16.0, community-contribution, org-contribution

Development

Successfully merging this pull request may close these issues.

[BUG]DotProductAttention:Disabling FlashAttention as it does not support MLA

4 participants