[PyTorch] Enable head dim 256 for FA4 #2932

Open
yaox12 wants to merge 9 commits into NVIDIA:main from yaox12:xiny/headdim256_fa

Conversation

@yaox12
Member

@yaox12 yaox12 commented Apr 27, 2026

Description

Requires FA4 version 4.0.0b11.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 marked this pull request as draft April 27, 2026 09:31
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from bdcc02e to 3b3f7d0 Compare April 27, 2026 09:31
@greptile-apps
Contributor

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR enables head_dim=256 support for FlashAttention 4 on SM100/SM103 GPUs by delegating head-dimension validation to FA4's own _validate_head_dims function instead of maintaining a parallel static guard in TE, and bumps the required FA4 version to 4.0.0b11.

  • backends.py: _validate_head_dims is imported alongside flash_attn_func/flash_attn_varlen_func in a single grouped import; if absent in an older FA4 install, an uncaught ImportError crashes the entire module load (previously flagged).
  • utils.py: Replaces the static per-arch head-dim check with a live call to FA4's validator; adds an SM100 cross-attention fallback for hd256 shapes; the MLA misalignment workaround is preserved as independent if checks; v4_installation_steps is correctly updated to 4.0.0b11.
  • test_attention.py: Adds test_dpa_fa4_hdim256 with an explicit SM100/SM103 skipif guard, and removes stale cuDNN version checks from all FA4 tests.

Confidence Score: 4/5

The core logic in utils.py is sound, but the grouped import in backends.py will crash the entire TE module load for any user who has FA4 installed at a version older than 4.0.0b11.

The import in backends.py bundles _validate_head_dims into the same grouped block as the two core FA4 functions. Any FA4 install older than 4.0.0b11 that lacks this symbol triggers an unhandled ImportError at module load time, making TE unusable for those users.
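One way to avoid the hard failure described above is to resolve the optional symbol separately and fall back to `None` when it is missing. The following is a minimal sketch, not the code in this PR: the helper name `optional_symbol` is hypothetical, and the FA4 module path shown in the comment is an assumption.

```python
import importlib
from typing import Any, Optional


def optional_symbol(module_name: str, symbol: str) -> Optional[Any]:
    """Return `symbol` from `module_name`, or None if either is unavailable."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module itself is not installed.
        return None
    # The module exists but may predate the symbol.
    return getattr(module, symbol, None)


# Demonstration with a standard-library module.
present = optional_symbol("math", "sqrt")
missing = optional_symbol("math", "no_such_symbol")

# Hypothetical FA4 usage (module path is an assumption, not taken from the PR):
# v4_validate_head_dims = optional_symbol("flash_attn.cute.interface", "_validate_head_dims")
```

Callers can then test the result against `None` before using it, which matches the `v4_validate_head_dims is not None` condition in the flowchart.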

transformer_engine/pytorch/attention/dot_product_attention/backends.py — the grouped FA4 import is the critical path that warrants a second look before merging.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Adds _validate_head_dims to the same grouped import as flash_attn_func/flash_attn_varlen_func; an ImportError on older FA4 (pre-4.0.0b11) crashes the entire module load rather than gracefully falling back. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Replaces static head-dim guard with a live call to FA4's _validate_head_dims; adds SM100 cross-attention fallback for hd256; MLA workaround restructured as independent if checks; v4_installation_steps updated to 4.0.0b11. |
| tests/pytorch/attention/test_attention.py | Adds dedicated test_dpa_fa4_hdim256 with explicit SM100/SM103 skip guard; removes stale cuDNN version checks from all FA4 tests. |
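The SM100/SM103 skip guard mentioned for the new test can be pictured as a small capability check. This is a sketch under assumptions, not the test code from the PR: the helper name `supports_fa4_hdim256` is hypothetical, and it assumes SM100 and SM103 correspond to compute capabilities (10, 0) and (10, 3).

```python
from typing import Tuple

# Assumed (major, minor) compute capabilities for SM100 and SM103.
FA4_HDIM256_CAPABILITIES = {(10, 0), (10, 3)}


def supports_fa4_hdim256(capability: Tuple[int, int]) -> bool:
    """Return True if the given compute capability is SM100 or SM103."""
    return tuple(capability) in FA4_HDIM256_CAPABILITIES
```

In a test module this could feed a guard such as `pytest.mark.skipif(not supports_fa4_hdim256(torch.cuda.get_device_capability()), reason="requires SM100/SM103")`.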

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend called] --> B{use_flash_attention_4 and v4_is_installed and v4_validate_head_dims is not None?}
    B -- No --> Z[Skip FA4 head-dim check]
    B -- Yes --> C[Compute _fa4_alignment]
    C --> D[Call v4_validate_head_dims]
    D -- AssertionError --> E[Disable FA4]
    D -- OK --> F{SM100 AND hd256 AND seqlen_q != seqlen_kv?}
    F -- Yes --> G[Disable FA4 cross-attn hd256]
    F -- No --> H{Training AND MLA AND SM100?}
    H -- Yes --> I[gcd misalignment check]
    I -- Misaligned --> J[Disable FA4 MLA bwd]
    I -- OK --> K[FA4 enabled]
    H -- No --> K

Reviews (8): Last reviewed commit: "Merge branch 'main' into xiny/headdim256..."

Comment thread tests/pytorch/attention/test_attention.py Outdated
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from 3b3f7d0 to 9a93156 Compare May 6, 2026 02:44
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from ae74e44 to 8aa5242 Compare May 6, 2026 02:55
@yaox12
Member Author

yaox12 commented May 6, 2026

/te-ci pytorch L3

@yaox12 yaox12 marked this pull request as ready for review May 6, 2026 02:59
@yaox12
Member Author

yaox12 commented May 6, 2026

@vcherepanov-nv @KshitijLakhani Please review.

@KshitijLakhani KshitijLakhani requested a review from mk-61 May 8, 2026 06:34
Comment thread tests/pytorch/attention/test_attention.py Outdated
Comment thread tests/pytorch/attention/test_attention.py
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (
Member

Should this still be checked when FlashAttentionUtils.v4_validate_head_dims == None?

Member Author

I double-checked, and this is a bug in FA4. The kernels produce wrong results on these shapes, but v4_validate_head_dims allows them, so we have to filter them out manually.
I raised an issue with FA4: Dao-AILab/flash-attention#2552
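The condition quoted in the code comment under review, `(tile_hdimv // 2) % dK_reduce_ncol != 0`, can be written as a standalone predicate. This is a sketch for illustration only: the function name is hypothetical, and how `tile_hdimv` and `dK_reduce_ncol` are derived from the attention shapes is not shown here.

```python
def dv_reads_misaligned(tile_hdimv: int, dk_reduce_ncol: int) -> bool:
    """Return True when half the V tile width is not a multiple of dk_reduce_ncol."""
    if dk_reduce_ncol <= 0:
        raise ValueError("dk_reduce_ncol must be positive")
    return (tile_hdimv // 2) % dk_reduce_ncol != 0
```

With illustrative values, `dv_reads_misaligned(128, 32)` is `False` because 64 is a multiple of 32, while `dv_reads_misaligned(80, 32)` is `True` because 40 is not.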

@vcherepanov-nv
Collaborator

LGTM

yaox12 added 2 commits May 10, 2026 22:28
@yaox12
Member Author

yaox12 commented May 11, 2026

/te-ci pytorch L3

yaox12 added 2 commits May 12, 2026 10:30
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12
Member Author

yaox12 commented May 12, 2026

/te-ci pytorch L3

@yaox12 yaox12 requested a review from cyanguwa as a code owner May 13, 2026 10:42
@yaox12
Member Author

yaox12 commented May 15, 2026

/te-ci pytorch L3

Member

@sudhakarsingh27 sudhakarsingh27 left a comment

LGTM, pending CI

@sudhakarsingh27 sudhakarsingh27 self-requested a review May 15, 2026 20:52
@yaox12
Member Author

yaox12 commented May 18, 2026

The B200 test failed with a single-element mismatch. It is most likely unrelated to this PR, since I saw similar errors in other pipelines.

@sudhakarsingh27
Member

sudhakarsingh27 commented May 18, 2026

Need to manually run L1 tests, triggering now
Doesn't look like it's needed

Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12
Member Author

yaox12 commented May 24, 2026

/te-ci pytorch L3

3 participants