Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
6cf58b4
fix
hiworldwzj Apr 9, 2026
f1251a3
add gitignore
flyinglandlord Mar 12, 2026
d9b1fdd
finish usable mtp kernel
flyinglandlord Mar 17, 2026
315366a
end-to-end finish
flyinglandlord Mar 19, 2026
73ea125
fix cudagraph support
flyinglandlord Mar 19, 2026
dc91e59
save runnable version of dynamic mtp
flyinglandlord Mar 26, 2026
aefe67e
save runnable version of dynamic mtp
flyinglandlord Mar 26, 2026
3c28fb0
fix
hiworldwzj Apr 9, 2026
2dc933e
save fixed dynamic mtp
flyinglandlord Mar 27, 2026
0b08de8
save
flyinglandlord Mar 30, 2026
952ec15
save
flyinglandlord Mar 30, 2026
3750118
add experiment script
flyinglandlord Apr 1, 2026
4bf4287
update mtp kernel support BLOCK_BATCH < max_verify_group_size
flyinglandlord Apr 1, 2026
05e0dfd
fix implementation issues
flyinglandlord Apr 3, 2026
c2b7569
save
flyinglandlord Apr 4, 2026
80219af
save
flyinglandlord Apr 8, 2026
41180d3
fix
hiworldwzj Apr 9, 2026
8afd7a8
fix
hiworldwzj Apr 9, 2026
2b277fa
fix
hiworldwzj Apr 9, 2026
93c2ada
fix
hiworldwzj Apr 9, 2026
6395447
fix
hiworldwzj Apr 9, 2026
fc20624
fix
hiworldwzj Apr 9, 2026
1b08d15
fix
hiworldwzj Apr 9, 2026
775adbd
fix
hiworldwzj Apr 9, 2026
d4830ff
fix
hiworldwzj Apr 9, 2026
da944f8
fix
hiworldwzj Apr 9, 2026
22c5996
fix
hiworldwzj Apr 9, 2026
6fbe8d8
fix
hiworldwzj Apr 9, 2026
c4a9f74
fix lightllm/server/router/model_infer/mode_backend/generic_pre_proce…
flyinglandlord Apr 9, 2026
5eb4889
update generic_pre_process.py
flyinglandlord Apr 9, 2026
379f256
fix
flyinglandlord Apr 9, 2026
e943a43
fix
flyinglandlord Apr 9, 2026
e121b9d
refactor qwen3_eagle3
flyinglandlord Apr 9, 2026
1723230
add stage1
hiworldwzj Apr 9, 2026
6890bc0
fix
hiworldwzj Apr 9, 2026
1e1fb98
fix
hiworldwzj Apr 9, 2026
29535b2
fix
hiworldwzj Apr 9, 2026
5b925a6
fix
hiworldwzj Apr 9, 2026
538200d
fix
hiworldwzj Apr 9, 2026
2831c70
fix
hiworldwzj Apr 9, 2026
158c7a3
fix
hiworldwzj Apr 9, 2026
4c08120
fix
hiworldwzj Apr 10, 2026
a837cbb
fix
hiworldwzj Apr 10, 2026
17ed333
fix
hiworldwzj Apr 10, 2026
7927e15
fix
hiworldwzj Apr 10, 2026
53a7077
fix base_backend.py
flyinglandlord Apr 10, 2026
322f713
fix
hiworldwzj Apr 10, 2026
c3e46c9
fix
hiworldwzj Apr 10, 2026
2f0c250
fix
hiworldwzj Apr 10, 2026
8ab2ab4
fix
hiworldwzj Apr 10, 2026
76ca4ce
fix
hiworldwzj Apr 10, 2026
1d0f18e
fix
hiworldwzj Apr 10, 2026
6e69701
fix
hiworldwzj Apr 10, 2026
1565698
fix
hiworldwzj Apr 10, 2026
5e857c8
fix
hiworldwzj Apr 10, 2026
3272073
fix
hiworldwzj Apr 11, 2026
1126013
fix
hiworldwzj Apr 11, 2026
f591158
fix
flyinglandlord Apr 13, 2026
74189c2
add vllm test script
flyinglandlord Apr 20, 2026
df54935
remove 200000 token limit in test script
flyinglandlord May 7, 2026
261f5b4
Mtp optimization ema overlap (#1310)
flyinglandlord May 27, 2026
03ac700
add sample_dynamic_mtp_req_mask
hiworldwzj May 27, 2026
7bc26a7
fix
hiworldwzj May 27, 2026
4b93c8c
fix
hiworldwzj May 27, 2026
d7bdef0
fix
hiworldwzj May 27, 2026
a731153
Merge branch 'mtp_optimization' of https://github.com/ModelTC/LightLL…
flyinglandlord May 27, 2026
97a0123
fix
flyinglandlord May 29, 2026
62226d5
fix
hiworldwzj May 29, 2026
0783c1e
fix
hiworldwzj May 29, 2026
4777936
fix
hiworldwzj May 29, 2026
9dcfc05
fix
hiworldwzj May 29, 2026
9229d28
add dynamic fa3 (#1334)
shihaobai Jun 9, 2026
2305238
update mtp planner class
hiworldwzj Jun 12, 2026
706125d
update mtp planner
hiworldwzj Jun 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
__pycache__/
.pyc
.codex
build
dist
*.egg-info
.idea
.vscode
tmp/
requirements-musa.txt

hf_datasets_cache/
wandb/
datasets/
trace/
experiment_results/
12,002 changes: 12,002 additions & 0 deletions datasets/gsm8k.json

Large diffs are not rendered by default.

37 changes: 31 additions & 6 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
from typing import Optional, TYPE_CHECKING
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.utils.envs_utils import enable_dynamic_mtp_verify, get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import (
build_dynamic_mtp_fa3_decode_params,
page_table_copy,
)
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor


Expand Down Expand Up @@ -125,8 +128,21 @@ class Fa3DecodeAttState(BaseDecodeAttState):
def init_state(self):
self.backend: Fa3AttBackend = self.backend
args_mtp_step = get_env_start_args().mtp_step
is_dynamic_mtp = args_mtp_step > 0 and enable_dynamic_mtp_verify()

if args_mtp_step > 0:
if is_dynamic_mtp:
att_batch_size = self.infer_state.batch_size
(b_q_seq_len, b_kv_seq_len, b_att_req_idx, self.b_att_seq_len,) = build_dynamic_mtp_fa3_decode_params(
b_req_idx=self.infer_state.b_req_idx,
b_seq_len=self.infer_state.b_seq_len,
b_mark_shared_group=self.infer_state.b_mark_shared_group,
att_batch_size=att_batch_size,
hold_req_id=self.backend.model.req_manager.HOLD_REQUEST_ID,
)
b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
self.cu_seqlens_q = b1_cu_q_seq_len.int()
self.cu_seqlens_k = b1_cu_kv_seq_len.int()
elif args_mtp_step > 0:
# 修正 mtp 在 fa3 下的输入。
mtp_size = args_mtp_step + 1
b_q_seq_len = torch.full(
Expand All @@ -139,12 +155,14 @@ def init_state(self):
b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
self.cu_seqlens_q = b1_cu_q_seq_len.int()
self.cu_seqlens_k = b1_cu_kv_seq_len.int()
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
else:
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()
att_batch_size = self.infer_state.batch_size

att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
if not is_dynamic_mtp:
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

model = self.backend.model
# 可以使用 cuda graph的时候从 buffer中申请
Expand All @@ -163,7 +181,14 @@ def init_state(self):
device=self.infer_state.input_ids.device,
)

if args_mtp_step > 0:
if is_dynamic_mtp:
page_table_copy(
page_table=self.page_table[:, : self.infer_state.max_kv_seq_len],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
b_req_idx=b_att_req_idx,
)
self.decode_max_q_seq_len = args_mtp_step + 1
elif args_mtp_step > 0:
page_table_copy(
page_table=self.page_table[:, : self.infer_state.max_kv_seq_len],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
Expand Down
51 changes: 49 additions & 2 deletions lightllm/common/basemodel/attention/triton/fp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dataclasses
import torch

from lightllm.utils.envs_utils import enable_dynamic_mtp_verify, get_env_start_args, enable_triton_mtp_kernel
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from typing import Optional

Expand Down Expand Up @@ -80,8 +82,17 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,

@dataclasses.dataclass
class TritonDecodeAttState(BaseDecodeAttState):
# MTP related state variables
b_mark_shared_group: torch.Tensor = None

def init_state(self):
    """Populate the MTP shared-group marker for this decode state.

    Reads ``mtp_step`` from the start arguments. When it is positive the
    marker tensor is taken from ``self.infer_state``; otherwise the field is
    reset to ``None``.
    """
    mtp_enabled = get_env_start_args().mtp_step > 0
    # infer_state is only consulted when MTP is on; otherwise the marker is cleared.
    self.b_mark_shared_group = self.infer_state.b_mark_shared_group if mtp_enabled else None

def copy_for_decode_cuda_graph(self, new_state: "TritonDecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)
Expand All @@ -99,9 +110,17 @@ def decode_att(
assert att_control.tp_alibi is not None
return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
else:

args_mtp_step = get_env_start_args().mtp_step

q_head_num = q.shape[1]
k_head_num = k.shape[1]
if q_head_num == k_head_num:

if args_mtp_step > 0 and (enable_dynamic_mtp_verify() or enable_triton_mtp_kernel()):
# MTP mode: use mtp diverse attention
assert q_head_num >= k_head_num, "MTP diverse attention requires q_head_num >= k_head_num"
return self._dynamic_mtp_decode_gqa_att(q=q, k=k, v=v, alloc_func=alloc_func)
elif q_head_num == k_head_num:
return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
elif q_head_num > k_head_num:
return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
Expand Down Expand Up @@ -182,6 +201,34 @@ def _normal_decode_gqa_flash_decoding_att(

return out

def _dynamic_mtp_decode_gqa_att(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alloc_func=torch.empty,
):
    """Run decode attention through the MTP "diverse" single-token kernel.

    Args:
        q: Query tensor, forwarded unchanged to the kernel.
        k: Key tensor, forwarded unchanged to the kernel.
        v: Value tensor, forwarded unchanged to the kernel.
        alloc_func: Allocator handed to the kernel as ``alloc_tensor_func``.

    Returns:
        The value returned by ``token_decode_attention_mtp_diverse_single_token``.
    """
    # The kernel is imported inside the method rather than at module level.
    from ...triton_kernel.att.decode_att.gqa.mtp_diverse import (
        token_decode_attention_mtp_diverse_single_token as mtp_diverse_kernel,
    )

    state = self.infer_state
    seq_lens = state.b_seq_len
    # The shared-group marker is always read from infer_state here (it is
    # populated from the model input when the infer state is created).
    group_marks = state.b_mark_shared_group

    return mtp_diverse_kernel(
        q=q,
        k=k,
        v=v,
        Req_to_tokens=state.req_manager.req_to_token_indexs,
        B_req_idx=state.b_req_idx,
        b_seq_len=seq_lens,
        b_mark_shared_group=group_marks,
        alloc_tensor_func=alloc_func,
    )

def _normal_decode_gqa_flash_decoding_att_vsm(
self,
q: torch.Tensor,
Expand Down
100 changes: 80 additions & 20 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,23 @@
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
from lightllm.utils.envs_utils import (
enable_triton_mtp_kernel,
get_env_start_args,
get_llm_data_type,
get_added_mtp_kv_layer_num,
)
from lightllm.distributed.communication_op import dist_group_manager
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput, OutHiddenState
from lightllm.common.triton_utils.autotuner import AutotuneLevel
from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.utils.envs_utils import (
set_model_init_status,
enable_diverse_mode_gqa_decode_fast_kernel,
enable_dynamic_mtp_verify,
)
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
from lightllm.utils.infer_utils import calculate_time, post_empty_cache
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend

Expand Down Expand Up @@ -93,16 +102,11 @@ def __init__(self, kvargs):
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
]
self.prefill_graph: PrefillCudaGraph = None

self._init_config()
self._init_speculative_algo(kvargs)

self._verify_must()
self._verify_params()
self._init_quant()
Expand Down Expand Up @@ -137,6 +141,38 @@ def __init__(self, kvargs):
set_model_init_status(True)
return

def _init_speculative_algo(self, kvargs):
self.is_mtp_draft_model = kvargs.get("is_mtp_draft_model", False)
self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
"eagle3",
]
if not self.is_mtp_mode:
self.output_hidden_layers = []
return

if not self.is_mtp_draft_model:
# 主 main model 需要输出 hidden state 用于 draft 模型进行 mtp 预测。
if self.args.mtp_mode == "eagle3":
# assert not self.args.enable_prefill_cudagraph, "eagle3 mode does not support prefill cudagraph"
assert (
not self.args.enable_decode_microbatch_overlap
), "eagle3 mode does not support decode microbatch overlap"
assert (
not self.args.enable_prefill_microbatch_overlap
), "eagle3 mode does not support prefill microbatch overlap"
self.output_hidden_layers = [1, self.config["n_layer"] // 2 - 1, self.config["n_layer"] - 4]
else:
self.output_hidden_layers = [self.config["n_layer"] - 1]
else:
# draft model 需要输出 hidden state 用于 多步 mtp 预测
self.output_hidden_layers = [self.config["n_layer"] - 1]

return

def _wait_other_modules_ready(self):
for event in self.wait_events:
event.wait()
Expand All @@ -151,6 +187,12 @@ def _init_config(self):
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
if self.finetune_config:
self.config["vocab_size"] = self.finetune_config.vocab_size

# eagle3 mode 下,需要修改 vocab_size 为 draft_vocab_size, 其他场景
# 这个代码并不会生效。
if "draft_vocab_size" in self.config.keys():
self.config["target_vocab_size"] = self.config["vocab_size"]
self.config["vocab_size"] = self.config["draft_vocab_size"]
return

@final
Expand Down Expand Up @@ -314,6 +356,8 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
if enable_diverse_mode_gqa_decode_fast_kernel():
infer_state.b_shared_seq_len = model_input.b_shared_seq_len
infer_state.b_mark_shared_group = model_input.b_mark_shared_group
elif enable_dynamic_mtp_verify() or enable_triton_mtp_kernel():
infer_state.b_mark_shared_group = model_input.b_mark_shared_group

infer_state.multimodal_params = model_input.multimodal_params

Expand Down Expand Up @@ -377,6 +421,11 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_model_input.b_mark_shared_group = F.pad(
new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=1
)
elif enable_dynamic_mtp_verify() or enable_triton_mtp_kernel():
assert new_model_input.b_mark_shared_group is not None
new_model_input.b_mark_shared_group = F.pad(
new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=0
)

# 特殊模型,特殊模式的特殊变量的特殊 padding
if new_model_input.mtp_draft_input_hiddens is not None:
Expand Down Expand Up @@ -561,12 +610,23 @@ def _context_forward(self, infer_state: InferStateInfo):
input_tensors = [input_embs]

def prefill_func(input_tensors, infer_state):
mtp_out_hidden_state = OutHiddenState(selected_layers=self.output_hidden_layers)
_input_embs = input_tensors[0]
for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
_input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
return [_input_embs]
mtp_out_hidden_state.add_hidden(
layer_index=i,
layer_num=self.layers_num,
hidden=_input_embs,
)

capture_hiddens = mtp_out_hidden_state.get_captured_hiddens()
if capture_hiddens is not None:
return [_input_embs, capture_hiddens]
else:
return [_input_embs]

handle_token_num = input_ids.shape[0]

Expand Down Expand Up @@ -596,8 +656,8 @@ def prefill_func(input_tensors, infer_state):
model_output = ModelOutput(logits=predict_logits)

# 特殊模型特殊模式的额外输出
if self.is_mtp_mode:
model_output.mtp_main_output_hiddens = input_embs
if self.is_mtp_mode and len(output_tensors) > 1:
model_output.mtp_main_output_hiddens = output_tensors[1]

# 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候
# 该调用没有实际意义
Expand All @@ -611,22 +671,21 @@ def _token_forward(self, infer_state: InferStateInfo):
cuda_input_ids = input_ids
pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
mtp_out_hidden_state = OutHiddenState(selected_layers=self.output_hidden_layers)
for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
input_embs: torch.Tensor = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
mtp_out_hidden_state.add_hidden(layer_index=i, layer_num=self.layers_num, hidden=input_embs)

capture_hiddens = mtp_out_hidden_state.get_captured_hiddens()
post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight)

if self.is_mtp_mode:
graph_out_hiddens = input_embs.contiguous()

model_output = ModelOutput(logits=predict_logits.contiguous())

# 特殊模型特殊模式的额外输出
if self.is_mtp_mode:
model_output.mtp_main_output_hiddens = graph_out_hiddens
if self.is_mtp_mode and capture_hiddens is not None:
model_output.mtp_main_output_hiddens = capture_hiddens.contiguous()

# 在 cuda graph 模式下,输出需要转为 no ref tensor, 加强mem pool 的复用,降低显存的使用。
if infer_state.is_cuda_graph:
Expand Down Expand Up @@ -1027,6 +1086,7 @@ def _gen_special_model_input(self, token_num: int):
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
or "Qwen3EagleModel" in str(self.__class__)
)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
Expand Down
Loading