Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import transformer_engine_torch as tex
Expand Down Expand Up @@ -180,6 +181,18 @@
__all__ = ["DotProductAttention"]


def _pad_thd_value_layer(value_layer, head_dim_qk):
"""Pad V for THD FlashAttention when Q/K and V head dimensions differ."""
orig_head_dim_v = value_layer.shape[-1]
return F.pad(value_layer, (0, head_dim_qk - orig_head_dim_v)), orig_head_dim_v


def _trim_thd_output(attn_out, num_attention_heads, padded_head_dim_v, orig_head_dim_v):
"""Trim FlashAttention THD output after padding V to the Q/K head dimension."""
attn_out = attn_out.reshape(attn_out.shape[0], num_attention_heads, padded_head_dim_v)
return attn_out[..., :orig_head_dim_v].reshape(attn_out.shape[0], -1)


class DotProductAttention(TransformerEngineBaseModule):
r"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
Expand Down Expand Up @@ -1630,6 +1643,16 @@ def forward(
)

if use_flash_attention:
orig_v_dim = None
if (
q_format == "thd"
and kv_format == "thd"
and not isinstance(value_layer, Float8TensorStorage)
and head_dim_qk != head_dim_v
and value_layer.shape[-1] < head_dim_qk
):
value_layer, orig_v_dim = _pad_thd_value_layer(value_layer, head_dim_qk)

Comment thread
HollowMan6 marked this conversation as resolved.
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
_alibi_cache,
Expand All @@ -1638,7 +1661,7 @@ def forward(
max_seqlen_kv,
alibi_slopes=alibi_slopes,
)
return self.flash_attention(
attn_out = self.flash_attention(
query_layer,
key_layer,
value_layer,
Expand Down Expand Up @@ -1666,6 +1689,9 @@ def forward(
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
if orig_v_dim is not None:
return _trim_thd_output(attn_out, num_attention_heads, head_dim_qk, orig_v_dim)
return attn_out

if use_fused_attention:
fu_core_attention_bias_type = core_attention_bias_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,6 @@ def get_attention_backend(

# Filter: Head dimension
if head_dim_qk != head_dim_v:
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
use_flash_attention_2 = False

qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
logger.debug(
Comment thread
HollowMan6 marked this conversation as resolved.
Expand Down
Loading