[PyTorch] Enable head dim 256 for FA4 #2932
Greptile Summary
This PR enables head_dim=256 support for FlashAttention 4 on SM100/SM103 GPUs by delegating head-dimension validation to FA4's own
Confidence Score: 4/5
The core logic in utils.py is sound, but the grouped import in backends.py will crash the entire TE module load for any user who has FA4 installed at a version older than 4.0.0b11. The import in backends.py bundles
transformer_engine/pytorch/attention/dot_product_attention/backends.py: the grouped FA4 import is the critical path that warrants a second look before merging.
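One way to keep a missing symbol from breaking module load is to resolve each optional symbol on its own and fall back to `None`. The following is a minimal sketch of that pattern, not the code in this PR. The helper name `optional_attr` is hypothetical, and the standard-library `math` module stands in for the FA4 module so the example runs anywhere.

```python
import importlib


def optional_attr(module_name, attr_name):
    """Return module.attr, or None if the module or the attribute is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The whole package is absent: report "not available" instead of raising.
        return None
    # An older release may lack this symbol: getattr falls back to None.
    return getattr(module, attr_name, None)


# Stand-in lookups; a real caller would pass the FA4 module and symbol names.
present = optional_attr("math", "sqrt")
missing_symbol = optional_attr("math", "no_such_symbol")
missing_module = optional_attr("no_such_module_xyz", "anything")
```

Callers can then test the result against `None` before using it, which matches the `v4_validate_head_dims is not None` condition in the flowchart.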
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[get_attention_backend called] --> B{use_flash_attention_4 and v4_is_installed and v4_validate_head_dims is not None?}
B -- No --> Z[Skip FA4 head-dim check]
B -- Yes --> C[Compute _fa4_alignment]
C --> D[Call v4_validate_head_dims]
D -- AssertionError --> E[Disable FA4]
D -- OK --> F{SM100 AND hd256 AND seqlen_q != seqlen_kv?}
F -- Yes --> G[Disable FA4 cross-attn hd256]
F -- No --> H{Training AND MLA AND SM100?}
H -- Yes --> I[gcd misalignment check]
I -- Misaligned --> J[Disable FA4 MLA bwd]
I -- OK --> K[FA4 enabled]
H -- No --> K
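The branches of the flowchart can be sketched as a single function. This is an illustration under stated assumptions, not the implementation in utils.py: the function name `check_fa4_head_dims`, its parameters, and the reason strings are hypothetical. The validator is modeled as a zero-argument callable, a single `head_dim` stands in for the head-dimension inputs, and the result of the gcd misalignment check is passed in as a precomputed boolean.

```python
from typing import Callable, Optional, Tuple


def check_fa4_head_dims(
    use_flash_attention_4: bool,
    v4_is_installed: bool,
    validate_head_dims: Optional[Callable[[], None]],
    is_sm100: bool,
    head_dim: int,
    seqlen_q: int,
    seqlen_kv: int,
    is_training: bool,
    is_mla: bool,
    mla_bwd_misaligned: bool,
) -> Tuple[bool, str]:
    """Return (fa4_enabled, reason), following the flowchart branch by branch."""
    # Node B: without FA4 or its validator, the head-dim check is skipped.
    if not (use_flash_attention_4 and v4_is_installed and validate_head_dims is not None):
        return use_flash_attention_4, "head-dim check skipped"
    # Node D: the validator signals unsupported head dims with AssertionError.
    try:
        validate_head_dims()
    except AssertionError:
        return False, "rejected by validator"
    # Node F: head dim 256 on SM100 with seqlen_q != seqlen_kv is disabled.
    if is_sm100 and head_dim == 256 and seqlen_q != seqlen_kv:
        return False, "cross-attention with head dim 256"
    # Nodes H and I: MLA training on SM100 must also pass the alignment check.
    if is_training and is_mla and is_sm100 and mla_bwd_misaligned:
        return False, "misaligned MLA backward"
    return True, "FA4 enabled"


def accept() -> None:
    """Stand-in validator that accepts every shape."""


def reject() -> None:
    """Stand-in validator that rejects every shape."""
    raise AssertionError("unsupported head dims")


common = dict(
    use_flash_attention_4=True, v4_is_installed=True, is_sm100=True, head_dim=256,
    seqlen_q=1024, is_training=False, is_mla=False, mla_bwd_misaligned=False,
)
self_attn = check_fa4_head_dims(validate_head_dims=accept, seqlen_kv=1024, **common)
cross_attn = check_fa4_head_dims(validate_head_dims=accept, seqlen_kv=2048, **common)
rejected = check_fa4_head_dims(validate_head_dims=reject, seqlen_kv=1024, **common)
```

Returning a reason string alongside the boolean is a design choice made for this sketch only; it makes each disabled path easy to tell apart in a log or a test.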
Reviews (8): Last reviewed commit: "Merge branch 'main' into xiny/headdim256..."
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
@vcherepanov-nv @KshitijLakhani Please review.
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (
Should this still be checked when FlashAttentionUtils.v4_validate_head_dims == None?
I double-checked, and this is a bug in FA4. The kernels produce wrong results on these shapes, but v4_validate_head_dims allows them, so we have to filter them out manually.
Raised an issue with FA4: Dao-AILab/flash-attention#2552
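The condition quoted in the code comment under review can be written as a small predicate. This is a sketch only: the function name `dv_reads_misaligned` is hypothetical, and the example values are illustrative numbers, not tile sizes taken from FA4.

```python
def dv_reads_misaligned(tile_hdimv: int, dk_reduce_ncol: int) -> bool:
    """True when (tile_hdimv // 2) is not a multiple of dk_reduce_ncol."""
    if dk_reduce_ncol <= 0:
        raise ValueError("dk_reduce_ncol must be positive")
    return (tile_hdimv // 2) % dk_reduce_ncol != 0


# Illustrative values only (assumptions, not FA4's actual configuration).
aligned_case = dv_reads_misaligned(tile_hdimv=128, dk_reduce_ncol=32)     # 64 % 32 == 0
misaligned_case = dv_reads_misaligned(tile_hdimv=192, dk_reduce_ncol=64)  # 96 % 64 == 32
```

A shape for which the predicate is true would be one of those filtered out manually.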
LGTM
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
/te-ci pytorch L3
The B200 test failed with a 1-element mismatch. It should be unrelated to this PR, because I saw similar errors in other pipelines.
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
Description
Need FA4 version 4.0.0b11.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: