Skip to content
5 changes: 3 additions & 2 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Attention backend selection utilities."""
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.attention.paged_fa3.fp import PagedFa3AttBackend
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.backend_validator import validate
from typing import Dict
Expand All @@ -23,7 +24,7 @@
data_type_to_backend = {
"None": {
"triton": TritonAttBackend,
"fa3": Fa3AttBackend,
"fa3": PagedFa3AttBackend if get_page_size() > 1 else Fa3AttBackend,
"flashinfer": FlashInferAttBackend,
},
"int4kv": {
Expand Down
2 changes: 1 addition & 1 deletion lightllm/common/basemodel/attention/flashinfer/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _nomarl_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty
) -> torch.Tensor:
self.backend: FlashInferAttBackend = self.backend # for typing
o_tensor = alloc_func(q.shape, q.dtype, device="cuda")
o_tensor = alloc_func(q.shape, q.dtype, device=q.device)
self.prefill_wrapper.run(
q,
(k.unsqueeze(1), v.unsqueeze(1)),
Expand Down
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/flashinfer/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def prefill_att(
def _fp8_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty
) -> torch.Tensor:
o_tensor = alloc_func(q.shape, q.dtype, device="cuda")
o_tensor = alloc_func(q.shape, q.dtype, device=q.device)
k = k.unsqueeze(1).view(torch.float8_e4m3fn)
v = v.unsqueeze(1).view(torch.float8_e4m3fn)
layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self)
Expand Down Expand Up @@ -97,7 +97,7 @@ def _fp8_decode_att(
v: torch.Tensor,
alloc_func=torch.empty,
):
o_tensor = alloc_func(q.shape, q.dtype, device="cuda")
o_tensor = alloc_func(q.shape, q.dtype, device=q.device)

k = k.unsqueeze(1).view(torch.float8_e4m3fn)
v = v.unsqueeze(1).view(torch.float8_e4m3fn)
Expand Down
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _mla_prefill_att(
) -> torch.Tensor:
self.backend: MlaFlashInferAttBackend = self.backend # for typing
k_nope, k_rope = k
o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda")
o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device=q.device)
q_head_num = q.shape[1]
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1)
self.prefill_wrapper.run(q, k, v, out=o_tensor)
Expand All @@ -125,7 +125,7 @@ def init_state(self):

self.kv_starts = self.infer_state.b1_cu_kv_seq_len

self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")
self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device=device)
if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch:
self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][
: batch_size * self.backend.max_seq_length
Expand Down
13 changes: 11 additions & 2 deletions lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,19 @@ def init_state(self):
hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
)
self.nsa_cache_seqlens = torch.minimum(
torch.full(size=(self.infer_state.batch_size,), fill_value=2048, dtype=torch.int32, device="cuda"),
torch.full(
size=(self.infer_state.batch_size,),
fill_value=2048,
dtype=torch.int32,
device=self.infer_state.input_ids.device,
),
self.infer_state.b_seq_len,
)
padded_seq_lens = torch.zeros(size=(self.nsa_cache_seqlens.shape[0] + 1,), dtype=torch.int32, device="cuda")
padded_seq_lens = torch.zeros(
size=(self.nsa_cache_seqlens.shape[0] + 1,),
dtype=torch.int32,
device=self.infer_state.input_ids.device,
)
# 进行 cumsum 操作
padded_seq_lens[1:].copy_(self.nsa_cache_seqlens, non_blocking=True)
self.nsa_cu_seqlens_k_new = padded_seq_lens.cumsum(dim=0, dtype=torch.int32)
Expand Down
Loading