From 71626d781760920d730eff109edac23cf1786795 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Jun 2026 08:41:22 +0000 Subject: [PATCH 01/17] remove nccl pd --- docs/CN/source/tutorial/api_server_args.rst | 8 +- .../source/tutorial/deepseek_deployment.rst | 16 +- docs/EN/source/tutorial/api_server_args.rst | 8 +- .../source/tutorial/deepseek_deployment.rst | 16 +- lightllm/common/basemodel/basemodel.py | 6 - .../deepseek2_mem_manager.py | 189 ------- .../kv_cache_mem_manager/mem_manager.py | 162 +----- lightllm/server/api_cli.py | 10 +- lightllm/server/api_http.py | 14 +- lightllm/server/api_start.py | 6 +- lightllm/server/core/objs/start_args_type.py | 4 +- lightllm/server/detokenization/manager.py | 2 +- lightllm/server/httpserver/manager.py | 2 +- .../httpserver_for_pd_master/manager.py | 101 +--- lightllm/server/pd_io_struct.py | 119 +---- lightllm/server/router/manager.py | 20 +- .../model_infer/mode_backend/__init__.py | 4 - .../model_infer/mode_backend/base_backend.py | 5 +- .../continues_batch/pd_mode/__init__.py | 0 .../pd_mode/decode_node_impl/__init__.py | 2 - .../pd_mode/decode_node_impl/decode_impl.py | 104 ---- .../decode_node_impl/decode_impl_for_dp.py | 28 - .../decode_node_impl/decode_infer_rpyc.py | 209 -------- .../decode_kv_move_manager.py | 383 -------------- .../decode_node_impl/decode_task_cache.py | 10 - .../decode_node_impl/decode_trans_obj.py | 305 ----------- .../decode_node_impl/decode_trans_process.py | 155 ------ .../pd_mode/decode_node_impl/up_status.py | 127 ----- .../pd_mode/prefill_node_impl/__init__.py | 2 - .../pd_mode/prefill_node_impl/prefill_impl.py | 121 ----- .../prefill_node_impl/prefill_impl_for_dp.py | 30 -- .../prefill_node_impl/prefill_infer_rpyc.py | 47 -- .../prefill_kv_move_manager.py | 241 --------- .../prefill_node_impl/prefill_task_cache.py | 8 - .../prefill_node_impl/prefill_trans_obj.py | 378 -------------- .../prefill_trans_process.py | 162 ------ .../continues_batch/pd_mode/task_queue.py | 48 -- .../continues_batch/pd_mode/utils.py | 20 - .../decode_node_impl/decode_trans_process.py | 9 +- .../pd_nixl/decode_node_impl/up_status.py | 10 +- .../mode_backend/pd_nixl/kv_transporter.py | 37 ++ .../pd_nixl/nccl_kv_transporter.py | 491 ++++++++++++++++++ .../pd_mode => pd_nixl}/p2p_fix.py | 56 +- .../prefill_trans_process.py | 9 +- .../server/router/model_infer/model_rpc.py | 19 +- lightllm/server/router/req_queue/__init__.py | 5 - .../chunked_prefill/impl_for_pd_decode.py | 82 --- skills/test_model/qwen3-8b-pd-nccl/SKILL.md | 187 ------- skills/test_model/qwen3-8b-pd-nixl/SKILL.md | 4 +- test/acc/test_pd_nccl.sh | 58 --- test/start_scripts/README.md | 19 +- test/start_scripts/multi_pd_master.sh | 34 -- .../multi_pd_master/pd_decode.sh | 19 - .../multi_pd_master/pd_prefill.sh | 21 - .../single_pd_master/pd_decode.sh | 20 - .../single_pd_master/pd_nixl_decode.sh | 4 +- .../single_pd_master/pd_nixl_prefill.sh | 4 +- .../single_pd_master/pd_prefill.sh | 21 - 58 files changed, 639 insertions(+), 3542 deletions(-) delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/__init__.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/__init__.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_task_cache.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/__init__.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_task_cache.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py create mode 100644 lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py rename lightllm/server/router/model_infer/mode_backend/{continues_batch/pd_mode => pd_nixl}/p2p_fix.py (64%) delete mode 100644 lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py delete mode 100644 skills/test_model/qwen3-8b-pd-nccl/SKILL.md delete mode 100644 test/acc/test_pd_nccl.sh delete mode 100644 test/start_scripts/multi_pd_master.sh delete mode 100644 test/start_scripts/multi_pd_master/pd_decode.sh delete mode 100644 test/start_scripts/multi_pd_master/pd_prefill.sh delete mode 100644 test/start_scripts/single_pd_master/pd_decode.sh delete mode 100644 test/start_scripts/single_pd_master/pd_prefill.sh diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst index 8e7f9d78e8..a42f30329c 100644 --- a/docs/CN/source/tutorial/api_server_args.rst +++ b/docs/CN/source/tutorial/api_server_args.rst @@ -13,8 +13,8 @@ APIServer 参数详解 设置运行模式,可选值: * ``normal``: 单服务器模式(默认) - * ``prefill``: 预填充模式(用于 pd 分离运行模式) - * ``decode``: 解码模式(用于 pd 分离运行模式) + * ``nixl_prefill``: 预填充模式(用于 pd 分离运行模式) + * ``nixl_decode``: 解码模式(用于 pd 分离运行模式) * ``pd_master``: pd 主节点模式(用于 pd 分离运行模式) * ``config_server``: 配置服务器模式(用于 pd 分离模式,用于注册 pd_master 节点并获取 pd_master 节点列表),专门为大规模、高并发场景设计,当 `pd_master` 遇到显著的 CPU 瓶颈时使用。 @@ -56,13 +56,13 @@ PD 分离模式参数 PD 主节点 IP 地址,默认为 ``0.0.0.0`` - 当 run_mode 设置为 prefill 或 decode 时需要设置此参数 + 当 run_mode 设置为 nixl_prefill 或 nixl_decode 时需要设置此参数 .. option:: --pd_master_port PD 主节点端口,默认为 ``1212`` - 当 run_mode 设置为 prefill 或 decode 时需要设置此参数 + 当 run_mode 设置为 nixl_prefill 或 nixl_decode 时需要设置此参数 .. option:: --pd_decode_rpyc_port diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index de7ecc84c3..8901dd9b98 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -174,7 +174,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 .. code-block:: bash # PD prefill 模式 for DeepSeek-R1 (DP+EP) on H200 - # 使用方法: sh pd_prefill.sh + # 使用方法: sh pd_nixl_prefill.sh + # 默认使用 NIXL 传输;如需使用 NCCL 数据面,可设置 LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl # nvidia-cuda-mps-control -d,运行MPS(可选, 有mps支持性能会好特别多,但是部分显卡和驱动环境开启mps会容易出现错误,建议升级驱动到较高版本,特别是H系列卡) export host=$1 @@ -182,7 +183,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "prefill" \ + --run_mode "nixl_prefill" \ --tp 8 \ --dp 8 \ --host $host \ @@ -200,13 +201,14 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 .. code-block:: bash # PD decode 模式 for DeepSeek-R1 (DP+EP) on H200 - # 使用方法: sh pd_decode.sh + # 使用方法: sh pd_nixl_decode.sh + # 默认使用 NIXL 传输;如需使用 NCCL 数据面,可设置 LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "decode" \ + --run_mode "nixl_decode" \ --tp 8 \ --dp 8 \ --host $host \ @@ -274,7 +276,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "prefill" \ + --run_mode "nixl_prefill" \ --host $host \ --port 8019 \ --tp 8 \ @@ -293,7 +295,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "decode" \ + --run_mode "nixl_decode" \ --host $host \ --port 8121 \ --nccl_port 12322 \ @@ -336,4 +338,4 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --tokenizer_path /path/DeepSeek-R1/ \ --url http://127.0.0.1:8088/generate_stream -以上所有脚本可以参考 `test/start_scripts/multi_pd_master/` 目录下的脚本。 \ No newline at end of file +以上所有脚本可以参考 `test/start_scripts/multi_pd_master/` 目录下的脚本。 diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index 84785de3b7..4e8083881a 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -13,8 +13,8 @@ Basic Configuration Parameters Set the running mode, optional values: * ``normal``: Single server mode (default) - * ``prefill``: Prefill mode (for pd disaggregation running mode) - * ``decode``: Decode mode (for pd disaggregation running mode) + * ``nixl_prefill``: Prefill mode (for pd disaggregation running mode) + * ``nixl_decode``: Decode mode (for pd disaggregation running mode) * ``pd_master``: pd master node mode (for pd disaggregation running mode) * ``config_server``: Configuration server mode (for pd disaggregation mode, used to register pd_master nodes and get pd_master node list), specifically designed for large-scale, high-concurrency scenarios, used when `pd_master` encounters significant CPU bottlenecks. @@ -56,13 +56,13 @@ PD disaggregation Mode Parameters PD master node IP address, default is ``0.0.0.0`` - This parameter needs to be set when run_mode is set to prefill or decode + This parameter needs to be set when run_mode is set to nixl_prefill or nixl_decode .. option:: --pd_master_port PD master node port, default is ``1212`` - This parameter needs to be set when run_mode is set to prefill or decode + This parameter needs to be set when run_mode is set to nixl_prefill or nixl_decode .. option:: --pd_decode_rpyc_port diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index 4c5a121dd6..24b65e4727 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -174,7 +174,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for .. code-block:: bash # PD prefill mode for DeepSeek-R1 (DP+EP) on H200 - # Usage: sh pd_prefill.sh + # Usage: sh pd_nixl_prefill.sh + # NIXL is used by default. To use NCCL as the data-plane backend, set LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl. # nvidia-cuda-mps-control -d, run MPS (optional, performance will be much better with mps support, but some GPUs may encounter errors when enabling mps, it's recommended to upgrade to a higher driver version, especially for H-series cards) export host=$1 @@ -182,7 +183,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "prefill" \ + --run_mode "nixl_prefill" \ --tp 8 \ --dp 8 \ --host $host \ @@ -197,13 +198,14 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for .. code-block:: bash # PD decode mode for DeepSeek-R1 (DP+EP) on H200 - # Usage: sh pd_decode.sh + # Usage: sh pd_nixl_decode.sh + # NIXL is used by default. To use NCCL as the data-plane backend, set LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl. export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "decode" \ + --run_mode "nixl_decode" \ --tp 8 \ --dp 8 \ --host $host \ @@ -271,7 +273,7 @@ Supports multiple PD Master nodes, providing better load balancing and high avai nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "prefill" \ + --run_mode "nixl_prefill" \ --host $host \ --port 8019 \ --tp 8 \ @@ -290,7 +292,7 @@ Supports multiple PD Master nodes, providing better load balancing and high avai nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "decode" \ + --run_mode "nixl_decode" \ --host $host \ --port 8121 \ --nccl_port 12322 \ @@ -333,4 +335,4 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --tokenizer_path /path/DeepSeek-R1/ \ --url http://127.0.0.1:8088/generate_stream -All the above scripts can be referenced from the scripts in the `test/start_scripts/multi_pd_master/` directory. \ No newline at end of file +All the above scripts can be referenced from the scripts in the `test/start_scripts/multi_pd_master/` directory. diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 473dcbafda..94f9d4c1a2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -110,7 +110,6 @@ def __init__(self, kvargs): # 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值 self.req_manager.mem_manager = self.mem_manager - self._init_kv_move_buffer() self._check_mem_size() self._init_infer_layer() self._init_some_value() @@ -197,11 +196,6 @@ def _init_mem_manager(self): ) return - def _init_kv_move_buffer(self): - # p d 分离的推理模式下才需要做这一步初始化 - if self.run_mode in ["prefill", "decode"]: - self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size) - def _check_mem_size(self): self.max_total_token_num = self.mem_manager.size diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 7a24a59110..9eb02b963c 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -1,13 +1,9 @@ import torch import os import torch.distributed as dist -from lightllm.server.pd_io_struct import KVMoveTask from .mem_manager import MemoryManager from typing import List, Union, Any from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_trans_kernel.kv_trans import kv_trans -from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node -from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io from .operator import Deepseek2MemOperator @@ -32,14 +28,6 @@ def get_cell_size(self): def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") - def alloc_kv_move_buffer(self, max_req_total_len): - self.kv_move_buffer = torch.empty( - (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" - ) - self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda") - self.token_dim_size = self.kv_move_buffer.shape[-1] * self.kv_move_buffer.shape[-2] - return - def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: self.kv_move_buffer = torch.empty( (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" @@ -96,180 +84,3 @@ def read_page_kv_move_buffer_to_mem( kv_buffer=mem.kv_buffer, mode="read", ) - - def send_to_decode_node( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["Deepseek2MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据发送到指定的一张卡上的buffer,再发送。 - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - - cur_device_index = self.kv_buffer.get_device() - cur_mem = mem_managers[cur_device_index] - for layer_index in range(cur_mem.layer_num): - move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index) - nccl_comm.send(move_buffer, dst=1) - return - - def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): - move_size = self.token_dim_size * len(token_indexes) - move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view( - 1, len(token_indexes), self.head_num, self.head_dim - ) - move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :] - return move_buffer - - def receive_from_prefill_node( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。 - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - - cur_device_index = self.kv_buffer.get_device() - token_num = len(move_token_indexes) - move_size = self.token_dim_size * token_num - recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim) - for layer_index in range(self.layer_num): - nccl_comm.recv(recive_buffer, src=0) - for i, mem in enumerate(mem_managers): - if i == cur_device_index: - mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) - else: - new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape) - from torch.cuda import comm - - comm.broadcast(recive_buffer, out=[new_recive_buffer]) - mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index) - return - - def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index): - self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor - return - - def send_to_decode_node_p2p( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - """ - 使用 p2p triton kernel 进行数据复制和传输的实现方式。 - """ - if not hasattr(self, "mem_ptrs_dict"): - self.mem_ptrs_dict = {} - for layer_index in range(self.layer_num): - mems_ptr = [] - for i in range(0, len(mem_managers), len(mem_managers) // dp_size_in_node): - mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) - mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") - self.mem_ptrs_dict[layer_index] = mems_ptr - - move_token_indexes = [] - token_dp_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)]) - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") - for layer_index in range(self.layer_num): - move_buffer = self._get_kv_move_data_p2p( - move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node - ) - nccl_comm.send(move_buffer, dst=1) - return - - def _get_kv_move_data_p2p( - self, - token_indexes: torch.Tensor, - token_dp_indexes: torch.Tensor, - layer_index: int, - kv_move_buffer: torch.Tensor, - dp_size_in_node: int, - ): - move_token_num = len(token_indexes) - move_size = self.token_dim_size * move_token_num - move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim) - kv_trans_v2_for_p_node( - input_mems=self.mem_ptrs_dict[layer_index], - input_idx=token_indexes, - input_dp_idx=token_dp_indexes, - output=move_buffer, - output_idx=self.kv_move_buf_indexes[0:move_token_num], - dp_size_in_node=dp_size_in_node, - ) - return move_buffer - - def receive_from_prefill_node_p2p( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - if not hasattr(self, "mem_ptrs_dict"): - self.mem_ptrs_dict = {} - for layer_index in range(self.layer_num): - mems_ptr = [] - for i in range(0, len(mem_managers)): - mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) - mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") - self.mem_ptrs_dict[layer_index] = mems_ptr - - move_token_indexes = [] - token_dp_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)]) - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") - - token_num = len(move_token_indexes) - move_size = self.token_dim_size * token_num - recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim) - for layer_index in range(self.layer_num): - nccl_comm.recv(recive_buffer, src=0) - self._write_kv_move_data_p2p( - move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node - ) - return - - def _write_kv_move_data_p2p( - self, - token_indexes: torch.Tensor, - token_dp_indexes: torch.Tensor, - buffer_tensor: torch.Tensor, - layer_index, - dp_size_in_node: int, - ): - move_token_num = len(token_indexes) - kv_trans_v2_for_d_node( - output_mems=self.mem_ptrs_dict[layer_index], - output_idx=token_indexes, - output_dp_idx=token_dp_indexes, - input=buffer_tensor, - input_idx=self.kv_move_buf_indexes[0:move_token_num], - dp_size_in_node=dp_size_in_node, - ) - return diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 47364af5f9..69b51b4ab4 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -4,15 +4,12 @@ import torch.distributed as dist import torch.multiprocessing as mp from typing import List, Tuple, Any, Union -from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from .allocator import KvCacheAllocator from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory -from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args -from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.config_utils import get_num_key_value_heads from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io @@ -88,19 +85,6 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 成员变量中,其与 req_manager 中的HOLD_REQUEST_ID具有类似的作用和意义。 self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") - def alloc_kv_move_buffer(self, max_req_total_len): - """ - pd 分离模式使用的特殊接口 - """ - if isinstance(self, MemoryManager) and type(self) is not MemoryManager: - raise NotImplementedError("subclass need reimpl this method") - self.kv_move_buffer = torch.empty( - (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda" - ) - self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda") - self.token_dim_size = self.kv_move_buffer.shape[-2] * self.kv_move_buffer.shape[-1] - return - def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir) self.kv_move_buffer = torch.empty( @@ -175,150 +159,6 @@ def read_page_kv_move_buffer_to_mem( # logger.info(f"dst token tensor {self.kv_buffer[:, mem_indexes[0], 0, 0]}") # logger.info(f"dst page token tensor {cur_page[0, :, 0, 0]}") - def send_to_decode_node( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据发送到指定的一张卡上的buffer,再发送。 - - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - - cur_device_index = self.kv_buffer.get_device() - cur_mem = mem_managers[cur_device_index] - for i, mem in enumerate(mem_managers): - for layer_index in range(mem.layer_num): - move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index) - if i == cur_device_index: - nccl_comm.send(move_buffer, dst=1) - else: - move_size = move_buffer.numel() - new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape) - from torch.cuda import comm - - comm.broadcast(move_buffer, out=[new_move_buffer]) - nccl_comm.send(new_move_buffer, dst=1) - return - - def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): - move_size = self.token_dim_size * len(token_indexes) - move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view( - 1, len(token_indexes), 2 * self.head_num, self.head_dim - ) - move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :] - return move_buffer - - def receive_from_prefill_node( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。 - - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - - cur_device_index = self.kv_buffer.get_device() - token_num = len(move_token_indexes) - move_size = self.token_dim_size * token_num - recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim) - for i, mem in enumerate(mem_managers): - for layer_index in range(mem.layer_num): - nccl_comm.recv(recive_buffer, src=0) - if i == cur_device_index: - mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) - else: - new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape) - from torch.cuda import comm - - comm.broadcast(recive_buffer, out=[new_recive_buffer]) - mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index) - return - - def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index): - self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor - return - - def send_to_decode_node_p2p( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - """ - 使用 p2p triton kernel 进行数据复制和传输的实现方式。 - """ - assert dp_size_in_node == 1 - - # 先将数据发送到指定的一张卡上的buffer,再发送。 - - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - for i, mem in enumerate(mem_managers): - for layer_index in range(mem.layer_num): - move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer) - nccl_comm.send(move_buffer, dst=1) - return - - def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): - move_token_num = len(token_indexes) - move_size = self.token_dim_size * move_token_num - move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim) - kv_trans( - self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num] - ) - return move_buffer - - def receive_from_prefill_node_p2p( - self, - move_tasks: List[KVMoveTask], - mem_managers: List["MemoryManager"], - dp_size_in_node: int, - nccl_comm: PyNcclCommunicator, - ): - assert dp_size_in_node == 1 - - # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。 - - move_token_indexes = [] - for task in move_tasks: - if task.move_kv_len != 0: - move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - - token_num = len(move_token_indexes) - move_size = self.token_dim_size * token_num - recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim) - for i, mem in enumerate(mem_managers): - for layer_index in range(mem.layer_num): - nccl_comm.recv(recive_buffer, src=0) - mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index) - return - - def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index): - move_token_num = len(token_indexes) - kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes) - return - def _free_buffers(self): self.kv_buffer = None @@ -359,7 +199,7 @@ def write_to_shm(self, req_manager): 将 mem manager 写入到 shm中,方便pd分离等特性直接从中读取,不依赖进程间队列。 """ if kv_trans_use_p2p(): - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor + from lightllm.server.router.model_infer.mode_backend.pd_nixl.p2p_fix import reduce_tensor mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index f86503fd18..fe18423a37 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -9,8 +9,6 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, choices=[ "normal", - "prefill", - "decode", "nixl_prefill", "nixl_decode", "pd_master", @@ -18,7 +16,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "visual_only", ], default="normal", - help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, + help="""set run mode, normal is started for a single server, nixl_prefill/nixl_decode/pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, specifically designed for large-scale, high-concurrency scenarios where `pd_master` encounters significant CPU bottlenecks.""", @@ -47,19 +45,19 @@ def make_argument_parser() -> argparse.ArgumentParser: "--pd_master_ip", type=str, default="0.0.0.0", - help="when run_mode set to prefill or decode, you need set this pd_mater_ip", + help="when run_mode set to nixl_prefill or nixl_decode, you need set this pd_mater_ip", ) parser.add_argument( "--pd_master_port", type=int, default=1212, - help="when run_mode set to prefill or decode, you need set this pd_mater_port", + help="when run_mode set to nixl_prefill or nixl_decode, you need set this pd_mater_port", ) parser.add_argument( "--pd_decode_rpyc_port", type=int, default=None, - help="p d mode, decode node used for kv move manager rpyc server port", + help="p d mode, decode node rpyc server port", ) parser.add_argument( "--select_p_d_node_strategy", diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 668250fa5a..62a4c4d805 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -231,7 +231,7 @@ async def token_load(request: Request): @app.post("/generate") async def generate(request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -253,7 +253,7 @@ async def generate(request: Request) -> Response: @app.post("/generate_stream") async def generate_stream(request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -275,7 +275,7 @@ async def generate_stream(request: Request) -> Response: @app.post("/get_score") async def get_score(request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -291,7 +291,7 @@ async def get_score(request: Request) -> Response: @app.post("/") async def compat_generate(request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -306,7 +306,7 @@ async def compat_generate(request: Request) -> Response: @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -323,7 +323,7 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request) @app.post("/v1/completions", response_model=CompletionResponse) async def completions(request: CompletionRequest, raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -340,7 +340,7 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo @app.post("/v1/messages") async def anthropic_messages(raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 654ba0f3e5..825888e9d7 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -83,7 +83,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual_only"]: + if args.run_mode not in ["normal", "nixl_prefill", "nixl_decode", "visual_only"]: return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 @@ -99,7 +99,7 @@ def normal_or_p_d_start(args): args.disable_audio = True # pd 分离模式下,不启动多模态的模块 - if args.run_mode in ["decode", "nixl_decode"]: + if args.run_mode == "nixl_decode": args.disable_audio = True args.disable_vision = True @@ -404,7 +404,7 @@ def normal_or_p_d_start(args): args.pd_p_allowed_port_max = 30000 # p d 分离模式下,decode节点的调度间隙是0 - if args.run_mode == "decode": + if args.run_mode == "nixl_decode": args.router_max_wait_tokens = 0 send_and_receive_node_ip(args) # 多机用于收发node ip diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 7ddd0941b8..0b12cbfc54 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -8,9 +8,7 @@ class StartArgs: run_mode: str = field( default="normal", - metadata={ - "choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode", "visual_only"] - }, + metadata={"choices": ["normal", "pd_master", "nixl_prefill", "nixl_decode", "config_server", "visual_only"]}, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..8c213914c7 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -39,7 +39,7 @@ def __init__( self.req_id_to_out: Dict[int, DecodeReq] = {} self.eos_id = args.eos_id self._init_get_token_id_to_token_str() - self.is_pd_decode_mode = self.args.run_mode == "decode" + self.is_pd_decode_mode = False self.shm_req_manager = ShmReqManager() def _init_get_token_id_to_token_str(self): diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 90ad87e7d1..b9799cd061 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -112,7 +112,7 @@ def __init__( self.metric_client = MetricClient(args.metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) - assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] + assert self.pd_mode in [NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] self.id_gen = ReqIDGenerator() self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 7b4b8ccaad..fd7969164e 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -9,7 +9,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional from lightllm.server.core.objs import FinishStatus -from ..pd_io_struct import PD_Client_Obj, UpKVStatus, NixlUpKVStatus, ObjType, NodeRole, NIXLDecodeNodeInfo +from ..pd_io_struct import PD_Client_Obj, NixlUpKVStatus, ObjType, NIXLDecodeNodeInfo from lightllm.server.core.objs import SamplingParams, StartArgs from ..multimodal_params import MultimodalParams from ..tokenizer import get_tokenizer @@ -61,7 +61,7 @@ async def remove_pd(self, pd_info_json): self.pd_manager.remove_pd(pd_info_json) return - async def update_req_status(self, upkv_status: Union[UpKVStatus, NixlUpKVStatus]): + async def update_req_status(self, upkv_status: NixlUpKVStatus): try: group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id) up_status_event = self.req_id_to_out_inf[group_request_id].up_status_event @@ -204,91 +204,6 @@ async def _log_req_header(self, request: Request, group_request_id: int): ) return - async def fetch_stream( - self, - p_node: PD_Client_Obj, - d_node: PD_Client_Obj, - prompt: Union[str, List[int]], - sampling_params: SamplingParams, - multimodal_params: MultimodalParams, - request: Request, - ): - group_request_id = sampling_params.group_request_id - sampling_params.pd_master_node_id.initialize(self.args.pd_node_id) - - req_status = ReqStatus(group_request_id, p_node, d_node) - self.req_id_to_out_inf[group_request_id] = req_status - - up_status_event = req_status.up_status_event - - d_start_args = d_node.start_args - decode_node_dict = { - "node_id": d_start_args["pd_node_id"], - "ip": d_start_args["host"], - "rpyc_port": d_start_args["pd_decode_rpyc_port"], - "max_new_tokens": sampling_params.max_new_tokens - 1, - } - - old_max_new_tokens = sampling_params.max_new_tokens - sampling_params.max_new_tokens = 1 - sampling_params.move_kv_to_decode_node.initialize(decode_node_dict if old_max_new_tokens != 1 else None) - sampling_params.suggested_dp_index = -1 - - await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) - - while True: - await req_status.wait_to_ready() - if await request.is_disconnected(): - raise ClientDisconnected( - group_request_id=group_request_id, reason="fetch_stream prefill period check network disconnected" - ) - - if await req_status.can_read(self.req_id_to_out_inf): - token_list = await req_status.pop_all_tokens() - for sub_req_id, request_output, metadata, finish_status in token_list: - if old_max_new_tokens != 1: - finish_status = FinishStatus(FinishStatus.NO_FINISH) - else: - finish_status = FinishStatus(FinishStatus.FINISHED_LENGTH) - # 得到 p 节点返回的 prompt_ids 信息 - if metadata.get("prompt_ids", None) is not None: - prompt_ids = metadata.get("prompt_ids") - prompt_ids.append(metadata.get("id")) - yield sub_req_id, request_output, metadata, finish_status - break - - # 如果只需要一个输出 token,prefill 完就直接结束掉吧 - if old_max_new_tokens == 1: - return - - try: - await asyncio.wait_for(up_status_event.wait(), timeout=60) - except asyncio.TimeoutError: - logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.") - raise ServerBusyError() - - sampling_params.move_kv_to_decode_node.initialize(None) - sampling_params.max_new_tokens = old_max_new_tokens - 1 - upkv_status: UpKVStatus = up_status_event.upkv_status - sampling_params.suggested_dp_index = upkv_status.dp_index - - await d_node.websocket.send_bytes( - pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))) - ) - - while True: - await req_status.wait_to_ready() - if await request.is_disconnected(): - raise ClientDisconnected( - group_request_id=group_request_id, reason="fetch_stream decode period check network disconnected" - ) - if await req_status.can_read(self.req_id_to_out_inf): - token_list = await req_status.pop_all_tokens() - for sub_req_id, request_output, metadata, finish_status in token_list: - yield sub_req_id, request_output, metadata, finish_status - - return - async def fetch_nixl_stream( self, p_node: PD_Client_Obj, @@ -392,11 +307,7 @@ async def _wait_to_token_package( is_first_token = True sub_req_id_to_mtp_accepted_token_num: Dict[int, int] = {} - client_mode: NodeRole = NodeRole(d_node.mode) - - fetch_stream = self.fetch_nixl_stream if client_mode.is_NP_or_ND() else self.fetch_stream - - async for sub_req_id, out_str, metadata, finish_status in fetch_stream( + async for sub_req_id, out_str, metadata, finish_status in self.fetch_nixl_stream( p_node, d_node, prompt, sampling_params, multimodal_params, request ): if await request.is_disconnected(): @@ -602,14 +513,14 @@ def register_pd(self, pd_info_json, websocket): pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client - if pd_client.mode in ["prefill", "nixl_prefill"]: + if pd_client.mode == "nixl_prefill": self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) - elif pd_client.mode in ["decode", "nixl_decode"]: + elif pd_client.mode == "nixl_decode": self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: - assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}" + assert False, f"mode must in ['nixl_prefill', 'nixl_decode'], but get {pd_client.mode}" self.selector.update_nodes(self.prefill_nodes, self.decode_nodes) diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index fe0259855d..6aacc01907 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -12,9 +12,6 @@ # 节点的行为 class NodeRole(enum.Enum): - P = "prefill" - D = "decode" - NP = "nixl_prefill" ND = "nixl_decode" @@ -22,10 +19,10 @@ class NodeRole(enum.Enum): PD_MASTER = "pd_master" def is_D(self): - return self == NodeRole.D or self == NodeRole.ND + return self == NodeRole.ND def is_P(self): - return self == NodeRole.P or self == NodeRole.NP + return self == NodeRole.NP def is_NP(self): return self == NodeRole.NP @@ -63,14 +60,14 @@ class _PD_Client_RunStatus: class PD_Client_Obj: node_id: int client_ip_port: str - mode: str # 只能是 prefill 或者 decode 节点 + mode: str # 只能是 nixl_prefill 或者 nixl_decode 节点 start_args: object # 节点的启动参数信息,用于做匹配性的校验,防止运行过程中出现问题。 websocket: WebSocket = None # 用于通信的 websocket 连接对象 run_status: _PD_Client_RunStatus = field(default_factory=_PD_Client_RunStatus) def __post_init__(self): - if self.mode not in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: - error_info = f"""mode must in ["prefill", "decode", "nixl_prefill", "nixl_decode"], but get {self.mode}""" + if self.mode not in ["nixl_prefill", "nixl_decode"]: + error_info = f"""mode must in ["nixl_prefill", "nixl_decode"], but get {self.mode}""" logger.error(error_info) raise ValueError(error_info) return @@ -88,112 +85,6 @@ def to_log_str(self): return f"PD_MASTER host_ip_port: {self.host_ip_port} node_id: {self.node_id}" -@dataclass -class UpKVStatus: - group_request_id: int - # The identifier of the pd_master node handling the request. - pd_master_node_id: int - # decode node dp_index to handle this request - dp_index: int - - def __post_init__(self): - if not isinstance(self.group_request_id, int): - error_info = "group_request_id only can be int" - logger.error(error_info) - raise ValueError(error_info) - - if not isinstance(self.pd_master_node_id, int): - error_info = "pd_master_node_id only can be int" - logger.error(error_info) - raise ValueError(error_info) - return - - -@dataclass -class DecodeNodeInfo: - node_id: int - ip: str - rpyc_port: str - max_new_tokens: int - - -@dataclass -class PDTransJoinInfo: - decode_id: int - decode_device_id: int - prefill_id: int - prefill_device_id: int - pd_prefill_nccl_ip: str - pd_prefill_nccl_port: int - # 用于标识一次唯一的连接,prefill_id 和 decode_id 相同时,可能因为网络原因重连,为了更好的区分 - # 一次连接,使用一个 uuid 为其标识 - connect_id: str - - -@dataclass -class PDTransLeaveInfo: - decode_id: int - prefill_id: int - # 用于标识一次唯一的连接,prefill_id 和 decode_id 相同时,可能因为网络原因重连,为了更好的区分 - # 一次连接,使用一个 uuid 为其标识 - connect_id: str - - -@dataclass -class KVMoveTask: - group_request_id: int - input_tokens: List[int] # 代表输入的token_id 序列 - prefill_token_indexes: List[int] # 在prefill节点上 mem manager kv buffer中的token index - # 在decode节点上 mem manager kv buffer中的token index, 其代表的是真实占用的额外token,并不与prefill_token_indexes 一样长 - decode_token_indexes: List[int] - move_kv_len: int # 因为 prompt cache 的原因,当prefill节点和decode节点沟通后,传输的kv的数量可能少于 prefill_value 的长度 - prefill_node_id: int - decode_node: DecodeNodeInfo - # 保存prefill 和 decode 节点对应处理的dp_index, 如果是普通tp模式,这个值一定是0, - # 如果是deepseekv2的tp dp 混合模式, 才有真正的意义。 - prefill_dp_index: int - decode_dp_index: int - pd_master_node_id: int - mark_start_time: float = None - # 标记任务使用某个连接id进行传输 - connect_id: str = None - - def __post_init__(self): - if len(self.input_tokens) <= 0: - error_info = "key must len >= 1" - logger.error(error_info) - raise ValueError(error_info) - - def to_prefill_log_info(self): - v_len = None if self.prefill_token_indexes is None else len(self.prefill_token_indexes) - d_i = self.prefill_dp_index - id = self.group_request_id - log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}" - return log + f" connect_id: {self.connect_id}" - - def to_decode_log_info(self): - v_len = None if self.decode_token_indexes is None else len(self.decode_token_indexes) - d_i = self.decode_dp_index - id = self.group_request_id - log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}" - return log + f" connect_id: {self.connect_id}" - - def id(self): - return self.group_request_id - - def get_cost_time(self): - if self.mark_start_time is not None: - return time.time() - self.mark_start_time - else: - return 100000000000 - - -@dataclass -class KVMoveTaskGroup: - tasks: List[KVMoveTask] - connect_id: str - - ####### 下边是 NIXL模式下使用的特定对象 ######## diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index a41c2f265a..368205ed3e 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -95,8 +95,8 @@ def __init__(self, args: StartArgs): ) self.metric_client = MetricClient(args.metric_port) - self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"] - self.is_pd_decode_mode = self.args.run_mode in ["decode", "nixl_decode"] + self.is_pd_run_mode = self.args.run_mode in ["nixl_prefill", "nixl_decode"] + self.is_pd_decode_mode = self.args.run_mode == "nixl_decode" # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 # 主要是为了防止调度失误,造成 OOM 等错误 self.router_lock = mp.Lock() @@ -203,14 +203,6 @@ async def wait_to_model_ready(self): self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") - if self.args.run_mode == "prefill": - # 启动 prefill kv move 管理进程 - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.prefill_node_impl import ( - start_prefill_kv_move_manager_process, - ) - - start_prefill_kv_move_manager_process(self.args, self.info_queue) - if self.args.run_mode == "nixl_prefill": from lightllm.server.router.model_infer.mode_backend.pd_nixl.prefill_node_impl import ( start_prefill_kv_move_manager_process, @@ -218,14 +210,6 @@ async def wait_to_model_ready(self): start_prefill_kv_move_manager_process(self.args, self.info_queue) - if self.args.run_mode == "decode": - # 启动 decode kv move 管理进程 - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( - start_decode_kv_move_manager_process, - ) - - start_decode_kv_move_manager_process(self.args, self.info_queue) - if self.args.run_mode == "nixl_decode": from lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl import ( start_decode_kv_move_manager_process, diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 82f3a8ddf4..1843f31314 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -10,10 +10,6 @@ from .diverse_backend.impl import DiversehBackend # pd mode backend -from .continues_batch.pd_mode.decode_node_impl.decode_impl import DecodeNode -from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode -from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ChunckedPrefillForPrefillNode -from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp import DPChunkedForPrefillNode from .pd_nixl.prefill_node_impl.prefill_impl import NIXLChunckedPrefillForPrefillNode from .pd_nixl.prefill_node_impl.prefill_impl_for_dp import NIXLDPChunkedForPrefillNode from .pd_nixl.decode_node_impl.decode_impl import NIXLDecodeNode diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0220dc87fb..f16fec243e 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -133,7 +133,7 @@ def init_model(self, kvargs): dp_rank_in_node=self.dp_rank_in_node, dp_world_size=self.dp_world_size, ) - if self.run_mode in ["prefill", "decode"] + if self.run_mode in ["nixl_prefill", "nixl_decode"] else None ) g_infer_state_lock.dp_world_size = self.dp_world_size @@ -238,8 +238,7 @@ def init_model(self, kvargs): ) if ( - self.args.run_mode in ["nixl_prefill", "nixl_decode", "prefill", "decode"] - or self.args.enable_dp_prompt_cache_fetch + self.args.run_mode in ["nixl_prefill", "nixl_decode"] or self.args.enable_dp_prompt_cache_fetch ): # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便 # 读取 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/__init__.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/__init__.py deleted file mode 100644 index 4b40544fe9..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .decode_kv_move_manager import start_decode_kv_move_manager_process -from .decode_trans_process import start_decode_trans_process diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py deleted file mode 100644 index d13987e23f..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -import threading -from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend -from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, g_infer_state_lock -from lightllm.server.core.objs import FinishStatus -from lightllm.utils.log_utils import init_logger -from rpyc.utils.server import ThreadedServer -from lightllm.common.basemodel.infer_lock import g_router_lock -from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask -from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.utils.dist_utils import create_new_group_for_current_dp - -logger = init_logger(__name__) - - -class DecodeNode(ChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue) -> None: - super().__init__() - self.info_queue: mp.Queue = info_queue - self.classed_req_strict_prefill = False - - def init_custom(self): - - self.lock_nccl_group = create_new_group_for_current_dp("gloo") - logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}") - - from .decode_infer_rpyc import PDDecodeInferRpcServer - - socket_path = f"/tmp/{get_unique_server_name()}_decode_node_infer_rpyc_{self.pd_rpyc_ports[self.rank_in_node]}" - if os.path.exists(socket_path): - os.remove(socket_path) - - t = ThreadedServer( - PDDecodeInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True} - ) - threading.Thread(target=lambda: t.start(), daemon=True).start() - return - - def _init_reqs(self, reqs: List[Tuple]): - """ - 替换请求初始化操作,替换为 Decode 节点独有的一些特殊初始化流程 - """ - if self.dp_size_in_node != 1: - dp_rank_in_node = self.dp_rank_in_node - reqs = [req for req in reqs if req[3] == dp_rank_in_node] - - g_infer_state_lock.acquire() - - uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=False) - # 匹配radix cache,并更新一些资源的管理。 - self._post_init_reqs(uninit_reqs=uninit_reqs) - - g_infer_state_lock.release() - req_ids = [e[0] for e in reqs] - - # pd nccl 的 decode 节点模式下不支持 cpu cache - assert not self.args.enable_cpu_cache - return req_ids - - def _post_init_reqs(self, uninit_reqs: List[InferReq]): - """ - 检查请求的 kv len 将可能有问题的请求立即结束掉 - """ - if len(uninit_reqs) == 0: - return - - remove_count = 0 - estimated_peak_token_count = 0 - for req_obj in uninit_reqs: - req_obj: InferReq = req_obj # for easy typing - request_id = req_obj.req_id - if request_id in g_success_kv_move_task_cache: - task, share_node, _ = g_success_kv_move_task_cache.pop(request_id) - task: KVMoveTask = task # for easy typing - self.radix_cache.dec_node_ref_counter(share_node) - req_all_len = len(task.input_tokens) + task.decode_node.max_new_tokens - remove_count += req_all_len - estimated_peak_token_count += req_all_len - req_obj._match_radix_cache() - else: - # 对于不合法的请求,直接模拟将其finished掉 - req_obj.cur_output_len += 1 - req_obj.set_next_gen_token_id(0, 0.0, 1) - req_obj.finish_status.set_status(FinishStatus.FINISHED_STOP) - - if self.is_master_in_dp: - req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len - req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len - req_obj.shm_req.finish_token_index = req_obj.get_cur_total_len() - 1 - req_obj.shm_req.finish_status.set_status(FinishStatus.FINISHED_STOP) - req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len - - req_id = req_obj.shm_req.request_id - logger.error(f"req_id: {req_id} forced to finished, it not in g_success_kv_move_task_cache") - - if self.is_master_in_dp: - with g_router_lock.obj: - self.shared_token_load.add_frozened_token_count(-remove_count, self.dp_rank_in_node) - self.shared_token_load.add_estimated_peak_token_count(estimated_peak_token_count, self.dp_rank_in_node) - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py deleted file mode 100644 index 8dc9ad1a6d..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch.multiprocessing as mp -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.utils.log_utils import init_logger -from typing import List, Tuple -from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend -from .decode_impl import DecodeNode - -logger = init_logger(__name__) - - -class DPForDecodeNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue) -> None: - super().__init__() - self.info_queue: mp.Queue = info_queue - self.classed_req_strict_prefill = False - return - - def init_custom(self): - DecodeNode.init_custom(self) - return - - def _init_reqs(self, reqs: List[Tuple]): - DecodeNode._init_reqs(self, reqs=reqs) - return - - def _post_init_reqs(self, uninit_reqs: List[InferReq]): - DecodeNode._post_init_reqs(self, uninit_reqs=uninit_reqs) - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py deleted file mode 100644 index 6d97714b8c..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -import torch.distributed as dist -import rpyc -import time -from typing import Dict, List, Tuple, Optional, Union -from rpyc.utils.classic import obtain -from .decode_impl import DecodeNode -from lightllm.common.basemodel.infer_lock import acquire_lock_until_ready, release_acquired_lock, g_router_lock -from .decode_task_cache import g_kv_move_task_cache, g_success_kv_move_task_cache -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class PDDecodeInferRpcServer(rpyc.Service): - def __init__(self, backend: DecodeNode) -> None: - super().__init__() - self.backend = backend - self.device_id = self.backend.current_device_id - self.dp_rank_in_node = self.backend.dp_rank_in_node - self.is_master_in_dp = self.backend.is_master_in_dp - return - - def on_connect(self, conn): - torch.cuda.set_device(f"cuda:{self.device_id}") - return - - def judge_token_is_ok(self, key_len, max_new_token): - # 多 dp 单卡模式下, 每个 dp 各自处理自己的, 不需要同步 - if self.backend.dp_world_size == 1: - with g_router_lock.obj: - shared_token_load = self.backend.shared_token_load - peak_num = shared_token_load.get_estimated_peak_token_count(self.dp_rank_in_node) - peak_num += shared_token_load.get_frozened_token_count(self.dp_rank_in_node) - peak_num += key_len + max_new_token - - if peak_num < self.backend.get_max_total_token_num(): - object_list = [True] - shared_token_load.add_frozened_token_count(key_len + max_new_token, self.dp_rank_in_node) - else: - object_list = [False] - return object_list[0] - - # 普通单dp模式下, 只有主 rank 处理信息,并将数据同步到其他rank上 - if self.is_master_in_dp: - with g_router_lock.obj: - shared_token_load = self.backend.shared_token_load - peak_num = shared_token_load.get_estimated_peak_token_count(self.dp_rank_in_node) - peak_num += shared_token_load.get_frozened_token_count(self.dp_rank_in_node) - peak_num += key_len + max_new_token - - if peak_num < self.backend.get_max_total_token_num(): - object_list = [True] - shared_token_load.add_frozened_token_count(key_len + max_new_token, self.dp_rank_in_node) - else: - object_list = [False] - dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group) - else: - object_list = [None] - dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group) - return object_list[0] - - def recover_frozen_token(self, key_len, max_new_token): - if self.is_master_in_dp: - with g_router_lock.obj: - shared_token_load = self.backend.shared_token_load - shared_token_load.add_frozened_token_count(-(key_len + max_new_token), self.dp_rank_in_node) - return - - def _alloc_to_frozen_some_tokens(self, move_task: KVMoveTask): - is_ok = self.judge_token_is_ok(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) - if not is_ok: - if self.is_master_in_dp: - logger.info(f"req_id: {move_task.to_decode_log_info()} alloc token failed") - shared_token_load = self.backend.shared_token_load - dp_rank = self.dp_rank_in_node - frozen_token_num = shared_token_load.get_frozened_token_count(dp_rank) - estimated_peak_token_num = shared_token_load.get_estimated_peak_token_count(dp_rank) - logger.debug( - f"radix refed token num {self.backend.radix_cache.get_refed_tokens_num()}\n" - f"radix hold token num {self.backend.radix_cache.get_tree_total_tokens_num()}\n" - f"mem manager can alloc token num {self.backend.model.mem_manager.allocator.can_use_mem_size}\n" - f"mem manager total size {self.backend.model.mem_manager.allocator.size}\n" - f"frozened token num {frozen_token_num}\n" - f"estimated peak token num {estimated_peak_token_num}\n" - ) - return None - - key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") - tree_node, kv_len, fused_token_indexes = self.backend.radix_cache.match_prefix(key, update_refs=True) - # 如果没匹配到,说明长度是0, 将fused_token_indexes做一下转换 - fused_token_indexes = [] if fused_token_indexes is None else fused_token_indexes.tolist() - need_len = len(move_task.input_tokens) - kv_len - if need_len == 0: - alloc_token_indexes = [] - else: - self.backend.radix_cache.free_radix_cache_to_get_enough_token(need_len) - alloc_token_indexes = self.backend.model.mem_manager.alloc(need_len) - if alloc_token_indexes is not None: - alloc_token_indexes = alloc_token_indexes.tolist() - - if alloc_token_indexes is None: - self.backend.radix_cache.dec_node_ref_counter(tree_node) - self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) - return None - - move_task.decode_token_indexes = alloc_token_indexes - move_task.move_kv_len = need_len - - g_kv_move_task_cache[move_task.group_request_id] = (move_task, tree_node, fused_token_indexes) - return move_task.decode_token_indexes - - # 返回 None 代表服务繁忙已经无法调度新的请求进入了 - def exposed_alloc_to_frozen_some_tokens(self, move_tasks: List[KVMoveTask]) -> List[Optional[List[int]]]: - move_tasks = obtain(move_tasks) - acquire_lock_until_ready(self.backend.lock_nccl_group) - try: - ans_list = [] - for move_task in move_tasks: - ans_list.append(self._alloc_to_frozen_some_tokens(move_task)) - return ans_list - except BaseException as e: - logger.exception(str(e)) - return None - finally: - release_acquired_lock() - - def _put_kv_received_to_radix_cache(self, group_req_id: int): - move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id) - radix_cache = self.backend.radix_cache - key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") - value = torch.tensor(fused_token_indexes + move_task.decode_token_indexes, dtype=torch.int64, device="cpu") - prefix_len, _ = radix_cache.insert(key, value) - assert len(fused_token_indexes) <= prefix_len - self.backend.model.mem_manager.free(value[len(fused_token_indexes) : prefix_len]) - self.backend.radix_cache.dec_node_ref_counter(tree_node) - - # 申请一段key,把 radix cache 锁住,防止极端情况下被刷掉, decode 端通过减两次引用计数来修正。 - tree_node, kv_len, _ = self.backend.radix_cache.match_prefix(key, update_refs=True) - assert len(key) == kv_len - g_success_kv_move_task_cache[group_req_id] = (move_task, tree_node, time.time()) - return - - def exposed_put_kv_received_to_radix_cache(self, group_req_ids: List[int]): - group_req_ids = obtain(group_req_ids) - acquire_lock_until_ready(self.backend.lock_nccl_group) - for group_req_id in group_req_ids: - self._put_kv_received_to_radix_cache(group_req_id) - release_acquired_lock() - return - - def _fail_to_realese_forzen_tokens(self, group_req_id: int): - move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id) - value = torch.tensor(move_task.decode_token_indexes, dtype=torch.int64, device="cpu") - self.backend.model.mem_manager.free(value) - self.backend.radix_cache.dec_node_ref_counter(tree_node) - self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) - return - - def exposed_fail_to_realese_forzen_tokens(self, group_req_ids: List[int]): - group_req_ids = obtain(group_req_ids) - acquire_lock_until_ready(self.backend.lock_nccl_group) - for group_req_id in group_req_ids: - self._fail_to_realese_forzen_tokens(group_req_id) - release_acquired_lock() - return - - def exposed_unfrozen_time_out_reqs_tokens(self): - acquire_lock_until_ready(self.backend.lock_nccl_group) - if self.backend.dp_world_size == 1: - need_release_reqs = self._get_time_out_reqs() - logger.info(f"kv time out reqs: {need_release_reqs}") - remove_tokens = self._remove_time_out_reqs(need_release_reqs) - if remove_tokens != 0: - with g_router_lock.obj: - self.backend.shared_token_load.add_frozened_token_count(-remove_tokens, self.dp_rank_in_node) - else: - if self.is_master_in_dp: - need_release_reqs = self._get_time_out_reqs() - logger.info(f"kv time out reqs: {need_release_reqs}") - dist.broadcast_object_list([need_release_reqs], src=0, group=self.backend.lock_nccl_group) - else: - receive_objs = [None] - dist.broadcast_object_list(receive_objs, src=0, group=self.backend.lock_nccl_group) - need_release_reqs = receive_objs[0] - remove_tokens = self._remove_time_out_reqs(need_release_reqs) - if self.is_master_in_dp and remove_tokens != 0: - with g_router_lock.obj: - self.backend.shared_token_load.add_frozened_token_count(-remove_tokens, self.dp_rank_in_node) - - release_acquired_lock() - return - - def _get_time_out_reqs(self): - need_release_reqs = [] - for req_id, (_, _, time_mark) in g_success_kv_move_task_cache.items(): - # 6s 这个请求都没有被调度使用,就会主动被删除掉锁定,释放其锁定的token - if time.time() - time_mark > 6: - need_release_reqs.append(req_id) - return need_release_reqs - - def _remove_time_out_reqs(self, need_release_reqs: List[int]) -> int: - remove_tokens = 0 - for req_id in need_release_reqs: - task, tree_node, _ = g_success_kv_move_task_cache.pop(req_id) - self.backend.radix_cache.dec_node_ref_counter(tree_node) - remove_tokens += len(task.input_tokens) + task.decode_node.max_new_tokens - return remove_tokens diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py deleted file mode 100644 index 4733a141bf..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ /dev/null @@ -1,383 +0,0 @@ -import rpyc -import random -import asyncio -import os -import signal -import collections -import time -import psutil -import threading -import inspect -import setproctitle -from rpyc.utils.classic import obtain -from dataclasses import dataclass -from typing import List, Dict, Optional, Tuple, Union -from rpyc import ThreadedServer -from lightllm.utils.log_utils import init_logger -from .decode_infer_rpyc import PDDecodeInferRpcServer -from ..task_queue import TaskQueue -import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo -from lightllm.utils.retry_utils import retry -import numpy as np -from rpyc import AsyncResult -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.utils.envs_utils import get_unique_server_name - -logger = init_logger(__name__) - -thread_local_data = threading.local() - -KV_MOVE_MAX_NUM = 16 - - -class DecodeKVMoveManager(rpyc.Service): - def __init__(self, args, info_queue: mp.Queue): - super().__init__() - self.args = args - # args.dp // args.nnodes 在跨机tp的场景下,可能为0 - self.dp_size_in_node = max(1, args.dp // args.nnodes) - self.node_world_size = args.tp // args.nnodes - self.dp_world_size = args.tp // args.dp - # 不支持跨机tp的pd 分离策略 - assert self.dp_world_size <= self.node_world_size - - self.info_queue = info_queue - self.infer_rpyc_lock = threading.Lock() - self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = [] - - from .decode_trans_obj import KVTransConnectObj - - self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {} - for port in self.args.pd_node_infer_rpyc_ports: - socket_path = f"/tmp/{get_unique_server_name()}_decode_node_infer_rpyc_{port}" - from rpyc.utils.factory import unix_connect - - con = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - self.infer_rpyc_objs.append(con.root) - logger.info(f"rpyc connect to port: {port} ok") - - from .up_status import start_up_kv_status_process - - self.up_status_in_queue = mp.Queue() - self.up_status_out_queue = mp.Queue() - start_up_kv_status_process(self.args, self.up_status_in_queue, self.up_status_out_queue) - - # fail release queue - self.fail_to_release_queue = TaskQueue(get_func=lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=None) - self.fail_to_release_thread = threading.Thread(target=self.handle_fail_release_task_loop, daemon=True) - self.fail_to_release_thread.start() - - # 在不使用p2p 复制kv 的方案时,需要全局的传输锁进行控制。这个时候kv传输的效率会下降。 - self.kv_trans_lock = threading.Lock() - - from .decode_trans_obj import KVTransProcess - - self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size - for device_id in range(self.node_world_size): - self.kv_trans_processes[device_id] = KVTransProcess() - assert self.kv_trans_processes[device_id].init_all(device_id, self) - - return - - # ================================================================================== - # _dp_alloc_to_frozen_some_tokens - # _put_kv_received_to_radix_cache - # _fail_to_realese_forzen_tokens - # _unfrozen_time_out_reqs_tokens - # 上述接口都是 kv move manager 与推理进程进行交互的接口,主要用于申请锁定kv资源或者释放 - # kv资源的接口 - # ================================================================================== - - async def wait_all_future_finish(self, futures: List[AsyncResult]): - await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) - return - - def _dp_alloc_to_frozen_some_tokens(self, dp_tasks: List[List[KVMoveTask]]) -> List[List[Optional[List[int]]]]: - with self.infer_rpyc_lock: - futures = [] - for dp_index in range(self.dp_size_in_node): - conn_start = dp_index * self.dp_world_size - conn_end = (dp_index + 1) * self.dp_world_size - conns = self.infer_rpyc_objs[conn_start:conn_end] - for conn in conns: - futures.append(rpyc.async_(conn.alloc_to_frozen_some_tokens)(dp_tasks[dp_index])) - - asyncio.run(self.wait_all_future_finish(futures)) - ans_values = [ - obtain(futures[dp_index * self.dp_world_size].value) for dp_index in range(self.dp_size_in_node) - ] - return ans_values - - def _put_kv_received_to_radix_cache(self, tasks: List[KVMoveTask]) -> None: - with self.infer_rpyc_lock: - dp_to_tasks = collections.defaultdict(list) - for task in tasks: - dp_to_tasks[task.decode_dp_index].append(task) - futures: List[AsyncResult] = [] - for decode_dp_index, _tasks in dp_to_tasks.items(): - conn_start = decode_dp_index * self.dp_world_size - conn_end = (decode_dp_index + 1) * self.dp_world_size - conns = self.infer_rpyc_objs[conn_start:conn_end] - for conn in conns: - futures.append( - rpyc.async_(conn.put_kv_received_to_radix_cache)([task.group_request_id for task in _tasks]) - ) - asyncio.run(self.wait_all_future_finish(futures)) - return - - def _fail_to_realese_forzen_tokens(self, tasks: List[KVMoveTask]) -> None: - with self.infer_rpyc_lock: - dp_to_tasks = collections.defaultdict(list) - for task in tasks: - dp_to_tasks[task.decode_dp_index].append(task) - futures: List[AsyncResult] = [] - for decode_dp_index, _tasks in dp_to_tasks.items(): - conn_start = decode_dp_index * self.dp_world_size - conn_end = (decode_dp_index + 1) * self.dp_world_size - conns = self.infer_rpyc_objs[conn_start:conn_end] - for conn in conns: - futures.append( - rpyc.async_(conn.fail_to_realese_forzen_tokens)([task.group_request_id for task in _tasks]) - ) - asyncio.run(self.wait_all_future_finish(futures)) - return - - def _unfrozen_time_out_reqs_tokens(self) -> None: - # 这个接口比较特殊,可以不区分 dp 的具体模式 - with self.infer_rpyc_lock: - futures: List[AsyncResult] = [] - for conn in self.infer_rpyc_objs: - futures.append(rpyc.async_(conn.unfrozen_time_out_reqs_tokens)()) - asyncio.run(self.wait_all_future_finish(futures)) - return - - # ================================================================================== - # put_to_fail_release_task_queue 将因为一些原因失败,需要释放锁定的kv资源的请求放入到 - # 对应的处理队列中,handle_fail_release_task_loop 是一个循环的线程,专门处理这些失败的请求 - # 通过调用与推理进程交互的接口,释放掉申请锁定的 kv 资源。 - # ================================================================================== - - def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): - if isinstance(task, KVMoveTask): - self.fail_to_release_queue.put(task) - elif isinstance(task, list): - self.fail_to_release_queue.put_list(task) - else: - assert False, "error input" - return - - def handle_fail_release_task_loop(self): - while True: - handle_list: List[KVMoveTask] = self.fail_to_release_queue.get_tasks(log_tag="fail_to_release_queue") - if len(handle_list) == 0: - time.sleep(0.01) - else: - self._fail_to_realese_forzen_tokens(handle_list) - return - - # ================================================================================== - # on_connect - # on_disconnect - # exposed_check_alive - # exposed_build_trans_process - # exposed_request_data_transfer - # 上述接口是decode kv move manager 暴露的 rpyc 调用接口,用于 prefill kv move manager - # 进行连接,进行一些元数据资源的交互。 - # ================================================================================== - - def on_connect(self, conn): - # 用于处理连接断开的时候,自动删除资源 - thread_local_data.connect_id = None - pass - - def on_disconnect(self, conn): - # 用于处理连接断开的时候,自动删除资源 - if thread_local_data.connect_id is not None: - self.remove_trans_obj(thread_local_data.connect_id) - logger.info(f"connect id {thread_local_data.connect_id} disconnect") - import gc - - gc.collect() - pass - - def exposed_check_alive(self): - # 用于 prefill node check 通信连接的状态。 - return - - def exposed_build_trans_connect( - self, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num, connect_id - ): - prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num = list( - map(obtain, [prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num]) - ) - connect_id = obtain(connect_id) - thread_local_data.connect_id = connect_id - - logger.info(f"build trans infos {prefill_node_id} {pd_prefill_nccl_ip} {pd_prefill_nccl_port} {connect_id}") - - from .decode_trans_obj import KVTransConnectObj - - tran_obj = KVTransConnectObj() - tran_obj.create(connect_id, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, self) - self.connect_id_to_trans_obj[connect_id] = tran_obj - return min(prefill_node_max_kv_trans_num, self.args.max_total_token_num) - - # 返回 None 代表繁忙, 放弃该任务的 kv 传送 - def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optional[int]]: - tasks: List[KVMoveTask] = obtain(tasks) - alloc_tokened_tasks = [] - ans_list = [] - try: - for task in tasks: - logger.info(f"exposed_request_data_transfer in {task.to_decode_log_info()}, type {type(task)}") - - trans_obj = self.get_trans_obj(tasks[0]) - assert trans_obj is not None - - id_to_test_range = {} - for task in tasks: - test_dp_indexes = list(range(self.dp_size_in_node)) - random.shuffle(test_dp_indexes) - id_to_test_range[task.group_request_id] = test_dp_indexes - - id_has_result = {} - for test_index in range(self.dp_size_in_node): - dp_tasks = [[] for _ in range(self.dp_size_in_node)] - for task in tasks: - if task.group_request_id not in id_has_result: - test_dp_index = id_to_test_range[task.group_request_id][test_index] - dp_tasks[test_dp_index].append(task) - if not all(len(t) == 0 for t in dp_tasks): - dp_tasks_ans = self._dp_alloc_to_frozen_some_tokens(dp_tasks) - for dp_index in range(self.dp_size_in_node): - for task, decode_token_indexes in zip(dp_tasks[dp_index], dp_tasks_ans[dp_index]): - if decode_token_indexes is not None: - id_has_result[task.group_request_id] = (dp_index, decode_token_indexes) - for task in tasks: - if task.group_request_id in id_has_result: - task.decode_dp_index = id_has_result[task.group_request_id][0] - task.decode_token_indexes = id_has_result[task.group_request_id][1] - task.move_kv_len = len(task.decode_token_indexes) - ans_list.append(task.move_kv_len) - alloc_tokened_tasks.append(task) - else: - logger.info(f"req id {task.id()} request_data_transfer fail, server is busy") - ans_list.append(None) - - except BaseException as e: - self.put_to_fail_release_task_queue(alloc_tokened_tasks) - alloc_tokened_tasks = [] - self.remove_trans_obj(tasks[0].connect_id) - logger.exception(str(e)) - raise e - - if alloc_tokened_tasks: - trans_obj.ready_to_move_queue.put( - alloc_tokened_tasks, error_handle_func=self.put_to_fail_release_task_queue - ) - - return ans_list - - # ================================================================================== - # 定时检测kv 传输成功,但是长时间没有pd master来触发推理的请求, - # 释放这些超时请求占用的kv资源 - # ================================================================================== - - def timer_loop(self): - try: - while True: - self._unfrozen_time_out_reqs_tokens() - time.sleep(3.5) - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - raise e - - # ================================================================================== - # 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启 - # ================================================================================== - - def check_trans_process_loop(self): - try: - while True: - for device_id in range(self.node_world_size): - if not self.kv_trans_processes[device_id].is_trans_process_health(): - raise Exception(f"device_id {device_id} kv process is unhealth") - - time.sleep(10.0) - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - - for device_id in range(self.node_world_size): - self.kv_trans_processes[device_id].killself() - - # 杀掉当前进程的父进程(router), 触发全局崩溃 - os.kill(os.getppid(), signal.SIGKILL) - os.kill(os.getpid(), signal.SIGKILL) - raise e - - # ================================================================================== - # 常用辅助功能函数 - # ================================================================================== - def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] - for obj in self.connect_id_to_trans_obj.values(): - counts[obj.device_index] += 1 - device_index = int(np.argmin(counts)) - return device_index - - def get_trans_obj(self, task: KVMoveTask): - self.__remove_dead_trans_obj() - return self.connect_id_to_trans_obj[task.connect_id] - - def __remove_dead_trans_obj(self): - del_connect_ids = [] - for connect_id, t_obj in self.connect_id_to_trans_obj.items(): - if t_obj.has_error_status(): - del_connect_ids.append(connect_id) - - for connect_id in del_connect_ids: - self.connect_id_to_trans_obj.pop(connect_id, None) - - if del_connect_ids: - import gc - - gc.collect() - return - - def remove_trans_obj(self, connect_id): - if connect_id in self.connect_id_to_trans_obj: - trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None) - if trans_obj is not None: - trans_obj.set_has_error() - return - - -def _init_env(args, info_queue: mp.Queue, event: mp.Event): - import lightllm.utils.rpyc_fix_utils as _ - - # 注册graceful 退出的处理 - graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::decode_kv_move_manager") - - manager = DecodeKVMoveManager(args, info_queue) - t = ThreadedServer(manager, port=args.pd_decode_rpyc_port, protocol_config={"allow_pickle": True}) - threading.Thread(target=lambda: t.start(), daemon=True).start() - - kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) - kv_trans_process_check.start() - - event.set() - manager.timer_loop() - return - - -def start_decode_kv_move_manager_process(args, info_queue: mp.Queue): - event = mp.Event() - proc = mp.Process(target=_init_env, args=(args, info_queue, event)) - proc.start() - event.wait() - assert proc.is_alive() - logger.info("decode kv move manager process started") - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_task_cache.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_task_cache.py deleted file mode 100644 index 48df4b86fb..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_task_cache.py +++ /dev/null @@ -1,10 +0,0 @@ -# 这个里面声明了一个全局变量,主要用于推理进程缓存发送给其他进程的Kv move 任务的缓存数据 -# 为了减少一些调用时候的序列化开销。有些调用就只需要传输一个请求id就可以了,不用传输特别的 -# 数据了,提升rpyc 调用的速度, 只用在 decode_impl.py 和 decode_infer_rpyc.py 文件中 -from typing import Dict, List, Tuple -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.server.router.dynamic_prompt.radix_cache import TreeNode - -g_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode, List[int]]] = {} - -g_success_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode, float]] = {} # 第三个float代表的是时间,用于判断过期条件。 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py deleted file mode 100644 index 939f065fb6..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py +++ /dev/null @@ -1,305 +0,0 @@ -import time -import psutil -import threading -from typing import List -from dataclasses import dataclass -from lightllm.utils.log_utils import init_logger -from ..task_queue import TaskQueue -import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup -from lightllm.utils.device_utils import kv_trans_use_p2p -from .decode_kv_move_manager import DecodeKVMoveManager -from lightllm.utils.time_utils import TimeChecker -from ..utils import join_if_alive, clear_queue - -logger = init_logger(__name__) - -KV_MOVE_MAX_NUM = 16 - - -@dataclass -class KVTransConnectObj: - connect_id: str = None - prefill_node_id: int = None - kv_trans_process: "KVTransProcess" = None - pd_prefill_nccl_ip: str = None - pd_prefill_nccl_port: int = None - device_index: int = None - manager: "DecodeKVMoveManager" = None - has_error: bool = False - ready_to_move_queue: TaskQueue = None - kv_move_thread: threading.Thread = None - move_finished_queue: TaskQueue = None - put_to_radix_thread: threading.Thread = None - timer_checker: TimeChecker = None - - def create( - self, - connect_id: str, - prefill_node_id: str, - pd_prefill_nccl_ip: str, - pd_prefill_nccl_port: int, - manager: "DecodeKVMoveManager", - ): - self.connect_id = connect_id - self.device_index = manager.get_next_device_index() - self.kv_trans_process = manager.kv_trans_processes[self.device_index] - decode_node_id = manager.args.pd_node_id - self.prefill_node_id = prefill_node_id - self.decode_node_id = decode_node_id - self.pd_prefill_nccl_ip = pd_prefill_nccl_ip - self.pd_prefill_nccl_port = pd_prefill_nccl_port - - self.manager = manager - self.timer_checker = TimeChecker(6) - - with self.kv_trans_process.device_lock: - clear_queue(self.kv_trans_process.task_out_queue) - self.kv_trans_process.task_in_queue.put( - PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=-1, - pd_prefill_nccl_ip=pd_prefill_nccl_ip, - pd_prefill_nccl_port=pd_prefill_nccl_port, - decode_id=decode_node_id, - decode_device_id=self.device_index, - connect_id=self.connect_id, - ) - ) - assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok" - - self.ready_to_move_queue = TaskQueue( - get_func=lambda datas: datas[0:1], fail_func=self.manager.put_to_fail_release_task_queue - ) - self.kv_move_thread = threading.Thread(target=self.kv_move_loop, daemon=True) - self.kv_move_thread.start() - - self.move_finished_queue = TaskQueue( - get_func=lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=self.manager.put_to_fail_release_task_queue - ) - self.put_to_radix_thread = threading.Thread(target=self.put_to_radix_loop, daemon=True) - self.put_to_radix_thread.start() - return - - # ================================================================================== - # 处理接受所有进行 kv 传输的请求,完成后,将请求放入到 move_finished_queue 中 - # ================================================================================== - - def _transfer_kv(self, move_tasks: List[KVMoveTask]): - with self.kv_trans_process.device_lock: - clear_queue(self.kv_trans_process.task_out_queue) - kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id) - kv_move_group.connect_id = self.connect_id - self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10) - assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok" - logger.info(f"_transfer_kv ok {move_tasks[0].to_decode_log_info()}") - - # 标记 decode 接收到 kv cache 的时间 - for move_task in move_tasks: - move_task.mark_start_time = time.time() - - self.move_finished_queue.put_list(move_tasks) - move_tasks.clear() - - def kv_move_loop(self): - func_name = self.kv_move_loop.__name__ - while not self.has_error: - move_tasks: List[List[KVMoveTask]] = self.ready_to_move_queue.get_tasks(log_tag="ready_to_move_queue") - if len(move_tasks) == 0: - time.sleep(0.01) - continue - - if len(move_tasks) != 1: - logger.error(f"error get need 1, but get {len(move_tasks)}") - assert False - - move_tasks: List[KVMoveTask] = move_tasks[0] - for task in move_tasks: - logger.info(f"{func_name} get task {task.to_decode_log_info()}") - - try: - self.timer_to_check_status(raise_exception=True) - if not kv_trans_use_p2p(): - with self.manager.kv_trans_lock: - self._transfer_kv(move_tasks) - else: - self._transfer_kv(move_tasks) - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.ready_to_move_queue.clear_tasks() - - finally: - self.manager.put_to_fail_release_task_queue(move_tasks) - - logger.error(f"{func_name} thread quit") - return - - # ================================================================================== - # 将传输完成的请求,放入到 radix cache 中进行管理。 - # ================================================================================== - - def put_to_radix_loop(self): - func_name = self.put_to_radix_loop.__name__ - while not self.has_error: - move_tasks: List[KVMoveTask] = self.move_finished_queue.get_tasks(log_tag="move_finished_queue") - if len(move_tasks) == 0: - time.sleep(0.01) - continue - - for task in move_tasks: - logger.info(f"{func_name} get put radix task {task.to_decode_log_info()}") - - try: - self.timer_to_check_status(raise_exception=True) - # random to check stats - self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) - for task in move_tasks.copy(): - logger.info( - f"{func_name} put kv to radix cache ok, req_id: {task.id()} cost_time {task.get_cost_time()} s" - ) - self.manager.up_status_in_queue.put( - UpKVStatus( - group_request_id=task.group_request_id, - dp_index=task.decode_dp_index, - pd_master_node_id=task.pd_master_node_id, - ) - ) - logger.info(f"{func_name} up kv status req_id: {task.id()} finished") - move_tasks.clear() - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.move_finished_queue.clear_tasks() - - finally: - self.manager.put_to_fail_release_task_queue(move_tasks) - - logger.error(f"{func_name} thread quit, info: {self.to_log_info()}") - return - - # ================================================================================== - # 错误处理检测操作的一些通用函数 - # ================================================================================== - - def timer_to_check_status(self, raise_exception=True): - if self.timer_checker.has_exceeded(): - try: - assert self.kv_trans_process.is_trans_process_health() - except BaseException as e: - logger.error(f"pid {self.kv_trans_process.process.pid} check failed") - logger.exception(str(e)) - - self.set_has_error() - if raise_exception: - raise e - return - - def has_error_status(self): - try: - assert self.has_error is False - assert self.kv_move_thread.is_alive() - assert self.put_to_radix_thread.is_alive() - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - return True - - return False - - def set_has_error(self): - self.has_error = True - - if self.ready_to_move_queue is not None: - self.ready_to_move_queue.has_error = True - - if self.move_finished_queue is not None: - self.move_finished_queue.has_error = True - - if self.manager is not None: - self.manager.remove_trans_obj(self.connect_id) - return - - def __del__(self): - logger.error(f"trans obj del start, info: {self.to_log_info()}") - - try: - self.set_has_error() - - join_if_alive(self.kv_move_thread) - join_if_alive(self.put_to_radix_thread) - - if self.connect_id is not None and self.kv_trans_process is not None: - self.kv_trans_process.task_in_queue.put( - PDTransLeaveInfo( - decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id - ) - ) - - if self.ready_to_move_queue is not None: - self.ready_to_move_queue.clear_tasks() - if self.move_finished_queue is not None: - self.move_finished_queue.clear_tasks() - - except BaseException as e: - logger.exception(str(e)) - - logger.error(f"trans obj deled, info: {self.to_log_info()}") - - def to_log_info(self): - log = f"connect_id: {self.connect_id} " - log += f"decode_node_id: {self.decode_node_id} " - log += f"prefill_node_id: {self.prefill_node_id} " - log += f"device_index: {self.device_index} " - return log - - -@dataclass -class KVTransProcess: - process: mp.Process = None - # 需要每个卡有一个锁来规划每次只能有一个 connection obj 操作对应显卡上的传输任务。 - device_lock: threading.Lock = None - task_in_queue: mp.Queue = None - task_out_queue: mp.Queue = None - device_id: int = None - - def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): - self.device_lock = threading.Lock() - self.device_id = device_id - self.task_in_queue = mp.Queue() - self.task_out_queue = mp.Queue() - - try: - from .decode_trans_process import start_decode_trans_process - - self.process = start_decode_trans_process( - manager.args, - device_id, - self.task_in_queue, - self.task_out_queue, - ) - assert self.task_out_queue.get(timeout=30) == "proc_start" - assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" - - return True - - except Exception as e: - logger.warning(f"Failed start kv trans process for device {device_id}: {e}") - logger.exception(str(e)) - return False - - def is_trans_process_health(self): - try: - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - logger.error(f"kv trans process for device: {self.device_id} dead!!!") - return False - else: - return True - except: - return False - - def killself(self): - self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py deleted file mode 100644 index cdca638873..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ /dev/null @@ -1,155 +0,0 @@ -import torch -import time -import sys -import inspect -import threading -import setproctitle -import torch.multiprocessing as mp -from torch.distributed import TCPStore -from datetime import timedelta -from typing import List, Dict, Union -from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup -from lightllm.utils.envs_utils import get_unique_server_name - -logger = init_logger(__name__) - - -def _handle_kvmove_task( - move_tasks: List[KVMoveTask], - task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], - connect_id_to_comm: Dict[str, PyNcclCommunicator], - connect_id: str, - dp_size_in_node: int, -): - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - device_index = connect_id_to_comm[connect_id].device.index - start = time.time() - if total_move_kv_len != 0: - cur_mem = mem_managers[device_index] - logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") - if kv_trans_use_p2p(): - cur_mem.receive_from_prefill_node_p2p( - move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] - ) - else: - cur_mem.receive_from_prefill_node( - move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] - ) - logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - raise e - - -def _handle_prefill_join( - node_info: PDTransJoinInfo, task_out_queue: mp.Queue, connect_id_to_comm: Dict[str, PyNcclCommunicator] -): - try: - logger.info(f"connect start {node_info}") - store_client = TCPStore( - host_name=node_info.pd_prefill_nccl_ip, - port=node_info.pd_prefill_nccl_port, - is_master=False, - use_libuv=True, - timeout=timedelta(seconds=30), - ) - src_id = node_info.prefill_id - dest_id = node_info.connect_id - logger.info(f"connect src_id {src_id} dest_id {dest_id}") - - result_list = [] - - def async_connect(): - torch.cuda.set_device(node_info.decode_device_id) - group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=False, store=store_client) - comm = PyNcclCommunicator(group, node_info.decode_device_id) - result_list.append(comm) - return - - connect_task = threading.Thread(target=async_connect, daemon=True) - connect_task.start() - connect_task.join(timeout=36) - if connect_task.is_alive(): - raise Exception(f"{node_info} connect time out") - - connect_id_to_comm[node_info.connect_id] = result_list[0] - logger.info(f"{node_info} kv trans connected") - task_out_queue.put("nccl_ok") - except Exception as e: - task_out_queue.put("nccl_fail") - logger.warning(f"error while connect to prefill node: {e}") - - -def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue): - import os - - # os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_MAX_NCHANNELS"] = "2" - os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" - os.environ["NCCL_SOCKET_NTHREADS"] = "1" - torch.backends.cudnn.enabled = False - - dp_size_in_node = max(1, args.dp // args.nnodes) - - setproctitle.setproctitle( - f"lightllm::{get_unique_server_name()}::decode_trans:Device{device_id}_DpSizeInNode{dp_size_in_node}" - ) - - try: - torch.cuda.set_device(device_id) - graceful_registry(inspect.currentframe().f_code.co_name) - task_out_queue.put("proc_start") - - # 从共享内存读取所有rank的mem_manager - node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) - ] - - task_out_queue.put("get_mem_managers_ok") - connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} - while True: - task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() - if isinstance(task, KVMoveTaskGroup): - _handle_kvmove_task( - task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node - ) - elif isinstance(task, PDTransJoinInfo): - _handle_prefill_join(task, task_out_queue, connect_id_to_comm) - elif isinstance(task, PDTransLeaveInfo): - if task.connect_id in connect_id_to_comm: - connect_id_to_comm[task.connect_id].destroy() - logger.info(f"destory {task} nccl communicator.") - else: - logger.info(f"no connect_id {task.connect_id} found in connect_id_to_comm") - - else: - logger.warning(f"unexpected task type: {task}") - - except Exception as e: - logger.error(f"Fatal error happened in kv trans process: {e} in device {device_id}") - raise - - -def start_decode_trans_process( - args, - device_id: int, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, -): - proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue)) - proc.start() - assert proc.is_alive() - logger.info(f"decode trans kv process for device: {device_id} start!") - return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py deleted file mode 100644 index 833ffecc89..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py +++ /dev/null @@ -1,127 +0,0 @@ -import time -import json -import asyncio -import threading -import websockets -import inspect -import setproctitle -import pickle - -from typing import Dict -from dataclasses import asdict -from lightllm.server.pd_io_struct import UpKVStatus -from lightllm.utils.log_utils import init_logger -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.server.pd_io_struct import PD_Master_Obj -import torch.multiprocessing as mp -from lightllm.utils.envs_utils import get_unique_server_name - -logger = init_logger(__name__) - - -class UpStatusManager: - def __init__(self, args, task_in_queue: mp.Queue, task_out_queue: mp.Queue): - self.args = args - self.task_queue: mp.Queue[UpKVStatus] = task_in_queue - self.task_out_queue = task_out_queue - self.daemon_thread = threading.Thread(target=self.thread_loop, daemon=True) - self.daemon_thread.start() - - def thread_loop(self): - asyncio.run(self.task_loop()) - - async def task_loop(self): - - self.id_to_handle_task: Dict[int, asyncio.Task] = {} - self.id_to_handle_queue: Dict[int, asyncio.Queue] = {} - - asyncio.create_task(self.dispatch_task_loop()) - - while True: - try: - from lightllm.server.httpserver.pd_loop import _get_pd_master_objs - - id_to_pd_master_obj = await _get_pd_master_objs(self.args) - logger.info(f"get pd_master_objs {id_to_pd_master_obj}") - - if id_to_pd_master_obj is not None: - for node_id, pd_master_obj in self.id_to_handle_task.items(): - if node_id not in id_to_pd_master_obj: - self.id_to_handle_task[node_id].cancel() - self.id_to_handle_task.pop(node_id, None) - self.id_to_handle_queue.pop(node_id, None) - logger.info(f"up_kv_status_task {pd_master_obj} cancelled") - - for node_id, pd_master_obj in id_to_pd_master_obj.items(): - if node_id not in self.id_to_handle_task: - self.id_to_handle_queue[node_id] = asyncio.Queue() - self.id_to_handle_task[node_id] = asyncio.create_task(self.up_kv_status_task(pd_master_obj)) - - await asyncio.sleep(30) - - except Exception as e: - logger.exception(str(e)) - await asyncio.sleep(10) - - async def dispatch_task_loop(self): - while True: - try: - loop = asyncio.get_event_loop() - upkv_status: UpKVStatus = await loop.run_in_executor(None, self.task_queue.get) - if upkv_status.pd_master_node_id in self.id_to_handle_queue: - await self.id_to_handle_queue[upkv_status.pd_master_node_id].put(upkv_status) - else: - logger.warning(f"upstatus {upkv_status} no connection to pd_master, drop it") - except BaseException as e: - logger.exception(str(e)) - await asyncio.sleep(10) - - async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): - while True: - try: - uri = f"ws://{pd_master_obj.host_ip_port}/kv_move_status" - async with websockets.connect(uri) as websocket: - import socket - - sock = websocket.transport.get_extra_info("socket") - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - while True: - try: - if pd_master_obj.node_id in self.id_to_handle_queue: - task_queue = self.id_to_handle_queue[pd_master_obj.node_id] - upkv_status: UpKVStatus = await task_queue.get() - await websocket.send(pickle.dumps(upkv_status)) - logger.info(f"up status: {upkv_status}") - else: - await asyncio.sleep(3) - except BaseException as e: - logger.error(str(e)) - raise e - except asyncio.CancelledError: - logger.info(f"up_kv_status_task {pd_master_obj} cancelled") - return - - except Exception as e: - logger.error(f"connetion to pd_master {pd_master_obj} has error: {str(e)}") - logger.exception(str(e)) - await asyncio.sleep(10) - logger.info("reconnection to pd_master") - - -def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue): - graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::up_kv_status") - up_kv_manager = UpStatusManager(args, task_in_queue, task_out_queue) - logger.info(f"up kv manager {str(up_kv_manager)} start ok") - while True: - time.sleep(666) - return - - -def start_up_kv_status_process(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue): - proc = mp.Process(target=_init_env, args=(args, task_in_queue, task_out_queue)) - proc.start() - assert proc.is_alive() - logger.info("up_kv_status_process start") - return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/__init__.py deleted file mode 100644 index 4100e14eda..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .prefill_trans_process import start_prefill_trans_process -from .prefill_kv_move_manager import start_prefill_kv_move_manager_process diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py deleted file mode 100644 index 8e7bddc64e..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -import time -import threading -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import InferReq -from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo -from lightllm.utils.log_utils import init_logger -from lightllm.common.basemodel.infer_lock import g_router_lock, g_infer_state_lock -from rpyc.utils.server import ThreadedServer -from .prefill_task_cache import g_kv_move_task_cache -from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.utils.dist_utils import create_new_group_for_current_dp -from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend - -logger = init_logger(__name__) - - -class ChunckedPrefillForPrefillNode(ChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue) -> None: - super().__init__() - self.support_overlap = False - self.info_queue: mp.Queue = info_queue - self.classed_req_no_decode = True - - def init_custom(self): - - self.lock_nccl_group = create_new_group_for_current_dp("gloo") - logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}") - - from .prefill_infer_rpyc import PDPrefillInferRpcServer - - socket_path = f"/tmp/{get_unique_server_name()}_prefill_node_infer_rpyc_{self.pd_rpyc_ports[self.rank_in_node]}" - if os.path.exists(socket_path): - os.remove(socket_path) - - t = ThreadedServer( - PDPrefillInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True} - ) - threading.Thread(target=lambda: t.start(), daemon=True).start() - return - - def _pre_handle_finished_reqs(self, finished_reqs): - self._prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(finished_reqs=finished_reqs) - return - - def _prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, finished_reqs: List[InferReq]): - if len(finished_reqs) == 0: - return - - # 提前在radix cache中回收相关的信息,并添加引用进行锁定,方便传输进程传输kv。 - if self.is_master_in_dp: - logger.info("prefill_req_handle_and_frozen_tokens") - - g_infer_state_lock.acquire() - try: - for req in finished_reqs: - - # 区分abort 和 正常结束的请求,正常结束的请求才发起kv传输任务。 - if not req.finish_status.is_finished(): - continue - - req: InferReq = req - key = req.get_input_token_ids()[0 : req.cur_kv_len] - key = torch.tensor(key, dtype=torch.int64, device="cpu") - value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len, new_shared_kv_node = self.radix_cache.insert(key, value) - old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len - self.model.mem_manager.free( - self.model.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len] - ) - # 将原有共享节点替换为新共享节点,新共享节点对应的长度为当前的cur_kv_len - - self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - self.radix_cache.add_node_ref_counter(new_shared_kv_node) - req.shared_kv_node = new_shared_kv_node - - _kv_len = req.cur_kv_len - _value = self.radix_cache.get_mem_index_value_by_node(new_shared_kv_node) - assert len(_value) == _kv_len - self.model.req_manager.req_to_token_indexs[req.req_idx][0:_kv_len] = _value - - assert new_shared_kv_node.node_prefix_total_len == req.cur_kv_len - - if req.shm_req.sample_params.move_kv_to_decode_node.exists: - # 注意兼容纯tp 和 tp dp 混合模式的逻辑 - if self.is_master_in_dp: - g_router_lock.acquire() - self.shared_token_load.add_frozened_token_count(len(key), self.dp_rank_in_node) - g_router_lock.release() - - share_node, kv_len, value = self.radix_cache.match_prefix(key, update_refs=True) - assert len(key) == len(value) - # 将下面的请求放入到任务队列中, 注意要使用raidx cache 返回的value - decode_node_info = DecodeNodeInfo(**req.shm_req.sample_params.move_kv_to_decode_node.to_dict()) - task = KVMoveTask( - group_request_id=req.shm_req.group_req_id, - input_tokens=key.tolist(), - prefill_token_indexes=value.tolist(), - decode_token_indexes=None, - prefill_node_id=self.args.pd_node_id, - decode_node=decode_node_info, - move_kv_len=None, - prefill_dp_index=self.dp_rank_in_node, - decode_dp_index=None, - pd_master_node_id=req.shm_req.sample_params.pd_master_node_id.get(), - mark_start_time=time.time(), - ) - g_kv_move_task_cache[task.group_request_id] = (task, share_node) - - # 注意兼容纯 tp 和 tp dp 混合模式的逻辑 - if self.is_master_in_dp: - self.info_queue.put(task) - except BaseException as e: - logger.exception(str(e)) - g_infer_state_lock.release() - if self.is_master_in_dp: - logger.info("prefill_req_handle_and_frozen_tokens end") - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py deleted file mode 100644 index 2897f71412..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch.multiprocessing as mp -from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import InferReq -from lightllm.utils.log_utils import init_logger -from .prefill_impl import ChunckedPrefillForPrefillNode -from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend - -logger = init_logger(__name__) - - -class DPChunkedForPrefillNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue) -> None: - super().__init__() - self.support_overlap = False - self.info_queue: mp.Queue = info_queue - self.classed_req_no_decode = True - - def init_custom(self): - ChunckedPrefillForPrefillNode.init_custom(self) - return - - def _pre_handle_finished_reqs(self, finished_reqs): - self._prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(finished_reqs=finished_reqs) - return - - def _prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, finished_reqs: List[InferReq]): - ChunckedPrefillForPrefillNode._prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue( - self, finished_reqs=finished_reqs - ) - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py deleted file mode 100644 index 1f2dd52c5a..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -import torch.distributed as dist -import rpyc -from typing import Dict, List, Tuple -from rpyc.utils.classic import obtain -from .prefill_impl import ChunckedPrefillForPrefillNode -from lightllm.common.basemodel.infer_lock import g_router_lock, acquire_lock_until_ready, release_acquired_lock -from .prefill_task_cache import g_kv_move_task_cache -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class PDPrefillInferRpcServer(rpyc.Service): - def __init__(self, backend: ChunckedPrefillForPrefillNode) -> None: - super().__init__() - self.backend = backend - self.device_id = self.backend.current_device_id - self.dp_rank_in_node = self.backend.dp_rank_in_node - self.is_master_in_dp = self.backend.is_master_in_dp - return - - def on_connect(self, conn): - torch.cuda.set_device(f"cuda:{self.device_id}") - return - - # pd 分离模式会使用的一些接口,用于做一些全局信息管理 - def exposed_remove_req_refs_from_prompt_cache(self, group_req_ids: List[int]): - group_req_ids = obtain(group_req_ids) - acquire_lock_until_ready(self.backend.lock_nccl_group) - for group_req_id in group_req_ids: - if group_req_id in g_kv_move_task_cache: - task, share_node = g_kv_move_task_cache.pop(group_req_id) - if share_node is not None: - self.backend.radix_cache.dec_node_ref_counter(share_node) - # 减少日志数量 - if self.is_master_in_dp: - logger.info(f"unfrozen tokens for req id: {group_req_id}") - - # 更新调度元数据 - if self.is_master_in_dp: - with g_router_lock.obj: - self.backend.shared_token_load.add_frozened_token_count( - -len(task.input_tokens), self.dp_rank_in_node - ) - release_acquired_lock() - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py deleted file mode 100644 index bd5af98ee6..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ /dev/null @@ -1,241 +0,0 @@ -import asyncio -import time -import rpyc -import sys -import os -import gc -import signal -import copy -import numpy as np -import psutil -import threading -import inspect -import collections -import setproctitle -from typing import List, Dict, Union -from lightllm.utils.log_utils import init_logger -from .prefill_infer_rpyc import PDPrefillInferRpcServer -import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.utils.retry_utils import retry -from rpyc import AsyncResult -from lightllm.utils.net_utils import get_hostname_ip -from ..task_queue import TaskQueue -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.utils.envs_utils import get_unique_server_name - -KV_MOVE_MAX_NUM = 16 - -logger = init_logger(__name__) - - -class PrefillKVMoveManager: - def __init__(self, args, info_queue: mp.Queue): - self.args = args - # args.dp // args.nnodes 在跨机tp的场景下,可能为0 - self.dp_size_in_node = max(1, args.dp // args.nnodes) - self.node_world_size = args.tp // args.nnodes - self.dp_world_size = args.tp // args.dp - # 不支持跨机tp的pd 分离策略 - assert self.dp_world_size <= self.node_world_size - - self.info_queue = info_queue - self.infer_rpyc_objs: List[PDPrefillInferRpcServer] = [] - - from .prefill_trans_obj import KVTransConnectObj - - self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {} - - for port in self.args.pd_node_infer_rpyc_ports: - socket_path = f"/tmp/{get_unique_server_name()}_prefill_node_infer_rpyc_{port}" - from rpyc.utils.factory import unix_connect - - con = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - self.infer_rpyc_objs.append(con.root) - logger.info(f"rpyc connect to infer rpyc port: {port} ok") - self.host_ip = get_hostname_ip() - if self.host_ip is None: - self.host_ip = args.host - - self.infer_rpyc_lock = threading.Lock() - - self.kv_trans_lock = threading.Lock() - # 释放token的task队列 - self.release_task_queue = TaskQueue(lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=None) - self.release_tasks_thread = threading.Thread(target=self.handle_release_task_loop, daemon=True) - self.release_tasks_thread.start() - - from .prefill_trans_obj import KVTransProcess - - self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size - for device_id in range(self.node_world_size): - self.kv_trans_processes[device_id] = KVTransProcess() - assert self.kv_trans_processes[device_id].init_all(device_id, self) - - return - - # ================================================================================== - # 主任务循环,接收需要进行kv传输的请求进行处理 - # ================================================================================== - - def task_dispatcher_loop(self): - try: - # 获取任务,并分发给相关卡的处理队列 - while True: - move_task: KVMoveTask = self.info_queue.get() - try: - trans_obj = self.__get_trans_obj(move_task) - trans_obj.request_kv_trans_task_queue.put(move_task) - except BaseException as e: - logger.exception(str(e)) - self.put_to_release_task_queue(move_task) - finally: - trans_obj = None - - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - raise e - - # ================================================================================== - # 请求出错或者完成kv传输后的处理队列和线程loop - # ================================================================================== - - def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): - if isinstance(task, KVMoveTask): - self.release_task_queue.put(task) - elif isinstance(task, list): - self.release_task_queue.put_list(task) - else: - logger.error("error input in put_to_release_task_queue func") - return - - def handle_release_task_loop(self): - while True: - handle_list: List[KVMoveTask] = self.release_task_queue.get_tasks(log_tag="release_task_queue") - if len(handle_list) == 0: - time.sleep(0.01) - else: - self._remove_req_refs_from_prompt_cache(handle_list) - return - - # ================================================================================== - # 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启 - # ================================================================================== - - def check_trans_process_loop(self): - try: - while True: - for device_id in range(self.node_world_size): - if not self.kv_trans_processes[device_id].is_trans_process_health(): - raise Exception(f"device_id {device_id} kv process is unhealth") - - time.sleep(10.0) - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - - for device_id in range(self.node_world_size): - self.kv_trans_processes[device_id].killself() - - # 杀掉当前进程的父进程(router), 触发全局崩溃 - os.kill(os.getppid(), signal.SIGKILL) - os.kill(os.getpid(), signal.SIGKILL) - raise e - - # ================================================================================== - # 与推理进程交互接口, _remove_req_refs_from_prompt_cache - # ================================================================================== - - def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]): - with self.infer_rpyc_lock: - dp_to_tasks = collections.defaultdict(list) - for task in tasks: - dp_to_tasks[task.prefill_dp_index].append(task) - futures: List[AsyncResult] = [] - for prefill_dp_index, _tasks in dp_to_tasks.items(): - conn_start = prefill_dp_index * self.dp_world_size - conn_end = (prefill_dp_index + 1) * self.dp_world_size - conns = self.infer_rpyc_objs[conn_start:conn_end] - for conn in conns: - futures.append( - rpyc.async_(conn.remove_req_refs_from_prompt_cache)([task.group_request_id for task in _tasks]) - ) - asyncio.run(self.wait_all_future_finish(futures)) - return - - async def wait_all_future_finish(self, futures: List[AsyncResult]): - await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) - return - - # ================================================================================== - # 辅助功能接口 - # ================================================================================== - - def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] - for obj in self.connect_id_to_trans_obj.values(): - counts[obj.device_index] += 1 - device_index = int(np.argmin(counts)) - return device_index - - def remove_trans_obj(self, connect_id): - if connect_id in self.connect_id_to_trans_obj: - trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None) - if trans_obj is not None: - trans_obj.set_has_error() - logger.error(f"remove tran obj decode_node_id {trans_obj.decode_node_id}") - return - - def __get_trans_obj(self, task: KVMoveTask): - self.__remove_dead_trans_obj() - # 如果已经存在连接对象,直接返回 - for obj in self.connect_id_to_trans_obj.values(): - if obj.decode_node_id == task.decode_node.node_id: - return obj - - # 如果不存在连接对象,创建新的连接对象 - gc.collect() - from .prefill_trans_obj import KVTransConnectObj - - trans_obj = KVTransConnectObj() - trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self) - self.connect_id_to_trans_obj[trans_obj.connect_id] = trans_obj - return trans_obj - - def __remove_dead_trans_obj(self): - del_connect_ids = [] - for connect_id, t_obj in self.connect_id_to_trans_obj.items(): - if t_obj.has_error_status(): - del_connect_ids.append(connect_id) - - for connect_id in del_connect_ids: - self.connect_id_to_trans_obj.pop(connect_id, None) - - if del_connect_ids: - gc.collect() - return - - -def _init_env(args, info_queue: mp.Queue, event: mp.Event): - import lightllm.utils.rpyc_fix_utils as _ - - # 注册graceful 退出的处理 - graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::prefill_kv_move_manager") - - manager = PrefillKVMoveManager(args, info_queue) - kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) - kv_trans_process_check.start() - event.set() - # 进入主循环 - manager.task_dispatcher_loop() - return - - -def start_prefill_kv_move_manager_process(args, info_queue: mp.Queue): - event = mp.Event() - proc = mp.Process(target=_init_env, args=(args, info_queue, event)) - proc.start() - event.wait() - assert proc.is_alive() - logger.info("prefill kv move manager process started") - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_task_cache.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_task_cache.py deleted file mode 100644 index afa8e87f44..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_task_cache.py +++ /dev/null @@ -1,8 +0,0 @@ -# 这个里面声明了一个全局变量,主要用于推理进程缓存发送给其他进程的Kv move 任务的缓存数据 -# 为了减少一些调用时候的序列化开销。有些调用就只需要传输一个请求id就可以了,不用传输特别的 -# 数据了,提升rpyc 调用的速度, 只用在 prefill_impl.py 和 prefill_infer_rpyc.py 文件中 -from typing import Dict, Tuple -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.server.router.dynamic_prompt.radix_cache import TreeNode - -g_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode]] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py deleted file mode 100644 index 022be45591..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py +++ /dev/null @@ -1,378 +0,0 @@ -import time -import rpyc -import copy -import uuid -import numpy as np -import psutil -import threading -from dataclasses import dataclass -from typing import List, Dict, Union -from lightllm.utils.log_utils import init_logger -import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup -from rpyc.utils.classic import obtain -from ..task_queue import TaskQueue -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.time_utils import TimeChecker -from .prefill_kv_move_manager import PrefillKVMoveManager -from lightllm.utils.net_utils import find_available_port -from ..utils import join_if_alive, clear_queue - -logger = init_logger(__name__) - - -@dataclass -class KVTransConnectObj: - connect_id: str = None - decode_node_id: int = None - rpyc_conn: object = None # rpyc_con 的连接对象 - kv_trans_process: "KVTransProcess" = None - device_index: int = None # 使用的gpu序号 - manager: "PrefillKVMoveManager" = None - has_error: bool = False - request_kv_trans_task_queue: TaskQueue = None - request_thread: threading.Thread = None - ready_kv_trans_task_queue: TaskQueue = None - kv_trans_thread: threading.Thread = None - timer_checker: TimeChecker = None - - # ================================================================================== - # 构建传输通信对象 - # ================================================================================== - - def create( - self, decode_node_id: int, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager" - ): - device_index = manager.get_next_device_index() # 分配使用的显卡index - self.kv_trans_process = manager.kv_trans_processes[device_index] - prefill_node_id = manager.args.pd_node_id - self.connect_id = str(uuid.uuid4()) - self.decode_node_id = decode_node_id - self.prefill_node_id = prefill_node_id - self.device_index = device_index - self.manager = manager - self.timer_checker = TimeChecker(6) - - con = rpyc.connect( - host=decode_node_ip, - port=decode_node_rpyc_port, - config={"allow_pickle": True, "sync_request_timeout": 60}, - keepalive=True, - ) - - self.rpyc_conn = con - - # 创建 nccl 连接 - with self.kv_trans_process.device_lock: - clear_queue(self.kv_trans_process.task_out_queue) - - self.kv_trans_process.task_in_queue.put( - PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=device_index, - pd_prefill_nccl_ip=manager.host_ip, - pd_prefill_nccl_port=self.kv_trans_process.kv_trans_port, - decode_id=decode_node_id, - decode_device_id=-1, - connect_id=self.connect_id, - ) - ) - - # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 - max_kv_trans_token_num = obtain( - con.root.build_trans_connect( - prefill_node_id, - manager.host_ip, - self.kv_trans_process.kv_trans_port, - manager.args.max_total_token_num, - self.connect_id, - ) - ) - self.max_kv_trans_token_num = max_kv_trans_token_num - assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok" - - self.request_kv_trans_task_queue = TaskQueue( - get_func=self._get_request_tasks, fail_func=self.manager.put_to_release_task_queue - ) - self.request_thread = threading.Thread(target=self.request_kv_trans_loop, daemon=True) - self.request_thread.start() - - self.ready_kv_trans_task_queue = TaskQueue(lambda datas: datas[0:1], self.manager.put_to_release_task_queue) - self.kv_trans_thread = threading.Thread(target=self.kv_trans_handle_loop, daemon=True) - self.kv_trans_thread.start() - - logger.info(f"create KVTransConnectObj success: {self.to_log_info()}") - return - - def _get_request_tasks(self, datas: List[KVMoveTask]): - """ - 根据可以p和d节点间协商得到的 max_kv_trans_token_num 限制,将排队等待 - 传输的请求打包成一个可以传输的list组。 - """ - ans_list = [] - token_num = 0 - for task in datas: - if token_num + len(task.prefill_token_indexes) <= self.max_kv_trans_token_num: - ans_list.append(task) - token_num += len(task.prefill_token_indexes) - else: - break - return ans_list - - # ================================================================================== - # 与 decode 节点进行元数据交互,申请锁定资源准备进行kv的传输 - # ================================================================================== - def request_kv_trans_loop(self): - func_name = self.request_kv_trans_loop.__name__ - - while not self.has_error: - move_tasks: List[KVMoveTask] = self.request_kv_trans_task_queue.get_tasks( - log_tag="request_kv_trans_task_queue" - ) - if len(move_tasks) == 0: - self.timer_check_status(raise_exception=False) - time.sleep(0.01) - continue - try: - self.timer_check_status(raise_exception=True) - for move_task in move_tasks: - move_task.connect_id = self.connect_id - logger.info( - f"{func_name} get task {move_task.to_prefill_log_info()} " - f"queue time {move_task.get_cost_time()} s " - ) - - trans_move_tasks = [copy.copy(move_task) for move_task in move_tasks] - for trans_move_task in trans_move_tasks: - trans_move_task.prefill_token_indexes = None - - mark_start = time.time() - move_kv_lens = self.rpyc_conn.root.request_data_transfer(trans_move_tasks) - move_kv_lens = obtain(move_kv_lens) - request_data_transfer_cost_time = time.time() - mark_start - - logger.info( - f"{func_name} request_data_transfer ok, {move_tasks[0].to_prefill_log_info()}" - f" cost time: {request_data_transfer_cost_time} s" - ) - - ok_trans_list = [] - for i, move_task in enumerate(move_tasks.copy()): - if move_kv_lens[i] is not None: - move_task.move_kv_len = move_kv_lens[i] - ok_trans_list.append(move_task) - move_tasks.remove(move_task) - else: - logger.info(f"prefill node kv move task req_id: {move_task.id()} not send, decode is busy") - - if ok_trans_list: - self.ready_kv_trans_task_queue.put( - ok_trans_list, error_handle_func=self.manager.put_to_release_task_queue - ) - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.request_kv_trans_task_queue.clear_tasks() - - finally: - # 将没有申请成功的请求放入到释放队列中 - self.manager.put_to_release_task_queue(move_tasks) - - logger.error(f"{func_name}, {self.to_log_info()} thread quit") - return - - # ================================================================================== - # 将准备好 kv 传输的请求进行 kv 传输 - # ================================================================================== - def _transfer_kv(self, move_tasks: List[KVMoveTask]): - with self.kv_trans_process.device_lock: - clear_queue(self.kv_trans_process.task_out_queue) - kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id) - self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10) - assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok" - self.manager.put_to_release_task_queue(move_tasks) - - logger.info( - f"_transfer_kv data ok, req_id: {move_tasks[0].id()}" - f" cost total time: {move_tasks[0].get_cost_time()} s" - ) - move_tasks.clear() - - def kv_trans_handle_loop(self): - func_name = self.kv_trans_handle_loop.__name__ - while not self.has_error: - move_tasks: List[List[KVMoveTask]] = self.ready_kv_trans_task_queue.get_tasks( - log_tag="ready_kv_trans_task_queue" - ) - if len(move_tasks) == 0: - self.timer_check_status(raise_exception=False) - time.sleep(0.01) - continue - - if len(move_tasks) != 1: - logger.error(f"error get kv trans move_tasks, must be 1, get {len(move_tasks)}") - assert len(move_tasks) == 1 - - move_tasks: List[KVMoveTask] = move_tasks[0] - - try: - self.timer_check_status(raise_exception=True) - for move_task in move_tasks: - logger.info( - f"{func_name} get task {move_task.to_prefill_log_info()} to start kv move" - f"queue time {move_task.get_cost_time()} s " - ) - - if not kv_trans_use_p2p(): - with self.manager.kv_trans_lock: - self._transfer_kv(move_tasks) - else: - self._transfer_kv(move_tasks) - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.ready_kv_trans_task_queue.clear_tasks() - finally: - self.manager.put_to_release_task_queue(move_tasks) - - logger.error(f"trans kv thread, {self.to_log_info()} thread quit") - return - - # ================================================================================== - # 错误处理检测操作的一些通用函数 - # ================================================================================== - - def has_error_status(self): - try: - assert self.has_error is False - assert self.request_thread.is_alive() - assert self.kv_trans_thread.is_alive() - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - return True - - return False - - def timer_check_status(self, raise_exception=True): - if self.timer_checker.has_exceeded(): - try: - self.rpyc_conn.root.check_alive() - assert self.kv_trans_process.is_trans_process_health() - except BaseException as e: - logger.error(f"pid {self.kv_trans_process.process.pid} check failed") - logger.exception(str(e)) - - self.set_has_error() - if raise_exception: - raise e - - return - - def set_has_error(self): - """ - 将当前传输对象标记为有错误,这样可以防止请求放入到处理队列中 - """ - self.has_error = True - - if self.request_kv_trans_task_queue is not None: - self.request_kv_trans_task_queue.has_error = True - - if self.ready_kv_trans_task_queue is not None: - self.ready_kv_trans_task_queue.has_error = True - - if self.manager is not None: - self.manager.remove_trans_obj(self.connect_id) - return - - def __del__(self): - """ - 函数中有很多判断是否是None的操作,主要是为了避免一些异常流程的del行为不报错。 - """ - logger.error(f"trans obj del start, info: {self.to_log_info()}") - - try: - self.set_has_error() - - join_if_alive(self.request_thread) - join_if_alive(self.kv_trans_thread) - - # 将未处理的请求,清理掉,clear_tasks 会将没处理完的请求 - # 放入到 manager 资源释放队列中 - if self.request_kv_trans_task_queue is not None: - self.request_kv_trans_task_queue.clear_tasks() - if self.ready_kv_trans_task_queue is not None: - self.ready_kv_trans_task_queue.clear_tasks() - - # 传输进程清理掉 nccl 连接 - if self.connect_id is not None: - self.kv_trans_process.task_in_queue.put( - PDTransLeaveInfo( - decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id - ) - ) - - except BaseException as e: - logger.exception(str(e)) - - logger.error(f"trans obj deled, info: {self.to_log_info()}") - - def to_log_info(self): - log = f"connect_id: {self.connect_id} " - log += f"decode_node_id: {self.decode_node_id} " - log += f"prefill_node_id: {self.prefill_node_id} " - log += f"device_index: {self.device_index} " - return log - - -@dataclass -class KVTransProcess: - process: mp.Process = None - # 需要每个卡有一个锁来规划每次只能有一个 connection obj 操作对应显卡上的传输任务。 - device_lock: threading.Lock = None - task_in_queue: mp.Queue = None - task_out_queue: mp.Queue = None - device_id: int = None - kv_trans_port: int = None - - def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): - self.device_id = device_id - self.device_lock = threading.Lock() - self.task_in_queue = mp.Queue() - self.task_out_queue = mp.Queue() - self.kv_trans_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) - - try: - from .prefill_trans_process import start_prefill_trans_process - - self.process = start_prefill_trans_process( - manager.args, - manager.host_ip, - self.kv_trans_port, - device_id, - self.task_in_queue, - self.task_out_queue, - ) - assert self.task_out_queue.get(timeout=30) == "proc_start" - assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" - - return True - except Exception as e: - logger.warning(f"Failed start kv trans process for device {device_id}: {e}") - logger.exception(str(e)) - return False - - def is_trans_process_health(self): - try: - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - logger.error(f"kv trans process for device: {self.device_id} dead!!!") - return False - else: - return True - except: - return False - - def killself(self): - self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py deleted file mode 100644 index a328e3e080..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ /dev/null @@ -1,162 +0,0 @@ -import torch -import time -import sys -import inspect -import threading -import setproctitle -import torch.multiprocessing as mp -from torch.distributed import TCPStore -from datetime import timedelta -from typing import List, Dict, Union -from lightllm.utils.log_utils import init_logger -from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup -from lightllm.utils.device_utils import kv_trans_use_p2p -from lightllm.utils.graceful_utils import graceful_registry -from lightllm.distributed.pynccl import StatelessP2PProcessGroup, PyNcclCommunicator -from lightllm.utils.envs_utils import get_unique_server_name - - -logger = init_logger(__name__) - - -def _handle_kvmove_task( - move_tasks: List[KVMoveTask], - task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], - connect_id_to_comm: Dict[str, PyNcclCommunicator], - connect_id: str, - dp_size_in_node: int, -): - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - device_index = connect_id_to_comm[connect_id].device.index - start = time.time() - if total_move_kv_len != 0: - logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") - cur_mem = mem_managers[device_index] - if kv_trans_use_p2p(): - cur_mem.send_to_decode_node_p2p( - move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] - ) - else: - cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id]) - logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info( - f"trans cost time: {(time.time() - start)}," - f"move_total_kv_len: {total_move_kv_len}, {move_tasks[0].to_prefill_log_info()}" - ) - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - - -def _handle_decode_join( - node_info: PDTransJoinInfo, - task_out_queue: mp.Queue, - connect_id_to_comm: Dict[str, PyNcclCommunicator], - store: TCPStore, -): - try: - logger.info(f"connect start {node_info}") - src_id = node_info.prefill_id - dest_id = node_info.connect_id - logger.info(f"connect src_id {src_id} dest_id {dest_id}") - result_list = [] - - def async_connect(): - torch.cuda.set_device(node_info.prefill_device_id) - group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=True, store=store) - comm = PyNcclCommunicator(group, node_info.prefill_device_id) - result_list.append(comm) - return - - connect_task = threading.Thread(target=async_connect, daemon=True) - connect_task.start() - connect_task.join(timeout=36) - if connect_task.is_alive(): - raise Exception(f"{node_info} connect time out") - - connect_id_to_comm[node_info.connect_id] = result_list[0] - logger.info(f"{node_info} kv trans connected!") - task_out_queue.put("nccl_ok") - except Exception as e: - task_out_queue.put("nccl_fail") - logger.warning(f"error while connect to decode node: {e} node_info {node_info}") - - -def _init_env( - args, - store_ip, - store_port, - device_id, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, -): - import os - - # os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_MAX_NCHANNELS"] = "2" - os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" - os.environ["NCCL_SOCKET_NTHREADS"] = "1" - torch.backends.cudnn.enabled = False - - dp_size_in_node = max(1, args.dp // args.nnodes) - setproctitle.setproctitle( - f"lightllm::{get_unique_server_name()}::prefill_trans:Device{device_id}_DpSizeInNode{dp_size_in_node}" - ) - - try: - torch.cuda.set_device(device_id) - graceful_registry(inspect.currentframe().f_code.co_name) - master_store = TCPStore( - host_name=store_ip, port=store_port, is_master=True, use_libuv=True, timeout=timedelta(seconds=30) - ) - task_out_queue.put("proc_start") - - # 从共享内存读取所有rank的mem_manager - node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) - ] - task_out_queue.put("get_mem_managers_ok") - connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} - - while True: - task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() - if isinstance(task, KVMoveTaskGroup): - _handle_kvmove_task( - task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node - ) - elif isinstance(task, PDTransJoinInfo): - _handle_decode_join(task, task_out_queue, connect_id_to_comm, master_store) - elif isinstance(task, PDTransLeaveInfo): - if task.connect_id in connect_id_to_comm: - connect_id_to_comm[task.connect_id].destroy() - connect_id_to_comm.pop(task.connect_id, None) - logger.info(f"destory {task} nccl communicator.") - else: - logger.error(f"connect id {task.connect_id} dont exist in connect_id_to_comm") - else: - logger.warning(f"unexpected task type: {task}") - - except Exception as e: - logger.error(f"Fatal error happened in kv trans process: {e}") - pass - - -def start_prefill_trans_process( - args, - store_ip, - store_port, - device_id, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, -): - proc = mp.Process(target=_init_env, args=(args, store_ip, store_port, device_id, task_in_queue, task_out_queue)) - proc.start() - assert proc.is_alive() - logger.info(f"prefill trans kv process for device: {device_id} started!") - return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py deleted file mode 100644 index 7b856e54a0..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py +++ /dev/null @@ -1,48 +0,0 @@ -import threading -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class TaskQueue: - def __init__(self, get_func, fail_func): - self.lock = threading.Lock() - self.datas = [] - self.get_func = get_func - self.fail_func = fail_func - self.has_error = False - - def size(self): - return len(self.datas) - - def put(self, obj, error_handle_func=None): - if self.has_error: - if error_handle_func is not None: - error_handle_func(obj) - raise Exception("has error") - - with self.lock: - self.datas.append(obj) - - def put_list(self, objs): - if self.has_error: - raise Exception("has error") - - with self.lock: - self.datas.extend(objs) - - def get_tasks(self, log_tag=None): - with self.lock: - ans = self.get_func(self.datas) - self.datas = self.datas[len(ans) :] - if len(self.datas) != 0: - logger.info(f"queue {log_tag} left size: {len(self.datas)}") - return ans - - def clear_tasks(self): - with self.lock: - if len(self.datas) != 0: - for obj in self.datas: - self.fail_func(obj) - self.datas = [] - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py deleted file mode 100644 index cd1360fd0a..0000000000 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import threading -import torch.multiprocessing as mp -from queue import Empty - - -def join_if_alive(thread: threading.Thread): - if thread is not None and thread.is_alive(): - try: - thread.join() - except Exception: - pass - return - - -def clear_queue(queue: mp.Queue): - while not queue.empty(): - try: - queue.get_nowait() - except Empty: - break diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index 776f41f24a..0a947f8071 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -18,7 +18,7 @@ from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.core.objs import StartArgs -from ..nixl_kv_transporter import NixlKVTransporter +from ..kv_transporter import create_kv_transporter from lightllm.utils.error_utils import log_exception from lightllm.utils.envs_utils import get_unique_server_name @@ -127,8 +127,11 @@ def __init__( page_num=self.args.nixl_pd_kv_page_num, page_size=self.args.nixl_pd_kv_page_size ) self.copy_cuda_stream = torch.cuda.Stream(priority=-1) - self.transporter = NixlKVTransporter( - node_id=self.args.pd_node_id, tp_idx=device_id, kv_move_buffer=kv_move_buffer + self.transporter = create_kv_transporter( + args=self.args, + node_id=self.args.pd_node_id, + tp_idx=device_id, + kv_move_buffer=kv_move_buffer, ) self.recv_task_group_queue = queue.Queue() self.waiting_dict_lock = threading.Lock() diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py index bf70694672..3926ad0eaa 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py @@ -7,9 +7,9 @@ import pickle import setproctitle -from typing import Dict, Union +from typing import Dict from dataclasses import asdict -from lightllm.server.pd_io_struct import UpKVStatus, NixlUpKVStatus +from lightllm.server.pd_io_struct import NixlUpKVStatus from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.pd_io_struct import PD_Master_Obj @@ -22,7 +22,7 @@ class UpStatusManager: def __init__(self, args, task_in_queue: mp.SimpleQueue): self.args = args - self.task_queue: mp.SimpleQueue[Union[UpKVStatus, NixlUpKVStatus]] = task_in_queue + self.task_queue: mp.SimpleQueue[NixlUpKVStatus] = task_in_queue self.daemon_thread = threading.Thread(target=self.thread_loop, daemon=True) self.daemon_thread.start() @@ -66,7 +66,7 @@ async def dispatch_task_loop(self): while True: try: loop = asyncio.get_event_loop() - upkv_status: UpKVStatus = await loop.run_in_executor(None, self.task_queue.get) + upkv_status: NixlUpKVStatus = await loop.run_in_executor(None, self.task_queue.get) if upkv_status.pd_master_node_id in self.id_to_handle_queue: await self.id_to_handle_queue[upkv_status.pd_master_node_id].put(upkv_status) else: @@ -89,7 +89,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): try: if pd_master_obj.node_id in self.id_to_handle_queue: task_queue = self.id_to_handle_queue[pd_master_obj.node_id] - upkv_status: Union[UpKVStatus, NixlUpKVStatus] = await task_queue.get() + upkv_status: NixlUpKVStatus = await task_queue.get() await websocket.send(pickle.dumps(upkv_status)) logger.info(f"up kv status: {upkv_status}") else: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py new file mode 100644 index 0000000000..ee6c04612b --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py @@ -0,0 +1,37 @@ +import os + +from torch import Tensor + +from lightllm.server.core.objs import StartArgs +from lightllm.utils.log_utils import init_logger +from lightllm.utils.net_utils import get_hostname_ip + +logger = init_logger(__name__) + + +def create_kv_transporter(args: StartArgs, node_id: int, tp_idx: int, kv_move_buffer: Tensor): + backend = os.getenv("LIGHTLLM_PD_KV_TRANSPORT_BACKEND", "nixl").lower() + if backend == "nixl": + from .nixl_kv_transporter import NixlKVTransporter + + return NixlKVTransporter(node_id=node_id, tp_idx=tp_idx, kv_move_buffer=kv_move_buffer) + + if backend == "nccl": + from .nccl_kv_transporter import NcclKVTransporter + + logger.info("Use NCCL as pd_nixl KV transporter backend") + port_min = args.pd_p_allowed_port_min + tp_idx * 100 + port_max = min(args.pd_p_allowed_port_max, port_min + 99) + if port_min > args.pd_p_allowed_port_max: + port_min = args.pd_p_allowed_port_min + port_max = args.pd_p_allowed_port_max + return NcclKVTransporter( + node_id=node_id, + tp_idx=tp_idx, + kv_move_buffer=kv_move_buffer, + host_ip=get_hostname_ip() or args.host, + store_port_min=port_min, + store_port_max=port_max, + ) + + raise ValueError(f"unsupported LIGHTLLM_PD_KV_TRANSPORT_BACKEND={backend}") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py new file mode 100644 index 0000000000..b983312dec --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py @@ -0,0 +1,491 @@ +import copy +import pickle +import threading +from dataclasses import dataclass +from datetime import timedelta +from typing import Dict, List, Optional + +import torch +from torch import Tensor +from torch.distributed import TCPStore + +from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup +from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NixlAgentMetadata +from lightllm.utils.log_utils import init_logger +from lightllm.utils.net_utils import get_hostname_ip + +logger = init_logger(__name__) + + +@dataclass +class NcclAgentMetadata: + agent_name: str + host_ip: str + store_port: int + device_id: int + + +@dataclass +class _NcclXferHandle: + thread: Optional[threading.Thread] + status: str = "PROC" + error_info: Optional[str] = None + + +class _PeerSeqTurn: + def __init__(self, transporter: "NcclKVTransporter", peer_name: str, seq: int): + self.transporter = transporter + self.peer_name = peer_name + self.seq = seq + + def __enter__(self): + with self.transporter._peer_seq_cond: + while self.transporter._peer_seq_to_run.get(self.peer_name, 0) != self.seq: + self.transporter._peer_seq_cond.wait() + return self + + def __exit__(self, exc_type, exc_value, traceback): + with self.transporter._peer_seq_cond: + self.transporter._peer_seq_to_run[self.peer_name] = self.seq + 1 + self.transporter._peer_seq_cond.notify_all() + return False + + +class NcclKVTransporter: + """ + NIXL-compatible transporter backed by NCCL point-to-point operations. + + NIXL provides remote notifications and one-sided WRITE. NCCL does not, so this + class uses a small TCPStore control plane for notifications and communicator + bootstrap while preserving the same request/ready/done/error interface used by + pd_nixl trans-process management. + """ + + def __init__( + self, + node_id: int, + tp_idx: int, + kv_move_buffer: Tensor, + host_ip: Optional[str] = None, + store_port: Optional[int] = None, + store_port_min: int = 20000, + store_port_max: int = 30000, + ): + self.node_id = node_id + self.tp_idx = tp_idx + self.kv_move_buffer = kv_move_buffer + self.capture_telemetry = False + self.num_pages, self.page_size, self.num_layers, self.kv_head_num, self.head_dims = kv_move_buffer.shape + + self.host_ip = host_ip or get_hostname_ip() + assert self.host_ip is not None, "can not get host ip for NcclKVTransporter" + + self.store, self.store_port = self._create_local_store( + store_port=store_port, + store_port_min=store_port_min, + store_port_max=store_port_max, + ) + self.remote_agents: Dict[str, NixlAgentMetadata] = {} + self.remote_stores: Dict[str, TCPStore] = {} + self._comms: Dict[str, PyNcclCommunicator] = {} + self._comm_create_lock = threading.Lock() + self._peer_seq_cond = threading.Condition() + self._peer_seq_to_assign: Dict[str, int] = {} + self._peer_seq_to_run: Dict[str, int] = {} + self._recv_notif_counter = 0 + self._deferred_notifs: List[bytes] = [] + self._recv_task_status: Dict[str, _NcclXferHandle] = {} + self._xfer_handle_counter = 0 + self._xfer_handles: Dict[int, _NcclXferHandle] = {} + return + + def _create_local_store( + self, store_port: Optional[int], store_port_min: int, store_port_max: int + ) -> tuple[TCPStore, int]: + if store_port is not None: + ports = [store_port] + else: + ports = list(range(store_port_min, store_port_max + 1)) + + last_error = None + for port in ports: + try: + store = TCPStore( + host_name=self.host_ip, + port=port, + is_master=True, + use_libuv=True, + timeout=timedelta(seconds=30), + ) + return store, port + except BaseException as e: + last_error = e + logger.warning(f"Create NCCL TCPStore on {self.host_ip}:{port} failed: {e}") + + raise RuntimeError( + f"can not allocate NCCL TCPStore port in [{store_port_min}, {store_port_max}]" + ) from last_error + + @property + def agent_name(self) -> str: + return f"{self.node_id}_{self.tp_idx}" + + @property + def agent_metadata(self) -> bytes: + return pickle.dumps( + NcclAgentMetadata( + agent_name=self.agent_name, + host_ip=self.host_ip, + store_port=self.store_port, + device_id=self.tp_idx, + ) + ) + + @property + def local_page_mem_desc(self) -> bytes: + return pickle.dumps( + { + "num_pages": self.num_pages, + "page_size": self.page_size, + "num_layers": self.num_layers, + "kv_head_num": self.kv_head_num, + "head_dims": self.head_dims, + "dtype": str(self.kv_move_buffer.dtype), + } + ) + + def get_new_notifs(self) -> Dict[str, List[bytes]]: + notifs: Dict[str, List[bytes]] = {} + still_deferred = [] + for notify in self._deferred_notifs: + ready_notify = self._get_ready_notify(notify) + if ready_notify is None: + still_deferred.append(notify) + else: + notifs.setdefault(self._get_notify_source_agent_name(ready_notify), []).append(ready_notify) + self._deferred_notifs = still_deferred + + while True: + key = self._notif_key(self.agent_name, self._recv_notif_counter) + if not self.store.check([key]): + break + notify = bytes(self.store.get(key)) + ready_notify = self._get_ready_notify(notify) + if ready_notify is None: + self._deferred_notifs.append(notify) + else: + notifs.setdefault(self._get_notify_source_agent_name(ready_notify), []).append(ready_notify) + self._recv_notif_counter += 1 + return notifs + + def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata): + if remote_agent.agent_name in self.remote_agents: + return + + metadata: NcclAgentMetadata = pickle.loads(remote_agent.agent_metadata) + assert ( + metadata.agent_name == remote_agent.agent_name + ), f"Peer name {metadata.agent_name} does not match remote name {remote_agent.agent_name}" + + self.remote_agents[remote_agent.agent_name] = remote_agent + self.remote_stores[remote_agent.agent_name] = TCPStore( + host_name=metadata.host_ip, + port=metadata.store_port, + is_master=False, + use_libuv=True, + timeout=timedelta(seconds=30), + ) + logger.info(f"Added NCCL remote agent {remote_agent.agent_name} at {metadata.host_ip}:{metadata.store_port}") + return + + def remove_remote_agent(self, peer_name: str): + if peer_name in self.remote_agents: + self.remote_agents.pop(peer_name, None) + self.remote_stores.pop(peer_name, None) + comm = self._comms.pop(peer_name, None) + if comm is not None: + comm.destroy() + with self._peer_seq_cond: + self._peer_seq_to_assign.pop(peer_name, None) + self._peer_seq_to_run.pop(peer_name, None) + else: + logger.warning(f"try to remove remote agent, but peer name {peer_name} agent did not exist") + return + + def send_write_done_task_to_decode_node(self, trans_task: NIXLChunckedTransTask): + new_trans_task = self._copy_notify_task(trans_task) + new_trans_task.nixl_write_stage = "done" + new_trans_task.prefill_agent_name = self.agent_name + new_trans_task.prefill_agent_metadata = self.agent_metadata + new_trans_task.prefill_num_pages = self.num_pages + new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc + self._send_task_notif(trans_task.decode_agent_name, new_trans_task) + return + + def send_write_request_task_to_decode_node(self, trans_task: NIXLChunckedTransTask): + new_trans_task = self._copy_notify_task(trans_task) + new_trans_task.nixl_write_stage = "request" + new_trans_task.prefill_agent_name = self.agent_name + new_trans_task.prefill_agent_metadata = self.agent_metadata + new_trans_task.prefill_num_pages = self.num_pages + new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc + self._send_task_notif(trans_task.decode_agent_name, new_trans_task) + return + + def send_write_ready_task_to_prefill_node(self, trans_task: NIXLChunckedTransTask): + self._start_recv_task(trans_task) + + new_trans_task = self._copy_notify_task(trans_task) + new_trans_task.nixl_write_stage = "ready" + new_trans_task.decode_agent_name = self.agent_name + new_trans_task.decode_agent_metadata = self.agent_metadata + new_trans_task.decode_num_pages = self.num_pages + new_trans_task.decode_page_reg_desc = self.local_page_mem_desc + self._send_task_notif(trans_task.prefill_agent_name, new_trans_task) + return + + def send_error_info_to_prefill_node(self, trans_task: NIXLChunckedTransTask): + if trans_task.prefill_agent_name is None: + return + new_trans_task = self._copy_notify_task(trans_task) + new_trans_task.nixl_write_stage = "error" + new_trans_task.decode_agent_name = self.agent_name + new_trans_task.decode_agent_metadata = self.agent_metadata + new_trans_task.decode_num_pages = self.num_pages + new_trans_task.decode_page_reg_desc = self.local_page_mem_desc + self._send_task_notif(trans_task.prefill_agent_name, new_trans_task) + return + + def send_error_info_to_decode_node(self, trans_task: NIXLChunckedTransTask): + new_trans_task = self._copy_notify_task(trans_task) + new_trans_task.nixl_write_stage = "error" + new_trans_task.prefill_agent_name = self.agent_name + new_trans_task.prefill_agent_metadata = self.agent_metadata + new_trans_task.prefill_num_pages = self.num_pages + new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc + self._send_task_notif(trans_task.decode_agent_name, new_trans_task) + return + + def write_blocks_paged(self, trans_task: NIXLChunckedTransTask) -> int: + assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None + decode_agent_name = trans_task.decode_agent_name + if decode_agent_name not in self.remote_agents: + self.connect_add_remote_agent(trans_task.create_decode_agent_obj()) + + self._ensure_comm( + remote_agent_name=decode_agent_name, + is_server=True, + store=self.store, + ) + handle = self._next_xfer_handle() + seq = self._assign_peer_seq(decode_agent_name) + xfer_handle = _NcclXferHandle( + thread=threading.Thread(target=self._send_page_task, args=(handle, trans_task, seq), daemon=True) + ) + self._xfer_handles[handle] = xfer_handle + xfer_handle.thread.start() + return handle + + def check_task_status(self, trans_task: NIXLChunckedTransTask) -> str: + assert trans_task.xfer_handle is not None + handle = self._xfer_handles[trans_task.xfer_handle] + if handle.status == "ERR": + logger.warning(f"Transfer failed with trans task {trans_task.to_str()}: {handle.error_info}") + return handle.status + + def release_xfer_handle(self, handle): + xfer_handle = self._xfer_handles.pop(handle, None) + if xfer_handle is not None: + xfer_handle.thread.join(timeout=1) + return + + def shutdown(self): + for handle in list(self._xfer_handles.keys()): + self.release_xfer_handle(handle) + for comm in list(self._comms.values()): + comm.destroy() + self._comms.clear() + self.remote_agents.clear() + self.remote_stores.clear() + return + + def _start_recv_task(self, trans_task: NIXLChunckedTransTask): + if trans_task.prefill_agent_name not in self.remote_agents: + self.connect_add_remote_agent(trans_task.create_prefill_agent_obj()) + self._recv_task_status[trans_task.get_key()] = _NcclXferHandle(thread=None) + seq = self._assign_peer_seq(trans_task.prefill_agent_name) + threading.Thread(target=self._recv_page_task, args=(copy.copy(trans_task), seq), daemon=True).start() + return + + def _send_page_task(self, handle: int, trans_task: NIXLChunckedTransTask, seq: int): + xfer_handle = self._xfer_handles[handle] + try: + remote_agent = self.remote_agents[trans_task.decode_agent_name] + remote_metadata: NcclAgentMetadata = pickle.loads(remote_agent.agent_metadata) + page_tensor = self.kv_move_buffer[trans_task.nixl_src_page_index] + comm = self._get_cached_comm(trans_task.decode_agent_name) + with self._peer_seq_turn(trans_task.decode_agent_name, seq): + comm.send(page_tensor, dst=1) + torch.cuda.current_stream().synchronize() + xfer_handle.status = "DONE" + logger.info( + f"NCCL send page done request_id={trans_task.request_id} " + f"src_page={trans_task.nixl_src_page_index} dst_agent={remote_metadata.agent_name}" + ) + except BaseException as e: + xfer_handle.status = "ERR" + xfer_handle.error_info = str(e) + logger.exception(str(e)) + self._drop_comm(trans_task.decode_agent_name) + return + + def _recv_page_task(self, trans_task: NIXLChunckedTransTask, seq: int): + try: + page_tensor = self.kv_move_buffer[trans_task.nixl_dst_page_index] + remote_agent = self.remote_agents[trans_task.prefill_agent_name] + remote_store = self.remote_stores[remote_agent.agent_name] + comm = self._ensure_comm( + remote_agent_name=trans_task.prefill_agent_name, + is_server=False, + store=remote_store, + ) + with self._peer_seq_turn(trans_task.prefill_agent_name, seq): + comm.recv(page_tensor, src=0) + torch.cuda.current_stream().synchronize() + self._recv_task_status[trans_task.get_key()].status = "DONE" + logger.info( + f"NCCL recv page done request_id={trans_task.request_id} " + f"dst_page={trans_task.nixl_dst_page_index}" + ) + except BaseException as e: + trans_task.error_info = str(e) + recv_status = self._recv_task_status.get(trans_task.get_key(), None) + if recv_status is not None: + recv_status.status = "ERR" + recv_status.error_info = str(e) + logger.exception(str(e)) + self._drop_comm(trans_task.prefill_agent_name) + self.send_error_info_to_prefill_node(trans_task) + return + + def _get_ready_notify(self, notify: bytes) -> Optional[bytes]: + try: + notify_obj = pickle.loads(notify) + except BaseException: + return notify + + if not isinstance(notify_obj, NIXLChunckedTransTask): + return notify + + if notify_obj.nixl_write_stage != "done": + return notify + + recv_status = self._recv_task_status.get(notify_obj.get_key(), None) + if recv_status is None or recv_status.status == "PROC": + return None + + self._recv_task_status.pop(notify_obj.get_key(), None) + if recv_status.status == "ERR": + notify_obj.error_info = recv_status.error_info or "nccl recv failed" + return pickle.dumps(notify_obj) + + return notify + + def _get_cached_comm(self, remote_agent_name: str) -> PyNcclCommunicator: + comm = self._comms.get(remote_agent_name) + if comm is None: + raise RuntimeError(f"NCCL communicator with peer {remote_agent_name} is not initialized") + return comm + + def _ensure_comm( + self, + remote_agent_name: str, + is_server: bool, + store: TCPStore, + ) -> PyNcclCommunicator: + comm = self._comms.get(remote_agent_name) + if comm is not None: + return comm + + with self._comm_create_lock: + comm = self._comms.get(remote_agent_name) + if comm is not None: + return comm + + if is_server: + src_id = self.agent_name + dest_id = remote_agent_name + else: + src_id = remote_agent_name + dest_id = self.agent_name + + group = StatelessP2PProcessGroup.create( + src_id=src_id, + dest_id=dest_id, + is_server=is_server, + store=store, + ) + comm = PyNcclCommunicator(group, self.tp_idx) + self._comms[remote_agent_name] = comm + logger.info(f"Created NCCL communicator with peer {remote_agent_name}") + return comm + + def _drop_comm(self, remote_agent_name: str): + with self._comm_create_lock: + comm = self._comms.pop(remote_agent_name, None) + if comm is not None: + comm.destroy() + logger.warning(f"Dropped NCCL communicator with peer {remote_agent_name}") + return + + def _assign_peer_seq(self, peer_name: str) -> int: + with self._peer_seq_cond: + seq = self._peer_seq_to_assign.get(peer_name, 0) + self._peer_seq_to_assign[peer_name] = seq + 1 + self._peer_seq_to_run.setdefault(peer_name, 0) + return seq + + def _peer_seq_turn(self, peer_name: str, seq: int): + return _PeerSeqTurn(self, peer_name, seq) + + def _send_task_notif(self, remote_agent_name: str, trans_task: NIXLChunckedTransTask): + if remote_agent_name not in self.remote_agents: + if remote_agent_name == trans_task.decode_agent_name: + self.connect_add_remote_agent(trans_task.create_decode_agent_obj()) + else: + self.connect_add_remote_agent(trans_task.create_prefill_agent_obj()) + + remote_store = self.remote_stores[remote_agent_name] + counter = remote_store.add(f"notif/{remote_agent_name}/counter", 1) - 1 + remote_store.set(self._notif_key(remote_agent_name, counter), pickle.dumps(trans_task)) + return + + def _copy_notify_task(self, trans_task: NIXLChunckedTransTask) -> NIXLChunckedTransTask: + new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) + new_trans_task.mem_indexes = None + new_trans_task.xfer_handle = None + return new_trans_task + + def _next_xfer_handle(self): + self._xfer_handle_counter += 1 + return self._xfer_handle_counter + + @staticmethod + def _notif_key(agent_name: str, counter: int) -> str: + return f"notif/{agent_name}/{counter}" + + @staticmethod + def _get_notify_source_agent_name(notify: bytes) -> str: + try: + notify_obj = pickle.loads(notify) + except BaseException: + return "unknown" + + if not isinstance(notify_obj, NIXLChunckedTransTask): + return "unknown" + + if notify_obj.nixl_write_stage == "request": + return notify_obj.prefill_agent_name or "unknown" + if notify_obj.nixl_write_stage in ["ready", "done"]: + return notify_obj.decode_agent_name or "unknown" + return notify_obj.prefill_agent_name or notify_obj.decode_agent_name or "unknown" diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py similarity index 64% rename from lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py rename to lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py index 62609c4c91..8ca5f4e808 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py @@ -1,16 +1,15 @@ # mypy: allow-untyped-defs -import multiprocessing -import os -import threading -from multiprocessing.reduction import ForkingPickler -from multiprocessing.util import register_after_fork -from typing import Union - import torch import torch.utils.hooks from torch._namedtensor_internals import check_serializing_named_tensor -from torch.multiprocessing.reductions import storage_from_cache, shared_cache, StorageWeakRef -from torch.multiprocessing.reductions import reduce_nested_tensor, reduce_sparse_tensor, rebuild_tensor +from torch.multiprocessing.reductions import ( + StorageWeakRef, + reduce_nested_tensor, + reduce_sparse_tensor, + rebuild_tensor, + shared_cache, + storage_from_cache, +) def p2p_fix_rebuild_cuda_tensor( @@ -30,13 +29,7 @@ def p2p_fix_rebuild_cuda_tensor( event_handle, event_sync_required, ): - # 因为接收进程在将 tensor 对应的 handle重新转化为指针的时候 - # 在其c++源码中会将当前显卡切换到storage_device再做操作,这样 - # 得到的指针可能不是接收进程当前上下文设备可以访问的,所以在这里 - # hack 修改了使用的 storage_device,这样后续tritonkernel同时 - # 访问几张显卡上的数据,进行p2p操作就不会出问题了。 storage_device = torch.cuda.current_device() - # If storage_handle is None, storage points to nullptr. if storage_handle is None or storage_size_bytes == 0: storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) else: @@ -55,12 +48,10 @@ def p2p_fix_rebuild_cuda_tensor( ) shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage) else: - # We already ref counting this Storage, but producer needs new ref-counters to be released. storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) _storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage - - t = torch._utils._rebuild_tensor( + tensor = torch._utils._rebuild_tensor( torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), tensor_offset, tensor_size, @@ -68,22 +59,20 @@ def p2p_fix_rebuild_cuda_tensor( ) if tensor_cls == torch.nn.parameter.Parameter: - # It is crucial for integer tensors to receive - # the requires_grad=False as an argument in the constructor - t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + tensor = torch.nn.parameter.Parameter(tensor, requires_grad=requires_grad) else: - t.requires_grad = requires_grad + tensor.requires_grad = requires_grad - return t + return tensor def reduce_tensor(tensor): if tensor.requires_grad and not tensor.is_leaf: raise RuntimeError( "Cowardly refusing to serialize non-leaf tensor which requires_grad, " - "since autograd does not support crossing process boundaries. " + "since autograd does not support crossing process boundaries. " "If you just want to transfer the data, call detach() on the tensor " - "before serializing (e.g., putting it on the queue)." + "before serializing." ) check_serializing_named_tensor(tensor) @@ -106,6 +95,8 @@ def reduce_tensor(tensor): storage = tensor._typed_storage() if storage._untyped_storage.device.type == "cuda": + from lightllm.server.router.model_infer.mode_backend.pd_nixl.p2p_fix import p2p_fix_rebuild_cuda_tensor + ( device, handle, @@ -118,25 +109,19 @@ def reduce_tensor(tensor): ) = storage._share_cuda_() tensor_offset = tensor.storage_offset() shared_cache[handle] = StorageWeakRef(storage) - # _backward_hooks purposely omitted here, see - # Note [Don't serialize hooks] - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import ( - p2p_fix_rebuild_cuda_tensor, - ) - return ( p2p_fix_rebuild_cuda_tensor, ( type(tensor), tensor.size(), tensor.stride(), - tensor_offset, # tensor offset in its storage + tensor_offset, type(storage), tensor.dtype, device, - handle, # identifier which CUDA allocation is the storage in. - storage_size_bytes, # size(in bytes) of the storage - storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation + handle, + storage_size_bytes, + storage_offset_bytes, tensor.requires_grad, ref_counter_handle, ref_counter_offset, @@ -145,7 +130,6 @@ def reduce_tensor(tensor): ), ) - # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] metadata = ( tensor.storage_offset(), tensor.size(), diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index cd124445ea..856fba1188 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -12,7 +12,7 @@ from lightllm.server.pd_io_struct import NIXLChunckedTransTask from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.core.objs import StartArgs -from ..nixl_kv_transporter import NixlKVTransporter +from ..kv_transporter import create_kv_transporter from lightllm.utils.error_utils import log_exception from lightllm.utils.envs_utils import get_unique_server_name @@ -101,8 +101,11 @@ def __init__( page_num=self.args.nixl_pd_kv_page_num, page_size=self.args.nixl_pd_kv_page_size ) self.copy_cuda_stream = torch.cuda.Stream(priority=-1) - self.transporter = NixlKVTransporter( - node_id=self.args.pd_node_id, tp_idx=device_id, kv_move_buffer=kv_move_buffer + self.transporter = create_kv_transporter( + args=self.args, + node_id=self.args.pd_node_id, + tp_idx=device_id, + kv_move_buffer=kv_move_buffer, ) self.waiting_dict_lock = threading.Lock() self.waiting_dict: Dict[str, NIXLChunckedTransTask] = {} diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 408b173371..0590c454ff 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -22,10 +22,6 @@ XgrammarBackend, DPChunkedPrefillBackend, DiversehBackend, - DecodeNode, - DPForDecodeNode, - ChunckedPrefillForPrefillNode, - DPChunkedForPrefillNode, NIXLChunckedPrefillForPrefillNode, NIXLDPChunkedForPrefillNode, NIXLDecodeNode, @@ -67,28 +63,15 @@ def exposed_init_model(self, kvargs): is_outlines_constraint_mode = self.args.output_constraint_mode == "outlines" is_xgrammar_constraint_mode = self.args.output_constraint_mode == "xgrammar" assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" - is_prefill_node = self.args.run_mode == "prefill" - is_decode_node = self.args.run_mode == "decode" is_nixl_prefill_node = self.args.run_mode == "nixl_prefill" is_nixl_decode_node = self.args.run_mode == "nixl_decode" - if is_prefill_node: - if self.args.dp > 1: - self.backend = DPChunkedForPrefillNode(self.info_queue) - else: - self.backend = ChunckedPrefillForPrefillNode(self.info_queue) - elif is_nixl_prefill_node: + if is_nixl_prefill_node: if self.args.dp > 1: self.backend = NIXLDPChunkedForPrefillNode(self.info_queue) else: self.backend = NIXLChunckedPrefillForPrefillNode(self.info_queue) - elif is_decode_node: - if self.args.dp > 1: - self.backend = DPForDecodeNode(self.info_queue) - else: - self.backend = DecodeNode(self.info_queue) - elif is_nixl_decode_node: if self.args.dp > 1: self.backend = NIXLDPForDecodeNode(self.info_queue) diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 067332d945..5f0cb866ed 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -1,4 +1,3 @@ -from .chunked_prefill.impl_for_pd_decode import QueueForPDDecode from .chunked_prefill.impl import ChunkedPrefillQueue from .chunked_prefill.beam_impl import ChunkedBeamContinuesBatchQueue from .chunked_prefill.impl_for_nixl_pd import NIXLPDQueue @@ -14,10 +13,6 @@ def _get_req_queue_class(args, router, dp_size_in_node: int): return ChunkedPrefillQueue if args.first_token_constraint_mode: return ChunkedPrefillQueue - if args.run_mode in ["decode"]: - return QueueForPDDecode - if args.run_mode in ["prefill"]: - return ChunkedPrefillQueue if args.run_mode in ["nixl_prefill", "nixl_decode"]: return NIXLPDQueue diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py deleted file mode 100644 index 4c2ebf7c00..0000000000 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py +++ /dev/null @@ -1,82 +0,0 @@ -import time -import uuid -import numpy as np -from typing import List -from lightllm.utils.infer_utils import calculate_time -from ...batch import Batch, Req -from lightllm.server.router.req_queue.base_queue import BaseQueue -from lightllm.common.basemodel.infer_lock import g_router_lock -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class QueueForPDDecode(BaseQueue): - def __init__(self, args, router, dp_index, dp_size_in_node) -> None: - super().__init__(args, router, dp_index, dp_size_in_node) - - def _init_cache_list(self, current_batch: Batch, is_busy): - if current_batch is not None: - self.cache_len_list = [ - req.get_tuple_tokens(is_busy, self.router.router_statics.ema_req_out_len) - for req in current_batch.reqs - if req.sample_params.suggested_dp_index == self.dp_index - ] - else: - self.cache_len_list = [] - return - - # @calculate_time(show=True, min_cost_ms=10) - def generate_new_batch(self, current_batch: Batch): - if len(self.waiting_req_list) == 0: - return None - - # 如果当前已经被调度的请求数量超过了上限,直接不调度新的请求了。 - exist_req_num = self.get_batch_dp_req_size(current_batch) - req_is_full = exist_req_num >= self.running_max_req_size - if req_is_full: - return None - - can_run_list = [] - abort_req_list = [] - aborted_count = 0 - for req in self.waiting_req_list: - if req.is_aborted: - # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. - # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 - aborted_count += 1 - abort_req_list.append(req) - continue - if exist_req_num + len(can_run_list) + 1 <= self.batch_max_tokens: - can_run_list.append(req) - else: - break - new_batch = None - if len(can_run_list) != 0: - new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) - for req in abort_req_list: - req: Req = req - logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") - self.free_aborted_req_cpu_cache_pages(req) - self.router.shm_req_manager.put_back_req_obj(req) - self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] - return new_batch - - def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): - is_busy = self.is_busy() - self._init_cache_list(current_batch, is_busy) - if len(self.cache_len_list) != 0: - self.cache_len_list.sort(key=lambda x: -x[1]) - left_out_len_array = np.array([e[1] for e in self.cache_len_list]) - has_run_len_array = np.array([e[0] for e in self.cache_len_list]) - cum_run_len_array = np.cumsum(has_run_len_array) - size_array = np.arange(1, len(self.cache_len_list) + 1, 1) - need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - else: - need_max_token_num = 0 - with g_router_lock.obj: - return ( - need_max_token_num, - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, - ) diff --git a/skills/test_model/qwen3-8b-pd-nccl/SKILL.md b/skills/test_model/qwen3-8b-pd-nccl/SKILL.md deleted file mode 100644 index fdf559a27e..0000000000 --- a/skills/test_model/qwen3-8b-pd-nccl/SKILL.md +++ /dev/null @@ -1,187 +0,0 @@ ---- -name: test-model-qwen3-8b-pd-nccl -description: >- - LightLLM Qwen3-8b PD disaggregation gsm8k: pd_master on 8089, prefill on 8001, decode on - 8002, tp 2 each. Assign four GPUs by running nvidia-smi and deciding prefill/decode pairs - (no fixed card IDs; no complex shell automation). lm_eval hits pd_master URL. HOST vs - PD_MASTER_IP when co-located. Requires LOG_DIR, MODEL_DIR, proxy cleared, no_proxy, summary.txt. - Use for PD NCCL-style separation tests. ---- - -# Qwen3-8B **PD 分离**(`pd_master` + `prefill` + `decode`)本地 GSM8K 评测 - -**测试标识**:同一 **`--model_dir`**(Qwen3-8B)下拆 **三条** `api_server` 进程——**调度/入口(pd_master)**、**prefill 节点**、**decode 节点**;评测 **`lm_eval`** 只访问 **`pd_master` 的 HTTP 端口(8089)**,由调度转发到 PD 链路。与单机单进程 Base–TP、MTP 等流程区分。 - -**端口约定**:**`pd_master`:`8089`**;**prefill:`8001`**;**decode:`8002`**。启动与就绪探测须覆盖这三处(以及日志中的 PD 注册/报错信息)。 - -**绑定 IP(`HOST` / `PD_MASTER_IP`)**:各进程的 **`--host`** 表示 **本服务监听绑定的 IP**(与其它集群「逻辑 hostname」概念区分时,此处一律按 **绑定地址** 理解)。当 **`pd_master`、`prefill`、`decode` 部署在同一台机器上时**,三者使用的绑定 IP **相同**:此时可只做一次赋值 **`export HOST="${PD_MASTER_IP}"`**(或先将本机对外/LAN IP 赋给 **`PD_MASTER_IP`**,再 **`export HOST="${PD_MASTER_IP}"`**),保证 **`pd_master` 的 `--host`** 与 **prefill/decode 的 `--host`** 一致;**`lm_eval` 的 `base_url` 仍指向 `pd_master`**,故 **`PD_MASTER_IP`** 也同时作为评测 URL 中的主机名。 - -整轮产物落在**同一日志目录**,写入 **`summary.txt`** 与各进程日志(见「日志目录」);**不要**写聚合启动脚本,按「启动说明」逐条手动启动并在后台落盘。 - -## 日志目录(含 `summary.txt`) - -- 每次评测先选定或新建**一个日志目录**(例如带时间戳或任务名),与其它测试轮次分开。 -- **三个 `api_server` 的标准输出/错误**分别写入该目录,建议命名:**`pd_master.log`**、**`prefill.log`**、**`decode.log`**(或分子目录 `pd_master/`、`prefill/`、`decode/`)。 -- **`summary.txt` 固定放在该日志目录下**,汇总:三台进程的启动参数摘要、端口与就绪情况、`lm_eval` 关键结果、失败原因与最终结论。 -- **`eval_gsm8k.log`**:`lm_eval` 终端输出;**`summary.txt`** 仍为**总览结论**。 - -## 启动说明 - -本节包含:启动前检查 → 可变项说明 → 显卡分配 → **按顺序**三条完整 server 命令 → 评测命令。 - -### 启动前检查 - -开跑前先确认资源与环境可用;**不满足则先清理占用端口的进程或释放 GPU**,再按顺序启动。 - -1. **显卡**:prefill / decode 各需 **2 张物理 GPU**(**`--tp 2`**),共 **4 张互不重复**的卡。**不要写死卡号**:先 **`nvidia-smi`**(见下文「显卡分配」),由执行者根据占用与集群情况选定 **prefill 两张、decode 两张**,再 **`export PREFILL_CUDA_DEVICES`**、**`DECODE_CUDA_DEVICES`** 后启动。 -2. **端口**:**`8089`、`8001`、`8002`** 均须未被监听(`ss -tlnp`、`lsof -i :端口` 等);若被占用,结束占用进程后再启动。 -3. **网络 / IP**:**`HOST`** 为 **prefill / decode 的服务绑定 IP**;**`PD_MASTER_IP`** 为 **`pd_master` 的 `--host`**,且与 **`lm_eval` 访问地址**一致。**单机三进程同机时**:**`HOST` 与 `PD_MASTER_IP` 取同一值**(见上文「绑定 IP」);多机分发时再按各节点真实监听地址分别设置。 -4. **代理**:启动 **任一 server 前**将 **`http_proxy` / `https_proxy` 置空**(见各命令块前 `export`);避免代理干扰本地 PD 通信。**评测阶段**使用 **`no_proxy`** 排除本机(见评测命令);若需先用代理下载 `lm_eval` 缓存,见「执行约定」。 - -### 启动服务的命令模板(可变项) - -| 可变项 | 含义 | -|--------|------| -| `LOG_DIR` | 本轮日志根目录,建议**绝对路径**;`export LOG_DIR=…`。 | -| `MODEL_DIR` | 模型目录,对应三条命令中的 **`--model_dir`**;`lm_eval` 的 **`tokenizer` 须与此路径一致**。 | -| `PD_MASTER_IP` | **`pd_master` 进程 `--host`** 所使用的 **绑定 IP**;同时也是 **`lm_eval` 里 `base_url` 的主机部分**(评测客户端访问 pd_master 的地址)。 | -| `HOST` | **`prefill` / `decode` 进程 `--host`** 所使用的 **绑定 IP**(本服务监听地址)。**与 `pd_master` 同机时**:与 **`PD_MASTER_IP` 相同**,可 **`export HOST="${PD_MASTER_IP}"`**。 | -| `PREFILL_CUDA_DEVICES` | **prefill** 的 **`CUDA_VISIBLE_DEVICES`**,形如 `a,b`(两张物理卡索引);由 **`nvidia-smi`** 判断后 **`export`**。 | -| `DECODE_CUDA_DEVICES` | **decode** 的 **`CUDA_VISIBLE_DEVICES`**,形如 `c,d`;与 prefill **四卡互不重复**。 | -| `pd_master.log` 等 | 文件名示例,可改名。 | - -开跑前导出(引号内替换为本机实际值): - -```bash -export LOG_DIR='〈日志根目录〉' -export MODEL_DIR='〈Qwen3-8B 模型目录〉' -export PD_MASTER_IP='〈本机绑定 IP:pd_master --host,且供 lm_eval 访问〉' -# 单机:prefill/decode 与 pd_master 同机时,绑定同一 IP -export HOST="${PD_MASTER_IP}" -# 多机:若 prefill/decode 监听地址不同,再单独 export HOST='〈该机上绑定 IP〉' -``` - -首次试跑可用的**默认 `MODEL_DIR`** 见「执行约定」。 - -### 显卡分配(`nvidia-smi` + 人工/Agent 决策,不用复杂脚本) - -约束:**prefill**、**decode** 各 **2 张物理 GPU**(**`--tp 2`**),共 **4 张互不重复**;**不要**默认写死 `0,1` / `2,3`。 - -1. **查看占用**:在启动 **prefill** 之前(**`pd_master` 已起来之后**即可),执行 **`nvidia-smi`**,需要时可带列表输出便于比对: - `nvidia-smi --query-gpu=index,name,memory.used,memory.free --format=csv` - 或先看总览再决定。 -2. **选定四卡**:由**执行本 skill 的 Agent(或操作者)**根据上述输出、机器上其它任务占用、是否需避开某几张卡等因素,**自行决定**哪 **2** 张给 **prefill**、哪 **2** 张给 **decode**(两组索引不得重叠)。 -3. **写入环境变量**:在同一 shell 中 **`export`**(示例数值仅作格式说明): - -```bash -export PREFILL_CUDA_DEVICES='〈物理索引1〉,〈物理索引2〉' -export DECODE_CUDA_DEVICES='〈物理索引3〉,〈物理索引4〉' -``` - -4. **记录**:把最终 **`PREFILL_CUDA_DEVICES`**、**`DECODE_CUDA_DEVICES`** 及当时 **`nvidia-smi` 要点**记入 **`summary.txt`**。 - -**禁止**:不必编写 **awk / mapfile / 长段 bash** 自动选卡脚本;以 **`nvidia-smi` 事实 + 明确决策**为准。 - -### 1)启动 `pd_master`(须最先就绪监听) - -每条命令前清空代理;以下为 **可直接执行** 的后台形式(含 **`nohup`** 与重定向)。若调试可去掉 `nohup` 与 `>> … &`。 - -```bash -export http_proxy= -export https_proxy= - -nohup python -m lightllm.server.api_server \ - --model_dir "${MODEL_DIR}" \ - --run_mode pd_master \ - --host "${PD_MASTER_IP}" \ - --port 8089 \ - >> "${LOG_DIR}/pd_master.log" 2>&1 & -``` - -### 2)启动 `prefill` 节点 - -**须在 pd_master 已监听且日志无致命错误后再启动**(见「执行约定」)。启动本命令前须已完成 **`nvidia-smi` 决策并 `export PREFILL_CUDA_DEVICES=…`**(见「显卡分配」)。 - -```bash -export http_proxy= -export https_proxy= - -LOADWORKER=18 CUDA_VISIBLE_DEVICES="${PREFILL_CUDA_DEVICES}" \ -nohup python -m lightllm.server.api_server \ - --model_dir "${MODEL_DIR}" \ - --run_mode prefill \ - --tp 2 \ - --dp 1 \ - --host "${HOST}" \ - --port 8001 \ - --disable_cudagraph \ - --pd_master_ip "${PD_MASTER_IP}" \ - --pd_master_port 8089 \ - >> "${LOG_DIR}/prefill.log" 2>&1 & -``` - -### 3)启动 `decode` 节点 - -启动前须已完成 **`export DECODE_CUDA_DEVICES=…`**(见「显卡分配」)。 - -```bash -export http_proxy= -export https_proxy= - -LOADWORKER=18 CUDA_VISIBLE_DEVICES="${DECODE_CUDA_DEVICES}" \ -nohup python -m lightllm.server.api_server \ - --model_dir "${MODEL_DIR}" \ - --run_mode decode \ - --tp 2 \ - --dp 1 \ - --host "${HOST}" \ - --port 8002 \ - --pd_master_ip "${PD_MASTER_IP}" \ - --pd_master_port 8089 \ - >> "${LOG_DIR}/decode.log" 2>&1 & -``` - -### 评测命令(prefill / decode 已与 pd_master 建立 PD 链路后执行) - -**`base_url` 指向 `pd_master`**:`http://${PD_MASTER_IP}:8089/v1/completions`。以下为带日志落盘的**完整命令**(`--model_args` 使用双引号以展开变量;**`tokenizer` 与 `MODEL_DIR` 一致**): - -```bash -export http_proxy= -export https_proxy= - -HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ -no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} \ -lm_eval --model local-completions \ - --model_args "{\"model\":\"qwen/qwen3-8b\", \"base_url\":\"http://${PD_MASTER_IP}:8089/v1/completions\", \"max_length\": 16384, \"tokenized_requests\": false, \"tokenizer\":\"${MODEL_DIR}\"}" \ - --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ - >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 -``` - -- **`no_proxy`**:须包含本机与 **`PD_MASTER_IP`**(及脚本中的 `0.0.0.0`、`::1` 等),避免评测流量误走 HTTP 代理。 -- 若环境需要,可同时设置 **`NO_PROXY`** 与 **`no_proxy`** 一致。 -- **`tokenized_requests`: `false`** 与脚本一致。 -- 调试可不重定向:去掉末尾 `>> "${LOG_DIR}/eval_gsm8k.log" 2>&1`。 - -## 执行约定(不要额外写“专用启动脚本”) - -**模型目录**:**首轮试跑**可先: - -```bash -export MODEL_DIR=/mtc/models/qwen3-8b -``` - -无法启动或路径类报错时,**请用户提供**本机实际 **`MODEL_DIR`**;保持 **`tokenizer` 与 `--model_dir` 同路径**。 - -**`lm_eval` 与代理 / 缓存**:若评测依赖首次下载缓存,可先**保留代理**单独跑一次 `lm_eval` 完成缓存下载,再**清空代理**并按上文 **`no_proxy`** 跑正式评测(与脚本注释一致)。 - -1. **启动顺序**:先 **`pd_master`** → 再 **`nvidia-smi` 决策并 `export PREFILL_CUDA_DEVICES` / `DECODE_CUDA_DEVICES`** → 再 **prefill** → 再 **decode**;不要颠倒。每一步将输出重定向到 **`LOG_DIR`** 下对应日志。 -2. **不要用 health 接口** 作为唯一依据;改为:**端口 listen**(8089 / 8001 / 8002)并结合日志判断是否已与 pd_master 建立 PD 链路或是否报错。 -3. **等待 / 轮询**:若端口未就绪或链路未建立,约 **每 20 秒** 查看 **`pd_master.log`、`prefill.log`、`decode.log`**,区分仍在启动还是已报错;异常写入 **`summary.txt`** 并停止后续步骤。 -4. **维护 `summary.txt`**:记录三条启动命令摘要(或等价)、**本次 `PREFILL_CUDA_DEVICES` / `DECODE_CUDA_DEVICES` 及选卡依据(`nvidia-smi` 要点)**、各端口检测结果、`lm_eval` 关键输出;结束后写**最终汇总**。 -5. **测试结束后**:**关闭本次启动的所有相关进程**(`pd_master`、prefill、decode),释放端口与 GPU。 -6. **错误记录**:启动或评测失败时,将错误摘要记入 **`summary.txt`**,并在对话中说明关键信息。 - -## 输出文件 - -- **`summary.txt`**:位于**本轮日志目录**,作为本次 PD 分离评测的**最终总结**。 -- **服务与评测日志**:**`pd_master.log`、`prefill.log`、`decode.log`、`eval_gsm8k.log`** 均落在**同一日志目录**。 diff --git a/skills/test_model/qwen3-8b-pd-nixl/SKILL.md b/skills/test_model/qwen3-8b-pd-nixl/SKILL.md index 08ee7258df..5ee864a910 100644 --- a/skills/test_model/qwen3-8b-pd-nixl/SKILL.md +++ b/skills/test_model/qwen3-8b-pd-nixl/SKILL.md @@ -14,7 +14,7 @@ description: >- # Qwen3-8B **PD 分离(NIXL)**(`pd_master` + `nixl_prefill` + `nixl_decode`)本地 GSM8K 评测 -**测试标识**:同一 **`--model_dir`**(Qwen3-8B)下拆 **三条** `api_server` 进程——**调度/入口(`pd_master`)**、**`nixl_prefill` 节点**、**`nixl_decode` 节点**;评测 **`lm_eval`** 只访问 **`pd_master` 的 HTTP 端口(8089)**。与 **NCCL 版 PD**(`prefill` / `decode`)区分之处在于 **`--run_mode`** 与 **prefill/decode 前须配置 UCX/RDMA 环境变量**。 +**测试标识**:同一 **`--model_dir`**(Qwen3-8B)下拆 **三条** `api_server` 进程——**调度/入口(`pd_master`)**、**`nixl_prefill` 节点**、**`nixl_decode` 节点**;评测 **`lm_eval`** 只访问 **`pd_master` 的 HTTP 端口(8089)**。默认使用 NIXL 传输;需要验证 NCCL 数据面时,设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**,上层仍保持相同的 `nixl_prefill` / `nixl_decode` 管理路径。 **端口约定**:**`pd_master`:`8089`**;**prefill:`8001`**;**decode:`8002`**。启动与就绪探测须覆盖这三处(以及日志中的 PD 注册/报错信息)。 @@ -70,7 +70,7 @@ export UCX_TLS=rc,cuda,gdr_copy ### 显卡分配(`nvidia-smi` + 人工/Agent 决策,不用复杂脚本) -与 **qwen3-8b-pd-nccl** skill 相同:**prefill**、**decode** 各 **2** 张 GPU,共 **4** 张互不重复。 +**nixl_prefill**、**nixl_decode** 各 **2** 张 GPU,共 **4** 张互不重复。需要验证 NCCL 数据面时,额外设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**。 1. 执行 **`nvidia-smi`**(可选用 `--query-gpu=index,name,memory.used,memory.free --format=csv`)。 2. 由执行者选定哪 2 张给 prefill、哪 2 张给 decode(不重叠)。 diff --git a/test/acc/test_pd_nccl.sh b/test/acc/test_pd_nccl.sh deleted file mode 100644 index 3f7dc267c8..0000000000 --- a/test/acc/test_pd_nccl.sh +++ /dev/null @@ -1,58 +0,0 @@ -$pd_master_ip 为本机的ip地址, 测试的时候,自己修改为对应的ip地址 - -# 启动pd_master节点 -# 测试前关闭代理 -export http_proxy= -export https_proxy= -python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --run_mode "pd_master" --host $pd_master_ip --port 8089 - -# 启动prefill 节点 -$host 为本机的ip地址, 测试的时候,自己修改为对应的ip地址 -$pd_master_ip 为pd_master的ip地址, 测试的时候,自己修改为对应的ip地址,在测试的时候为本机ip地址 -# 测试前关闭代理 -export http_proxy= -export https_proxy= -LOADWORKER=18 CUDA_VISIBLE_DEVICES=0,1 python -m lightllm.server.api_server \ ---model_dir /mtc/models/qwen3-8b \ ---run_mode "prefill" \ ---tp 2 \ ---dp 1 \ ---host $host \ ---port 8001 \ ---disable_cudagraph \ ---pd_master_ip $pd_master_ip \ ---pd_master_port 8089 - -# 启动 decode 节点 -# 测试前关闭代理 -export http_proxy= -export https_proxy= -$host 为本机的ip地址, 测试的时候,自己修改为对应的ip地址 -$pd_master_ip 为pd_master的ip地址, 测试的时候,自己修改为对应的ip地址,在测试的时候为本机ip地址 -LOADWORKER=18 CUDA_VISIBLE_DEVICES=2,3 python -m lightllm.server.api_server \ ---model_dir /mtc/models/qwen3-8b \ ---run_mode "decode" \ ---tp 2 \ ---dp 1 \ ---host $host \ ---port 8002 \ ---pd_master_ip $pd_master_ip \ ---pd_master_port 8089 - -# 等待 prefill 和 decode 节点启动完成,并连上 pd master以后,执行测试脚本 -# 测试前关闭代理 -export http_proxy= -export https_proxy= -$pd_master_ip 为pd_master的ip地址, 测试的时候,自己修改为对应的ip地址 -export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval \ ---model local-completions --model_args \ -'{"model":"qwen/qwen3-8b", "base_url":"http://$pd_master_ip:8089/v1/completions", "max_length": 16384, "tokenized_requests": false}' \ ---tasks gsm8k --batch_size 500 --confirm_run_unsafe_code - -# 1. 按顺序在不同的cmd中启动上面的程序,然后再执行评测脚本,将结果写入out.txt 中,注意需要标记启动的参数和结果信息。 -# 2. 执行评测命令的时候,需要用no_proxy 将本地local ip 排除。 -# 3. 不要写额外的脚本来启动服务,就是单独一个一个的按照上面的描述启动服务,然后再执行评测脚本,然后注意等待服务启动完成,可以20s检测一次其控制台输出,看是否启动完成,还是启动报错。 -# 4. 最后需要总结下测试的结果,并将结果输出到对话中。 -# 5. 如果启动过程中出现错误,需要记录错误信息,并输出到对话中。 -# 6. 测试完成后,关闭所有启动的进程。 -# 7. lm_eval 的评测命令有时候需要利用代理去下载一些缓存,所以可以先不关闭代码,跑一次lm_eval对应的命令,等cache下载好了以后,再关闭代理,跑第二次评测命令。 \ No newline at end of file diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md index 8ed44a2753..a84acf96dc 100644 --- a/test/start_scripts/README.md +++ b/test/start_scripts/README.md @@ -20,15 +20,13 @@ This directory contains various startup scripts for deploying DeepSeek models wi #### Single PD Master Mode - `single_pd_master/pd_master.sh` - PD Master service -- `single_pd_master/pd_prefill.sh` - Prefill service -- `single_pd_master/pd_decode.sh` - Decode service +- `single_pd_master/pd_nixl_prefill.sh` - Prefill service +- `single_pd_master/pd_nixl_decode.sh` - Decode service #### Multi PD Master Mode - `multi_pd_master/config_server.sh` - Configuration server - `multi_pd_master/pd_master_1.sh` - PD Master 1 - `multi_pd_master/pd_master_2.sh` - PD Master 2 -- `multi_pd_master/pd_prefill.sh` - Prefill service -- `multi_pd_master/pd_decode.sh` - Decode service ## Usage Instructions @@ -73,10 +71,10 @@ sh multi_node_ep_node1.sh sh single_pd_master/pd_master.sh # Step 2: Start Prefill service -sh single_pd_master/pd_prefill.sh +sh single_pd_master/pd_nixl_prefill.sh # Step 3: Start Decode service -sh single_pd_master/pd_decode.sh +sh single_pd_master/pd_nixl_decode.sh ``` ### 6. Multi PD Master Mode @@ -89,9 +87,8 @@ sh multi_pd_master/config_server.sh sh multi_pd_master/pd_master_1.sh sh multi_pd_master/pd_master_2.sh -# Step 3: Start Prefill and Decode services -sh multi_pd_master/pd_prefill.sh -sh multi_pd_master/pd_decode.sh +# Step 3: Start Prefill and Decode services with the nixl_prefill/nixl_decode run modes. +# Multi-PD startup scripts for these nodes are not provided in this directory. ``` ## Configuration Guide @@ -99,7 +96,7 @@ sh multi_pd_master/pd_decode.sh ### Environment Variables - `LOADWORKER`: Model loading thread count, recommended 8-18 -- `DISABLE_KV_TRANS_USE_P2P`: Disable P2P communication optimization to transfer kv data +- `LIGHTLLM_PD_KV_TRANSPORT_BACKEND`: KV transporter backend for PD disaggregation, `nixl` by default; set to `nccl` to use the NCCL data plane. - `CUDA_VISIBLE_DEVICES`: Specify GPU devices to use ### Important Parameters @@ -198,4 +195,4 @@ python benchmark_client.py \ 2. Adjust parameters according to actual hardware configuration 3. Ensure network environment meets multi-node deployment requirements 4. Recommend thorough testing before production deployment -5. Regularly monitor service status and performance metrics \ No newline at end of file +5. Regularly monitor service status and performance metrics diff --git a/test/start_scripts/multi_pd_master.sh b/test/start_scripts/multi_pd_master.sh deleted file mode 100644 index 7b83923929..0000000000 --- a/test/start_scripts/multi_pd_master.sh +++ /dev/null @@ -1,34 +0,0 @@ -# 多 pd_master 节点部署示例 -python -m lightllm.server.api_server --run_mode "config_server" --config_server_host 10.120.114.74 --config_server_port 60088 - -python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat --run_mode "pd_master" --host 10.120.114.74 --port 60011 --config_server_host 10.120.114.74 --config_server_port 60088 - -python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat --run_mode "pd_master" --host 10.120.114.74 --port 60012 --config_server_host 10.120.114.74 --config_server_port 60088 - -nvidia-cuda-mps-control -d -CUDA_VISIBLE_DEVICES=0 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ ---run_mode "prefill" \ ---host 10.120.178.74 \ ---port 8019 \ ---tp 1 \ ---nccl_port 2732 \ ---max_total_token_num 40000 \ ---tokenizer_mode fast \ ---max_req_total_len 16000 \ ---running_max_req_size 128 \ ---disable_cudagraph \ ---config_server_host 10.120.114.74 \ ---config_server_port 60088 - -CUDA_VISIBLE_DEVICES=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ ---run_mode "decode" \ ---host 10.120.178.74 \ ---port 8121 \ ---nccl_port 12322 \ ---tp 1 \ ---max_total_token_num 40000 \ ---graph_max_len_in_batch 2048 \ ---graph_max_batch_size 16 \ ---tokenizer_mode fast \ ---config_server_host 10.120.114.74 \ ---config_server_port 60088 \ No newline at end of file diff --git a/test/start_scripts/multi_pd_master/pd_decode.sh b/test/start_scripts/multi_pd_master/pd_decode.sh deleted file mode 100644 index cb55ec338f..0000000000 --- a/test/start_scripts/multi_pd_master/pd_decode.sh +++ /dev/null @@ -1,19 +0,0 @@ -# decode -# host: the host of the decode server -# config_server_host: the host of the config server -# sh decode.sh -export host=$1 -export config_server_host=$2 -nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ ---model_dir /path/DeepSeek-R1 \ ---run_mode "decode" \ ---host $host \ ---port 8121 \ ---nccl_port 12322 \ ---tp 8 \ ---dp 8 \ ---config_server_host $config_server_host \ ---config_server_port 60088 -# if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_decode_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/multi_pd_master/pd_prefill.sh b/test/start_scripts/multi_pd_master/pd_prefill.sh deleted file mode 100644 index 45f6c0c011..0000000000 --- a/test/start_scripts/multi_pd_master/pd_prefill.sh +++ /dev/null @@ -1,21 +0,0 @@ -# prefill -# host: the host of the prefill server -# config_server_host: the host of the config server -# sh pd_prefill.sh -export host=$1 -export config_server_host=$2 -nvidia-cuda-mps-control -d -LOADWORKER=18 python -m lightllm.server.api_server \ ---model_dir /path/DeepSeek-R1 \ ---run_mode "prefill" \ ---host $host \ ---port 8019 \ ---tp 8 \ ---dp 8 \ ---nccl_port 2732 \ ---disable_cudagraph \ ---config_server_host $config_server_host \ ---config_server_port 60088 \ ---enable_ep_moe -# if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_prefill_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh deleted file mode 100644 index dac7a6dac6..0000000000 --- a/test/start_scripts/single_pd_master/pd_decode.sh +++ /dev/null @@ -1,20 +0,0 @@ -# PD decode mode for deepseek R1 (DP+EP) on H200 -# host: the host of the current node -# pd_master_ip: the ip of the pd master -# sh pd_decode.sh -export host=$1 -export pd_master_ip=$2 -nvidia-cuda-mps-control -d -LOADWORKER=18 python -m lightllm.server.api_server \ ---model_dir /path/DeepSeek-R1 \ ---run_mode "decode" \ ---tp 8 \ ---dp 8 \ ---host $host \ ---port 8121 \ ---nccl_port 12322 \ ---enable_ep_moe \ ---pd_master_ip $pd_master_ip \ ---pd_master_port 60011 -# if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_decode_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh b/test/start_scripts/single_pd_master/pd_nixl_decode.sh index 4b3fd0bc4e..f4f279d89e 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_decode.sh @@ -1,7 +1,7 @@ # PD decode mode for deepseek R1 (DP+EP) on H200 # host: the host of the current node # pd_master_ip: the ip of the pd master -# sh pd_decode.sh +# sh pd_nixl_decode.sh export host=$1 export pd_master_ip=$2 @@ -22,4 +22,4 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_decode_microbatch_overlap \ No newline at end of file +#--enable_decode_microbatch_overlap diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh index f415919f90..3a10a32f2a 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh @@ -1,7 +1,7 @@ # PD prefill mode for deepseek R1 (DP+EP) on H200 # host: the host of the current node # pd_master_ip: the ip of the pd master -# sh pd_prefill.sh +# sh pd_nixl_prefill.sh ### nixl pd mode used export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) @@ -24,4 +24,4 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_prefill_microbatch_overlap \ No newline at end of file +#--enable_prefill_microbatch_overlap diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh deleted file mode 100644 index 6bde9ef32c..0000000000 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ /dev/null @@ -1,21 +0,0 @@ -# PD prefill mode for deepseek R1 (DP+EP) on H200 -# host: the host of the current node -# pd_master_ip: the ip of the pd master -# sh pd_prefill.sh -export host=$1 -export pd_master_ip=$2 -nvidia-cuda-mps-control -d -LOADWORKER=18 python -m lightllm.server.api_server \ ---model_dir /path/DeepSeek-R1 \ ---run_mode "prefill" \ ---tp 8 \ ---dp 8 \ ---host $host \ ---port 8019 \ ---nccl_port 2732 \ ---disable_cudagraph \ ---pd_master_ip $pd_master_ip \ ---pd_master_port 60011 \ ---enable_ep_moe -# if you want to enable microbatch overlap, you can uncomment the following lines -#--enable_prefill_microbatch_overlap \ No newline at end of file From 6aa40a2b6438f247c572ebefe768af00f746d83b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Jun 2026 09:01:18 +0000 Subject: [PATCH 02/17] gifix --- .../router/model_infer/mode_backend/continues_batch/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/__init__.py diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/__init__.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 From 37e91217982048073b62659e16b2a0370ff58c09 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Jun 2026 09:23:53 +0000 Subject: [PATCH 03/17] fix --- lightllm/common/basemodel/infer_lock.py | 138 ------------------ lightllm/server/router/manager.py | 7 - .../server/router/model_infer/infer_batch.py | 9 -- .../model_infer/mode_backend/base_backend.py | 33 ----- .../mode_backend/chunked_prefill/impl.py | 5 - .../mode_backend/diverse_backend/impl.py | 6 - .../mode_backend/dp_backend/impl.py | 13 -- .../generic_padded_pre_process.py | 5 - .../mode_backend/generic_pre_process.py | 5 - .../pd_nixl/decode_node_impl/decode_impl.py | 8 +- .../server/router/model_infer/model_rpc.py | 8 - .../server/router/req_queue/base_queue.py | 8 +- .../router/req_queue/chunked_prefill/impl.py | 47 +++--- .../chunked_prefill/impl_for_nixl_pd.py | 1 - .../server/router/req_queue/dp_base_queue.py | 8 +- 15 files changed, 29 insertions(+), 272 deletions(-) delete mode 100644 lightllm/common/basemodel/infer_lock.py diff --git a/lightllm/common/basemodel/infer_lock.py b/lightllm/common/basemodel/infer_lock.py deleted file mode 100644 index 9da027e662..0000000000 --- a/lightllm/common/basemodel/infer_lock.py +++ /dev/null @@ -1,138 +0,0 @@ -# 这不是一个很好的设计但是不是很好找到更好更简单对架构入侵更小的实现方法。 -# 这个地方声明的锁和计数,主要是用来解决在 PD 分离模式下,kv_move_manager 进程中会出现 -# 通过rpyc调用操作 radix cache 和 mem_manager 中的数据的问题,这可能导致严重的数据同步 -# 问题,主要原因是各个tp的推理进程运行到的位置节点并没有严格的保证,导致radix cache 和 -# mem manager 中的数据出现各个进程间不一致的问题。 -# 下面的实现中,通过一个锁和计数对象, 配合使用的方式,来解决这个问题。 -from dataclasses import dataclass -import numpy as np -import threading -from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray -import torch.distributed as dist -import time -import torch.multiprocessing as mp -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class InferStateLock: - def __init__(self, name, rank_in_dp: int, dp_rank_in_node: int, dp_world_size: int): - self.infer_lock = threading.Lock() - self.dp_rank_in_node = dp_rank_in_node - # sync_world_size 应该是 min(dp_world_size, node_world_size) - self.dp_world_size = dp_world_size - self.rank_in_dp = rank_in_dp - # 默认开 128 tp 的空间, 现在应该没什么卡能开这么大的tp 吧 - self.lock_tp_infos = SharedArray( - f"{name}_dp_rank_{str(self.dp_rank_in_node)}_lock_tp_infos", shape=(self.dp_world_size + 1,), dtype=np.int64 - ) - self.lock_tp_infos.arr[:] = 0 - - def add_cur_mark(self): - self.lock_tp_infos.arr[self.rank_in_dp] += 1 - - def get_cur_mark(self): - return self.lock_tp_infos.arr[self.rank_in_dp] - - def get_max_mark_in_group(self): - return np.max(self.lock_tp_infos.arr[0 : self.dp_world_size]) - - def judge_cur_mark_equal_max_mark_in_group(self): - return self.get_cur_mark() == self.get_max_mark_in_group() - - def judge_mark_in_group_all_same(self): - marks = self.lock_tp_infos.arr[0 : self.dp_world_size] - return bool(np.all(marks == marks[0])) - - def acquire_lock_and_update_cur_mark(self): - self.infer_lock.acquire() - self.add_cur_mark() - - def release_lock(self): - self.infer_lock.release() - - def set_group_wait_mark(self): - if self.rank_in_dp == 0: - self.lock_tp_infos.arr[-1] = 1 - - def unset_group_wait_mark(self): - if self.rank_in_dp == 0: - self.lock_tp_infos.arr[-1] = 0 - - def get_group_wait_mark(self): - return self.lock_tp_infos.arr[-1] - - -@dataclass -class G_Infer_Lock: - obj: InferStateLock = None - dp_world_size: int = None - - def acquire(self): - if self.obj is not None: - # 当遇到有同步请求的时候,同时自己的mark已经是最大的mark的时候,就在这里休眠, - # 不去竞争锁, 因为 wait_mark == 1 的时候, 说明acquire_lock_until_ready被调用, - # 有推理进程在申请同步点操作 - while self.obj.get_group_wait_mark() == 1 and self.obj.judge_cur_mark_equal_max_mark_in_group(): - time.sleep(0) - - self.obj.acquire_lock_and_update_cur_mark() - - def release(self): - if self.obj is not None: - self.obj.release_lock() - - -# 后续由 backend 对象来对obj进行初始化赋值,方便进行全局调用 -g_infer_state_lock = G_Infer_Lock() - - -# 下面两个函数需要配对使用 -def acquire_lock_until_ready(nccl_group): - # 单卡一tp不用过度加锁 - if g_infer_state_lock.dp_world_size == 1: - g_infer_state_lock.obj.infer_lock.acquire() - return - - g_infer_state_lock.obj.set_group_wait_mark() - while True: - g_infer_state_lock.obj.infer_lock.acquire() - dist.barrier(nccl_group) - judge_ans = g_infer_state_lock.obj.judge_mark_in_group_all_same() - dist.barrier(nccl_group) - - if judge_ans is not True: - # 释放锁进行重试 - g_infer_state_lock.obj.infer_lock.release() - time.sleep(0.001) - logger.info("wait get locks sleep 1ms") - else: - break - - g_infer_state_lock.obj.unset_group_wait_mark() - return - - -def release_acquired_lock(): - g_infer_state_lock.obj.infer_lock.release() - - -@dataclass -class G_Router_Lock: - """ - 保护pd分离模式下, 一些调度相关信息数据的操作。 - """ - - obj = None # 进程锁对象 - - def acquire(self): - if self.obj is not None: - self.obj.acquire() - - def release(self): - if self.obj is not None: - self.obj.release() - - -g_router_lock = G_Router_Lock() diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 368205ed3e..1e869093d0 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -28,7 +28,6 @@ from lightllm.utils.log_utils import init_logger, log_time_ready from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -97,11 +96,6 @@ def __init__(self, args: StartArgs): self.metric_client = MetricClient(args.metric_port) self.is_pd_run_mode = self.args.run_mode in ["nixl_prefill", "nixl_decode"] self.is_pd_decode_mode = self.args.run_mode == "nixl_decode" - # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 - # 主要是为了防止调度失误,造成 OOM 等错误 - self.router_lock = mp.Lock() - g_router_lock.obj = self.router_lock - self.shm_reqs_io_buffer = ShmObjsIOBuffer() self.cpu_cache_client = ( @@ -135,7 +129,6 @@ async def wait_to_model_ready(self): rank_in_node=rank_in_node, node_world_size=node_world_size, info_queue=self.info_queue, - router_lock=self.router_lock, ) ) tasks.append(task) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 6c4b19e65c..419d6491ac 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -18,7 +18,6 @@ ) from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.server.multimodal_params import MultimodalParams from lightllm.utils.custom_kernel_utis import custom_cat from lightllm.utils.envs_utils import get_env_start_args @@ -288,15 +287,12 @@ def _filter(self, finished_request_ids: List[int]): def filter_reqs(self, finished_reqs: List["InferReq"]): if finished_reqs: - g_infer_state_lock.acquire() self._filter([req.req_id for req in finished_reqs]) - g_infer_state_lock.release() return @torch.no_grad() def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if pause_reqs: - g_infer_state_lock.acquire() free_token_index = [] for req in pause_reqs: @@ -314,13 +310,10 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if len(free_token_index) != 0: free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) - - g_infer_state_lock.release() return self def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int): if paused_reqs: - g_infer_state_lock.acquire() for req in paused_reqs: prefill_need_token_num = req.get_cur_total_len() @@ -338,8 +331,6 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo req.shm_req.is_paused = False logger.debug(f"infer recover paused req id {req.req_id}") can_alloc_token_num -= prefill_need_token_num - - g_infer_state_lock.release() return def get_can_alloc_token_num(self): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index f16fec243e..3953f8d38d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -11,7 +11,6 @@ from lightllm.models import get_model from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad -from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock from lightllm.common.basemodel.basemodel import TpPartBaseModel from lightllm.common.req_manager import ReqManagerForMamba from lightllm.common.linear_att_cache_manager import LinearAttCacheManager @@ -124,24 +123,6 @@ def init_model(self, kvargs): self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) - # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 - # init_process_group 之后调用 - g_infer_state_lock.obj = ( - InferStateLock( - name=get_unique_server_name(), - rank_in_dp=self.rank_in_dp, - dp_rank_in_node=self.dp_rank_in_node, - dp_world_size=self.dp_world_size, - ) - if self.run_mode in ["nixl_prefill", "nixl_decode"] - else None - ) - g_infer_state_lock.dp_world_size = self.dp_world_size - self.infer_state_lock = g_infer_state_lock - # 防止InferStateLock 中的全局共享信息被重复异常初始化,导致同步异常的问题。 - # 所以做一次barrier等待 - dist.barrier() - if self.args.enable_multimodal: g_infer_context.init_cpu_embed_cache_client() @@ -500,10 +481,7 @@ def _init_reqs(self, reqs: List[Tuple]): if self.dp_size_in_node != 1: dp_rank_in_node = self.dp_rank_in_node reqs = [req for req in reqs if req[3] == dp_rank_in_node] - - g_infer_state_lock.acquire() g_infer_context.add_reqs(reqs) - g_infer_state_lock.release() req_ids = [e[0] for e in reqs] if self.args.enable_cpu_cache: @@ -513,9 +491,7 @@ def _init_reqs(self, reqs: List[Tuple]): def _load_cpu_cache_to_reqs(self, req_ids): req_objs: List[InferReq] = [g_infer_context.requests_mapping[req_id] for req_id in req_ids] - g_infer_state_lock.acquire() self.multi_level_cache_module.load_cpu_cache_to_reqs(reqs=req_objs) - g_infer_state_lock.release() return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: @@ -533,13 +509,11 @@ def _timer_merge_radix_tree(self): and (self._radix_tree_merge_counter % self._radix_tree_merge_update_delta == 0) and self.radix_cache is not None ): - g_infer_state_lock.acquire() start = time.time() self.radix_cache.merge_unreferenced_nodes() self.logger.info( f"radix tree merge_unreferenced_nodes cost time {time.time() - start} s in rank {self.global_rank}" ) - g_infer_state_lock.release() return # 一些可以复用的通用功能函数 @@ -597,9 +571,6 @@ def _get_classed_reqs( wait_pause_count = 0 prefill_tokens = 0 - # 因为会使用到 radix cache 和 mem_manager 的计数信息 - # 所以需要加锁保护。 - g_infer_state_lock.acquire() can_alloc_token_num = g_infer_context.get_can_alloc_token_num() for req_obj in ready_reqs: @@ -660,8 +631,6 @@ def _get_classed_reqs( req_obj.wait_pause = True wait_pause_count += 1 - g_infer_state_lock.release() - self._pre_handle_finished_reqs(finished_reqs=finished_reqs) # 如果使能了 cpu cache 功能,对于已经完成的请求,进行 gpu kv 卸载到 cpu cache的操作。 if self.args.enable_cpu_cache: @@ -764,9 +733,7 @@ def _post_handle( # 一些可以复用的通用功能函数 def _filter_reqs(self, reqs: List[InferReq]): if reqs: - g_infer_state_lock.acquire() g_infer_context.filter_reqs(reqs) - g_infer_state_lock.release() return # 一些可以复用的通用功能函数 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 60045fab6c..e06f06c4e3 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -15,7 +15,6 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.common.basemodel.triton_kernel.mtp_utils import ( @@ -316,9 +315,7 @@ def decode_mtp( ) if len(need_free_mem_indexes) > 0: - g_infer_state_lock.acquire() g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) - g_infer_state_lock.release() # 第四阶段 event_pack.notify_pre_post_handle() @@ -384,11 +381,9 @@ def _draft_decode_eagle( ): batch_size = main_model_input.batch_size num_reqs = batch_size // (self.mtp_step + 1) - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_reqs * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(num_reqs * self.mtp_step) - g_infer_state_lock.release() eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) # share some inference info with the main model diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb620..34e174bc59 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -14,7 +14,6 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from ..chunked_prefill.impl import ChunkedPrefillBackend -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.utils.envs_utils import get_env_start_args @@ -167,7 +166,6 @@ def _diverse_pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: b return update_func_objs def _master_req_to_radix_cache(self, master_req: InferReq): - g_infer_state_lock.acquire() key = master_req.get_input_token_ids()[0 : master_req.cur_kv_len] key = torch.tensor(key, dtype=torch.int64, device="cpu") value = self.model.req_manager.req_to_token_indexs[master_req.req_idx][: master_req.cur_kv_len].detach().cpu() @@ -189,11 +187,9 @@ def _master_req_to_radix_cache(self, master_req: InferReq): share_node, kv_len, value = self.radix_cache.match_prefix(key, update_refs=False) assert share_node == new_shared_kv_node and kv_len == master_req.cur_kv_len self.model.req_manager.req_to_token_indexs[master_req.req_idx][0 : master_req.cur_kv_len] = value - g_infer_state_lock.release() return def _copy_master_req_to_slave_req(self, slave_req: InferReq): - g_infer_state_lock.acquire() master_req = slave_req.related_master_req assert master_req is not None @@ -213,6 +209,4 @@ def _copy_master_req_to_slave_req(self, slave_req: InferReq): slave_req.shm_req.shm_cur_kv_len = slave_req.cur_kv_len assert kv_len <= slave_req.shm_req.input_len - - g_infer_state_lock.release() return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index f1896b201a..878d4fade0 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -14,7 +14,6 @@ padded_overlap_prepare_decode_inputs, ) from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import ( prepare_mtp_prefill_inputs, ) @@ -70,8 +69,6 @@ def _init_reqs(self, reqs: List[Tuple]): current_dp_reqs = [req for req in reqs if req[3] == dp_rank_in_node] other_dp_reqs = [req for req in reqs if req[3] != dp_rank_in_node] - g_infer_state_lock.acquire() - infer_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True) req_dp_ranks = [req[3] for req in reqs] self.dp_kv_shared_module.fill_reqs_info(reqs=infer_reqs) @@ -79,7 +76,6 @@ def _init_reqs(self, reqs: List[Tuple]): self.dp_kv_shared_module.kv_trans(trans_tasks=trans_taskes) g_infer_context._filter(finished_request_ids=[req[0] for req in other_dp_reqs]) - g_infer_state_lock.release() req_ids = [e[0] for e in current_dp_reqs] @@ -520,9 +516,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): extra_post_req_handle_func=self.extra_post_req_handle_func, ) if len(need_free_mem_indexes) > 0: - g_infer_state_lock.acquire() g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) - g_infer_state_lock.release() # 第四阶段 event_pack.notify_pre_post_handle() @@ -594,12 +588,9 @@ def _draft_decode_eagle( real_req_num = req_num // (self.mtp_step + 1) padded_req_num = model_input.batch_size // (self.mtp_step + 1) - real_req_num eagle_mem_indexes_cpu = None - - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) - g_infer_state_lock.release() eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) # process the draft model output @@ -842,9 +833,7 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf extra_post_req_handle_func=self.extra_post_req_handle_func, ) if len(need_free_mem_indexes) > 0: - g_infer_state_lock.acquire() g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) - g_infer_state_lock.release() event_pack.notify_pre_post_handle() else: event_pack.notify_post_handle_and_wait_pre_post_handle() @@ -957,11 +946,9 @@ def _draft_decode_eagle_overlap( real_req_num = real_req_num0 + real_req_num1 padded_req_num0 = model_input0.batch_size // (self.mtp_step + 1) - real_req_num0 padded_req_num1 = model_input1.batch_size // (self.mtp_step + 1) - real_req_num1 - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) - g_infer_state_lock.release() eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) eagle_mem_indexes0 = eagle_mem_indexes[0 : real_req_num0 * self.mtp_step] eagle_mem_indexes1 = eagle_mem_indexes[real_req_num0 * self.mtp_step : real_req_num * self.mtp_step] diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index d3796b6392..aa257e4cef 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -7,7 +7,6 @@ from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.utils.infer_utils import calculate_time from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput @@ -95,11 +94,9 @@ def padded_prepare_prefill_inputs( b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) - g_infer_state_lock.release() if padded_req_num > 0: mem_indexes = F.pad( @@ -202,11 +199,9 @@ def padded_prepare_decode_inputs( # dynamic prompt cache 准备 token padded_mem_indexes_num = padded_req_num * (args_mtp_step + 1) - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_mem_indexes_num) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_mem_indexes_num) - g_infer_state_lock.release() if padded_mem_indexes_num > 0: mem_indexes = F.pad( diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index ae1af19565..eb67c87715 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -2,7 +2,6 @@ import numpy as np from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context -from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelInput from lightllm.utils.envs_utils import ( enable_diverse_mode_gqa_decode_fast_kernel, @@ -69,11 +68,9 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) - g_infer_state_lock.release() model_input = ModelInput( batch_size=b_seq_len.shape[0], @@ -141,11 +138,9 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mark_shared_group = None # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) - g_infer_state_lock.release() model_input = ModelInput( batch_size=b_seq_len.shape[0], diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py index bba3d11965..a00af420c6 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py @@ -3,7 +3,7 @@ from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskGroup, NIXLAbortReq from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, g_infer_state_lock +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.server.core.objs import FinishStatus from lightllm.utils.log_utils import init_logger from lightllm.utils.device_utils import kv_trans_use_p2p @@ -31,12 +31,9 @@ def _init_reqs(self, reqs: List[Tuple]): dp_rank_in_node = self.dp_rank_in_node reqs = [req for req in reqs if req[3] == dp_rank_in_node] - g_infer_state_lock.acquire() - uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True) # 匹配radix cache,并更新一些资源的管理。 self._post_init_reqs(uninit_reqs=uninit_reqs) - g_infer_state_lock.release() # pd nixl 的 decode 节点模式下当前不支持 cpu cache, 未来可能会支持。 assert not self.args.enable_cpu_cache @@ -64,7 +61,6 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: 主要用于在 nixl pd 分离模式下, 由子类继承重载, prefill 和 decode 节点过滤 kv 传输错误,或者 kv 传输没有完成的请求。 """ - g_infer_state_lock.acquire() ans_list: List[InferReq] = [] for request_id in req_ids: req_obj: InferReq = g_infer_context.requests_mapping[request_id] @@ -108,8 +104,6 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len ans_list.append(req_obj) - - g_infer_state_lock.release() return ans_list def _decode_node_gen_trans_tasks(self, req_obj: InferReq): diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 0590c454ff..afa0fb4c7f 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -152,7 +152,6 @@ def _init_env( rank_in_node, node_world_size, info_queue, - router_lock, socket_path, success_event, ): @@ -163,11 +162,6 @@ def _init_env( setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::model_infer:RANK{rank}") start_parent_check_thread() - # 将调度锁注册到全局的共享变量中 - from lightllm.common.basemodel.infer_lock import g_router_lock - - g_router_lock.obj = router_lock - model_rpc_server = ModelRpcServer(args, rank, rank_in_node, node_world_size, info_queue) # Start rpyc server with Unix socket t = ThreadedServer(model_rpc_server, socket_path=socket_path, protocol_config={"allow_pickle": True}) @@ -183,7 +177,6 @@ async def start_model_process( rank_in_node, node_world_size, info_queue: mp.Queue, - router_lock, ): import lightllm.utils.rpyc_fix_utils as _ @@ -200,7 +193,6 @@ async def start_model_process( rank_in_node, node_world_size, info_queue, - router_lock, socket_path, success_event, ), diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index 73113a59b8..96e3486cde 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -2,7 +2,6 @@ from lightllm.utils.infer_utils import calculate_time from ..batch import Batch, Req from lightllm.server.core.objs import FinishStatus -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.config_utils import get_fixed_kv_len from lightllm.server.core.objs import StartArgs @@ -82,8 +81,7 @@ def update_token_load(self, current_batch: Batch, force_update=False): if self.router.shared_token_load.need_update_dynamic_max_load() or force_update: estimated_peak_token_count, dynamic_max_load = self.calcu_batch_token_load(current_batch) token_ratio1 = self.router.get_used_tokens(self.dp_index) / self.router.max_total_token_num - with g_router_lock.obj: - self.router.shared_token_load.set_current_load(token_ratio1, self.dp_index) - self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, self.dp_index) - self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, self.dp_index) + self.router.shared_token_load.set_current_load(token_ratio1, self.dp_index) + self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, self.dp_index) + self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, self.dp_index) return diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 884b5930b0..1201c30285 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -2,7 +2,6 @@ import numpy as np from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -37,27 +36,26 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens size_array = np.arange(1, len(self.cache_len_list) + 1, 1) need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - with g_router_lock.obj: - ok_token_num = ( - need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens - ) + ok_token_num = ( + need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) + < self.max_total_tokens + ) - ok_req_num = len(self.cache_len_list) <= self.running_max_req_size + ok_req_num = len(self.cache_len_list) <= self.running_max_req_size - new_batch_first_router_need_tokens += req.get_first_router_need_tokens() - ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens + new_batch_first_router_need_tokens += req.get_first_router_need_tokens() + ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens - if ok_token_num and ok_req_num and ok_prefill: - self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index) - self.router.shared_token_load.set_dynamic_max_load( - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, - self.dp_index, - ) - return True, new_batch_first_router_need_tokens - else: - return False, new_batch_first_router_need_tokens + if ok_token_num and ok_req_num and ok_prefill: + self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index) + self.router.shared_token_load.set_dynamic_max_load( + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) + / self.max_total_tokens, + self.dp_index, + ) + return True, new_batch_first_router_need_tokens + else: + return False, new_batch_first_router_need_tokens # @calculate_time(show=True, min_cost_ms=10) def generate_new_batch(self, current_batch: Batch): @@ -121,9 +119,8 @@ def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): else: need_max_token_num = 0 - with g_router_lock.obj: - return ( - need_max_token_num, - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, - ) + return ( + need_max_token_num, + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) + / self.max_total_tokens, + ) diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py index 3b831c92a6..482568ebfb 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py @@ -3,7 +3,6 @@ from typing import Tuple from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index a73823b8b7..866e1b9f42 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -3,7 +3,6 @@ from ..batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer -from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -70,8 +69,7 @@ def update_token_load(self, current_batch: Batch, force_update=False): current_batch ) token_ratio1 = self.router.get_used_tokens(dp_index) / self.router.max_total_token_num - with g_router_lock.obj: - self.router.shared_token_load.set_current_load(token_ratio1, dp_index) - self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index) - self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index) + self.router.shared_token_load.set_current_load(token_ratio1, dp_index) + self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index) + self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index) return From 1f7f265c7b5df657b5687b20d5822c0e02aa38c4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Jun 2026 09:34:20 +0000 Subject: [PATCH 04/17] fix --- lightllm/server/router/manager.py | 7 +------ lightllm/server/router/req_queue/base_queue.py | 6 ++---- .../router/req_queue/chunked_prefill/beam_impl.py | 11 +++-------- .../router/req_queue/chunked_prefill/impl.py | 11 +++-------- lightllm/server/router/token_load.py | 15 +-------------- 5 files changed, 10 insertions(+), 40 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 1e869093d0..8db0eaf77f 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -72,7 +72,6 @@ def __init__(self, args: StartArgs): self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) for dp_index in range(self.dp_size_in_node): self.shared_token_load.set_estimated_peak_token_count(0, dp_index) - self.shared_token_load.set_frozened_token_count(0, dp_index) self.shared_token_load.set_current_load(0.0, dp_index) self.shared_token_load.set_logical_max_load(0.0, dp_index) self.shared_token_load.set_dynamic_max_load(0.0, dp_index) @@ -232,13 +231,11 @@ async def loop_for_fwd( - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) ) / self.max_total_token_num d_i = dp_index - frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i) estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) logger.debug( f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" f"dp_i {d_i} paused req num: {paused_req_num} \n" - f"dp_i {d_i} frozen token num: {frozen_token_num} \n" f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" @@ -264,11 +261,9 @@ async def loop_for_fwd( self.metric_client.gauge_set("lightllm_queue_size", 0.0) self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0) # 60s print once - if log_time_ready("frozen_info", 60): + if log_time_ready("token_load_info", 60): for dp_i in range(self.dp_size_in_node): - frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i) estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i) - logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n") logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n") await asyncio.sleep(self._get_schedule_time_interval()) diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index 96e3486cde..0d1ffe6967 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -45,9 +45,7 @@ def is_busy(self): # 计算当前所有的token使用量, 如果使用了dynamic prompt cache, 使用的token量中不包含,cache tree 中未被引用的数据。 cur_all_used_tokens = self.router.get_used_tokens(self.dp_index) # 判断当前服务是否处于token使用率过高的状态,过高的情况下,调度要偏向保守 - cur_token_ratio = ( - cur_all_used_tokens + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - ) / self.max_total_tokens + cur_token_ratio = cur_all_used_tokens / self.max_total_tokens is_busy = cur_token_ratio >= self.router_token_ratio return is_busy @@ -70,7 +68,7 @@ def generate_new_batch(self, current_batch: Batch): def calcu_batch_token_load(self, current_batch: Batch): if current_batch is None: - return 0, self.router.shared_token_load.get_frozened_token_count(self.dp_index) / self.max_total_tokens + return 0, 0.0 else: return self._calcu_batch_token_load_batch_not_none(current_batch) diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py index 23f94de704..63084d9d3b 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py @@ -49,10 +49,7 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new # prefill token 计算, 因为对beam的prefill计算过程是共享的,所以只计算一个请求对应的token数量 new_batch_first_router_need_tokens += req.get_first_router_need_tokens() - ok_token_num = ( - need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens - ) + ok_token_num = need_max_token_num < self.max_total_tokens ok_req_num = len(self.cache_len_list) <= self.running_max_req_size @@ -62,8 +59,7 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new if ok_token_num and ok_req_num and ok_prefill: self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index) self.router.shared_token_load.set_dynamic_max_load( - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, + need_max_token_num / self.max_total_tokens, self.dp_index, ) return True, new_batch_first_router_need_tokens @@ -167,6 +163,5 @@ def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): need_max_token_num = max(need_max_token_num, cumsum_len + index * cur_ouput_len) return ( need_max_token_num, - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, + need_max_token_num / self.max_total_tokens, ) diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 1201c30285..e82cc7e181 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -36,10 +36,7 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens size_array = np.arange(1, len(self.cache_len_list) + 1, 1) need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - ok_token_num = ( - need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens - ) + ok_token_num = need_max_token_num < self.max_total_tokens ok_req_num = len(self.cache_len_list) <= self.running_max_req_size @@ -49,8 +46,7 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens if ok_token_num and ok_req_num and ok_prefill: self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index) self.router.shared_token_load.set_dynamic_max_load( - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, + need_max_token_num / self.max_total_tokens, self.dp_index, ) return True, new_batch_first_router_need_tokens @@ -121,6 +117,5 @@ def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): return ( need_max_token_num, - (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) - / self.max_total_tokens, + need_max_token_num / self.max_total_tokens, ) diff --git a/lightllm/server/router/token_load.py b/lightllm/server/router/token_load.py index e4ce4b8352..45fa34d5da 100644 --- a/lightllm/server/router/token_load.py +++ b/lightllm/server/router/token_load.py @@ -20,7 +20,7 @@ def __init__(self, name, dp_size_in_node) -> None: f"{name}_ext_infos", shape=( self.dp_size_in_node, - 2, + 1, ), dtype=np.int64, ) @@ -40,19 +40,6 @@ def add_estimated_peak_token_count(self, value: int, index: int): def get_estimated_peak_token_count(self, index: int) -> int: return self.shared_token_infos.arr[index, 0] - # 记录系统被临时固定的不能被使用的token数,主要在于 pd 分离的模式下 - # 推理系统需要在 kv 传输时临时固定一些 token, 防止调度系统估计失误,导致调度问题 - def set_frozened_token_count(self, obj: int, index: int): - self.shared_token_infos.arr[index, 1] = obj - return - - def get_frozened_token_count(self, index: int) -> int: - return self.shared_token_infos.arr[index, 1] - - def add_frozened_token_count(self, value: int, index: int): - self.shared_token_infos.arr[index, 1] += value - return - # current_load 当前使用token量,估计的负载 def set_current_load(self, value, index: int): self.shared_token_load.arr[index, 0] = value From ab811d564ae0d11d5bc2d1ab527c78964459d3df Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 00:44:58 +0000 Subject: [PATCH 05/17] fix --- lightllm/utils/envs_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 2bdd4005fa..26e94dfb8b 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -11,7 +11,7 @@ def set_unique_server_name(args): - node_uuid = uuid.uuid1().hex[0:8] + node_uuid = uuid.uuid4().hex[0:16] if args.run_mode == "pd_master": os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(node_uuid) + "_pd_master" From 3b6767c7feb178a1bfc7989fff98d04da1712050 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 00:57:08 +0000 Subject: [PATCH 06/17] fix --- lightllm/server/httpserver/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index b9799cd061..fd97886844 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -468,7 +468,7 @@ async def generate( yield sub_req_id, request_output, metadata, finish_status except (ClientDisconnected, Exception) as e: - logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + logger.warning(f"group_request_id: {group_request_id} has exception {str(e)}") if isinstance(e, ClientDisconnected): logger.warning(f"group_request_id: {group_request_id} {e.reason}") From 29a17eb024f684d602be87de9d3718b7ef87a51b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 02:27:47 +0000 Subject: [PATCH 07/17] fix --- .../pd_nixl/decode_node_impl/decode_trans_process.py | 2 ++ .../router/model_infer/mode_backend/pd_nixl/kv_transporter.py | 4 ++-- .../pd_nixl/prefill_node_impl/prefill_trans_process.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index 0a947f8071..cb1b92c184 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -46,6 +46,8 @@ def _init_env( task_out_queue: mp.Queue, up_status_in_queue: Optional[mp.SimpleQueue], ): + import lightllm.utils.rpyc_fix_utils as _ + import os # ------------------------------------------------------------------------- diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py index ee6c04612b..c413515eda 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py @@ -30,8 +30,8 @@ def create_kv_transporter(args: StartArgs, node_id: int, tp_idx: int, kv_move_bu tp_idx=tp_idx, kv_move_buffer=kv_move_buffer, host_ip=get_hostname_ip() or args.host, - store_port_min=port_min, - store_port_max=port_max, + control_port_min=port_min, + control_port_max=port_max, ) raise ValueError(f"unsupported LIGHTLLM_PD_KV_TRANSPORT_BACKEND={backend}") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 856fba1188..fa7367958c 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -40,6 +40,7 @@ def _init_env( task_in_queue: mp.Queue, task_out_queue: mp.Queue, ): + import lightllm.utils.rpyc_fix_utils as _ import os From d6b7cd2e07f5ae32a23d9219ea5277784a9c6dab Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 06:56:28 +0000 Subject: [PATCH 08/17] fix --- .../pd_nixl/nccl_kv_transporter.py | 627 ++++++++++-------- 1 file changed, 347 insertions(+), 280 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py index b983312dec..d048cbf9b7 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py @@ -1,16 +1,20 @@ import copy +import errno +import queue import pickle import threading from dataclasses import dataclass -from datetime import timedelta from typing import Dict, List, Optional +import rpyc import torch from torch import Tensor -from torch.distributed import TCPStore +from rpyc.utils.classic import obtain +from rpyc.utils.server import ThreadedServer from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NixlAgentMetadata +from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.utils.net_utils import get_hostname_ip @@ -21,42 +25,16 @@ class NcclAgentMetadata: agent_name: str host_ip: str - store_port: int + control_port: int device_id: int -@dataclass -class _NcclXferHandle: - thread: Optional[threading.Thread] - status: str = "PROC" - error_info: Optional[str] = None - - -class _PeerSeqTurn: - def __init__(self, transporter: "NcclKVTransporter", peer_name: str, seq: int): - self.transporter = transporter - self.peer_name = peer_name - self.seq = seq - - def __enter__(self): - with self.transporter._peer_seq_cond: - while self.transporter._peer_seq_to_run.get(self.peer_name, 0) != self.seq: - self.transporter._peer_seq_cond.wait() - return self - - def __exit__(self, exc_type, exc_value, traceback): - with self.transporter._peer_seq_cond: - self.transporter._peer_seq_to_run[self.peer_name] = self.seq + 1 - self.transporter._peer_seq_cond.notify_all() - return False - - class NcclKVTransporter: """ NIXL-compatible transporter backed by NCCL point-to-point operations. NIXL provides remote notifications and one-sided WRITE. NCCL does not, so this - class uses a small TCPStore control plane for notifications and communicator + class uses a small TCP control channel for notifications and communicator bootstrap while preserving the same request/ready/done/error interface used by pd_nixl trans-process management. """ @@ -67,65 +45,31 @@ def __init__( tp_idx: int, kv_move_buffer: Tensor, host_ip: Optional[str] = None, - store_port: Optional[int] = None, - store_port_min: int = 20000, - store_port_max: int = 30000, + control_port_min: int = 20000, + control_port_max: int = 30000, ): self.node_id = node_id self.tp_idx = tp_idx self.kv_move_buffer = kv_move_buffer + args = get_env_start_args() + assert args.run_mode in ["nixl_prefill", "nixl_decode"], args.run_mode + self.is_prefill_node = args.run_mode == "nixl_prefill" self.capture_telemetry = False self.num_pages, self.page_size, self.num_layers, self.kv_head_num, self.head_dims = kv_move_buffer.shape self.host_ip = host_ip or get_hostname_ip() assert self.host_ip is not None, "can not get host ip for NcclKVTransporter" - self.store, self.store_port = self._create_local_store( - store_port=store_port, - store_port_min=store_port_min, - store_port_max=store_port_max, + self.control_channel = _NcclControlChannel( + host_ip=self.host_ip, + port_min=control_port_min, + port_max=control_port_max, ) self.remote_agents: Dict[str, NixlAgentMetadata] = {} - self.remote_stores: Dict[str, TCPStore] = {} - self._comms: Dict[str, PyNcclCommunicator] = {} - self._comm_create_lock = threading.Lock() - self._peer_seq_cond = threading.Condition() - self._peer_seq_to_assign: Dict[str, int] = {} - self._peer_seq_to_run: Dict[str, int] = {} - self._recv_notif_counter = 0 - self._deferred_notifs: List[bytes] = [] - self._recv_task_status: Dict[str, _NcclXferHandle] = {} - self._xfer_handle_counter = 0 - self._xfer_handles: Dict[int, _NcclXferHandle] = {} + self._peers: Dict[str, "_NcclPeer"] = {} + self._peer_lock = threading.Lock() return - def _create_local_store( - self, store_port: Optional[int], store_port_min: int, store_port_max: int - ) -> tuple[TCPStore, int]: - if store_port is not None: - ports = [store_port] - else: - ports = list(range(store_port_min, store_port_max + 1)) - - last_error = None - for port in ports: - try: - store = TCPStore( - host_name=self.host_ip, - port=port, - is_master=True, - use_libuv=True, - timeout=timedelta(seconds=30), - ) - return store, port - except BaseException as e: - last_error = e - logger.warning(f"Create NCCL TCPStore on {self.host_ip}:{port} failed: {e}") - - raise RuntimeError( - f"can not allocate NCCL TCPStore port in [{store_port_min}, {store_port_max}]" - ) from last_error - @property def agent_name(self) -> str: return f"{self.node_id}_{self.tp_idx}" @@ -136,7 +80,7 @@ def agent_metadata(self) -> bytes: NcclAgentMetadata( agent_name=self.agent_name, host_ip=self.host_ip, - store_port=self.store_port, + control_port=self.control_channel.port, device_id=self.tp_idx, ) ) @@ -156,26 +100,8 @@ def local_page_mem_desc(self) -> bytes: def get_new_notifs(self) -> Dict[str, List[bytes]]: notifs: Dict[str, List[bytes]] = {} - still_deferred = [] - for notify in self._deferred_notifs: - ready_notify = self._get_ready_notify(notify) - if ready_notify is None: - still_deferred.append(notify) - else: - notifs.setdefault(self._get_notify_source_agent_name(ready_notify), []).append(ready_notify) - self._deferred_notifs = still_deferred - - while True: - key = self._notif_key(self.agent_name, self._recv_notif_counter) - if not self.store.check([key]): - break - notify = bytes(self.store.get(key)) - ready_notify = self._get_ready_notify(notify) - if ready_notify is None: - self._deferred_notifs.append(notify) - else: - notifs.setdefault(self._get_notify_source_agent_name(ready_notify), []).append(ready_notify) - self._recv_notif_counter += 1 + for notify in self.control_channel.get_notifs(): + notifs.setdefault(self._get_notify_source_agent_name(notify), []).append(notify) return notifs def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata): @@ -188,26 +114,18 @@ def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata): ), f"Peer name {metadata.agent_name} does not match remote name {remote_agent.agent_name}" self.remote_agents[remote_agent.agent_name] = remote_agent - self.remote_stores[remote_agent.agent_name] = TCPStore( - host_name=metadata.host_ip, - port=metadata.store_port, - is_master=False, - use_libuv=True, - timeout=timedelta(seconds=30), + logger.info( + f"Added NCCL remote agent {remote_agent.agent_name} at {metadata.host_ip}:{metadata.control_port}" ) - logger.info(f"Added NCCL remote agent {remote_agent.agent_name} at {metadata.host_ip}:{metadata.store_port}") return def remove_remote_agent(self, peer_name: str): if peer_name in self.remote_agents: self.remote_agents.pop(peer_name, None) - self.remote_stores.pop(peer_name, None) - comm = self._comms.pop(peer_name, None) - if comm is not None: - comm.destroy() - with self._peer_seq_cond: - self._peer_seq_to_assign.pop(peer_name, None) - self._peer_seq_to_run.pop(peer_name, None) + with self._peer_lock: + peer = self._peers.pop(peer_name, None) + if peer is not None: + peer.close() else: logger.warning(f"try to remove remote agent, but peer name {peer_name} agent did not exist") return @@ -233,7 +151,10 @@ def send_write_request_task_to_decode_node(self, trans_task: NIXLChunckedTransTa return def send_write_ready_task_to_prefill_node(self, trans_task: NIXLChunckedTransTask): - self._start_recv_task(trans_task) + if trans_task.prefill_agent_name not in self.remote_agents: + self.connect_add_remote_agent(trans_task.create_prefill_agent_obj()) + + self._get_peer(trans_task.prefill_agent_name).start_recv(trans_task) new_trans_task = self._copy_notify_task(trans_task) new_trans_task.nixl_write_stage = "ready" @@ -266,226 +187,372 @@ def send_error_info_to_decode_node(self, trans_task: NIXLChunckedTransTask): self._send_task_notif(trans_task.decode_agent_name, new_trans_task) return - def write_blocks_paged(self, trans_task: NIXLChunckedTransTask) -> int: + def write_blocks_paged(self, trans_task: NIXLChunckedTransTask) -> "_NcclXferHandle": assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None decode_agent_name = trans_task.decode_agent_name if decode_agent_name not in self.remote_agents: self.connect_add_remote_agent(trans_task.create_decode_agent_obj()) - self._ensure_comm( - remote_agent_name=decode_agent_name, - is_server=True, - store=self.store, - ) - handle = self._next_xfer_handle() - seq = self._assign_peer_seq(decode_agent_name) - xfer_handle = _NcclXferHandle( - thread=threading.Thread(target=self._send_page_task, args=(handle, trans_task, seq), daemon=True) - ) - self._xfer_handles[handle] = xfer_handle - xfer_handle.thread.start() - return handle + return self._get_peer(decode_agent_name).send_page(trans_task) def check_task_status(self, trans_task: NIXLChunckedTransTask) -> str: assert trans_task.xfer_handle is not None - handle = self._xfer_handles[trans_task.xfer_handle] - if handle.status == "ERR": - logger.warning(f"Transfer failed with trans task {trans_task.to_str()}: {handle.error_info}") - return handle.status + return trans_task.xfer_handle.check_status() def release_xfer_handle(self, handle): - xfer_handle = self._xfer_handles.pop(handle, None) - if xfer_handle is not None: - xfer_handle.thread.join(timeout=1) return def shutdown(self): - for handle in list(self._xfer_handles.keys()): - self.release_xfer_handle(handle) - for comm in list(self._comms.values()): - comm.destroy() - self._comms.clear() + with self._peer_lock: + peers = list(self._peers.values()) + self._peers.clear() + for peer in peers: + peer.close() self.remote_agents.clear() - self.remote_stores.clear() + self.control_channel.close() return - def _start_recv_task(self, trans_task: NIXLChunckedTransTask): - if trans_task.prefill_agent_name not in self.remote_agents: - self.connect_add_remote_agent(trans_task.create_prefill_agent_obj()) - self._recv_task_status[trans_task.get_key()] = _NcclXferHandle(thread=None) - seq = self._assign_peer_seq(trans_task.prefill_agent_name) - threading.Thread(target=self._recv_page_task, args=(copy.copy(trans_task), seq), daemon=True).start() + def _get_peer(self, peer_name: str) -> "_NcclPeer": + with self._peer_lock: + peer = self._peers.get(peer_name) + if peer is None: + peer = _NcclPeer(self, peer_name) + self._peers[peer_name] = peer + return peer + + def _send_task_notif(self, remote_agent_name: str, trans_task: NIXLChunckedTransTask): + if remote_agent_name not in self.remote_agents: + if remote_agent_name == trans_task.decode_agent_name: + self.connect_add_remote_agent(trans_task.create_decode_agent_obj()) + else: + self.connect_add_remote_agent(trans_task.create_prefill_agent_obj()) + + remote_metadata = self._get_remote_metadata(remote_agent_name) + self.control_channel.send_notif( + remote_agent_name, + remote_metadata.host_ip, + remote_metadata.control_port, + pickle.dumps(trans_task), + ) return - def _send_page_task(self, handle: int, trans_task: NIXLChunckedTransTask, seq: int): - xfer_handle = self._xfer_handles[handle] + def _get_remote_metadata(self, remote_agent_name: str) -> NcclAgentMetadata: + remote_agent = self.remote_agents[remote_agent_name] + return pickle.loads(remote_agent.agent_metadata) + + def _copy_notify_task(self, trans_task: NIXLChunckedTransTask) -> NIXLChunckedTransTask: + new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) + new_trans_task.mem_indexes = None + new_trans_task.xfer_handle = None + return new_trans_task + + def _get_notify_source_agent_name(self, notify: bytes) -> str: + notify_obj = pickle.loads(notify) + assert isinstance(notify_obj, NIXLChunckedTransTask), type(notify_obj) + + if notify_obj.error_info is not None: + if self.is_prefill_node: + assert notify_obj.decode_agent_name is not None + return notify_obj.decode_agent_name + else: + assert notify_obj.prefill_agent_name is not None + return notify_obj.prefill_agent_name + + if notify_obj.nixl_write_stage == "request": + assert notify_obj.prefill_agent_name is not None + return notify_obj.prefill_agent_name + + if notify_obj.nixl_write_stage in ["ready", "done"]: + assert notify_obj.decode_agent_name is not None + return notify_obj.decode_agent_name + + raise AssertionError(f"unexpected notify stage: {notify_obj.nixl_write_stage}") + + +@dataclass +class _NcclXferHandle: + peer_name: str + event: torch.cuda.Event + status: str = "PROC" + error_info: Optional[str] = None + + def check_status(self) -> str: + if self.status != "PROC": + return self.status + try: - remote_agent = self.remote_agents[trans_task.decode_agent_name] - remote_metadata: NcclAgentMetadata = pickle.loads(remote_agent.agent_metadata) - page_tensor = self.kv_move_buffer[trans_task.nixl_src_page_index] - comm = self._get_cached_comm(trans_task.decode_agent_name) - with self._peer_seq_turn(trans_task.decode_agent_name, seq): - comm.send(page_tensor, dst=1) - torch.cuda.current_stream().synchronize() - xfer_handle.status = "DONE" - logger.info( - f"NCCL send page done request_id={trans_task.request_id} " - f"src_page={trans_task.nixl_src_page_index} dst_agent={remote_metadata.agent_name}" - ) + if self.event.query(): + self.status = "DONE" except BaseException as e: - xfer_handle.status = "ERR" - xfer_handle.error_info = str(e) - logger.exception(str(e)) - self._drop_comm(trans_task.decode_agent_name) + self.status = "ERR" + self.error_info = str(e) + return self.status + + +class _NcclPeer: + def __init__(self, transporter: NcclKVTransporter, peer_name: str): + self.transporter = transporter + self.peer_name = peer_name + self.comm: Optional[PyNcclCommunicator] = None + self.stream: Optional[torch.cuda.Stream] = None + self.recv_queue: Optional["queue.Queue[Optional[NIXLChunckedTransTask]]"] = None + self._lock = threading.Lock() + + def send_page(self, trans_task: NIXLChunckedTransTask) -> _NcclXferHandle: + assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None + page_tensor = self.transporter.kv_move_buffer[trans_task.nixl_src_page_index] + comm = self._ensure_comm(is_server=True) + stream = self._get_stream() + + comm.send(page_tensor, dst=1, stream=stream) + event = torch.cuda.Event() + event.record(stream) + + logger.info( + f"NCCL send page posted request_id={trans_task.request_id} " + f"src_page={trans_task.nixl_src_page_index} dst_agent={self.peer_name}" + ) + return _NcclXferHandle(peer_name=self.peer_name, event=event) + + def start_recv(self, trans_task: NIXLChunckedTransTask): + self._get_recv_queue().put(copy.copy(trans_task)) + return + + def close(self): + with self._lock: + recv_queue = self.recv_queue + self.recv_queue = None + comm = self.comm + self.comm = None + self.stream = None + + if recv_queue is not None: + recv_queue.put(None) + if comm is not None: + comm.destroy() return - def _recv_page_task(self, trans_task: NIXLChunckedTransTask, seq: int): + def _get_stream(self) -> torch.cuda.Stream: + with self._lock: + if self.stream is None: + torch.cuda.set_device(self.transporter.tp_idx) + self.stream = torch.cuda.Stream() + return self.stream + + def _get_recv_queue(self) -> "queue.Queue[Optional[NIXLChunckedTransTask]]": + with self._lock: + if self.recv_queue is not None: + return self.recv_queue + + self.recv_queue = queue.Queue() + threading.Thread(target=self._recv_page_loop, args=(self.recv_queue,), daemon=True).start() + return self.recv_queue + + def _recv_page_loop(self, recv_queue: "queue.Queue[Optional[NIXLChunckedTransTask]]"): + torch.cuda.set_device(self.transporter.tp_idx) + while True: + trans_task = recv_queue.get() + if trans_task is None: + return + self._recv_page(trans_task) + + def _recv_page(self, trans_task: NIXLChunckedTransTask): try: - page_tensor = self.kv_move_buffer[trans_task.nixl_dst_page_index] - remote_agent = self.remote_agents[trans_task.prefill_agent_name] - remote_store = self.remote_stores[remote_agent.agent_name] - comm = self._ensure_comm( - remote_agent_name=trans_task.prefill_agent_name, - is_server=False, - store=remote_store, - ) - with self._peer_seq_turn(trans_task.prefill_agent_name, seq): - comm.recv(page_tensor, src=0) - torch.cuda.current_stream().synchronize() - self._recv_task_status[trans_task.get_key()].status = "DONE" + page_tensor = self.transporter.kv_move_buffer[trans_task.nixl_dst_page_index] + comm = self._ensure_comm(is_server=False) + stream = self._get_stream() + comm.recv(page_tensor, src=0, stream=stream) logger.info( f"NCCL recv page done request_id={trans_task.request_id} " f"dst_page={trans_task.nixl_dst_page_index}" ) except BaseException as e: trans_task.error_info = str(e) - recv_status = self._recv_task_status.get(trans_task.get_key(), None) - if recv_status is not None: - recv_status.status = "ERR" - recv_status.error_info = str(e) logger.exception(str(e)) - self._drop_comm(trans_task.prefill_agent_name) - self.send_error_info_to_prefill_node(trans_task) + self._drop_comm() + self.transporter.send_error_info_to_prefill_node(trans_task) return - def _get_ready_notify(self, notify: bytes) -> Optional[bytes]: - try: - notify_obj = pickle.loads(notify) - except BaseException: - return notify + def _ensure_comm(self, is_server: bool) -> PyNcclCommunicator: + with self._lock: + if self.comm is not None: + return self.comm - if not isinstance(notify_obj, NIXLChunckedTransTask): - return notify + if is_server: + src_id = self.transporter.agent_name + dest_id = self.peer_name + else: + src_id = self.peer_name + dest_id = self.transporter.agent_name - if notify_obj.nixl_write_stage != "done": - return notify + group = StatelessP2PProcessGroup.create( + src_id=src_id, + dest_id=dest_id, + is_server=is_server, + store=_NcclControlStore(self.transporter, self.peer_name), + ) + self.comm = PyNcclCommunicator(group, self.transporter.tp_idx) + logger.info(f"Created NCCL communicator with peer {self.peer_name}") + return self.comm - recv_status = self._recv_task_status.get(notify_obj.get_key(), None) - if recv_status is None or recv_status.status == "PROC": - return None + def _drop_comm(self): + with self._lock: + comm = self.comm + self.comm = None - self._recv_task_status.pop(notify_obj.get_key(), None) - if recv_status.status == "ERR": - notify_obj.error_info = recv_status.error_info or "nccl recv failed" - return pickle.dumps(notify_obj) + if comm is not None: + comm.destroy() + logger.warning(f"Dropped NCCL communicator with peer {self.peer_name}") + return - return notify - def _get_cached_comm(self, remote_agent_name: str) -> PyNcclCommunicator: - comm = self._comms.get(remote_agent_name) - if comm is None: - raise RuntimeError(f"NCCL communicator with peer {remote_agent_name} is not initialized") - return comm +class _NcclControlService(rpyc.Service): + def __init__(self, channel: "_NcclControlChannel"): + super().__init__() + self.channel = channel - def _ensure_comm( - self, - remote_agent_name: str, - is_server: bool, - store: TCPStore, - ) -> PyNcclCommunicator: - comm = self._comms.get(remote_agent_name) - if comm is not None: - return comm + def exposed_push_notif(self, payload: bytes): + payload = obtain(payload) + self.channel.notif_queue.put(payload) + return - with self._comm_create_lock: - comm = self._comms.get(remote_agent_name) - if comm is not None: - return comm + def exposed_set_value(self, key: str, value: bytes): + key = obtain(key) + value = obtain(value) + self.channel.add_store_value(key, value) + return - if is_server: - src_id = self.agent_name - dest_id = remote_agent_name - else: - src_id = remote_agent_name - dest_id = self.agent_name - group = StatelessP2PProcessGroup.create( - src_id=src_id, - dest_id=dest_id, - is_server=is_server, - store=store, - ) - comm = PyNcclCommunicator(group, self.tp_idx) - self._comms[remote_agent_name] = comm - logger.info(f"Created NCCL communicator with peer {remote_agent_name}") - return comm - - def _drop_comm(self, remote_agent_name: str): - with self._comm_create_lock: - comm = self._comms.pop(remote_agent_name, None) - if comm is not None: - comm.destroy() - logger.warning(f"Dropped NCCL communicator with peer {remote_agent_name}") +class _NcclControlChannel: + def __init__( + self, + host_ip: str, + port_min: int, + port_max: int, + ): + self.notif_queue: "queue.Queue[bytes]" = queue.Queue() + self._store_values: Dict[str, bytes] = {} + self._store_cond = threading.Condition() + self._conn_lock = threading.Lock() + self._conns: Dict[tuple[str, str, int], rpyc.Connection] = {} + self._server, self.port = self._start_server(host_ip, port_min, port_max) + + def _start_server(self, host_ip: str, port_min: int, port_max: int) -> tuple[ThreadedServer, int]: + last_error = None + for cur_port in range(port_min, port_max + 1): + try: + server = ThreadedServer( + _NcclControlService(self), + hostname=host_ip, + port=cur_port, + protocol_config={ + "allow_pickle": True, + "allow_all_attrs": True, + "allow_getattr": True, + "allow_setattr": True, + }, + ) + threading.Thread(target=server.start, daemon=True).start() + logger.info(f"NCCL RPyC control channel listen on {host_ip}:{cur_port}") + return server, cur_port + except OSError as e: + last_error = e + if e.errno == errno.EADDRINUSE: + logger.info(f"NCCL RPyC control port {host_ip}:{cur_port} is in use, try next port") + else: + logger.warning(f"Create NCCL RPyC control channel on {host_ip}:{cur_port} failed: {e}") + raise RuntimeError(f"can not allocate NCCL control port in [{port_min}, {port_max}]") from last_error + + def close(self): + with self._conn_lock: + for conn in self._conns.values(): + try: + conn.close() + except Exception: + pass + self._conns.clear() + self._server.close() return - def _assign_peer_seq(self, peer_name: str) -> int: - with self._peer_seq_cond: - seq = self._peer_seq_to_assign.get(peer_name, 0) - self._peer_seq_to_assign[peer_name] = seq + 1 - self._peer_seq_to_run.setdefault(peer_name, 0) - return seq + def add_store_value(self, key: str, value: bytes): + with self._store_cond: + self._store_values[key] = value + self._store_cond.notify_all() + return - def _peer_seq_turn(self, peer_name: str, seq: int): - return _PeerSeqTurn(self, peer_name, seq) + def wait_store_value(self, key: str, timeout: float = 30.0) -> bytes: + with self._store_cond: + ok = self._store_cond.wait_for(lambda: key in self._store_values, timeout=timeout) + if not ok: + raise TimeoutError(f"wait timeout after {int(timeout * 1000)}ms, key: {key}") + return self._store_values.pop(key) - def _send_task_notif(self, remote_agent_name: str, trans_task: NIXLChunckedTransTask): - if remote_agent_name not in self.remote_agents: - if remote_agent_name == trans_task.decode_agent_name: - self.connect_add_remote_agent(trans_task.create_decode_agent_obj()) - else: - self.connect_add_remote_agent(trans_task.create_prefill_agent_obj()) + def get_notifs(self) -> List[bytes]: + notifs = [] + while True: + try: + notifs.append(self.notif_queue.get_nowait()) + except queue.Empty: + break + return notifs - remote_store = self.remote_stores[remote_agent_name] - counter = remote_store.add(f"notif/{remote_agent_name}/counter", 1) - 1 - remote_store.set(self._notif_key(remote_agent_name, counter), pickle.dumps(trans_task)) + def send_notif(self, peer_name: str, host_ip: str, port: int, payload: bytes): + self._call(peer_name, host_ip, port, "push_notif", payload) return - def _copy_notify_task(self, trans_task: NIXLChunckedTransTask) -> NIXLChunckedTransTask: - new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) - new_trans_task.mem_indexes = None - new_trans_task.xfer_handle = None - return new_trans_task + def send_store_value(self, peer_name: str, host_ip: str, port: int, key: str, value: bytes): + self._call(peer_name, host_ip, port, "set_value", key, value) + return + + def _call(self, peer_name: str, host_ip: str, port: int, method: str, *args): + conn_key = (peer_name, host_ip, port) + with self._conn_lock: + conn = self._conns.get(conn_key) + if conn is None: + conn = rpyc.connect( + host_ip, + port, + config={ + "allow_pickle": True, + "allow_all_attrs": True, + "allow_getattr": True, + "allow_setattr": True, + }, + ) + self._conns[conn_key] = conn + try: + getattr(conn.root, method)(*args) + except Exception as e: + self._conns.pop(conn_key, None) + try: + conn.close() + except Exception: + pass + raise RuntimeError(f"NCCL control RPC {method} to {peer_name} failed") from e + return - def _next_xfer_handle(self): - self._xfer_handle_counter += 1 - return self._xfer_handle_counter - @staticmethod - def _notif_key(agent_name: str, counter: int) -> str: - return f"notif/{agent_name}/{counter}" +class _NcclControlStore: + def __init__(self, transporter: "NcclKVTransporter", remote_agent_name: str): + self.transporter = transporter + self.remote_agent_name = remote_agent_name + + def set(self, key: str, value: bytes): + remote_metadata = self.transporter._get_remote_metadata(self.remote_agent_name) + self.transporter.control_channel.send_store_value( + self.remote_agent_name, + remote_metadata.host_ip, + remote_metadata.control_port, + self._send_key(key), + bytes(value), + ) + return - @staticmethod - def _get_notify_source_agent_name(notify: bytes) -> str: - try: - notify_obj = pickle.loads(notify) - except BaseException: - return "unknown" + def get(self, key: str) -> bytes: + return self.transporter.control_channel.wait_store_value(self._recv_key(key)) - if not isinstance(notify_obj, NIXLChunckedTransTask): - return "unknown" + def _send_key(self, key: str) -> str: + return f"{self.transporter.agent_name}->{self.remote_agent_name}:{key}" - if notify_obj.nixl_write_stage == "request": - return notify_obj.prefill_agent_name or "unknown" - if notify_obj.nixl_write_stage in ["ready", "done"]: - return notify_obj.decode_agent_name or "unknown" - return notify_obj.prefill_agent_name or notify_obj.decode_agent_name or "unknown" + def _recv_key(self, key: str) -> str: + return f"{self.remote_agent_name}->{self.transporter.agent_name}:{key}" From 2d5d56a198aca58f153d3c318006a01efa8de0a2 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 08:53:59 +0000 Subject: [PATCH 09/17] fix --- docs/CN/source/tutorial/api_server_args.rst | 8 +- .../source/tutorial/deepseek_deployment.rst | 12 +-- docs/EN/source/tutorial/api_server_args.rst | 8 +- .../source/tutorial/deepseek_deployment.rst | 12 +-- .../kv_cache_mem_manager/mem_manager.py | 2 +- lightllm/server/api_cli.py | 18 ++-- lightllm/server/api_http.py | 14 ++-- lightllm/server/api_start.py | 6 +- .../{nixl_params.py => pd_kv_trans_params.py} | 10 +-- lightllm/server/core/objs/sampling_params.py | 10 +-- lightllm/server/core/objs/start_args_type.py | 6 +- lightllm/server/httpserver/manager.py | 44 +++++----- lightllm/server/httpserver/pd_loop.py | 34 ++++---- .../httpserver_for_pd_master/manager.py | 50 +++++------ lightllm/server/pd_io_struct.py | 72 +++++++--------- lightllm/server/router/manager.py | 12 +-- .../server/router/model_infer/infer_batch.py | 36 ++++---- .../model_infer/mode_backend/__init__.py | 8 +- .../model_infer/mode_backend/base_backend.py | 48 +++++------ .../mode_backend/chunked_prefill/impl.py | 4 +- .../mode_backend/diverse_backend/impl.py | 2 +- .../mode_backend/dp_backend/impl.py | 8 +- .../mode_backend/{pd_nixl => pd}/__init__.py | 0 .../{pd_nixl => pd}/base_kv_move_manager.py | 16 ++-- .../decode_node_impl/__init__.py | 0 .../decode_node_impl/decode_impl.py | 56 ++++++------- .../decode_node_impl/decode_impl_for_dp.py | 20 ++--- .../decode_kv_move_manager.py | 12 +-- .../decode_node_impl/decode_trans_process.py | 62 +++++++------- .../decode_node_impl/up_status.py | 10 +-- .../{pd_nixl => pd}/kv_transporter.py | 2 +- .../{pd_nixl => pd}/nccl_kv_transporter.py | 82 +++++++++---------- .../{pd_nixl => pd}/nixl_kv_transporter.py | 50 +++++------ .../mode_backend/{pd_nixl => pd}/p2p_fix.py | 2 +- .../prefill_node_impl/__init__.py | 0 .../prefill_node_impl/prefill_impl.py | 62 +++++++------- .../prefill_node_impl/prefill_impl_for_dp.py | 18 ++-- .../prefill_kv_move_manager.py | 6 +- .../prefill_trans_process.py | 46 +++++------ .../{pd_nixl => pd}/trans_process_obj.py | 0 .../server/router/model_infer/model_rpc.py | 24 +++--- lightllm/server/router/req_queue/__init__.py | 2 +- skills/test_model/qwen3-8b-pd-nixl/SKILL.md | 28 +++---- .../qwen3-8b-pd-nixl/check_nvidia_peermem.sh | 2 +- .../test_model/qwen3.5-0.8b-pd-nixl/SKILL.md | 48 +++++------ .../check_nvidia_peermem.sh | 2 +- test/acc/{test_pd_nixl.sh => test_pd.sh} | 4 +- test/start_scripts/README.md | 10 +-- .../{pd_nixl_decode.sh => pd_decode.sh} | 4 +- .../{pd_nixl_prefill.sh => pd_prefill.sh} | 4 +- 50 files changed, 492 insertions(+), 504 deletions(-) rename lightllm/server/core/objs/{nixl_params.py => pd_kv_trans_params.py} (54%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/__init__.py (100%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/base_kv_move_manager.py (89%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/decode_node_impl/__init__.py (100%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/decode_node_impl/decode_impl.py (82%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/decode_node_impl/decode_impl_for_dp.py (67%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/decode_node_impl/decode_kv_move_manager.py (88%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/decode_node_impl/decode_trans_process.py (91%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/decode_node_impl/up_status.py (93%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/kv_transporter.py (95%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/nccl_kv_transporter.py (87%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/nixl_kv_transporter.py (87%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/p2p_fix.py (97%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/prefill_node_impl/__init__.py (100%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/prefill_node_impl/prefill_impl.py (70%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/prefill_node_impl/prefill_impl_for_dp.py (66%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/prefill_node_impl/prefill_kv_move_manager.py (93%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/prefill_node_impl/prefill_trans_process.py (90%) rename lightllm/server/router/model_infer/mode_backend/{pd_nixl => pd}/trans_process_obj.py (100%) rename test/acc/{test_pd_nixl.sh => test_pd.sh} (98%) rename test/start_scripts/single_pd_master/{pd_nixl_decode.sh => pd_decode.sh} (90%) rename test/start_scripts/single_pd_master/{pd_nixl_prefill.sh => pd_prefill.sh} (91%) diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst index a42f30329c..8e7f9d78e8 100644 --- a/docs/CN/source/tutorial/api_server_args.rst +++ b/docs/CN/source/tutorial/api_server_args.rst @@ -13,8 +13,8 @@ APIServer 参数详解 设置运行模式,可选值: * ``normal``: 单服务器模式(默认) - * ``nixl_prefill``: 预填充模式(用于 pd 分离运行模式) - * ``nixl_decode``: 解码模式(用于 pd 分离运行模式) + * ``prefill``: 预填充模式(用于 pd 分离运行模式) + * ``decode``: 解码模式(用于 pd 分离运行模式) * ``pd_master``: pd 主节点模式(用于 pd 分离运行模式) * ``config_server``: 配置服务器模式(用于 pd 分离模式,用于注册 pd_master 节点并获取 pd_master 节点列表),专门为大规模、高并发场景设计,当 `pd_master` 遇到显著的 CPU 瓶颈时使用。 @@ -56,13 +56,13 @@ PD 分离模式参数 PD 主节点 IP 地址,默认为 ``0.0.0.0`` - 当 run_mode 设置为 nixl_prefill 或 nixl_decode 时需要设置此参数 + 当 run_mode 设置为 prefill 或 decode 时需要设置此参数 .. option:: --pd_master_port PD 主节点端口,默认为 ``1212`` - 当 run_mode 设置为 nixl_prefill 或 nixl_decode 时需要设置此参数 + 当 run_mode 设置为 prefill 或 decode 时需要设置此参数 .. option:: --pd_decode_rpyc_port diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index 8901dd9b98..00cf12da0c 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -174,7 +174,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 .. code-block:: bash # PD prefill 模式 for DeepSeek-R1 (DP+EP) on H200 - # 使用方法: sh pd_nixl_prefill.sh + # 使用方法: sh pd_prefill.sh # 默认使用 NIXL 传输;如需使用 NCCL 数据面,可设置 LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl # nvidia-cuda-mps-control -d,运行MPS(可选, 有mps支持性能会好特别多,但是部分显卡和驱动环境开启mps会容易出现错误,建议升级驱动到较高版本,特别是H系列卡) @@ -183,7 +183,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "nixl_prefill" \ + --run_mode "prefill" \ --tp 8 \ --dp 8 \ --host $host \ @@ -201,14 +201,14 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 .. code-block:: bash # PD decode 模式 for DeepSeek-R1 (DP+EP) on H200 - # 使用方法: sh pd_nixl_decode.sh + # 使用方法: sh pd_decode.sh # 默认使用 NIXL 传输;如需使用 NCCL 数据面,可设置 LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "nixl_decode" \ + --run_mode "decode" \ --tp 8 \ --dp 8 \ --host $host \ @@ -276,7 +276,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "nixl_prefill" \ + --run_mode "prefill" \ --host $host \ --port 8019 \ --tp 8 \ @@ -295,7 +295,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "nixl_decode" \ + --run_mode "decode" \ --host $host \ --port 8121 \ --nccl_port 12322 \ diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index 4e8083881a..84785de3b7 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -13,8 +13,8 @@ Basic Configuration Parameters Set the running mode, optional values: * ``normal``: Single server mode (default) - * ``nixl_prefill``: Prefill mode (for pd disaggregation running mode) - * ``nixl_decode``: Decode mode (for pd disaggregation running mode) + * ``prefill``: Prefill mode (for pd disaggregation running mode) + * ``decode``: Decode mode (for pd disaggregation running mode) * ``pd_master``: pd master node mode (for pd disaggregation running mode) * ``config_server``: Configuration server mode (for pd disaggregation mode, used to register pd_master nodes and get pd_master node list), specifically designed for large-scale, high-concurrency scenarios, used when `pd_master` encounters significant CPU bottlenecks. @@ -56,13 +56,13 @@ PD disaggregation Mode Parameters PD master node IP address, default is ``0.0.0.0`` - This parameter needs to be set when run_mode is set to nixl_prefill or nixl_decode + This parameter needs to be set when run_mode is set to prefill or decode .. option:: --pd_master_port PD master node port, default is ``1212`` - This parameter needs to be set when run_mode is set to nixl_prefill or nixl_decode + This parameter needs to be set when run_mode is set to prefill or decode .. option:: --pd_decode_rpyc_port diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index 24b65e4727..3a968b8948 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -174,7 +174,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for .. code-block:: bash # PD prefill mode for DeepSeek-R1 (DP+EP) on H200 - # Usage: sh pd_nixl_prefill.sh + # Usage: sh pd_prefill.sh # NIXL is used by default. To use NCCL as the data-plane backend, set LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl. # nvidia-cuda-mps-control -d, run MPS (optional, performance will be much better with mps support, but some GPUs may encounter errors when enabling mps, it's recommended to upgrade to a higher driver version, especially for H-series cards) @@ -183,7 +183,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "nixl_prefill" \ + --run_mode "prefill" \ --tp 8 \ --dp 8 \ --host $host \ @@ -198,14 +198,14 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for .. code-block:: bash # PD decode mode for DeepSeek-R1 (DP+EP) on H200 - # Usage: sh pd_nixl_decode.sh + # Usage: sh pd_decode.sh # NIXL is used by default. To use NCCL as the data-plane backend, set LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl. export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "nixl_decode" \ + --run_mode "decode" \ --tp 8 \ --dp 8 \ --host $host \ @@ -273,7 +273,7 @@ Supports multiple PD Master nodes, providing better load balancing and high avai nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "nixl_prefill" \ + --run_mode "prefill" \ --host $host \ --port 8019 \ --tp 8 \ @@ -292,7 +292,7 @@ Supports multiple PD Master nodes, providing better load balancing and high avai nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ - --run_mode "nixl_decode" \ + --run_mode "decode" \ --host $host \ --port 8121 \ --nccl_port 12322 \ diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 69b51b4ab4..658d3e899c 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -199,7 +199,7 @@ def write_to_shm(self, req_manager): 将 mem manager 写入到 shm中,方便pd分离等特性直接从中读取,不依赖进程间队列。 """ if kv_trans_use_p2p(): - from lightllm.server.router.model_infer.mode_backend.pd_nixl.p2p_fix import reduce_tensor + from lightllm.server.router.model_infer.mode_backend.pd.p2p_fix import reduce_tensor mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index fe18423a37..15f1fa27ea 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -9,14 +9,14 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, choices=[ "normal", - "nixl_prefill", - "nixl_decode", + "prefill", + "decode", "pd_master", "config_server", "visual_only", ], default="normal", - help="""set run mode, normal is started for a single server, nixl_prefill/nixl_decode/pd_master is for pd split run mode, + help="""set run mode, normal is started for a single server, prefill/decode/pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, specifically designed for large-scale, high-concurrency scenarios where `pd_master` encounters significant CPU bottlenecks.""", @@ -45,13 +45,13 @@ def make_argument_parser() -> argparse.ArgumentParser: "--pd_master_ip", type=str, default="0.0.0.0", - help="when run_mode set to nixl_prefill or nixl_decode, you need set this pd_mater_ip", + help="when run_mode set to prefill or decode, you need set this pd_mater_ip", ) parser.add_argument( "--pd_master_port", type=int, default=1212, - help="when run_mode set to nixl_prefill or nixl_decode, you need set this pd_mater_port", + help="when run_mode set to prefill or decode, you need set this pd_mater_port", ) parser.add_argument( "--pd_decode_rpyc_port", @@ -87,17 +87,17 @@ def make_argument_parser() -> argparse.ArgumentParser: proxy module use config server to find remote vit infer nodes to infer img""", ) parser.add_argument( - "--nixl_pd_kv_page_num", + "--pd_kv_page_num", type=int, default=16, - help="nixl pd mode, kv move page_num", + help="pd mode, kv move page_num", ) parser.add_argument( - "--nixl_pd_kv_page_size", + "--pd_kv_page_size", type=int, default=1024, - help="nixl pd mode, kv page size.", + help="pd mode, kv page size.", ) parser.add_argument( diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 62a4c4d805..c6809fd2ad 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -231,7 +231,7 @@ async def token_load(request: Request): @app.post("/generate") async def generate(request: Request) -> Response: - if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -253,7 +253,7 @@ async def generate(request: Request) -> Response: @app.post("/generate_stream") async def generate_stream(request: Request) -> Response: - if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -275,7 +275,7 @@ async def generate_stream(request: Request) -> Response: @app.post("/get_score") async def get_score(request: Request) -> Response: - if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -291,7 +291,7 @@ async def get_score(request: Request) -> Response: @app.post("/") async def compat_generate(request: Request) -> Response: - if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -306,7 +306,7 @@ async def compat_generate(request: Request) -> Response: @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -323,7 +323,7 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request) @app.post("/v1/completions", response_model=CompletionResponse) async def completions(request: CompletionRequest, raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) @@ -340,7 +340,7 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo @app.post("/v1/messages") async def anthropic_messages(raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["nixl_prefill", "nixl_decode"]: + if get_env_start_args().run_mode in ["prefill", "decode"]: return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 825888e9d7..64cb276d34 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -83,7 +83,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "nixl_prefill", "nixl_decode", "visual_only"]: + if args.run_mode not in ["normal", "prefill", "decode", "visual_only"]: return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 @@ -99,7 +99,7 @@ def normal_or_p_d_start(args): args.disable_audio = True # pd 分离模式下,不启动多模态的模块 - if args.run_mode == "nixl_decode": + if args.run_mode == "decode": args.disable_audio = True args.disable_vision = True @@ -404,7 +404,7 @@ def normal_or_p_d_start(args): args.pd_p_allowed_port_max = 30000 # p d 分离模式下,decode节点的调度间隙是0 - if args.run_mode == "nixl_decode": + if args.run_mode == "decode": args.router_max_wait_tokens = 0 send_and_receive_node_ip(args) # 多机用于收发node ip diff --git a/lightllm/server/core/objs/nixl_params.py b/lightllm/server/core/objs/pd_kv_trans_params.py similarity index 54% rename from lightllm/server/core/objs/nixl_params.py rename to lightllm/server/core/objs/pd_kv_trans_params.py index 8b64554f84..68d9de3aa9 100644 --- a/lightllm/server/core/objs/nixl_params.py +++ b/lightllm/server/core/objs/pd_kv_trans_params.py @@ -2,13 +2,13 @@ import ctypes from typing import Optional -LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES = int(os.getenv("LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES", 8 * 1024)) +LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES = int(os.getenv("LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES", 8 * 1024)) -class NIXLParamObj(ctypes.Structure): +class PDKVTransParamObj(ctypes.Structure): _pack_ = 4 _fields_ = [ - ("data", ctypes.c_ubyte * LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES), + ("data", ctypes.c_ubyte * LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES), ("data_len", ctypes.c_int), ] @@ -21,8 +21,8 @@ def set(self, obj_bytes: Optional[bytes]): return assert ( - len(obj_bytes) <= LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES - ), f"NIXL_PARAM_OBJ bytes len {len(obj_bytes)} exceeds length of {LIGHTLLM_NIXL_PARAM_OBJ_MAX_BYTES} bytes." + len(obj_bytes) <= LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES + ), f"PD_KV_TRANS_PARAM_OBJ bytes len {len(obj_bytes)} exceeds length of {LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES} bytes." ctypes.memmove(self.data, obj_bytes, len(obj_bytes)) self.data_len = len(obj_bytes) return diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index c94f3c6957..f503a92fb3 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -3,7 +3,7 @@ from typing import Optional, List, Tuple, Union from transformers import GenerationConfig from lightllm.server.req_id_generator import MAX_BEST_OF -from .nixl_params import NIXLParamObj +from .pd_kv_trans_params import PDKVTransParamObj _SAMPLING_EPS = 1e-5 DEFAULT_INPUT_PENALTY = os.getenv("INPUT_PENALTY", "False").upper() in ["ON", "TRUE", "1"] @@ -333,8 +333,8 @@ class SamplingParams(ctypes.Structure): ("move_kv_to_decode_node", DecodeNode), # move kv to deocde node, only used in pd mode # in pd split mode, use to keep the id of pd master ("pd_master_node_id", NodeUUId), - # nixl params object, only used in nixl pd mode, used to build nixl connection in p and d - ("nixl_params", NIXLParamObj), + # pd params object, only used in pd mode, used to build kv transport connection in prefill and decode + ("pd_kv_trans_params", PDKVTransParamObj), ("skip_special_tokens", ctypes.c_bool), # whether to skip special tokens when decoding ("add_special_tokens", ctypes.c_bool), # whether to add special tokens when encoding ( @@ -386,8 +386,8 @@ def init(self, tokenizer, **kwargs): self.move_kv_to_decode_node = DecodeNode() self.move_kv_to_decode_node.initialize(kwargs.get("move_kv_to_decode_node", None)) - self.nixl_params = NIXLParamObj() - self.nixl_params.set(kwargs.get("nixl_params", None)) + self.pd_kv_trans_params = PDKVTransParamObj() + self.pd_kv_trans_params.set(kwargs.get("pd_kv_trans_params", None)) self.pd_master_node_id = NodeUUId() self.pd_master_node_id.initialize(kwargs.get("pd_master_node_id", 0)) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 0b12cbfc54..5ed804d6a0 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -8,7 +8,7 @@ class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "pd_master", "nixl_prefill", "nixl_decode", "config_server", "visual_only"]}, + metadata={"choices": ["normal", "pd_master", "prefill", "decode", "config_server", "visual_only"]}, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) @@ -171,8 +171,8 @@ class StartArgs: mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) kv_quant_calibration_config_path: Optional[str] = field(default=None) - nixl_pd_kv_page_num: int = field(default=16) - nixl_pd_kv_page_size: int = field(default=1024) + pd_kv_page_num: int = field(default=16) + pd_kv_page_size: int = field(default=1024) pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) cpu_cache_storage_size: float = field(default=2) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index fd97886844..5049cd96f7 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -17,7 +17,7 @@ from websockets import ClientConnection from fastapi import Request from ..tokenizer import get_tokenizer -from ..pd_io_struct import NodeRole, ObjType, NIXLDecodeNodeInfo +from ..pd_io_struct import NodeRole, ObjType, PDDecodeNodeInfo from ..embed_cache.utils import get_shm_name_data, create_shm from ..multimodal_params import AudioItem, MultimodalParams, ImageItem from ..req_id_generator import ReqIDGenerator @@ -34,7 +34,7 @@ from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.utils.error_utils import ClientDisconnected, NixlPrefillNodeStopGenToken +from lightllm.utils.error_utils import ClientDisconnected, PDPrefillNodeStopGenToken from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -112,7 +112,7 @@ def __init__( self.metric_client = MetricClient(args.metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) - assert self.pd_mode in [NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] + assert self.pd_mode in [NodeRole.NORMAL, NodeRole.P, NodeRole.D] self.id_gen = ReqIDGenerator() self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() @@ -196,7 +196,7 @@ def _assert_image_token_count(self, token_num: int): return async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams): - # 只有 P 和 NORMAL 节点需要真的管理多模态资源 + # 只有 prefill 和 NORMAL 节点需要真的管理多模态资源 if self.pd_mode.is_P_or_NORMAL(): items, md5sums, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: @@ -226,7 +226,7 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, return async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): - # 只有 P 和 NORMAL 节点需要真的管理多模态资源 + # 只有 prefill 和 NORMAL 节点需要真的管理多模态资源 if self.pd_mode.is_P_or_NORMAL(): if multimodal_params is not None: ids_to_release = [] @@ -312,10 +312,10 @@ async def generate( sampling_params: SamplingParams, multimodal_params: MultimodalParams, request: Request, - # 该参数只会在 nixl pd mode 中使用,用于上报一些信息给 pd_master - nixl_pd_upload_websocket: ClientConnection = None, + # 该参数只会在 pd mode 中使用,用于上报一些信息给 pd_master + pd_upload_websocket: ClientConnection = None, # 用于等待 pd_master 下发的交换信息 - nixl_pd_event: asyncio.Event = None, + pd_event: asyncio.Event = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: start_time = time.time() @@ -370,28 +370,26 @@ async def generate( "check_and_repair_length_done", ) - if nixl_pd_upload_websocket is not None and self.pd_mode.is_NP(): - # 在 nixl pd 模式下的 p 节点, 为了更好的兼容多模态的推理流程,np 节点需要先上报其 encode 好的 prompt ids 信息,然后 - # 再等待 pd_master 传输下来的对应的进行 decode 节点的decode信息,然后再执行后续的流程 - logger.info( - f"nixl prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}" - ) - await nixl_pd_upload_websocket.send( - pickle.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids)) + if pd_upload_websocket is not None and self.pd_mode.is_P(): + # 在 pd 模式下的 prefill 节点,为了兼容多模态推理流程,需要先上报 encode 好的 prompt ids, + # 再等待 pd_master 下发对应请求的 decode 节点信息,然后执行后续流程。 + logger.info(f"pd prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}") + await pd_upload_websocket.send( + pickle.dumps((ObjType.PD_UPLOAD_PREFILL_PROMPT_IDS, group_request_id, prompt_ids)) ) try: - await asyncio.wait_for(nixl_pd_event.wait(), timeout=180) + await asyncio.wait_for(pd_event.wait(), timeout=180) except asyncio.TimeoutError: - logger.error(f"nixl np node wait nixl_pd_event 180s time out, group_req_id {group_request_id}") - raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out") + logger.error(f"pd prefill node wait pd_event 180s time out, group_req_id {group_request_id}") + raise Exception(f"group_req_id {group_request_id} wait pd_event time out") - decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info - sampling_params.nixl_params.set(pickle.dumps(decode_node_info)) + decode_node_info: PDDecodeNodeInfo = pd_event.decode_node_info + sampling_params.pd_kv_trans_params.set(pickle.dumps(decode_node_info)) if decode_node_info.ready_kv_len == len(prompt_ids) - 1: # 如果 decode 节点的 ready_kv_len 和 prefill encode 的 len(prompt ids) -1 相等,说明不需要进行 prefill - # 直接 raise NixlPrefillNodeStopGenToken - raise NixlPrefillNodeStopGenToken(group_request_id=group_request_id) + # 直接 raise PDPrefillNodeStopGenToken + raise PDPrefillNodeStopGenToken(group_request_id=group_request_id) # 申请资源并存储 alloced_req_indexes = [] diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index e341da2a85..dcf0c89fed 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -20,7 +20,7 @@ from ..pd_io_struct import PD_Master_Obj from lightllm.server.core.objs import StartArgs from lightllm.server.core.objs import SamplingParams -from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.error_utils import PDPrefillNodeStopGenToken logger = init_logger(__name__) @@ -115,8 +115,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O if obj[0] == ObjType.REQ: prompt, sampling_params, multimodal_params = obj[1] group_req_id = sampling_params.group_request_id - nixl_pd_event = asyncio.Event() - group_req_id_to_event[group_req_id] = nixl_pd_event + pd_event = asyncio.Event() + group_req_id_to_event[group_req_id] = pd_event asyncio.create_task( _pd_process_generate( manager=manager, @@ -124,8 +124,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O sampling_params=sampling_params, multimodal_params=multimodal_params, forwarding_queue=forwarding_queue, - nixl_pd_upload_websocket=websocket, - nixl_pd_event=nixl_pd_event, + pd_upload_websocket=websocket, + pd_event=pd_event, ) ) elif obj[0] == ObjType.ABORT: @@ -141,14 +141,14 @@ async def delayed_abort_task(group_req_id, retry_count): asyncio.create_task(delayed_abort_task(group_req_id=group_req_id, retry_count=4)) - elif obj[0] == ObjType.NIXL_REQ_DECODE_NODE_INFO: + elif obj[0] == ObjType.PD_REQ_DECODE_NODE_INFO: _, group_req_id, decode_node_info = obj - nixl_pd_event = group_req_id_to_event.pop(group_req_id, None) - if nixl_pd_event is None: - logger.error(f"error in find nixl_pd_event, info: {obj}") + pd_event = group_req_id_to_event.pop(group_req_id, None) + if pd_event is None: + logger.error(f"error in find pd_event, info: {obj}") continue - nixl_pd_event.decode_node_info = decode_node_info - nixl_pd_event.set() + pd_event.decode_node_info = decode_node_info + pd_event.set() else: logger.error(f"recevie error obj {str(obj)}") @@ -209,8 +209,8 @@ async def _pd_process_generate( sampling_params: SamplingParams, multimodal_params: Dict, forwarding_queue: AsyncQueue, - nixl_pd_upload_websocket: ClientConnection, - nixl_pd_event: asyncio.Event, + pd_upload_websocket: ClientConnection, + pd_event: asyncio.Event, ): try: async for sub_req_id, request_output, metadata, finish_status in manager.generate( @@ -218,13 +218,13 @@ async def _pd_process_generate( sampling_params=sampling_params, multimodal_params=multimodal_params, request=None, - nixl_pd_upload_websocket=nixl_pd_upload_websocket, - nixl_pd_event=nixl_pd_event, + pd_upload_websocket=pd_upload_websocket, + pd_event=pd_event, ): metadata["node_mode"] = manager.args.run_mode await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status)) - except NixlPrefillNodeStopGenToken as e: - logger.info(f"nixl prefill node stop gen token for group_request_id {e.group_request_id}") + except PDPrefillNodeStopGenToken as e: + logger.info(f"pd prefill node stop gen token for group_request_id {e.group_request_id}") except BaseException as e: logger.error(str(e)) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index fd7969164e..104da9f26e 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -9,7 +9,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional from lightllm.server.core.objs import FinishStatus -from ..pd_io_struct import PD_Client_Obj, NixlUpKVStatus, ObjType, NIXLDecodeNodeInfo +from ..pd_io_struct import PD_Client_Obj, PDUpKVStatus, ObjType, PDDecodeNodeInfo from lightllm.server.core.objs import SamplingParams, StartArgs from ..multimodal_params import MultimodalParams from ..tokenizer import get_tokenizer @@ -61,7 +61,7 @@ async def remove_pd(self, pd_info_json): self.pd_manager.remove_pd(pd_info_json) return - async def update_req_status(self, upkv_status: NixlUpKVStatus): + async def update_req_status(self, upkv_status: PDUpKVStatus): try: group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id) up_status_event = self.req_id_to_out_inf[group_request_id].up_status_event @@ -204,7 +204,7 @@ async def _log_req_header(self, request: Request, group_request_id: int): ) return - async def fetch_nixl_stream( + async def fetch_pd_stream( self, p_node: PD_Client_Obj, d_node: PD_Client_Obj, @@ -220,25 +220,25 @@ async def fetch_nixl_stream( self.req_id_to_out_inf[group_request_id] = req_status up_status_event = req_status.up_status_event - nixl_np_up_prompt_ids_event = req_status.nixl_np_up_prompt_ids_event + prefill_prompt_ids_event = req_status.prefill_prompt_ids_event old_max_new_tokens = sampling_params.max_new_tokens sampling_params.max_new_tokens = 1 await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) try: - await asyncio.wait_for(nixl_np_up_prompt_ids_event.wait(), timeout=60) + await asyncio.wait_for(prefill_prompt_ids_event.wait(), timeout=60) except asyncio.TimeoutError: - logger.warning(f"group_request_id: {group_request_id} wait np up prompt ids time out") + logger.warning(f"group_request_id: {group_request_id} wait prefill prompt ids time out") raise ServerBusyError() if await request.is_disconnected(): raise ClientDisconnected( - group_request_id=group_request_id, reason="fetch_nixl_stream prefill period check network disconnected" + group_request_id=group_request_id, reason="fetch_pd_stream prefill period check network disconnected" ) - prompt_ids = nixl_np_up_prompt_ids_event.prompt_ids - logger.info(f"group_request_id: {group_request_id} get np up prompt ids len {len(prompt_ids)}") + prompt_ids = prefill_prompt_ids_event.prompt_ids + logger.info(f"group_request_id: {group_request_id} get prefill prompt ids len {len(prompt_ids)}") sampling_params.max_new_tokens = old_max_new_tokens await d_node.websocket.send_bytes( @@ -252,11 +252,11 @@ async def fetch_nixl_stream( raise ServerBusyError() # 将 decode 节点上报的当前请求使用的decode节点的信息下发给 p 节点,这样 p 节点才知道将 kv 传输给那个 d 节点。 - upkv_status: NixlUpKVStatus = up_status_event.upkv_status - nixl_params: bytes = upkv_status.nixl_params - decode_node_info: NIXLDecodeNodeInfo = pickle.loads(nixl_params) + upkv_status: PDUpKVStatus = up_status_event.upkv_status + pd_kv_trans_params: bytes = upkv_status.pd_kv_trans_params + decode_node_info: PDDecodeNodeInfo = pickle.loads(pd_kv_trans_params) await p_node.websocket.send_bytes( - pickle.dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)) + pickle.dumps((ObjType.PD_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)) ) first_token_gen = False @@ -265,18 +265,18 @@ async def fetch_nixl_stream( if await request.is_disconnected(): raise ClientDisconnected( group_request_id=group_request_id, - reason="fetch_nixl_stream decode period check network disconnected", + reason="fetch_pd_stream decode period check network disconnected", ) if await req_status.can_read(self.req_id_to_out_inf): token_list = await req_status.pop_all_tokens() for sub_req_id, request_output, metadata, finish_status in token_list: output_index = metadata.get("count_output_tokens") - # 因为 nixl 的 prefill 和 decode 节点都有可能上报首token,所以需要做一下过滤。 + # 因为 pd 的 prefill 和 decode 节点都有可能上报首token,所以需要做一下过滤。 if output_index == 1: if first_token_gen is False: first_token_gen = True node_run_mode = metadata.pop("node_mode", None) - if node_run_mode == "nixl_prefill": + if node_run_mode == "prefill": if old_max_new_tokens != 1 and finish_status.is_finished_length(): finish_status = FinishStatus(FinishStatus.NO_FINISH) yield sub_req_id, request_output, metadata, finish_status @@ -307,7 +307,7 @@ async def _wait_to_token_package( is_first_token = True sub_req_id_to_mtp_accepted_token_num: Dict[int, int] = {} - async for sub_req_id, out_str, metadata, finish_status in self.fetch_nixl_stream( + async for sub_req_id, out_str, metadata, finish_status in self.fetch_pd_stream( p_node, d_node, prompt, sampling_params, multimodal_params, request ): if await request.is_disconnected(): @@ -431,16 +431,16 @@ async def handle_loop(self): req_status.event.set() except: pass - elif obj[0] == ObjType.NIXL_UPLOAD_NP_PROMPT_IDS: + elif obj[0] == ObjType.PD_UPLOAD_PREFILL_PROMPT_IDS: _, group_req_id, prompt_ids = obj try: req_status: ReqStatus = self.req_id_to_out_inf[group_req_id] async with req_status.lock: - req_status.nixl_np_up_prompt_ids_event.prompt_ids = prompt_ids - req_status.nixl_np_up_prompt_ids_event.set() + req_status.prefill_prompt_ids_event.prompt_ids = prompt_ids + req_status.prefill_prompt_ids_event.set() except: logger.error( - f"NIXL_UPLOAD_NP_PROMPT_IDS fail find req status for group_req_id: {group_req_id}" + f"PD_UPLOAD_PREFILL_PROMPT_IDS fail find req status for group_req_id: {group_req_id}" ) else: logger.error(f"recevie error obj {obj}") @@ -463,7 +463,7 @@ def __init__(self, req_id, p_node, d_node) -> None: self.lock = asyncio.Lock() self.event = asyncio.Event() self.up_status_event = asyncio.Event() - self.nixl_np_up_prompt_ids_event = asyncio.Event() + self.prefill_prompt_ids_event = asyncio.Event() self.out_token_info_list: List[Tuple[int, str, dict, FinishStatus]] = [] self.p_node: PD_Client_Obj = p_node self.d_node: PD_Client_Obj = d_node @@ -513,14 +513,14 @@ def register_pd(self, pd_info_json, websocket): pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client - if pd_client.mode == "nixl_prefill": + if pd_client.mode == "prefill": self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) - elif pd_client.mode == "nixl_decode": + elif pd_client.mode == "decode": self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: - assert False, f"mode must in ['nixl_prefill', 'nixl_decode'], but get {pd_client.mode}" + assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}" self.selector.update_nodes(self.prefill_nodes, self.decode_nodes) diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 6aacc01907..1fa564cb99 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -12,23 +12,16 @@ # 节点的行为 class NodeRole(enum.Enum): - NP = "nixl_prefill" - ND = "nixl_decode" - + P = "prefill" + D = "decode" NORMAL = "normal" PD_MASTER = "pd_master" def is_D(self): - return self == NodeRole.ND + return self == NodeRole.D def is_P(self): - return self == NodeRole.NP - - def is_NP(self): - return self == NodeRole.NP - - def is_ND(self): - return self == NodeRole.ND + return self == NodeRole.P def is_normal(self): return self == NodeRole.NORMAL @@ -39,16 +32,13 @@ def is_P_or_NORMAL(self): def is_P_or_D(self): return self.is_P() or self.is_D() - def is_NP_or_ND(self): - return self == NodeRole.NP or self == NodeRole.ND - class ObjType(enum.Enum): ABORT = 1 REQ = 2 TOKEN_PACKS = 3 - NIXL_UPLOAD_NP_PROMPT_IDS = 4 # nixl p 节点上报生成的 prompt ids 信息。 - NIXL_REQ_DECODE_NODE_INFO = 5 # nixl pd master 节点下发给 nixl p 节点的对应请求对应的 d 节点的信息。 + PD_UPLOAD_PREFILL_PROMPT_IDS = 4 # prefill 节点上报生成的 prompt ids 信息。 + PD_REQ_DECODE_NODE_INFO = 5 # pd master 节点下发给 prefill 节点的请求对应的 decode 节点信息。 @dataclass @@ -60,14 +50,14 @@ class _PD_Client_RunStatus: class PD_Client_Obj: node_id: int client_ip_port: str - mode: str # 只能是 nixl_prefill 或者 nixl_decode 节点 + mode: str # 只能是 prefill 或者 decode 节点 start_args: object # 节点的启动参数信息,用于做匹配性的校验,防止运行过程中出现问题。 websocket: WebSocket = None # 用于通信的 websocket 连接对象 run_status: _PD_Client_RunStatus = field(default_factory=_PD_Client_RunStatus) def __post_init__(self): - if self.mode not in ["nixl_prefill", "nixl_decode"]: - error_info = f"""mode must in ["nixl_prefill", "nixl_decode"], but get {self.mode}""" + if self.mode not in ["prefill", "decode"]: + error_info = f"""mode must in ["prefill", "decode"], but get {self.mode}""" logger.error(error_info) raise ValueError(error_info) return @@ -85,14 +75,14 @@ def to_log_str(self): return f"PD_MASTER host_ip_port: {self.host_ip_port} node_id: {self.node_id}" -####### 下边是 NIXL模式下使用的特定对象 ######## +####### 下边是 pd kv 传输使用的对象 ######## @dataclass -class NixlUpKVStatus: +class PDUpKVStatus: group_request_id: int pd_master_node_id: int - nixl_params: bytes # nixl 建立连接所使用的元数据对象 + pd_kv_trans_params: bytes # pd kv 传输建立连接所使用的元数据对象 def __post_init__(self): @@ -110,11 +100,11 @@ def __post_init__(self): def __str__(self): req_id = self.group_request_id pd_m_id = self.pd_master_node_id - return f"group_request_id: {req_id} pd_master_node_id: {pd_m_id} nixl_params_len: {len(self.nixl_params)}" + return f"group_request_id: {req_id} pd_master_node_id: {pd_m_id} pd_kv_trans_params_len: {len(self.pd_kv_trans_params)}" @dataclass -class NIXLDecodeNodeInfo: +class PDDecodeNodeInfo: decode_node_id: int pd_master_node_id: int @@ -128,7 +118,7 @@ class NIXLDecodeNodeInfo: @dataclass -class NixlAgentMetadata: +class PDAgentMetadata: agent_name: str agent_metadata: bytes num_pages: int @@ -137,7 +127,7 @@ class NixlAgentMetadata: @dataclass -class NIXLChunckedTransTask: +class PDChunckedTransTask: request_id: int start_kv_index: int end_kv_index: int @@ -164,11 +154,11 @@ class NIXLChunckedTransTask: first_gen_token_id: Optional[int] first_gen_token_logprob: Optional[float] - nixl_write_stage: Optional[str] = None + write_stage: Optional[str] = None # transfer params - nixl_src_page_index: Optional[int] = None - nixl_dst_page_index: Optional[int] = None + src_page_index: Optional[int] = None + dst_page_index: Optional[int] = None # xfer_handle xfer_handle: Optional[int] = None @@ -193,7 +183,7 @@ def __post_init__(self): assert self.start_kv_index == self.end_kv_index assert len(self.mem_indexes) == 0 else: - raise ValueError(f"unknown NIXL trans page kind {self.page_kind}") + raise ValueError(f"unknown PD trans page kind {self.page_kind}") self.create_time = time.time() return @@ -218,7 +208,7 @@ def get_key(self) -> str: return f"{self.request_id}_{self.page_kind}_{self.start_kv_index}_{self.end_kv_index}" def to_str(self): - obj: NIXLChunckedTransTask = copy.copy(self) + obj: PDChunckedTransTask = copy.copy(self) obj.mem_indexes = None if obj.decode_agent_metadata is not None: obj.decode_agent_metadata = b"xxx" @@ -238,8 +228,8 @@ def transfer_kv_num(self): def need_transfer_page(self): return self.page_kind != "kv" or self.transfer_kv_num() != 0 - def createRetObj(self) -> "NIXLChunckedTransTaskRet": - ret = NIXLChunckedTransTaskRet( + def createRetObj(self) -> "PDChunckedTransTaskRet": + ret = PDChunckedTransTaskRet( request_id=self.request_id, start_kv_index=self.start_kv_index, end_kv_index=self.end_kv_index, @@ -250,16 +240,16 @@ def createRetObj(self) -> "NIXLChunckedTransTaskRet": ) return ret - def create_prefill_agent_obj(self) -> NixlAgentMetadata: - return NixlAgentMetadata( + def create_prefill_agent_obj(self) -> PDAgentMetadata: + return PDAgentMetadata( agent_name=self.prefill_agent_name, agent_metadata=self.prefill_agent_metadata, num_pages=self.prefill_num_pages, page_reg_desc=self.prefill_page_reg_desc, ) - def create_decode_agent_obj(self) -> NixlAgentMetadata: - return NixlAgentMetadata( + def create_decode_agent_obj(self) -> PDAgentMetadata: + return PDAgentMetadata( agent_name=self.decode_agent_name, agent_metadata=self.decode_agent_metadata, num_pages=self.decode_num_pages, @@ -268,7 +258,7 @@ def create_decode_agent_obj(self) -> NixlAgentMetadata: @dataclass -class NIXLChunckedTransTaskRet: +class PDChunckedTransTaskRet: request_id: int start_kv_index: int end_kv_index: int @@ -282,11 +272,11 @@ def get_key(self) -> str: @dataclass -class NIXLChunckedTransTaskGroup: - task_list: List[NIXLChunckedTransTask] = field(default_factory=list) +class PDChunckedTransTaskGroup: + task_list: List[PDChunckedTransTask] = field(default_factory=list) @dataclass -class NIXLAbortReq: +class PDAbortReq: request_id: int device_id: int diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 8db0eaf77f..a5419adb9e 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -93,8 +93,8 @@ def __init__(self, args: StartArgs): ) self.metric_client = MetricClient(args.metric_port) - self.is_pd_run_mode = self.args.run_mode in ["nixl_prefill", "nixl_decode"] - self.is_pd_decode_mode = self.args.run_mode == "nixl_decode" + self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"] + self.is_pd_decode_mode = self.args.run_mode == "decode" self.shm_reqs_io_buffer = ShmObjsIOBuffer() self.cpu_cache_client = ( @@ -195,15 +195,15 @@ async def wait_to_model_ready(self): self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") - if self.args.run_mode == "nixl_prefill": - from lightllm.server.router.model_infer.mode_backend.pd_nixl.prefill_node_impl import ( + if self.args.run_mode == "prefill": + from lightllm.server.router.model_infer.mode_backend.pd.prefill_node_impl import ( start_prefill_kv_move_manager_process, ) start_prefill_kv_move_manager_process(self.args, self.info_queue) - if self.args.run_mode == "nixl_decode": - from lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl import ( + if self.args.run_mode == "decode": + from lightllm.server.router.model_infer.mode_backend.pd.decode_node_impl import ( start_decode_kv_move_manager_process, ) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 419d6491ac..cc925061f7 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -21,7 +21,7 @@ from lightllm.server.multimodal_params import MultimodalParams from lightllm.utils.custom_kernel_utis import custom_cat from lightllm.utils.envs_utils import get_env_start_args -from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo +from lightllm.server.pd_io_struct import PDDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient logger = init_logger(__name__) @@ -458,11 +458,11 @@ def __init__( logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids") self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size] - # nixl decode node information - if self.shm_param.nixl_params.data_len > 0: - self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) + # pd decode node information + if self.shm_param.pd_kv_trans_params.data_len > 0: + self.pd_decode_node: PDDecodeNodeInfo = pickle.loads(self.shm_param.pd_kv_trans_params.get()) else: - self.nixl_decode_node: NIXLDecodeNodeInfo = None + self.pd_decode_node: PDDecodeNodeInfo = None # only pd mode used. self.pd_master_node_id: int = self.shm_param.pd_master_node_id.get() @@ -522,12 +522,12 @@ def __init__( self.slave_reqs: List[InferReq] = [] self.related_master_req: InferReq = None - # nixl pd 分离模式使用的变量, 普通模式下这些变量没有具体用途 - self.nixl_trans_kv_start_index: int = 0 - self.nixl_pd_task_num: int = 0 - self.nixl_pd_task_sunccess_num: int = 0 - self.nixl_pd_task_failed_num: int = 0 - self.nixl_trans_device_id: int = -1 + # pd 分离模式使用的变量, 普通模式下这些变量没有具体用途 + self.pd_trans_kv_start_index: int = 0 + self.pd_task_num: int = 0 + self.pd_task_success_num: int = 0 + self.pd_task_failed_num: int = 0 + self.pd_trans_device_id: int = -1 # 类似 qwen3.5 这种混合linear att 模型使用的状态,记录申请来用于保存对应的线性att缓存的 buffer id # 当 prefill 阶段结束后, 对应长度的 linear att state 会写入到申请 buffer id 对应的块中, 方便插入到 radix cache中 @@ -576,9 +576,9 @@ def _init_all_state(self): self.shm_req.link_logprobs_shm_array() self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) - # 更新 nixl pd 分离模式下, prefill 节点需要开始传输的起始位置 - if self.sampling_param.nixl_decode_node is not None: - self.nixl_trans_kv_start_index = self.sampling_param.nixl_decode_node.ready_kv_len + # 更新 pd 分离模式下, prefill 节点需要开始传输的起始位置 + if self.sampling_param.pd_decode_node is not None: + self.pd_trans_kv_start_index = self.sampling_param.pd_decode_node.ready_kv_len self.cur_kv_len = 0 self.cur_output_len = 0 @@ -892,12 +892,12 @@ def handle( eos_ids: List[int], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]], is_master_in_dp: bool, - nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, + pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, ): - # nixl_prefill_chuncked_handle_func 主要是为了处理 nixl prefill 模式下 + # pd_prefill_chunked_handle_func 主要是为了处理 pd prefill 模式下 # 分块 prefill 后,形成对应的pd 分块传输处理。 - if nixl_prefill_chuncked_handle_func is not None: - nixl_prefill_chuncked_handle_func(self.req_obj, next_token_id, next_token_logprob, self.output_len) + if pd_prefill_chunked_handle_func is not None: + pd_prefill_chunked_handle_func(self.req_obj, next_token_id, next_token_logprob, self.output_len) if self.output_len <= 0: return diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 1843f31314..1a4bf6c020 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -10,7 +10,7 @@ from .diverse_backend.impl import DiversehBackend # pd mode backend -from .pd_nixl.prefill_node_impl.prefill_impl import NIXLChunckedPrefillForPrefillNode -from .pd_nixl.prefill_node_impl.prefill_impl_for_dp import NIXLDPChunkedForPrefillNode -from .pd_nixl.decode_node_impl.decode_impl import NIXLDecodeNode -from .pd_nixl.decode_node_impl.decode_impl_for_dp import NIXLDPForDecodeNode +from .pd.prefill_node_impl.prefill_impl import PDChunkedPrefillForPrefillNode +from .pd.prefill_node_impl.prefill_impl_for_dp import PDDPChunkedForPrefillNode +from .pd.decode_node_impl.decode_impl import PDDecodeNode +from .pd.decode_node_impl.decode_impl_for_dp import PDDPForDecodeNode diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 3953f8d38d..a1466d226b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -47,7 +47,7 @@ from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token -from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.server.pd_io_struct import PDChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule @@ -73,8 +73,8 @@ def __init__(self) -> None: self.classed_req_no_decode = False self.classed_req_strict_prefill = True - # nixl pd mode callback func - self.nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None + # pd mode callback func + self.pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None # counter self._radix_tree_merge_counter: int = 0 @@ -104,8 +104,8 @@ def init_model(self, kvargs): self.eos_id: List[int] = kvargs.get("eos_id", [2]) self.disable_cudagraph = self.args.disable_cudagraph self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1 - self.is_nixl_pd_mode = self.run_mode in ["nixl_prefill", "nixl_decode"] - self.is_nixl_decode_mode = self.run_mode == "nixl_decode" + self.is_pd_mode = self.run_mode in ["prefill", "decode"] + self.is_pd_decode_mode = self.run_mode == "decode" self.logger = init_logger(__name__) @@ -219,7 +219,7 @@ def init_model(self, kvargs): ) if ( - self.args.run_mode in ["nixl_prefill", "nixl_decode"] or self.args.enable_dp_prompt_cache_fetch + self.args.run_mode in ["prefill", "decode"] or self.args.enable_dp_prompt_cache_fetch ): # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便 # 读取 @@ -232,8 +232,8 @@ def init_model(self, kvargs): self.init_dp_kv_shared() self.shm_reqs_io_buffer = ShmObjsIOBuffer() - # 只会在 nixl pd 模式下才会使用,用于上传分块传输任务是否成功。 - self.shm_nixl_trans_io_buffer = ShmObjsIOBuffer(tail_str="nixl") + # 只会在 pd pd 模式下才会使用,用于上传分块传输任务是否成功。 + self.shm_pd_trans_io_buffer = ShmObjsIOBuffer(tail_str="pd") # 开启 mtp 模式,需要完成mtp model的初始化 if self.args.mtp_mode: @@ -380,10 +380,10 @@ def _try_read_new_reqs_normal(self): if new_buffer_is_ready: self._read_reqs_buffer_and_init_reqs() - # nixl pd mode 从 shm_nixl_trans_io_buffer 读取分块传输的完成进度。 - if self.is_nixl_pd_mode: + # pd mode 从 shm_pd_trans_io_buffer 读取分块传输的完成进度。 + if self.is_pd_mode: if self.is_master_in_node: - if self.shm_nixl_trans_io_buffer.is_ready(): + if self.shm_pd_trans_io_buffer.is_ready(): self.node_broadcast_tensor.fill_(1) else: self.node_broadcast_tensor.fill_(0) @@ -392,7 +392,7 @@ def _try_read_new_reqs_normal(self): broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False) new_buffer_is_ready = self.node_broadcast_tensor.detach().item() if new_buffer_is_ready: - self._read_nixl_trans_io_buffer_and_update_req_status() + self._read_pd_trans_io_buffer_and_update_req_status() return def _try_read_new_reqs_multinode_tp(self): @@ -415,7 +415,7 @@ def _try_read_new_reqs_multinode_tp(self): if new_buffer_is_ready: self._read_reqs_buffer_and_init_reqs() - assert self.is_nixl_pd_mode is False + assert self.is_pd_mode is False return def _read_reqs_buffer_and_init_reqs(self): @@ -436,20 +436,20 @@ def _read_reqs_buffer_and_init_reqs(self): self._init_reqs(reqs=init_reqs) return - def _read_nixl_trans_io_buffer_and_update_req_status(self): - cmds: List[NIXLChunckedTransTaskRet] = self.shm_nixl_trans_io_buffer.read_obj() - self.shm_nixl_trans_io_buffer.sub_state() + def _read_pd_trans_io_buffer_and_update_req_status(self): + cmds: List[PDChunckedTransTaskRet] = self.shm_pd_trans_io_buffer.read_obj() + self.shm_pd_trans_io_buffer.sub_state() if cmds: for obj in cmds: if obj.request_id in g_infer_context.requests_mapping: req: InferReq = g_infer_context.requests_mapping[obj.request_id] if obj.has_error: - req.nixl_pd_task_failed_num += 1 + req.pd_task_failed_num += 1 else: - req.nixl_pd_task_sunccess_num += 1 - # nixl decode 节点需要预填充 prefill 节点发送过来的产生的首token信息,以使 + req.pd_task_success_num += 1 + # pd decode 节点需要预填充 prefill 节点发送过来的产生的首token信息,以使 # 推理过程可以继续。 - if self.is_nixl_decode_mode: + if self.is_pd_decode_mode: if obj.first_gen_token_id is not None: assert req.cur_output_len == 0 req.cur_output_len += 1 @@ -465,7 +465,7 @@ def _read_nixl_trans_io_buffer_and_update_req_status(self): eos_ids=self.eos_id, extra_post_req_handle_func=None, is_master_in_dp=self.is_master_in_dp, - nixl_prefill_chuncked_handle_func=None, + pd_prefill_chunked_handle_func=None, ) return @@ -497,7 +497,7 @@ def _load_cpu_cache_to_reqs(self, req_ids): def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: """ 将错误请求从 req_ids 中过滤出来, 然后让 _get_classed_reqs 进行处理。 该函数 - 主要用于在 nixl pd 分离模式下, 由子类继承重载, prefill 和 decode 节点过滤 kv 传输错误,或者 kv + 主要用于在 pd 分离模式下, 由子类继承重载, prefill 和 decode 节点过滤 kv 传输错误,或者 kv 传输没有完成的请求。 """ return [g_infer_context.requests_mapping[request_id] for request_id in req_ids] @@ -705,7 +705,7 @@ def _post_handle( next_token_logprobs: List[float], run_reqs_update_packs: List[InferReqUpdatePack], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, - nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, + pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, ): """ extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 @@ -722,7 +722,7 @@ def _post_handle( eos_ids=self.eos_id, extra_post_req_handle_func=extra_post_req_handle_func, is_master_in_dp=self.is_master_in_dp, - nixl_prefill_chuncked_handle_func=nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=pd_prefill_chunked_handle_func, ) g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter( diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index e06f06c4e3..792a10a788 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -137,7 +137,7 @@ def prefill_normal( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -223,7 +223,7 @@ def prefill_mtp( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 34e174bc59..f1681eda52 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -139,7 +139,7 @@ def _diverse_pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: b pack = InferReqUpdatePack(req_obj=req_obj, output_len=0) update_func_objs.append(pack) pre_master_req_pack = pack - # TODO 如果 diverse mode 需要支持 nixl pd 分离,则应该每个分块prefill后都进行相关的复制, + # TODO 如果 diverse mode 需要支持 pd 分离,则应该每个分块prefill后都进行相关的复制, # 暂时不支持 diverse mode 和 pd 模式的混合 continue diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 878d4fade0..e6b9d1c18d 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -182,7 +182,7 @@ def prefill_normal( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -295,7 +295,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -420,7 +420,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) # 第四阶段 @@ -719,7 +719,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + pd_prefill_chunked_handle_func=self.pd_prefill_chunked_handle_func, ) event_pack.notify_pre_post_handle() else: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/__init__.py b/lightllm/server/router/model_infer/mode_backend/pd/__init__.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/__init__.py rename to lightllm/server/router/model_infer/mode_backend/pd/__init__.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd/base_kv_move_manager.py similarity index 89% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py rename to lightllm/server/router/model_infer/mode_backend/pd/base_kv_move_manager.py index eb11728029..125edede25 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/base_kv_move_manager.py @@ -6,7 +6,7 @@ import torch.multiprocessing as mp from typing import List, Dict, Union, Callable, Optional from lightllm.utils.log_utils import init_logger -from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.server.pd_io_struct import PDChunckedTransTaskRet from lightllm.server.core.objs import StartArgs from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from .trans_process_obj import KVTransProcess @@ -47,7 +47,7 @@ def __init__( threading.Thread(target=self.task_ret_handle_loop, args=(trans_process,), daemon=True).start() # 通过 io buffer 将命令写入到推理进程中 - self.shm_nixl_trans_io_buffer = ShmObjsIOBuffer(tail_str="nixl") + self.shm_pd_trans_io_buffer = ShmObjsIOBuffer(tail_str="pd") for func in [self.task_dispatcher_loop, self.task_ret_upload_loop, self.check_trans_process_loop]: threading.Thread(target=func, daemon=True).start() @@ -66,15 +66,15 @@ def task_dispatcher_loop(self): @log_exception def task_ret_upload_loop(self): while True: - ret_obj: NIXLChunckedTransTaskRet = self.ret_obj_queue.get() - ret_objs: List[NIXLChunckedTransTaskRet] = [ret_obj] + ret_obj: PDChunckedTransTaskRet = self.ret_obj_queue.get() + ret_objs: List[PDChunckedTransTaskRet] = [ret_obj] ret_objs.extend(self._collect_return_objects()) while True: - if self.shm_nixl_trans_io_buffer.is_empty(): + if self.shm_pd_trans_io_buffer.is_empty(): # to do, 这里写入的数量,可能会超过共享管道的大小。 - self.shm_nixl_trans_io_buffer.write_obj(ret_objs) - self.shm_nixl_trans_io_buffer.set_ready() + self.shm_pd_trans_io_buffer.write_obj(ret_objs) + self.shm_pd_trans_io_buffer.set_ready() break else: time.sleep(0.01) @@ -97,7 +97,7 @@ def _collect_return_objects(self): @log_exception def task_ret_handle_loop(self, trans_process: KVTransProcess): while True: - ret_obj: NIXLChunckedTransTaskRet = trans_process.task_out_queue.get() + ret_obj: PDChunckedTransTaskRet = trans_process.task_out_queue.get() self.ret_obj_queue.put(ret_obj) # ================================================================================== diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/__init__.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/__init__.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/__init__.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl.py similarity index 82% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl.py index a00af420c6..242e1089e7 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl.py @@ -1,6 +1,6 @@ import random import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskGroup, NIXLAbortReq +from lightllm.server.pd_io_struct import PDChunckedTransTask, PDChunckedTransTaskGroup, PDAbortReq from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq @@ -11,7 +11,7 @@ logger = init_logger(__name__) -class NIXLDecodeNode(ChunkedPrefillBackend): +class PDDecodeNode(ChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.info_queue: mp.Queue = info_queue @@ -65,13 +65,13 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: for request_id in req_ids: req_obj: InferReq = g_infer_context.requests_mapping[request_id] - if self.is_master_in_dp and req_obj.infer_aborted and req_obj.nixl_pd_task_num != 0: - self.info_queue.put(NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id)) + if self.is_master_in_dp and req_obj.infer_aborted and req_obj.pd_task_num != 0: + self.info_queue.put(PDAbortReq(request_id=req_obj.req_id, device_id=req_obj.pd_trans_device_id)) - if req_obj.nixl_pd_task_num != (req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num): + if req_obj.pd_task_num != (req_obj.pd_task_failed_num + req_obj.pd_task_success_num): continue - if req_obj.nixl_pd_task_failed_num > 0: + if req_obj.pd_task_failed_num > 0: # 强制停止 if not req_obj.finish_status.is_finished(): req_obj.cur_output_len += 1 @@ -110,12 +110,12 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): """ decode node 生成所有的传输任务对象。 """ - group = NIXLChunckedTransTaskGroup() + group = PDChunckedTransTaskGroup() input_len = req_obj.shm_req.input_len # 当 decode 节点不能匹配足够的kv的时候,才进行真实的 kv 传输。 if input_len - req_obj.cur_kv_len > 1: - page_size = self.args.nixl_pd_kv_page_size - req_obj.nixl_trans_kv_start_index = req_obj.cur_kv_len + page_size = self.args.pd_kv_page_size + req_obj.pd_trans_kv_start_index = req_obj.cur_kv_len need_mem_size = input_len - req_obj.cur_kv_len if need_mem_size > 0: @@ -127,13 +127,13 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): req_obj.req_idx, req_obj.cur_kv_len : (req_obj.cur_kv_len + need_mem_size) ] = mem_indexes - while req_obj.nixl_trans_kv_start_index < input_len: - cur_page_size = min(page_size, input_len - req_obj.nixl_trans_kv_start_index) + while req_obj.pd_trans_kv_start_index < input_len: + cur_page_size = min(page_size, input_len - req_obj.pd_trans_kv_start_index) # 生成页面传输任务, 放入kv move manager 的处理队列中 - start_index = req_obj.nixl_trans_kv_start_index - end_index = req_obj.nixl_trans_kv_start_index + cur_page_size + start_index = req_obj.pd_trans_kv_start_index + end_index = req_obj.pd_trans_kv_start_index + cur_page_size page_mem_indexes = mem_indexes[start_index - req_obj.cur_kv_len : end_index - req_obj.cur_kv_len] - self._create_nixl_trans_task( + self._create_pd_trans_task( req_obj=req_obj, mem_indexes=page_mem_indexes.tolist(), kv_start_index=start_index, @@ -141,13 +141,13 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): group=group, ) # update - req_obj.nixl_trans_kv_start_index += cur_page_size + req_obj.pd_trans_kv_start_index += cur_page_size req_obj.cur_kv_len += len(mem_indexes) # 如果当前是linear att 混合模型,则需要创建一个linear att 状态的传输任务 if g_infer_context.is_linear_att_mixed_model: - self._create_nixl_trans_task( + self._create_pd_trans_task( req_obj=req_obj, mem_indexes=[], kv_start_index=input_len, @@ -161,7 +161,7 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): if not group.task_list: # 需要上报一个包含 0 长度的trans task,触发 kv move manager 给 pd master 上报 # upkv_status 状态,使推理流程完整。 - self._create_nixl_trans_task( + self._create_pd_trans_task( req_obj=req_obj, mem_indexes=[], kv_start_index=req_obj.cur_kv_len, @@ -173,31 +173,31 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): self.info_queue.put(group) return - def _create_nixl_trans_task( + def _create_pd_trans_task( self, req_obj: InferReq, mem_indexes: List[int], kv_start_index: int, kv_end_index: int, - group: NIXLChunckedTransTaskGroup, + group: PDChunckedTransTaskGroup, page_kind: str = "kv", ): # 确定传输设备 - if req_obj.nixl_trans_device_id == -1: - if not hasattr(self, "nixl_iter_device_id"): - self.nixl_iter_device_id = 0 - req_obj.nixl_trans_device_id = self.nixl_iter_device_id + if req_obj.pd_trans_device_id == -1: + if not hasattr(self, "pd_iter_device_id"): + self.pd_iter_device_id = 0 + req_obj.pd_trans_device_id = self.pd_iter_device_id # only self.is_master_in_dp will be used. - self.nixl_iter_device_id = (self.nixl_iter_device_id + 1) % self.node_world_size + self.pd_iter_device_id = (self.pd_iter_device_id + 1) % self.node_world_size if page_kind == "kv": req_idx = None elif page_kind == "linear_att_state": req_idx = req_obj.req_idx else: - raise ValueError(f"unknown NIXL trans page kind {page_kind}") + raise ValueError(f"unknown PD trans page kind {page_kind}") - trans_task = NIXLChunckedTransTask( + trans_task = PDChunckedTransTask( request_id=req_obj.req_id, start_kv_index=kv_start_index, end_kv_index=kv_end_index, @@ -206,7 +206,7 @@ def _create_nixl_trans_task( prefill_dp_index=None, decode_dp_index=self.dp_rank_in_node, src_device_id=None, - dst_device_id=req_obj.nixl_trans_device_id, + dst_device_id=req_obj.pd_trans_device_id, mem_indexes=mem_indexes, prefill_agent_name=None, prefill_agent_metadata=None, @@ -222,5 +222,5 @@ def _create_nixl_trans_task( req_idx=req_idx, ) group.task_list.append(trans_task) - req_obj.nixl_pd_task_num += 1 + req_obj.pd_task_num += 1 return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl_for_dp.py similarity index 67% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl_for_dp.py index dc46d795a9..87af300003 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_impl_for_dp.py @@ -3,12 +3,12 @@ from lightllm.utils.log_utils import init_logger from typing import List, Tuple from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend -from .decode_impl import NIXLDecodeNode, NIXLChunckedTransTaskGroup +from .decode_impl import PDDecodeNode, PDChunckedTransTaskGroup logger = init_logger(__name__) -class NIXLDPForDecodeNode(DPChunkedPrefillBackend): +class PDDPForDecodeNode(DPChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.info_queue: mp.Queue = info_queue @@ -16,30 +16,30 @@ def __init__(self, info_queue: mp.Queue) -> None: return def init_custom(self): - return NIXLDecodeNode.init_custom(self) + return PDDecodeNode.init_custom(self) def _init_reqs(self, reqs: List[Tuple]): - return NIXLDecodeNode._init_reqs(self, reqs=reqs) + return PDDecodeNode._init_reqs(self, reqs=reqs) def _post_init_reqs(self, uninit_reqs: List[InferReq]): - return NIXLDecodeNode._post_init_reqs(self, uninit_reqs=uninit_reqs) + return PDDecodeNode._post_init_reqs(self, uninit_reqs=uninit_reqs) def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: - return NIXLDecodeNode._filter_not_ready_reqs(self, req_ids=req_ids) + return PDDecodeNode._filter_not_ready_reqs(self, req_ids=req_ids) def _decode_node_gen_trans_tasks(self, req_obj: InferReq): - return NIXLDecodeNode._decode_node_gen_trans_tasks(self, req_obj=req_obj) + return PDDecodeNode._decode_node_gen_trans_tasks(self, req_obj=req_obj) - def _create_nixl_trans_task( + def _create_pd_trans_task( self, req_obj: InferReq, mem_indexes: List[int], kv_start_index: int, kv_end_index: int, - group: NIXLChunckedTransTaskGroup, + group: PDChunckedTransTaskGroup, page_kind: str = "kv", ): - return NIXLDecodeNode._create_nixl_trans_task( + return PDDecodeNode._create_pd_trans_task( self, req_obj=req_obj, mem_indexes=mem_indexes, diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_kv_move_manager.py similarity index 88% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_kv_move_manager.py index f1b2244024..41bdcd361c 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_kv_move_manager.py @@ -5,7 +5,7 @@ import time from typing import List, Dict, Optional, Tuple, Union, Callable from lightllm.utils.log_utils import init_logger -from lightllm.server.pd_io_struct import NIXLChunckedTransTaskGroup, NIXLAbortReq +from lightllm.server.pd_io_struct import PDChunckedTransTaskGroup, PDAbortReq from lightllm.server.core.objs import StartArgs from lightllm.utils.graceful_utils import graceful_registry from ..trans_process_obj import KVTransProcess @@ -31,7 +31,7 @@ def _init_env(args, info_queue: mp.Queue, event: mp.Event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_decode_kv_move_manager") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::decode_kv_move_manager") from .up_status import start_up_kv_status_process @@ -76,11 +76,11 @@ def __init__( def task_dispatcher_loop(self): # 获取任务,并分发给相关卡的处理队列 while True: - task_group: Union[NIXLChunckedTransTaskGroup, NIXLAbortReq] = self.info_queue.get() + task_group: Union[PDChunckedTransTaskGroup, PDAbortReq] = self.info_queue.get() - if isinstance(task_group, NIXLChunckedTransTaskGroup): + if isinstance(task_group, PDChunckedTransTaskGroup): device_id = task_group.task_list[0].dst_device_id - elif isinstance(task_group, NIXLAbortReq): + elif isinstance(task_group, PDAbortReq): device_id = task_group.device_id else: assert False, f"error obj {task_group}" @@ -88,7 +88,7 @@ def task_dispatcher_loop(self): try: trans_process: KVTransProcess = self.kv_trans_processes[device_id] trans_process.task_in_queue.put(task_group) - if isinstance(task_group, NIXLChunckedTransTaskGroup): + if isinstance(task_group, PDChunckedTransTaskGroup): logger.info( f"kv move manager dispatch task group {task_group.task_list[0].to_str()} to device {device_id}" ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_trans_process.py similarity index 91% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_trans_process.py index cb1b92c184..036c6f162b 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/decode_trans_process.py @@ -10,12 +10,12 @@ from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import ( - NIXLChunckedTransTask, - NIXLChunckedTransTaskGroup, - NixlUpKVStatus, - NIXLAbortReq, + PDChunckedTransTask, + PDChunckedTransTaskGroup, + PDUpKVStatus, + PDAbortReq, ) -from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo +from lightllm.server.pd_io_struct import PDDecodeNodeInfo from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.core.objs import StartArgs from ..kv_transporter import create_kv_transporter @@ -53,7 +53,7 @@ def _init_env( # ------------------------------------------------------------------------- # 问题背景(PD NIXL + 同卡多进程): # decode 物理 GPU 上至少有两个独立 CUDA 进程:model_infer(解码推理)与 - # nixl_decode_trans(把 prefill 侧 KV page 拷入 decode KV cache)。 + # decode_trans(把 prefill 侧 KV page 拷入 decode KV cache)。 # lm_eval batch=64 时会在短时间内并发大量 read_page;拷贝在 copy_cuda_stream # 上排队,而推理在另一进程的 stream 上执行,彼此无法 cudaStreamWaitEvent # 协调。日志里的 read_page_gpu_time(event 差值)会把「等 GPU 时间片 / @@ -72,7 +72,7 @@ def _init_env( os.environ["CUDA_MPS_CLIENT_PRIORITY"] = "0" torch.backends.cudnn.enabled = False - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_decode_trans:Device{device_id}") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::decode_trans:Device{device_id}") try: torch.cuda.set_device(device_id) @@ -126,7 +126,7 @@ def __init__( self.up_status_in_queue = up_status_in_queue cur_mem_manager: MemoryManager = self.mem_managers[device_id] kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer( - page_num=self.args.nixl_pd_kv_page_num, page_size=self.args.nixl_pd_kv_page_size + page_num=self.args.pd_kv_page_num, page_size=self.args.pd_kv_page_size ) self.copy_cuda_stream = torch.cuda.Stream(priority=-1) self.transporter = create_kv_transporter( @@ -137,14 +137,14 @@ def __init__( ) self.recv_task_group_queue = queue.Queue() self.waiting_dict_lock = threading.Lock() - self.waiting_dict: Dict[str, NIXLChunckedTransTask] = {} + self.waiting_dict: Dict[str, PDChunckedTransTask] = {} self.request_page_task_queue = queue.Queue() self.ready_page_task_queue = queue.Queue() self.success_queue = queue.Queue() self.failed_queue = queue.Queue() self.page_index_queue = queue.Queue() - for page_index in range(self.args.nixl_pd_kv_page_num): + for page_index in range(self.args.pd_kv_page_num): self.page_index_queue.put(page_index) # warmup 预先加载一次kv 数据到 mem manager,避免第一次拷贝时出现卡顿。 @@ -179,10 +179,10 @@ def _warmup(self): @log_exception def recv_task_loop(self): while True: - obj: Union[NIXLChunckedTransTaskGroup, NIXLAbortReq] = self.task_in_queue.get() - if isinstance(obj, NIXLChunckedTransTaskGroup): + obj: Union[PDChunckedTransTaskGroup, PDAbortReq] = self.task_in_queue.get() + if isinstance(obj, PDChunckedTransTaskGroup): self.recv_task_group_queue.put(obj) - elif isinstance(obj, NIXLAbortReq): + elif isinstance(obj, PDAbortReq): self._abort(request_id=obj.request_id) else: assert False, f"recv error obj {obj}" @@ -191,7 +191,7 @@ def _abort(self, request_id: int, error_info: str = "aborted req"): aborted_tasks = [] with self.waiting_dict_lock: for key, trans_task in list(self.waiting_dict.items()): - if trans_task.request_id == request_id and trans_task.nixl_dst_page_index is None: + if trans_task.request_id == request_id and trans_task.dst_page_index is None: # 对于 已经分配了page index 的任务,不能直接失败,需要两边走完正常流程再失败,不然可能 # 出现复杂的异步协同问题。 aborted_tasks.append(self.waiting_dict.pop(key)) @@ -204,7 +204,7 @@ def _abort(self, request_id: int, error_info: str = "aborted req"): @log_exception def dispatch_task_loop(self): while True: - trans_task_group: NIXLChunckedTransTaskGroup = self.recv_task_group_queue.get() + trans_task_group: PDChunckedTransTaskGroup = self.recv_task_group_queue.get() with self.waiting_dict_lock: for task in trans_task_group.task_list: @@ -217,7 +217,7 @@ def dispatch_task_loop(self): # up status task = trans_task_group.task_list[0] - decode_node_info = NIXLDecodeNodeInfo( + decode_node_info = PDDecodeNodeInfo( decode_node_id=self.args.pd_node_id, pd_master_node_id=task.pd_master_node_id, agent_name=self.transporter.agent_name, @@ -228,10 +228,10 @@ def dispatch_task_loop(self): ready_kv_len=task.start_kv_index, ) - up_status = NixlUpKVStatus( + up_status = PDUpKVStatus( group_request_id=task.request_id, pd_master_node_id=task.pd_master_node_id, - nixl_params=pickle.dumps(decode_node_info), + pd_kv_trans_params=pickle.dumps(decode_node_info), ) self.up_status_in_queue.put(up_status) @@ -258,7 +258,7 @@ def accept_peer_task_loop( except: notify_obj = None - if not isinstance(notify_obj, NIXLChunckedTransTask): + if not isinstance(notify_obj, PDChunckedTransTask): continue # 请求有错误 @@ -281,7 +281,7 @@ def accept_peer_task_loop( # 到了请求页面的阶段 remote_trans_task = notify_obj - if remote_trans_task.nixl_write_stage == "request": + if remote_trans_task.write_stage == "request": with self.waiting_dict_lock: local_trans_task = self.waiting_dict.pop(remote_trans_task.get_key(), None) if local_trans_task is not None: @@ -310,7 +310,7 @@ def accept_peer_task_loop( continue # prefill 写完数据到了 done 阶段 - if remote_trans_task.nixl_write_stage == "done": + if remote_trans_task.write_stage == "done": with self.waiting_dict_lock: local_trans_task = self.waiting_dict.pop(remote_trans_task.get_key(), None) if local_trans_task is not None: @@ -354,8 +354,8 @@ def request_page_loop(self): torch.cuda.set_device(self.device_id) while True: dst_page_index = self.page_index_queue.get() - trans_task: NIXLChunckedTransTask = self.request_page_task_queue.get() - trans_task.nixl_dst_page_index = dst_page_index + trans_task: PDChunckedTransTask = self.request_page_task_queue.get() + trans_task.dst_page_index = dst_page_index trans_task.start_trans_time = time.time() key = trans_task.get_key() try: @@ -378,7 +378,7 @@ def request_page_loop(self): def read_page_to_mems_loop(self): torch.cuda.set_device(self.device_id) while True: - trans_task: NIXLChunckedTransTask = self.ready_page_task_queue.get() + trans_task: PDChunckedTransTask = self.ready_page_task_queue.get() copy_start_event = torch.cuda.Event(enable_timing=True) copy_end_event = torch.cuda.Event(enable_timing=True) with torch.cuda.stream(stream=self.copy_cuda_stream): @@ -386,7 +386,7 @@ def read_page_to_mems_loop(self): cur_mem = self.mem_managers[self.device_id] cur_mem.read_page_kv_move_buffer_to_mem( trans_task.mem_indexes, - page_index=trans_task.nixl_dst_page_index, + page_index=trans_task.dst_page_index, dp_index=trans_task.decode_dp_index, mem_managers=self.mem_managers, dp_world_size=self.dp_world_size, @@ -401,7 +401,7 @@ def success_loop(self): torch.cuda.set_device(self.device_id) while True: copy_end_event, copy_start_event, trans_task = self.success_queue.get() - trans_task: NIXLChunckedTransTask = trans_task + trans_task: PDChunckedTransTask = trans_task copy_end_event: Optional[torch.cuda.Event] = copy_end_event copy_start_event: Optional[torch.cuda.Event] = copy_start_event read_page_gpu_time_ms = -1.0 @@ -409,8 +409,8 @@ def success_loop(self): copy_end_event.synchronize() read_page_gpu_time_ms = copy_start_event.elapsed_time(copy_end_event) - if trans_task.nixl_dst_page_index is not None: - self.page_index_queue.put(trans_task.nixl_dst_page_index) + if trans_task.dst_page_index is not None: + self.page_index_queue.put(trans_task.dst_page_index) if trans_task.xfer_handle is not None: self.transporter.release_xfer_handle(trans_task.xfer_handle) @@ -430,11 +430,11 @@ def success_loop(self): def fail_loop(self): torch.cuda.set_device(self.device_id) while True: - trans_task: NIXLChunckedTransTask = self.failed_queue.get() + trans_task: PDChunckedTransTask = self.failed_queue.get() # 回收页面 - if trans_task.nixl_dst_page_index is not None: - self.page_index_queue.put(trans_task.nixl_dst_page_index) + if trans_task.dst_page_index is not None: + self.page_index_queue.put(trans_task.dst_page_index) if trans_task.xfer_handle is not None: self.transporter.release_xfer_handle(trans_task.xfer_handle) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py similarity index 93% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py rename to lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py index 3926ad0eaa..bc1d00f384 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/decode_node_impl/up_status.py @@ -9,7 +9,7 @@ from typing import Dict from dataclasses import asdict -from lightllm.server.pd_io_struct import NixlUpKVStatus +from lightllm.server.pd_io_struct import PDUpKVStatus from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.pd_io_struct import PD_Master_Obj @@ -22,7 +22,7 @@ class UpStatusManager: def __init__(self, args, task_in_queue: mp.SimpleQueue): self.args = args - self.task_queue: mp.SimpleQueue[NixlUpKVStatus] = task_in_queue + self.task_queue: mp.SimpleQueue[PDUpKVStatus] = task_in_queue self.daemon_thread = threading.Thread(target=self.thread_loop, daemon=True) self.daemon_thread.start() @@ -66,7 +66,7 @@ async def dispatch_task_loop(self): while True: try: loop = asyncio.get_event_loop() - upkv_status: NixlUpKVStatus = await loop.run_in_executor(None, self.task_queue.get) + upkv_status: PDUpKVStatus = await loop.run_in_executor(None, self.task_queue.get) if upkv_status.pd_master_node_id in self.id_to_handle_queue: await self.id_to_handle_queue[upkv_status.pd_master_node_id].put(upkv_status) else: @@ -89,7 +89,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): try: if pd_master_obj.node_id in self.id_to_handle_queue: task_queue = self.id_to_handle_queue[pd_master_obj.node_id] - upkv_status: NixlUpKVStatus = await task_queue.get() + upkv_status: PDUpKVStatus = await task_queue.get() await websocket.send(pickle.dumps(upkv_status)) logger.info(f"up kv status: {upkv_status}") else: @@ -110,7 +110,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): def _init_env(args, task_in_queue: mp.SimpleQueue): graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_up_kv_status") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::pd_up_kv_status") up_kv_manager = UpStatusManager(args, task_in_queue) logger.info(f"up kv manager {str(up_kv_manager)} start ok") while True: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd/kv_transporter.py similarity index 95% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py rename to lightllm/server/router/model_infer/mode_backend/pd/kv_transporter.py index c413515eda..236a737448 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/kv_transporter.py @@ -19,7 +19,7 @@ def create_kv_transporter(args: StartArgs, node_id: int, tp_idx: int, kv_move_bu if backend == "nccl": from .nccl_kv_transporter import NcclKVTransporter - logger.info("Use NCCL as pd_nixl KV transporter backend") + logger.info("Use NCCL as pd KV transporter backend") port_min = args.pd_p_allowed_port_min + tp_idx * 100 port_max = min(args.pd_p_allowed_port_max, port_min + 99) if port_min > args.pd_p_allowed_port_max: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd/nccl_kv_transporter.py similarity index 87% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py rename to lightllm/server/router/model_infer/mode_backend/pd/nccl_kv_transporter.py index d048cbf9b7..1bea648a83 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nccl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/nccl_kv_transporter.py @@ -13,7 +13,7 @@ from rpyc.utils.server import ThreadedServer from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup -from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NixlAgentMetadata +from lightllm.server.pd_io_struct import PDChunckedTransTask, PDAgentMetadata from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.utils.net_utils import get_hostname_ip @@ -31,12 +31,12 @@ class NcclAgentMetadata: class NcclKVTransporter: """ - NIXL-compatible transporter backed by NCCL point-to-point operations. + PD KV transporter backed by NCCL point-to-point operations. - NIXL provides remote notifications and one-sided WRITE. NCCL does not, so this - class uses a small TCP control channel for notifications and communicator - bootstrap while preserving the same request/ready/done/error interface used by - pd_nixl trans-process management. + NCCL does not provide remote notifications or one-sided WRITE, so this class + uses a small RPyC control channel for notifications and communicator bootstrap + while preserving the same request/ready/done/error interface used by pd + trans-process management. """ def __init__( @@ -52,8 +52,8 @@ def __init__( self.tp_idx = tp_idx self.kv_move_buffer = kv_move_buffer args = get_env_start_args() - assert args.run_mode in ["nixl_prefill", "nixl_decode"], args.run_mode - self.is_prefill_node = args.run_mode == "nixl_prefill" + assert args.run_mode in ["prefill", "decode"], args.run_mode + self.is_prefill_node = args.run_mode == "prefill" self.capture_telemetry = False self.num_pages, self.page_size, self.num_layers, self.kv_head_num, self.head_dims = kv_move_buffer.shape @@ -65,7 +65,7 @@ def __init__( port_min=control_port_min, port_max=control_port_max, ) - self.remote_agents: Dict[str, NixlAgentMetadata] = {} + self.remote_agents: Dict[str, PDAgentMetadata] = {} self._peers: Dict[str, "_NcclPeer"] = {} self._peer_lock = threading.Lock() return @@ -104,7 +104,7 @@ def get_new_notifs(self) -> Dict[str, List[bytes]]: notifs.setdefault(self._get_notify_source_agent_name(notify), []).append(notify) return notifs - def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata): + def connect_add_remote_agent(self, remote_agent: PDAgentMetadata): if remote_agent.agent_name in self.remote_agents: return @@ -130,9 +130,9 @@ def remove_remote_agent(self, peer_name: str): logger.warning(f"try to remove remote agent, but peer name {peer_name} agent did not exist") return - def send_write_done_task_to_decode_node(self, trans_task: NIXLChunckedTransTask): + def send_write_done_task_to_decode_node(self, trans_task: PDChunckedTransTask): new_trans_task = self._copy_notify_task(trans_task) - new_trans_task.nixl_write_stage = "done" + new_trans_task.write_stage = "done" new_trans_task.prefill_agent_name = self.agent_name new_trans_task.prefill_agent_metadata = self.agent_metadata new_trans_task.prefill_num_pages = self.num_pages @@ -140,9 +140,9 @@ def send_write_done_task_to_decode_node(self, trans_task: NIXLChunckedTransTask) self._send_task_notif(trans_task.decode_agent_name, new_trans_task) return - def send_write_request_task_to_decode_node(self, trans_task: NIXLChunckedTransTask): + def send_write_request_task_to_decode_node(self, trans_task: PDChunckedTransTask): new_trans_task = self._copy_notify_task(trans_task) - new_trans_task.nixl_write_stage = "request" + new_trans_task.write_stage = "request" new_trans_task.prefill_agent_name = self.agent_name new_trans_task.prefill_agent_metadata = self.agent_metadata new_trans_task.prefill_num_pages = self.num_pages @@ -150,14 +150,14 @@ def send_write_request_task_to_decode_node(self, trans_task: NIXLChunckedTransTa self._send_task_notif(trans_task.decode_agent_name, new_trans_task) return - def send_write_ready_task_to_prefill_node(self, trans_task: NIXLChunckedTransTask): + def send_write_ready_task_to_prefill_node(self, trans_task: PDChunckedTransTask): if trans_task.prefill_agent_name not in self.remote_agents: self.connect_add_remote_agent(trans_task.create_prefill_agent_obj()) self._get_peer(trans_task.prefill_agent_name).start_recv(trans_task) new_trans_task = self._copy_notify_task(trans_task) - new_trans_task.nixl_write_stage = "ready" + new_trans_task.write_stage = "ready" new_trans_task.decode_agent_name = self.agent_name new_trans_task.decode_agent_metadata = self.agent_metadata new_trans_task.decode_num_pages = self.num_pages @@ -165,11 +165,11 @@ def send_write_ready_task_to_prefill_node(self, trans_task: NIXLChunckedTransTas self._send_task_notif(trans_task.prefill_agent_name, new_trans_task) return - def send_error_info_to_prefill_node(self, trans_task: NIXLChunckedTransTask): + def send_error_info_to_prefill_node(self, trans_task: PDChunckedTransTask): if trans_task.prefill_agent_name is None: return new_trans_task = self._copy_notify_task(trans_task) - new_trans_task.nixl_write_stage = "error" + new_trans_task.write_stage = "error" new_trans_task.decode_agent_name = self.agent_name new_trans_task.decode_agent_metadata = self.agent_metadata new_trans_task.decode_num_pages = self.num_pages @@ -177,9 +177,9 @@ def send_error_info_to_prefill_node(self, trans_task: NIXLChunckedTransTask): self._send_task_notif(trans_task.prefill_agent_name, new_trans_task) return - def send_error_info_to_decode_node(self, trans_task: NIXLChunckedTransTask): + def send_error_info_to_decode_node(self, trans_task: PDChunckedTransTask): new_trans_task = self._copy_notify_task(trans_task) - new_trans_task.nixl_write_stage = "error" + new_trans_task.write_stage = "error" new_trans_task.prefill_agent_name = self.agent_name new_trans_task.prefill_agent_metadata = self.agent_metadata new_trans_task.prefill_num_pages = self.num_pages @@ -187,15 +187,15 @@ def send_error_info_to_decode_node(self, trans_task: NIXLChunckedTransTask): self._send_task_notif(trans_task.decode_agent_name, new_trans_task) return - def write_blocks_paged(self, trans_task: NIXLChunckedTransTask) -> "_NcclXferHandle": - assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None + def write_blocks_paged(self, trans_task: PDChunckedTransTask) -> "_NcclXferHandle": + assert trans_task.src_page_index is not None and trans_task.dst_page_index is not None decode_agent_name = trans_task.decode_agent_name if decode_agent_name not in self.remote_agents: self.connect_add_remote_agent(trans_task.create_decode_agent_obj()) return self._get_peer(decode_agent_name).send_page(trans_task) - def check_task_status(self, trans_task: NIXLChunckedTransTask) -> str: + def check_task_status(self, trans_task: PDChunckedTransTask) -> str: assert trans_task.xfer_handle is not None return trans_task.xfer_handle.check_status() @@ -220,7 +220,7 @@ def _get_peer(self, peer_name: str) -> "_NcclPeer": self._peers[peer_name] = peer return peer - def _send_task_notif(self, remote_agent_name: str, trans_task: NIXLChunckedTransTask): + def _send_task_notif(self, remote_agent_name: str, trans_task: PDChunckedTransTask): if remote_agent_name not in self.remote_agents: if remote_agent_name == trans_task.decode_agent_name: self.connect_add_remote_agent(trans_task.create_decode_agent_obj()) @@ -240,15 +240,15 @@ def _get_remote_metadata(self, remote_agent_name: str) -> NcclAgentMetadata: remote_agent = self.remote_agents[remote_agent_name] return pickle.loads(remote_agent.agent_metadata) - def _copy_notify_task(self, trans_task: NIXLChunckedTransTask) -> NIXLChunckedTransTask: - new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) + def _copy_notify_task(self, trans_task: PDChunckedTransTask) -> PDChunckedTransTask: + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) new_trans_task.mem_indexes = None new_trans_task.xfer_handle = None return new_trans_task def _get_notify_source_agent_name(self, notify: bytes) -> str: notify_obj = pickle.loads(notify) - assert isinstance(notify_obj, NIXLChunckedTransTask), type(notify_obj) + assert isinstance(notify_obj, PDChunckedTransTask), type(notify_obj) if notify_obj.error_info is not None: if self.is_prefill_node: @@ -258,15 +258,15 @@ def _get_notify_source_agent_name(self, notify: bytes) -> str: assert notify_obj.prefill_agent_name is not None return notify_obj.prefill_agent_name - if notify_obj.nixl_write_stage == "request": + if notify_obj.write_stage == "request": assert notify_obj.prefill_agent_name is not None return notify_obj.prefill_agent_name - if notify_obj.nixl_write_stage in ["ready", "done"]: + if notify_obj.write_stage in ["ready", "done"]: assert notify_obj.decode_agent_name is not None return notify_obj.decode_agent_name - raise AssertionError(f"unexpected notify stage: {notify_obj.nixl_write_stage}") + raise AssertionError(f"unexpected notify stage: {notify_obj.write_stage}") @dataclass @@ -295,12 +295,12 @@ def __init__(self, transporter: NcclKVTransporter, peer_name: str): self.peer_name = peer_name self.comm: Optional[PyNcclCommunicator] = None self.stream: Optional[torch.cuda.Stream] = None - self.recv_queue: Optional["queue.Queue[Optional[NIXLChunckedTransTask]]"] = None + self.recv_queue: Optional["queue.Queue[Optional[PDChunckedTransTask]]"] = None self._lock = threading.Lock() - def send_page(self, trans_task: NIXLChunckedTransTask) -> _NcclXferHandle: - assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None - page_tensor = self.transporter.kv_move_buffer[trans_task.nixl_src_page_index] + def send_page(self, trans_task: PDChunckedTransTask) -> _NcclXferHandle: + assert trans_task.src_page_index is not None and trans_task.dst_page_index is not None + page_tensor = self.transporter.kv_move_buffer[trans_task.src_page_index] comm = self._ensure_comm(is_server=True) stream = self._get_stream() @@ -310,11 +310,11 @@ def send_page(self, trans_task: NIXLChunckedTransTask) -> _NcclXferHandle: logger.info( f"NCCL send page posted request_id={trans_task.request_id} " - f"src_page={trans_task.nixl_src_page_index} dst_agent={self.peer_name}" + f"src_page={trans_task.src_page_index} dst_agent={self.peer_name}" ) return _NcclXferHandle(peer_name=self.peer_name, event=event) - def start_recv(self, trans_task: NIXLChunckedTransTask): + def start_recv(self, trans_task: PDChunckedTransTask): self._get_recv_queue().put(copy.copy(trans_task)) return @@ -339,7 +339,7 @@ def _get_stream(self) -> torch.cuda.Stream: self.stream = torch.cuda.Stream() return self.stream - def _get_recv_queue(self) -> "queue.Queue[Optional[NIXLChunckedTransTask]]": + def _get_recv_queue(self) -> "queue.Queue[Optional[PDChunckedTransTask]]": with self._lock: if self.recv_queue is not None: return self.recv_queue @@ -348,7 +348,7 @@ def _get_recv_queue(self) -> "queue.Queue[Optional[NIXLChunckedTransTask]]": threading.Thread(target=self._recv_page_loop, args=(self.recv_queue,), daemon=True).start() return self.recv_queue - def _recv_page_loop(self, recv_queue: "queue.Queue[Optional[NIXLChunckedTransTask]]"): + def _recv_page_loop(self, recv_queue: "queue.Queue[Optional[PDChunckedTransTask]]"): torch.cuda.set_device(self.transporter.tp_idx) while True: trans_task = recv_queue.get() @@ -356,15 +356,15 @@ def _recv_page_loop(self, recv_queue: "queue.Queue[Optional[NIXLChunckedTransTas return self._recv_page(trans_task) - def _recv_page(self, trans_task: NIXLChunckedTransTask): + def _recv_page(self, trans_task: PDChunckedTransTask): try: - page_tensor = self.transporter.kv_move_buffer[trans_task.nixl_dst_page_index] + page_tensor = self.transporter.kv_move_buffer[trans_task.dst_page_index] comm = self._ensure_comm(is_server=False) stream = self._get_stream() comm.recv(page_tensor, src=0, stream=stream) logger.info( f"NCCL recv page done request_id={trans_task.request_id} " - f"dst_page={trans_task.nixl_dst_page_index}" + f"dst_page={trans_task.dst_page_index}" ) except BaseException as e: trans_task.error_info = str(e) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd/nixl_kv_transporter.py similarity index 87% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py rename to lightllm/server/router/model_infer/mode_backend/pd/nixl_kv_transporter.py index 696d98afab..bd5e11f05d 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/nixl_kv_transporter.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Dict from torch import Tensor -from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NixlAgentMetadata +from lightllm.server.pd_io_struct import PDChunckedTransTask, PDAgentMetadata from lightllm.utils.log_utils import init_logger @@ -39,7 +39,7 @@ def __init__(self, node_id: int, tp_idx: int, kv_move_buffer: Tensor): logger.info("NIXL telemetry enabled") self.nixl_agent = NixlWrapper(self.agent_name, conf) self._register_kv_move_buffer(kv_move_buffer=kv_move_buffer) - self.remote_agents: Dict[str, NixlAgentMetadata] = {} + self.remote_agents: Dict[str, PDAgentMetadata] = {} return @property @@ -72,7 +72,7 @@ def _create_paged_xfer_handles(self, reg_desc: "nixlBind.nixlRegDList", page_num descs = self.nixl_agent.get_xfer_descs(pages_data, "VRAM") return self.nixl_agent.prep_xfer_dlist(agent_name, descs, "VRAM") - def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata): + def connect_add_remote_agent(self, remote_agent: PDAgentMetadata): if remote_agent.agent_name in self.remote_agents: return @@ -102,7 +102,7 @@ def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata): def remove_remote_agent(self, peer_name: str): if peer_name in self.remote_agents: try: - remote_agent: NixlAgentMetadata = self.remote_agents.pop(peer_name, None) + remote_agent: PDAgentMetadata = self.remote_agents.pop(peer_name, None) assert remote_agent.agent_name == peer_name self.nixl_agent.remove_remote_agent(remote_agent.agent_name) if remote_agent.page_xfer_handles is not None: @@ -113,15 +113,15 @@ def remove_remote_agent(self, peer_name: str): else: logger.warning(f"try to remove remote agent, but peer name {peer_name} agent did not exist") - def send_write_done_task_to_decode_node(self, trans_task: NIXLChunckedTransTask): + def send_write_done_task_to_decode_node(self, trans_task: PDChunckedTransTask): decode_agent_name = trans_task.decode_agent_name if decode_agent_name not in self.remote_agents: logger.warning(f"decode_agent_name {decode_agent_name} not exist") _remote_agent = trans_task.create_decode_agent_obj() self.connect_add_remote_agent(_remote_agent) - new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) - new_trans_task.nixl_write_stage = "done" + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "done" new_trans_task.mem_indexes = None new_trans_task.xfer_handle = None new_trans_task.decode_agent_metadata = None @@ -136,15 +136,15 @@ def send_write_done_task_to_decode_node(self, trans_task: NIXLChunckedTransTask) ) return - def send_write_request_task_to_decode_node(self, trans_task: NIXLChunckedTransTask): + def send_write_request_task_to_decode_node(self, trans_task: PDChunckedTransTask): decode_agent_name = trans_task.decode_agent_name if decode_agent_name not in self.remote_agents: logger.warning(f"decode_agent_name {decode_agent_name} not exist") _remote_agent = trans_task.create_decode_agent_obj() self.connect_add_remote_agent(_remote_agent) - new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) - new_trans_task.nixl_write_stage = "request" + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "request" new_trans_task.mem_indexes = None new_trans_task.xfer_handle = None new_trans_task.prefill_agent_name = self.agent_name @@ -157,15 +157,15 @@ def send_write_request_task_to_decode_node(self, trans_task: NIXLChunckedTransTa ) return - def send_write_ready_task_to_prefill_node(self, trans_task: NIXLChunckedTransTask): + def send_write_ready_task_to_prefill_node(self, trans_task: PDChunckedTransTask): prefill_agent_name = trans_task.prefill_agent_name if prefill_agent_name not in self.remote_agents: logger.warning(f"prefill_agent_name {prefill_agent_name} not exist") _remote_agent = trans_task.create_prefill_agent_obj() self.connect_add_remote_agent(_remote_agent) - new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) - new_trans_task.nixl_write_stage = "ready" + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "ready" new_trans_task.mem_indexes = None new_trans_task.xfer_handle = None new_trans_task.decode_agent_name = self.agent_name @@ -178,7 +178,7 @@ def send_write_ready_task_to_prefill_node(self, trans_task: NIXLChunckedTransTas ) return - def send_error_info_to_prefill_node(self, trans_task: NIXLChunckedTransTask): + def send_error_info_to_prefill_node(self, trans_task: PDChunckedTransTask): # decode node 主动发送错误信息给 prefill node, 但是只有到达一定阶段的任务才有对端的信息 # 才能发送 if trans_task.prefill_agent_name is None: @@ -191,8 +191,8 @@ def send_error_info_to_prefill_node(self, trans_task: NIXLChunckedTransTask): _remote_agent = trans_task.create_prefill_agent_obj() self.connect_add_remote_agent(_remote_agent) assert trans_task.error_info is not None - new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) - new_trans_task.nixl_write_stage = "error" + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "error" new_trans_task.mem_indexes = None new_trans_task.xfer_handle = None new_trans_task.decode_agent_name = self.agent_name @@ -209,7 +209,7 @@ def send_error_info_to_prefill_node(self, trans_task: NIXLChunckedTransTask): self.remove_remote_agent(peer_name=prefill_agent_name) return - def send_error_info_to_decode_node(self, trans_task: NIXLChunckedTransTask): + def send_error_info_to_decode_node(self, trans_task: PDChunckedTransTask): try: decode_agent_name = trans_task.decode_agent_name if decode_agent_name not in self.remote_agents: @@ -217,8 +217,8 @@ def send_error_info_to_decode_node(self, trans_task: NIXLChunckedTransTask): _remote_agent = trans_task.create_decode_agent_obj() self.connect_add_remote_agent(_remote_agent) assert trans_task.error_info is not None - new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task) - new_trans_task.nixl_write_stage = "error" + new_trans_task: PDChunckedTransTask = copy.copy(trans_task) + new_trans_task.write_stage = "error" new_trans_task.mem_indexes = None new_trans_task.xfer_handle = None new_trans_task.prefill_agent_name = self.agent_name @@ -237,7 +237,7 @@ def send_error_info_to_decode_node(self, trans_task: NIXLChunckedTransTask): def write_blocks_paged( self, - trans_task: NIXLChunckedTransTask, + trans_task: PDChunckedTransTask, ) -> int: """ prefill node call this function to write kv blocks into decode node pages @@ -248,16 +248,16 @@ def write_blocks_paged( _remote_agent = trans_task.create_decode_agent_obj() self.connect_add_remote_agent(_remote_agent) - assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None - remote_agent: NixlAgentMetadata = self.remote_agents[decode_agent_name] + assert trans_task.src_page_index is not None and trans_task.dst_page_index is not None + remote_agent: PDAgentMetadata = self.remote_agents[decode_agent_name] src_handle = self.page_local_xfer_handles dst_handle = remote_agent.page_xfer_handles handle = self.nixl_agent.make_prepped_xfer( "WRITE", src_handle, - [trans_task.nixl_src_page_index], + [trans_task.src_page_index], dst_handle, - [trans_task.nixl_dst_page_index], + [trans_task.dst_page_index], b"", ) if not handle: @@ -267,7 +267,7 @@ def write_blocks_paged( return handle - def check_task_status(self, trans_task: NIXLChunckedTransTask) -> str: + def check_task_status(self, trans_task: PDChunckedTransTask) -> str: assert trans_task.xfer_handle is not None handle = trans_task.xfer_handle xfer_state = self.nixl_agent.check_xfer_state(handle) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py b/lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py similarity index 97% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py rename to lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py index 8ca5f4e808..0307df582b 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py @@ -95,7 +95,7 @@ def reduce_tensor(tensor): storage = tensor._typed_storage() if storage._untyped_storage.device.type == "cuda": - from lightllm.server.router.model_infer.mode_backend.pd_nixl.p2p_fix import p2p_fix_rebuild_cuda_tensor + from lightllm.server.router.model_infer.mode_backend.pd.p2p_fix import p2p_fix_rebuild_cuda_tensor ( device, diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/__init__.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/__init__.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/__init__.py diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py similarity index 70% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py index 87d72df75c..0f297a1fe5 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py @@ -2,7 +2,7 @@ import random from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq -from lightllm.server.pd_io_struct import NIXLChunckedTransTask +from lightllm.server.pd_io_struct import PDChunckedTransTask from lightllm.utils.log_utils import init_logger from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.server.router.model_infer.infer_batch import g_infer_context @@ -11,13 +11,13 @@ logger = init_logger(__name__) -class NIXLChunckedPrefillForPrefillNode(ChunkedPrefillBackend): +class PDChunkedPrefillForPrefillNode(ChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.support_overlap = False self.info_queue: mp.Queue = info_queue self.classed_req_no_decode = True - self.nixl_prefill_chuncked_handle_func = self._prefill_chuncked_handle_func + self.pd_prefill_chunked_handle_func = self._prefill_chuncked_handle_func def init_custom(self): assert kv_trans_use_p2p() @@ -35,12 +35,12 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: prefill_finished = req_obj.shm_req.input_len <= req_obj.cur_kv_len if prefill_finished: # 等待所有传输任务都已经完成。 - if req_obj.nixl_pd_task_num == (req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num): + if req_obj.pd_task_num == (req_obj.pd_task_failed_num + req_obj.pd_task_success_num): ans_list.append(req_obj) else: if req_obj.infer_aborted: - if req_obj.nixl_pd_task_num == ( - req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num + if req_obj.pd_task_num == ( + req_obj.pd_task_failed_num + req_obj.pd_task_success_num ): ans_list.append(req_obj) else: @@ -58,19 +58,19 @@ def _prefill_chuncked_handle_func( assert req_obj.cur_kv_len <= req_obj.shm_req.input_len input_len = req_obj.shm_req.input_len - page_size = self.args.nixl_pd_kv_page_size + page_size = self.args.pd_kv_page_size prefill_finished = req_obj.cur_kv_len == input_len - trans_task_list: List[NIXLChunckedTransTask] = [] - while req_obj.nixl_trans_kv_start_index < req_obj.cur_kv_len: - cur_page_size = min(page_size, req_obj.cur_kv_len - req_obj.nixl_trans_kv_start_index) + trans_task_list: List[PDChunckedTransTask] = [] + while req_obj.pd_trans_kv_start_index < req_obj.cur_kv_len: + cur_page_size = min(page_size, req_obj.cur_kv_len - req_obj.pd_trans_kv_start_index) # 生成页面传输任务, 放入kv move manager 的处理队列中 if cur_page_size == page_size or prefill_finished: - trans_task = self._create_nixl_trans_task( + trans_task = self._create_pd_trans_task( req_obj=req_obj, - kv_start_index=req_obj.nixl_trans_kv_start_index, - kv_end_index=req_obj.nixl_trans_kv_start_index + cur_page_size, + kv_start_index=req_obj.pd_trans_kv_start_index, + kv_end_index=req_obj.pd_trans_kv_start_index + cur_page_size, ) - req_obj.nixl_trans_kv_start_index += cur_page_size + req_obj.pd_trans_kv_start_index += cur_page_size trans_task_list.append(trans_task) else: break @@ -78,7 +78,7 @@ def _prefill_chuncked_handle_func( if prefill_finished and len(trans_task_list) != 0 and output_len == 1: if g_infer_context.is_linear_att_mixed_model: trans_task_list.append( - self._create_nixl_trans_task( + self._create_pd_trans_task( req_obj=req_obj, kv_start_index=input_len, kv_end_index=input_len, @@ -93,21 +93,21 @@ def _prefill_chuncked_handle_func( self.info_queue.put(trans_task) return - def _create_nixl_trans_task( + def _create_pd_trans_task( self, req_obj: InferReq, kv_start_index: int, kv_end_index: int, page_kind: str = "kv", - ) -> NIXLChunckedTransTask: + ) -> PDChunckedTransTask: # 确定传输设备 - if req_obj.nixl_trans_device_id == -1: - if not hasattr(self, "nixl_iter_device_id"): - self.nixl_iter_device_id = 0 - req_obj.nixl_trans_device_id = self.nixl_iter_device_id - self.nixl_iter_device_id = (self.nixl_iter_device_id + 1) % self.node_world_size + if req_obj.pd_trans_device_id == -1: + if not hasattr(self, "pd_iter_device_id"): + self.pd_iter_device_id = 0 + req_obj.pd_trans_device_id = self.pd_iter_device_id + self.pd_iter_device_id = (self.pd_iter_device_id + 1) % self.node_world_size - nixl_decode_node_info = req_obj.sampling_param.nixl_decode_node + pd_decode_node_info = req_obj.sampling_param.pd_decode_node if page_kind == "kv": mem_indexes = ( self.model.req_manager.req_to_token_indexs[req_obj.req_idx, kv_start_index:kv_end_index] @@ -120,8 +120,8 @@ def _create_nixl_trans_task( mem_indexes = [] req_idx = req_obj.req_idx else: - raise ValueError(f"unknown NIXL trans page kind {page_kind}") - trans_task = NIXLChunckedTransTask( + raise ValueError(f"unknown PD trans page kind {page_kind}") + trans_task = PDChunckedTransTask( request_id=req_obj.req_id, start_kv_index=kv_start_index, end_kv_index=kv_end_index, @@ -129,21 +129,21 @@ def _create_nixl_trans_task( pd_master_node_id=req_obj.sampling_param.pd_master_node_id, prefill_dp_index=self.dp_rank_in_node, decode_dp_index=None, - src_device_id=req_obj.nixl_trans_device_id, + src_device_id=req_obj.pd_trans_device_id, dst_device_id=None, mem_indexes=mem_indexes, prefill_agent_name=None, prefill_agent_metadata=None, prefill_num_pages=None, prefill_page_reg_desc=None, - decode_agent_name=nixl_decode_node_info.agent_name, - decode_agent_metadata=nixl_decode_node_info.agent_metadata, - decode_num_pages=nixl_decode_node_info.num_pages, - decode_page_reg_desc=nixl_decode_node_info.page_reg_desc, + decode_agent_name=pd_decode_node_info.agent_name, + decode_agent_metadata=pd_decode_node_info.agent_metadata, + decode_num_pages=pd_decode_node_info.num_pages, + decode_page_reg_desc=pd_decode_node_info.page_reg_desc, first_gen_token_id=None, first_gen_token_logprob=None, page_kind=page_kind, req_idx=req_idx, ) - req_obj.nixl_pd_task_num += 1 + req_obj.pd_task_num += 1 return trans_task diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl_for_dp.py similarity index 66% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl_for_dp.py index daa041afea..5f4465a730 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl_for_dp.py @@ -2,38 +2,38 @@ from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq from lightllm.utils.log_utils import init_logger -from .prefill_impl import NIXLChunckedPrefillForPrefillNode, NIXLChunckedTransTask +from .prefill_impl import PDChunkedPrefillForPrefillNode, PDChunckedTransTask from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend logger = init_logger(__name__) -class NIXLDPChunkedForPrefillNode(DPChunkedPrefillBackend): +class PDDPChunkedForPrefillNode(DPChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.support_overlap = False self.info_queue: mp.Queue = info_queue self.classed_req_no_decode = True - self.nixl_prefill_chuncked_handle_func = self._prefill_chuncked_handle_func + self.pd_prefill_chunked_handle_func = self._prefill_chuncked_handle_func def init_custom(self): - NIXLChunckedPrefillForPrefillNode.init_custom(self) + PDChunkedPrefillForPrefillNode.init_custom(self) return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: - return NIXLChunckedPrefillForPrefillNode._filter_not_ready_reqs(self, req_ids) + return PDChunkedPrefillForPrefillNode._filter_not_ready_reqs(self, req_ids) def _prefill_chuncked_handle_func( self, req_obj: InferReq, next_token_id: int, next_token_prob: float, output_len: int ): - return NIXLChunckedPrefillForPrefillNode._prefill_chuncked_handle_func( + return PDChunkedPrefillForPrefillNode._prefill_chuncked_handle_func( self, req_obj=req_obj, next_token_id=next_token_id, next_token_prob=next_token_prob, output_len=output_len ) - def _create_nixl_trans_task( + def _create_pd_trans_task( self, req_obj: InferReq, kv_start_index: int, kv_end_index: int, page_kind: str = "kv" - ) -> NIXLChunckedTransTask: - return NIXLChunckedPrefillForPrefillNode._create_nixl_trans_task( + ) -> PDChunckedTransTask: + return PDChunkedPrefillForPrefillNode._create_pd_trans_task( self, req_obj=req_obj, kv_start_index=kv_start_index, diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_kv_move_manager.py similarity index 93% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_kv_move_manager.py index fb95091158..7bda07d54d 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_kv_move_manager.py @@ -4,7 +4,7 @@ import time from typing import List, Dict, Union, Callable from lightllm.utils.log_utils import init_logger -from lightllm.server.pd_io_struct import NIXLChunckedTransTask +from lightllm.server.pd_io_struct import PDChunckedTransTask from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.core.objs import StartArgs from ..trans_process_obj import KVTransProcess @@ -30,7 +30,7 @@ def _init_env(args, info_queue: mp.Queue, event: mp.Event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_prefill_kv_move_manager") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::prefill_kv_move_manager") from .prefill_trans_process import start_prefill_trans_process @@ -56,7 +56,7 @@ def __init__(self, args: StartArgs, info_queue: mp.Queue, start_trans_process_fu def task_dispatcher_loop(self): # 获取任务,并分发给相关卡的处理队列 while True: - task: NIXLChunckedTransTask = self.info_queue.get() + task: PDChunckedTransTask = self.info_queue.get() device_id = task.src_device_id try: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_trans_process.py similarity index 90% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py rename to lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_trans_process.py index fa7367958c..e286a10f96 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_trans_process.py @@ -9,7 +9,7 @@ from typing import List, Dict, Optional from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.server.pd_io_struct import NIXLChunckedTransTask +from lightllm.server.pd_io_struct import PDChunckedTransTask from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.core.objs import StartArgs from ..kv_transporter import create_kv_transporter @@ -48,7 +48,7 @@ def _init_env( os.environ["CUDA_MPS_CLIENT_PRIORITY"] = "0" torch.backends.cudnn.enabled = False - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_prefill_trans:Device{device_id}") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::prefill_trans:Device{device_id}") try: torch.cuda.set_device(device_id) @@ -99,7 +99,7 @@ def __init__( cur_mem_manager: MemoryManager = self.mem_managers[device_id] kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer( - page_num=self.args.nixl_pd_kv_page_num, page_size=self.args.nixl_pd_kv_page_size + page_num=self.args.pd_kv_page_num, page_size=self.args.pd_kv_page_size ) self.copy_cuda_stream = torch.cuda.Stream(priority=-1) self.transporter = create_kv_transporter( @@ -109,7 +109,7 @@ def __init__( kv_move_buffer=kv_move_buffer, ) self.waiting_dict_lock = threading.Lock() - self.waiting_dict: Dict[str, NIXLChunckedTransTask] = {} + self.waiting_dict: Dict[str, PDChunckedTransTask] = {} self.local_copy_kv_queue = queue.Queue() self.ready_transfer_queue = queue.Queue() @@ -118,7 +118,7 @@ def __init__( self.failed_queue = queue.Queue() self.page_index_queue = queue.Queue() - for page_index in range(self.args.nixl_pd_kv_page_num): + for page_index in range(self.args.pd_kv_page_num): self.page_index_queue.put(page_index) # warmup 预先执行一次 kv 写入 page buffer,避免第一次拷贝时出现卡顿。 @@ -169,8 +169,8 @@ def recv_task_loop(self): while True: page_index = self.page_index_queue.get() - trans_task: NIXLChunckedTransTask = self.task_in_queue.get() - trans_task.nixl_src_page_index = page_index + trans_task: PDChunckedTransTask = self.task_in_queue.get() + trans_task.src_page_index = page_index # 初次校验 time out if trans_task.time_out(): @@ -183,14 +183,14 @@ def recv_task_loop(self): def local_copy_kv_loop(self): torch.cuda.set_device(self.device_id) while True: - trans_task: NIXLChunckedTransTask = self.local_copy_kv_queue.get() + trans_task: PDChunckedTransTask = self.local_copy_kv_queue.get() # 将kv 数据拷贝到 page 上,然后传输给 decode node,让其进行读取。 with torch.cuda.stream(stream=self.copy_cuda_stream): cur_mem = self.mem_managers[self.device_id] cur_mem.write_mem_to_page_kv_move_buffer( trans_task.mem_indexes, - page_index=trans_task.nixl_src_page_index, + page_index=trans_task.src_page_index, dp_index=trans_task.prefill_dp_index, mem_managers=self.mem_managers, dp_world_size=self.dp_world_size, @@ -208,7 +208,7 @@ def ready_transfer_loop(self): torch.cuda.set_device(self.device_id) while True: sync_event, trans_task = self.ready_transfer_queue.get() - trans_task: NIXLChunckedTransTask = trans_task + trans_task: PDChunckedTransTask = trans_task sync_event: torch.cuda.Event = sync_event sync_event.synchronize() key = trans_task.get_key() @@ -246,7 +246,7 @@ def accept_decode_write_task_loop(self): except BaseException: notify_obj = None - if not isinstance(notify_obj, NIXLChunckedTransTask): + if not isinstance(notify_obj, PDChunckedTransTask): continue if notify_obj.error_info is not None: @@ -254,17 +254,17 @@ def accept_decode_write_task_loop(self): self._abort(request_id=notify_obj.request_id, error_info=notify_obj.error_info) continue - if notify_obj.nixl_write_stage == "ready": + if notify_obj.write_stage == "ready": key = notify_obj.get_key() with self.waiting_dict_lock: trans_task = self.waiting_dict.pop(key, None) if trans_task is not None: - trans_task.nixl_dst_page_index = notify_obj.nixl_dst_page_index + trans_task.dst_page_index = notify_obj.dst_page_index self.write_peer_kv_queue.put(trans_task) logger.info( f"recv WRITE ready from decode request_id={trans_task.request_id} " f"kv=[{trans_task.start_kv_index},{trans_task.end_kv_index}) " - f"srcpage={trans_task.nixl_src_page_index} dstpage={trans_task.nixl_dst_page_index}" + f"srcpage={trans_task.src_page_index} dstpage={trans_task.dst_page_index}" ) else: logger.warning( @@ -301,7 +301,7 @@ def write_peer_kv_loop(self): torch.cuda.set_device(self.device_id) while True: trans_task = self.write_peer_kv_queue.get() - trans_task: NIXLChunckedTransTask = trans_task + trans_task: PDChunckedTransTask = trans_task try: xfer_handle = self.transporter.write_blocks_paged(trans_task=trans_task) @@ -348,7 +348,7 @@ def update_task_status_loop( logger.info( f"write trans task request_id={trans_task.request_id} " f"kv=[{trans_task.start_kv_index},{trans_task.end_kv_index}) " - f"src_page={trans_task.nixl_src_page_index} dst_page={trans_task.nixl_dst_page_index} " + f"src_page={trans_task.src_page_index} dst_page={trans_task.dst_page_index} " f"xfer time: {total_us:.3f} us, " f"post time: {post_us:.3f} us, backend time: {backend_us:.3f} us, " f"nixl_backend: {nixl_backend}, total_bytes: {telem.totalBytes}" @@ -358,7 +358,7 @@ def update_task_status_loop( f"send WRITE done nixl notify " f"request_id={trans_task.request_id} " f"kv=[{trans_task.start_kv_index},{trans_task.end_kv_index}) " - f"src_page={trans_task.nixl_src_page_index} dst_page={trans_task.nixl_dst_page_index}" + f"src_page={trans_task.src_page_index} dst_page={trans_task.dst_page_index}" ) self.success_queue.put(trans_task) elif ret == "ERR": @@ -376,10 +376,10 @@ def update_task_status_loop( def success_loop(self): torch.cuda.set_device(self.device_id) while True: - trans_task: NIXLChunckedTransTask = self.success_queue.get() + trans_task: PDChunckedTransTask = self.success_queue.get() # 写回后,回收页面 - if trans_task.nixl_src_page_index is not None: - self.page_index_queue.put(trans_task.nixl_src_page_index) + if trans_task.src_page_index is not None: + self.page_index_queue.put(trans_task.src_page_index) if trans_task.xfer_handle is not None: self.transporter.release_xfer_handle(trans_task.xfer_handle) @@ -397,11 +397,11 @@ def success_loop(self): def fail_loop(self): torch.cuda.set_device(self.device_id) while True: - trans_task: NIXLChunckedTransTask = self.failed_queue.get() + trans_task: PDChunckedTransTask = self.failed_queue.get() # 回收页面 - if trans_task.nixl_src_page_index is not None: - self.page_index_queue.put(trans_task.nixl_src_page_index) + if trans_task.src_page_index is not None: + self.page_index_queue.put(trans_task.src_page_index) if trans_task.xfer_handle is not None: self.transporter.release_xfer_handle(trans_task.xfer_handle) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/trans_process_obj.py b/lightllm/server/router/model_infer/mode_backend/pd/trans_process_obj.py similarity index 100% rename from lightllm/server/router/model_infer/mode_backend/pd_nixl/trans_process_obj.py rename to lightllm/server/router/model_infer/mode_backend/pd/trans_process_obj.py diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index afa0fb4c7f..864a7405b7 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -22,10 +22,10 @@ XgrammarBackend, DPChunkedPrefillBackend, DiversehBackend, - NIXLChunckedPrefillForPrefillNode, - NIXLDPChunkedForPrefillNode, - NIXLDecodeNode, - NIXLDPForDecodeNode, + PDChunkedPrefillForPrefillNode, + PDDPChunkedForPrefillNode, + PDDecodeNode, + PDDPForDecodeNode, ) from lightllm.server.router.model_infer.mode_backend.redundancy_expert_manager import RedundancyExpertManager from lightllm.server.core.objs.start_args_type import StartArgs @@ -63,20 +63,20 @@ def exposed_init_model(self, kvargs): is_outlines_constraint_mode = self.args.output_constraint_mode == "outlines" is_xgrammar_constraint_mode = self.args.output_constraint_mode == "xgrammar" assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" - is_nixl_prefill_node = self.args.run_mode == "nixl_prefill" - is_nixl_decode_node = self.args.run_mode == "nixl_decode" + is_prefill_node = self.args.run_mode == "prefill" + is_decode_node = self.args.run_mode == "decode" - if is_nixl_prefill_node: + if is_prefill_node: if self.args.dp > 1: - self.backend = NIXLDPChunkedForPrefillNode(self.info_queue) + self.backend = PDDPChunkedForPrefillNode(self.info_queue) else: - self.backend = NIXLChunckedPrefillForPrefillNode(self.info_queue) + self.backend = PDChunkedPrefillForPrefillNode(self.info_queue) - elif is_nixl_decode_node: + elif is_decode_node: if self.args.dp > 1: - self.backend = NIXLDPForDecodeNode(self.info_queue) + self.backend = PDDPForDecodeNode(self.info_queue) else: - self.backend = NIXLDecodeNode(self.info_queue) + self.backend = PDDecodeNode(self.info_queue) elif self.args.dp > 1: self.backend = DPChunkedPrefillBackend() diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 5f0cb866ed..f977f6d567 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -13,7 +13,7 @@ def _get_req_queue_class(args, router, dp_size_in_node: int): return ChunkedPrefillQueue if args.first_token_constraint_mode: return ChunkedPrefillQueue - if args.run_mode in ["nixl_prefill", "nixl_decode"]: + if args.run_mode in ["prefill", "decode"]: return NIXLPDQueue if args.disable_chunked_prefill: diff --git a/skills/test_model/qwen3-8b-pd-nixl/SKILL.md b/skills/test_model/qwen3-8b-pd-nixl/SKILL.md index 5ee864a910..2012daadc5 100644 --- a/skills/test_model/qwen3-8b-pd-nixl/SKILL.md +++ b/skills/test_model/qwen3-8b-pd-nixl/SKILL.md @@ -1,8 +1,8 @@ --- name: test-model-qwen3-8b-pd-nixl description: >- - LightLLM Qwen3-8b PD disaggregation over NIXL gsm8k: pd_master on 8089, nixl_prefill on 8001, - nixl_decode on 8002, tp 2 each. Assign four GPUs via nvidia-smi then export + LightLLM Qwen3-8b PD disaggregation over NIXL gsm8k: pd_master on 8089, prefill on 8001, + decode on 8002, tp 2 each. Assign four GPUs via nvidia-smi then export PREFILL_CUDA_DEVICES / DECODE_CUDA_DEVICES (no fixed card IDs; no complex shell automation). UCX_NET_DEVICES and TLS for RDMA per cluster. lm_eval hits pd_master URL. HOST vs PD_MASTER_IP when co-located. Before lm_eval, must POST one completion via curl to @@ -12,13 +12,13 @@ description: >- Use for PD NIXL-style separation tests. --- -# Qwen3-8B **PD 分离(NIXL)**(`pd_master` + `nixl_prefill` + `nixl_decode`)本地 GSM8K 评测 +# Qwen3-8B **PD 分离(NIXL)**(`pd_master` + `prefill` + `decode`)本地 GSM8K 评测 -**测试标识**:同一 **`--model_dir`**(Qwen3-8B)下拆 **三条** `api_server` 进程——**调度/入口(`pd_master`)**、**`nixl_prefill` 节点**、**`nixl_decode` 节点**;评测 **`lm_eval`** 只访问 **`pd_master` 的 HTTP 端口(8089)**。默认使用 NIXL 传输;需要验证 NCCL 数据面时,设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**,上层仍保持相同的 `nixl_prefill` / `nixl_decode` 管理路径。 +**测试标识**:同一 **`--model_dir`**(Qwen3-8B)下拆 **三条** `api_server` 进程——**调度/入口(`pd_master`)**、**`prefill` 节点**、**`decode` 节点**;评测 **`lm_eval`** 只访问 **`pd_master` 的 HTTP 端口(8089)**。默认使用 NIXL 传输;需要验证 NCCL 数据面时,设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**,上层仍保持相同的 `prefill` / `decode` 管理路径。 **端口约定**:**`pd_master`:`8089`**;**prefill:`8001`**;**decode:`8002`**。启动与就绪探测须覆盖这三处(以及日志中的 PD 注册/报错信息)。 -**绑定 IP(`HOST` / `PD_MASTER_IP`)**:各进程的 **`--host`** 表示 **本服务监听绑定的 IP**。当 **`pd_master`、`nixl_prefill`、`nixl_decode` 部署在同一台机器上时**,三者使用的绑定 IP **相同**:可 **`export HOST="${PD_MASTER_IP}"`**;**`lm_eval` 的 `base_url` 仍指向 `pd_master`**。 +**绑定 IP(`HOST` / `PD_MASTER_IP`)**:各进程的 **`--host`** 表示 **本服务监听绑定的 IP**。当 **`pd_master`、`prefill`、`decode` 部署在同一台机器上时**,三者使用的绑定 IP **相同**:可 **`export HOST="${PD_MASTER_IP}"`**;**`lm_eval` 的 `base_url` 仍指向 `pd_master`**。 整轮产物落在**同一日志目录**,写入 **`summary.txt`** 与各进程日志;**不要**写聚合启动脚本,按「启动说明」逐条手动启动并在后台落盘。 @@ -50,7 +50,7 @@ description: >- | `LOG_DIR` | 本轮日志根目录;`export LOG_DIR=…`。 | | `MODEL_DIR` | **`--model_dir`**;`lm_eval` 的 **`tokenizer` 须与此路径一致**。 | | `PD_MASTER_IP` | **`pd_master` 的 `--host`**;**`lm_eval` 的 `base_url` 主机**。 | -| `HOST` | **`nixl_prefill` / `nixl_decode` 的 `--host`**。同机时 **`export HOST="${PD_MASTER_IP}"`**。 | +| `HOST` | **`prefill` / `decode` 的 `--host`**。同机时 **`export HOST="${PD_MASTER_IP}"`**。 | | `PREFILL_CUDA_DEVICES` | prefill 的 **`CUDA_VISIBLE_DEVICES`**(两张物理索引);**`nvidia-smi` 后 export**。 | | `DECODE_CUDA_DEVICES` | decode 的 **`CUDA_VISIBLE_DEVICES`**;与 prefill **四卡互不重复**。 | | `UCX_NET_DEVICES` | UCX 使用的 HCA 列表,形如 `mlx5_0:1,mlx5_1:1`;**按本机 `ibv_devinfo` 与规划填写**。 | @@ -70,7 +70,7 @@ export UCX_TLS=rc,cuda,gdr_copy ### 显卡分配(`nvidia-smi` + 人工/Agent 决策,不用复杂脚本) -**nixl_prefill**、**nixl_decode** 各 **2** 张 GPU,共 **4** 张互不重复。需要验证 NCCL 数据面时,额外设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**。 +**prefill**、**decode** 各 **2** 张 GPU,共 **4** 张互不重复。需要验证 NCCL 数据面时,额外设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**。 1. 执行 **`nvidia-smi`**(可选用 `--query-gpu=index,name,memory.used,memory.free --format=csv`)。 2. 由执行者选定哪 2 张给 prefill、哪 2 张给 decode(不重叠)。 @@ -119,7 +119,7 @@ nohup python -m lightllm.server.api_server \ >> "${LOG_DIR}/pd_master.log" 2>&1 & ``` -### 2)启动 `nixl_prefill` 节点 +### 2)启动 `prefill` 节点 **须在 `pd_master` 就绪后**再启动。启动前已完成 **`nvidia-smi` 决策**并 **`export PREFILL_CUDA_DEVICES`**,且已设置 **UCX**。 @@ -130,7 +130,7 @@ export https_proxy= LOADWORKER=18 CUDA_VISIBLE_DEVICES="${PREFILL_CUDA_DEVICES}" \ nohup python -m lightllm.server.api_server \ --model_dir "${MODEL_DIR}" \ - --run_mode nixl_prefill \ + --run_mode prefill \ --tp 2 \ --dp 1 \ --host "${HOST}" \ @@ -143,7 +143,7 @@ nohup python -m lightllm.server.api_server \ (若需显式传入 UCX,可在同一 shell 中于本块之前 **`export UCX_NET_DEVICES`** 等;**`nohup` 会继承当前 shell 的环境变量**。) -### 3)启动 `nixl_decode` 节点 +### 3)启动 `decode` 节点 启动前 **`export DECODE_CUDA_DEVICES`**,并确保 **UCX** 已设置。 @@ -155,7 +155,7 @@ export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} LOADWORKER=18 CUDA_VISIBLE_DEVICES="${DECODE_CUDA_DEVICES}" \ nohup python -m lightllm.server.api_server \ --model_dir "${MODEL_DIR}" \ - --run_mode nixl_decode \ + --run_mode decode \ --tp 2 \ --dp 1 \ --host "${HOST}" \ @@ -169,7 +169,7 @@ nohup python -m lightllm.server.api_server \ NIXL PD 链路在首次真实推理前易出现冷启动与传输路径问题。**在跑 `lm_eval` 正式评测之前**,必须先对 **`pd_master`** 的 **`/v1/completions`** 发 **至少一次** HTTP 请求,确认返回 **2xx** 且响应体含正常 completion(再走长评测)。 -1. **时机**:**`nixl_prefill` 与 `nixl_decode` 均已启动**,且日志显示已与 **`pd_master`** 建立 PD 链路后再执行(可与端口 listen、日志轮询结合判断)。 +1. **时机**:**`prefill` 与 `decode` 均已启动**,且日志显示已与 **`pd_master`** 建立 PD 链路后再执行(可与端口 listen、日志轮询结合判断)。 2. **代理**:执行 **`curl` 前**同样 **`export http_proxy=` / `export https_proxy=`**;若评测机对 **`PD_MASTER_IP`** 走代理会失败,可对本次 shell 设置 **`no_proxy`**(与下文 `lm_eval` 一致,须包含 **`${PD_MASTER_IP}`**)。 3. **记录**:将 **`curl` 使用的命令、HTTP 状态码、若失败则错误摘要** 写入 **`summary.txt`**;成功后再启动 **`lm_eval`**。 @@ -212,11 +212,11 @@ lm_eval --model local-completions \ **模型目录**:首轮可 **`export MODEL_DIR=/mtc/models/qwen3-8b`**;路径报错时由用户提供本机 **`MODEL_DIR`**。 -1. **启动顺序**:**`bash skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh >> "${LOG_DIR}/summary.txt"`** → **`pd_master`** → **`nvidia-smi` + export 四卡** → **设置 UCX** → **`nixl_prefill`** → **`nixl_decode`** → **`curl` warmup(须成功)** → **`lm_eval`**。 +1. **启动顺序**:**`bash skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh >> "${LOG_DIR}/summary.txt"`** → **`pd_master`** → **`nvidia-smi` + export 四卡** → **设置 UCX** → **`prefill`** → **`decode`** → **`curl` warmup(须成功)** → **`lm_eval`**。 2. **不要用 health 接口**作为唯一依据;结合 **端口 listen** 与 **`pd_master.log` / `prefill.log` / `decode.log`**。 3. **约每 20 秒**查看日志直至就绪或报错;异常写入 **`summary.txt`**。 4. **`summary.txt`**:记录启动摘要、**`PREFILL_CUDA_DEVICES` / `DECODE_CUDA_DEVICES`** 与选卡依据、**`UCX_NET_DEVICES` 等**、**`curl` warmup 结果(或 `curl_warmup.log` 路径)**、评测关键输出、最终结论。 -5. **结束后**关闭 **`pd_master`、`nixl_prefill`、`nixl_decode`** 相关进程。 +5. **结束后**关闭 **`pd_master`、`prefill`、`decode`** 相关进程。 6. 当用户说明是压测的时候,将lmeval 的 --batch_size 修改为 500 7. 发现 connetion to pd_master has error 错误的时候,可以先容忍一会,这种网络状态错误有时是可以自行恢复的。 diff --git a/skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh b/skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh index e68a038417..21bc2f35e6 100755 --- a/skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh +++ b/skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh @@ -35,7 +35,7 @@ if [[ "$FAIL" -ne 0 ]]; then Enable GPUDirect RDMA: sudo modprobe nvidia_peermem lsmod | grep nvidia_peermem - # cross-node: run on every host; then restart nixl_prefill / nixl_decode + # cross-node: run on every host; then restart prefill / decode EOF exit 1 fi diff --git a/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md b/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md index 983f76e551..244a7a06a6 100644 --- a/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md +++ b/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md @@ -2,9 +2,9 @@ name: test-model-qwen3.5-0.8b-pd-nixl description: >- LightLLM Qwen3.5-0.8B PD disaggregation over NIXL gsm8k: pd_master on 8089, - nixl_prefill on 8001, nixl_decode on 8002. Supports TP1 and TP2 runs by setting + prefill on 8001, decode on 8002. Supports TP1 and TP2 runs by setting TP / PREFILL_CUDA_DEVICES / DECODE_CUDA_DEVICES. Qwen3.5 has linear-attention - state transfer; use --nixl_pd_kv_page_size 2048 and a large enough page_num + state transfer; use --pd_kv_page_size 2048 and a large enough page_num such as 256. lm_eval hits pd_master URL. Requires UCX/RDMA env, nvidia_peermem check, curl warmup before lm_eval, registration wait in pd_master.log, and summary.txt. Includes optional repeated-prompt decode cache probe for linear-att @@ -14,7 +14,7 @@ description: >- # Qwen3.5-0.8B **PD 分离(NIXL)** 本地 GSM8K 评测 **测试标识**:同一 **`MODEL_DIR`(Qwen3.5-0.8B)** 下拆三条 `api_server` 进程: -**`pd_master`**、**`nixl_prefill`**、**`nixl_decode`**。评测和 warmup 只访问 +**`pd_master`**、**`prefill`**、**`decode`**。评测和 warmup 只访问 **`pd_master` 的 HTTP 端口 `8089`**。 Qwen3.5 与 Qwen3-8B 的关键差异: @@ -22,8 +22,8 @@ Qwen3.5 与 Qwen3-8B 的关键差异: | 项 | Qwen3.5-0.8B NIXL PD 要点 | |---|---| | linear-att 状态 | PD 传输除了 KV page,还会传 `linear_att_state` 特殊页 | -| NIXL page size | 建议固定 **`--nixl_pd_kv_page_size 2048`**;`1024` 可能不足以容纳 linear-att 状态 | -| page num | 建议 **`--nixl_pd_kv_page_num 256`** 起步,避免 page 池过小干扰评测 | +| NIXL page size | 建议固定 **`--pd_kv_page_size 2048`**;`1024` 可能不足以容纳 linear-att 状态 | +| page num | 建议 **`--pd_kv_page_num 256`** 起步,避免 page 池过小干扰评测 | | cache 判断 | repeated prompt 可能只在 prefill 侧命中,decode 侧不一定 decode-only 命中 | ## 日志目录 @@ -40,7 +40,7 @@ Qwen3.5 与 Qwen3-8B 的关键差异: 建议命名: ```bash -export LOG_DIR="/mtc/wzj/lightllm_dev2/LightLLM/test/benchmark/static_inference/log/qwen35_pd_nixl_$(date +%Y%m%d_%H%M%S)" +export LOG_DIR="/mtc/wzj/lightllm_dev2/LightLLM/test/benchmark/static_inference/log/qwen35_pd_$(date +%Y%m%d_%H%M%S)" mkdir -p "${LOG_DIR}" ``` @@ -64,8 +64,8 @@ export MODEL_NAME='qwen/Qwen3.5-0.8B' export TP=2 export PREFILL_CUDA_DEVICES='0,1' export DECODE_CUDA_DEVICES='2,3' -export NIXL_PD_KV_PAGE_SIZE=2048 -export NIXL_PD_KV_PAGE_NUM=256 +export PD_KV_PAGE_SIZE=2048 +export PD_KV_PAGE_NUM=256 export PD_MASTER_IP="$(hostname -I | awk '{print $1}')" export HOST="${PD_MASTER_IP}" ``` @@ -78,8 +78,8 @@ export MODEL_NAME='qwen/Qwen3.5-0.8B' export TP=1 export PREFILL_CUDA_DEVICES='4' export DECODE_CUDA_DEVICES='5' -export NIXL_PD_KV_PAGE_SIZE=2048 -export NIXL_PD_KV_PAGE_NUM=256 +export PD_KV_PAGE_SIZE=2048 +export PD_KV_PAGE_NUM=256 export PD_MASTER_IP="$(hostname -I | awk '{print $1}')" export HOST="${PD_MASTER_IP}" ``` @@ -108,8 +108,8 @@ export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} echo "TP=${TP}" echo "PREFILL_CUDA_DEVICES=${PREFILL_CUDA_DEVICES}" echo "DECODE_CUDA_DEVICES=${DECODE_CUDA_DEVICES}" - echo "NIXL_PD_KV_PAGE_SIZE=${NIXL_PD_KV_PAGE_SIZE}" - echo "NIXL_PD_KV_PAGE_NUM=${NIXL_PD_KV_PAGE_NUM}" + echo "PD_KV_PAGE_SIZE=${PD_KV_PAGE_SIZE}" + echo "PD_KV_PAGE_NUM=${PD_KV_PAGE_NUM}" echo "PD_MASTER_IP=${PD_MASTER_IP}" echo "HOST=${HOST}" echo "UCX_NET_DEVICES=${UCX_NET_DEVICES-}" @@ -132,13 +132,13 @@ nohup python -m lightllm.server.api_server \ 等待 `8089` listen 后再启动节点。 -### 2. 启动 `nixl_prefill` +### 2. 启动 `prefill` ```bash LOADWORKER=18 CUDA_VISIBLE_DEVICES="${PREFILL_CUDA_DEVICES}" \ nohup python -m lightllm.server.api_server \ --model_dir "${MODEL_DIR}" \ - --run_mode nixl_prefill \ + --run_mode prefill \ --tp "${TP}" \ --dp 1 \ --host "${HOST}" \ @@ -146,26 +146,26 @@ nohup python -m lightllm.server.api_server \ --disable_cudagraph \ --pd_master_ip "${PD_MASTER_IP}" \ --pd_master_port 8089 \ - --nixl_pd_kv_page_size "${NIXL_PD_KV_PAGE_SIZE}" \ - --nixl_pd_kv_page_num "${NIXL_PD_KV_PAGE_NUM}" \ + --pd_kv_page_size "${PD_KV_PAGE_SIZE}" \ + --pd_kv_page_num "${PD_KV_PAGE_NUM}" \ >> "${LOG_DIR}/prefill.log" 2>&1 & ``` -### 3. 启动 `nixl_decode` +### 3. 启动 `decode` ```bash LOADWORKER=18 CUDA_VISIBLE_DEVICES="${DECODE_CUDA_DEVICES}" \ nohup python -m lightllm.server.api_server \ --model_dir "${MODEL_DIR}" \ - --run_mode nixl_decode \ + --run_mode decode \ --tp "${TP}" \ --dp 1 \ --host "${HOST}" \ --port 8002 \ --pd_master_ip "${PD_MASTER_IP}" \ --pd_master_port 8089 \ - --nixl_pd_kv_page_size "${NIXL_PD_KV_PAGE_SIZE}" \ - --nixl_pd_kv_page_num "${NIXL_PD_KV_PAGE_NUM}" \ + --pd_kv_page_size "${PD_KV_PAGE_SIZE}" \ + --pd_kv_page_num "${PD_KV_PAGE_NUM}" \ >> "${LOG_DIR}/decode.log" 2>&1 & ``` @@ -174,14 +174,14 @@ nohup python -m lightllm.server.api_server \ 不要只看端口。必须等待 `pd_master.log` 同时出现: ```text -mode: nixl_prefill ... registed -mode: nixl_decode ... registed +mode: prefill ... registed +mode: decode ... registed ``` 可用命令: ```bash -rg 'mode: nixl_prefill .* registed|mode: nixl_decode .* registed|ERROR|Traceback|Exception' "${LOG_DIR}/pd_master.log" "${LOG_DIR}/prefill.log" "${LOG_DIR}/decode.log" +rg 'mode: prefill .* registed|mode: decode .* registed|ERROR|Traceback|Exception' "${LOG_DIR}/pd_master.log" "${LOG_DIR}/prefill.log" "${LOG_DIR}/decode.log" ``` ## Warmup @@ -293,7 +293,7 @@ decode-only 全命中的期望信号: | `NIXL_ERR_BACKEND` / `uct_iface_open(rc_verbs/mlx5_8:1) failed: Address not valid` | 显式设置可用 `UCX_NET_DEVICES`,例如避开 `mlx5_8/9` | | `digest sent was rejected` | 多为快速重启后的共享内存 / multiprocessing authkey 残留;清理端口和残留 `lightllm::...` worker 后重启 | | `can not find waiting WRITE task` | 检查 NIXL notify key、abort 日志、以及 `pd_io_struct.py` 中 key 是否包含进程本地 `req_idx` | -| 1024 page size 失败 | Qwen3.5 linear-att state 页可能放不下;使用 `--nixl_pd_kv_page_size 2048` | +| 1024 page size 失败 | Qwen3.5 linear-att state 页可能放不下;使用 `--pd_kv_page_size 2048` | | 第二次同 prompt 仍走 WRITE | 可能是 decode 侧没有建立可复用 cache,或 linear-att 尾块状态无法全命中 | ## 收尾 diff --git a/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh b/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh index 6c0fbbc118..86dca002d0 100755 --- a/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh +++ b/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh @@ -35,7 +35,7 @@ if [[ "$FAIL" -ne 0 ]]; then Enable GPUDirect RDMA: sudo modprobe nvidia_peermem lsmod | grep nvidia_peermem - # cross-node: run on every host; then restart nixl_prefill / nixl_decode + # cross-node: run on every host; then restart prefill / decode EOF exit 1 fi diff --git a/test/acc/test_pd_nixl.sh b/test/acc/test_pd.sh similarity index 98% rename from test/acc/test_pd_nixl.sh rename to test/acc/test_pd.sh index 8bbd7007e5..ee94e73e91 100644 --- a/test/acc/test_pd_nixl.sh +++ b/test/acc/test_pd.sh @@ -18,7 +18,7 @@ export UCX_LOG_LEVEL=info export UCX_TLS=rc,cuda,gdr_copy LOADWORKER=18 CUDA_VISIBLE_DEVICES=0,1 python -m lightllm.server.api_server \ --model_dir /mtc/models/qwen3-8b \ ---run_mode "nixl_prefill" \ +--run_mode "prefill" \ --tp 2 \ --dp 1 \ --host $host \ @@ -39,7 +39,7 @@ $host 为本机的ip地址, 测试的时候,自己修改为对应的ip地址 $pd_master_ip 为pd_master的ip地址, 测试的时候,自己修改为对应的ip地址,在测试的时候为本机ip地址 LOADWORKER=18 CUDA_VISIBLE_DEVICES=2,3 python -m lightllm.server.api_server \ --model_dir /mtc/models/qwen3-8b \ ---run_mode "nixl_decode" \ +--run_mode "decode" \ --tp 2 \ --dp 1 \ --host $host \ diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md index a84acf96dc..ffaad87f7f 100644 --- a/test/start_scripts/README.md +++ b/test/start_scripts/README.md @@ -20,8 +20,8 @@ This directory contains various startup scripts for deploying DeepSeek models wi #### Single PD Master Mode - `single_pd_master/pd_master.sh` - PD Master service -- `single_pd_master/pd_nixl_prefill.sh` - Prefill service -- `single_pd_master/pd_nixl_decode.sh` - Decode service +- `single_pd_master/pd_prefill.sh` - Prefill service +- `single_pd_master/pd_decode.sh` - Decode service #### Multi PD Master Mode - `multi_pd_master/config_server.sh` - Configuration server @@ -71,10 +71,10 @@ sh multi_node_ep_node1.sh sh single_pd_master/pd_master.sh # Step 2: Start Prefill service -sh single_pd_master/pd_nixl_prefill.sh +sh single_pd_master/pd_prefill.sh # Step 3: Start Decode service -sh single_pd_master/pd_nixl_decode.sh +sh single_pd_master/pd_decode.sh ``` ### 6. Multi PD Master Mode @@ -87,7 +87,7 @@ sh multi_pd_master/config_server.sh sh multi_pd_master/pd_master_1.sh sh multi_pd_master/pd_master_2.sh -# Step 3: Start Prefill and Decode services with the nixl_prefill/nixl_decode run modes. +# Step 3: Start Prefill and Decode services with the prefill/decode run modes. # Multi-PD startup scripts for these nodes are not provided in this directory. ``` diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh similarity index 90% rename from test/start_scripts/single_pd_master/pd_nixl_decode.sh rename to test/start_scripts/single_pd_master/pd_decode.sh index f4f279d89e..eb45622f3e 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ b/test/start_scripts/single_pd_master/pd_decode.sh @@ -1,7 +1,7 @@ # PD decode mode for deepseek R1 (DP+EP) on H200 # host: the host of the current node # pd_master_ip: the ip of the pd master -# sh pd_nixl_decode.sh +# sh pd_decode.sh export host=$1 export pd_master_ip=$2 @@ -12,7 +12,7 @@ export UCX_TLS=rc,cuda,gdr_copy nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ ---run_mode "nixl_decode" \ +--run_mode "decode" \ --tp 8 \ --dp 8 \ --host $host \ diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh similarity index 91% rename from test/start_scripts/single_pd_master/pd_nixl_prefill.sh rename to test/start_scripts/single_pd_master/pd_prefill.sh index 3a10a32f2a..4cd4e1a705 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -1,7 +1,7 @@ # PD prefill mode for deepseek R1 (DP+EP) on H200 # host: the host of the current node # pd_master_ip: the ip of the pd master -# sh pd_nixl_prefill.sh +# sh pd_prefill.sh ### nixl pd mode used export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) @@ -13,7 +13,7 @@ export pd_master_ip=$2 nvidia-cuda-mps-control -d LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ ---run_mode "nixl_prefill" \ +--run_mode "prefill" \ --tp 8 \ --dp 8 \ --host $host \ From f48a91c30034550ed861ec93256eb8877c43c567 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 08:57:56 +0000 Subject: [PATCH 10/17] fix --- lightllm/server/router/req_queue/__init__.py | 4 ++-- .../{impl_for_nixl_pd.py => impl_for_pd.py} | 2 +- lightllm/utils/error_utils.py | 6 +++--- skills/test_model/qwen3-8b-pd-nixl/SKILL.md | 14 +++++++------- test/start_scripts/single_pd_master/pd_prefill.sh | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) rename lightllm/server/router/req_queue/chunked_prefill/{impl_for_nixl_pd.py => impl_for_pd.py} (99%) diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index f977f6d567..eb991bb4b9 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -1,6 +1,6 @@ from .chunked_prefill.impl import ChunkedPrefillQueue from .chunked_prefill.beam_impl import ChunkedBeamContinuesBatchQueue -from .chunked_prefill.impl_for_nixl_pd import NIXLPDQueue +from .chunked_prefill.impl_for_pd import PDQueue from .dp_base_queue import DpQueue @@ -14,7 +14,7 @@ def _get_req_queue_class(args, router, dp_size_in_node: int): if args.first_token_constraint_mode: return ChunkedPrefillQueue if args.run_mode in ["prefill", "decode"]: - return NIXLPDQueue + return PDQueue if args.disable_chunked_prefill: # 虽然也使用chuncked prefill queue 但是由于 args.chunked_prefill_size = args.max_req_total_len diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py similarity index 99% rename from lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py rename to lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py index 482568ebfb..5ec09f5760 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py @@ -8,7 +8,7 @@ logger = init_logger(__name__) -class NIXLPDQueue(BaseQueue): +class PDQueue(BaseQueue): def __init__(self, args, router, dp_index, dp_size_in_node) -> None: super().__init__(args, router, dp_index, dp_size_in_node) diff --git a/lightllm/utils/error_utils.py b/lightllm/utils/error_utils.py index 0e2db7f4e8..77ad90c618 100644 --- a/lightllm/utils/error_utils.py +++ b/lightllm/utils/error_utils.py @@ -38,10 +38,10 @@ def __init__(self, group_request_id: Optional[int] = None, reason: str = "client self.reason = reason -class NixlPrefillNodeStopGenToken(Exception): - def __init__(self, group_request_id, message="Nixl prefill node stop gen token"): +class PDPrefillNodeStopGenToken(Exception): + def __init__(self, group_request_id, message="PD prefill node stop gen token"): """ - Initialize the NixlPrefillNodeStopGenToken + Initialize the PDPrefillNodeStopGenToken Args: message (str): Error message to display diff --git a/skills/test_model/qwen3-8b-pd-nixl/SKILL.md b/skills/test_model/qwen3-8b-pd-nixl/SKILL.md index 2012daadc5..a1775d09d8 100644 --- a/skills/test_model/qwen3-8b-pd-nixl/SKILL.md +++ b/skills/test_model/qwen3-8b-pd-nixl/SKILL.md @@ -1,18 +1,18 @@ --- name: test-model-qwen3-8b-pd-nixl description: >- - LightLLM Qwen3-8b PD disaggregation over NIXL gsm8k: pd_master on 8089, prefill on 8001, + LightLLM Qwen3-8b PD disaggregation gsm8k: pd_master on 8089, prefill on 8001, decode on 8002, tp 2 each. Assign four GPUs via nvidia-smi then export PREFILL_CUDA_DEVICES / DECODE_CUDA_DEVICES (no fixed card IDs; no complex shell automation). UCX_NET_DEVICES and TLS for RDMA per cluster. lm_eval hits pd_master URL. HOST vs PD_MASTER_IP when co-located. Before lm_eval, must POST one completion via curl to pd_master for warmup verification. Requires LOG_DIR, MODEL_DIR, proxy cleared, no_proxy, - summary.txt. Same-GPU model_infer + nixl_*_trans need NVIDIA MPS for best KV copy perf; + summary.txt. Same-GPU model_infer + pd_*_trans need NVIDIA MPS for best KV copy perf; record MPS on/off in summary. Run check_nvidia_peermem.sh in this skill dir; record in summary.txt. - Use for PD NIXL-style separation tests. + Use for PD separation tests with either the default NIXL transport or NCCL transport. --- -# Qwen3-8B **PD 分离(NIXL)**(`pd_master` + `prefill` + `decode`)本地 GSM8K 评测 +# Qwen3-8B **PD 分离**(`pd_master` + `prefill` + `decode`)本地 GSM8K 评测 **测试标识**:同一 **`--model_dir`**(Qwen3-8B)下拆 **三条** `api_server` 进程——**调度/入口(`pd_master`)**、**`prefill` 节点**、**`decode` 节点**;评测 **`lm_eval`** 只访问 **`pd_master` 的 HTTP 端口(8089)**。默认使用 NIXL 传输;需要验证 NCCL 数据面时,设置 **`LIGHTLLM_PD_KV_TRANSPORT_BACKEND=nccl`**,上层仍保持相同的 `prefill` / `decode` 管理路径。 @@ -41,7 +41,7 @@ description: >- 4. **代理**:启动 **任一 server 前**将 **`http_proxy` / `https_proxy` 置空**;评测使用 **`no_proxy`**(见评测命令)。 5. **RDMA / UCX**:prefill 与 decode 进程在启动 Python 前须设置 **`UCX_NET_DEVICES`**(及可选 **`UCX_LOG_LEVEL`**、**`UCX_TLS`**),取值依赖本机 **`ibv_devinfo`** 与机房拓扑(见「UCX / RDMA」);**不要**默认照抄他机上的设备名或排除列表。 6. **`nvidia_peermem`**:`bash skills/test_model/qwen3-8b-pd-nixl/check_nvidia_peermem.sh >> "${LOG_DIR}/summary.txt"`;失败按脚本提示 `modprobe` 后重启服务(跨机各节点都要做)。 -7. **CUDA MPS(强烈建议,见下节)**:**要达到 NIXL PD 最优 KV 拷贝与 batch 评测性能,须在启动 `api_server` 之前在本机启用 NVIDIA MPS**。未开 MPS 时功能通常仍可用,但易出现 **`read_page_gpu_time` 数十秒级毛刺**、**`lm_eval` 单 batch 近百秒**;**`summary.txt` 须写明 MPS 是否已开启及验证方式**。 +7. **CUDA MPS(强烈建议,见下节)**:**要达到 PD KV 拷贝与 batch 评测最佳性能,须在启动 `api_server` 之前在本机启用 NVIDIA MPS**。未开 MPS 时功能通常仍可用,但易出现 **`read_page_gpu_time` 数十秒级毛刺**、**`lm_eval` 单 batch 近百秒**;**`summary.txt` 须写明 MPS 是否已开启及验证方式**。 ### 启动服务的命令模板(可变项) @@ -79,7 +79,7 @@ export UCX_TLS=rc,cuda,gdr_copy **禁止**:为选卡编写 **awk / mapfile / 长段 bash** 自动化;以 **`nvidia-smi` 事实 + 明确决策**为准。 -### UCX / RDMA(NIXL 传输) +### UCX / RDMA(默认 NIXL 传输) - **`UCX_NET_DEVICES`**:须覆盖本进程要用的 **RDMA 设备**;是否排除某些 HCA(例如数据面网卡)由**本机拓扑**决定,在 **`summary.txt`** 中写明依据。 - **`UCX_TLS`**:常见 **`rc,cuda,gdr_copy`**;若环境不支持再按报错调整。 @@ -167,7 +167,7 @@ nohup python -m lightllm.server.api_server \ ### 测试前 curl warmup(**须执行**,再走 `lm_eval`) -NIXL PD 链路在首次真实推理前易出现冷启动与传输路径问题。**在跑 `lm_eval` 正式评测之前**,必须先对 **`pd_master`** 的 **`/v1/completions`** 发 **至少一次** HTTP 请求,确认返回 **2xx** 且响应体含正常 completion(再走长评测)。 +PD 链路在首次真实推理前易出现冷启动与传输路径问题。**在跑 `lm_eval` 正式评测之前**,必须先对 **`pd_master`** 的 **`/v1/completions`** 发 **至少一次** HTTP 请求,确认返回 **2xx** 且响应体含正常 completion(再走长评测)。 1. **时机**:**`prefill` 与 `decode` 均已启动**,且日志显示已与 **`pd_master`** 建立 PD 链路后再执行(可与端口 listen、日志轮询结合判断)。 2. **代理**:执行 **`curl` 前**同样 **`export http_proxy=` / `export https_proxy=`**;若评测机对 **`PD_MASTER_IP`** 走代理会失败,可对本次 shell 设置 **`no_proxy`**(与下文 `lm_eval` 一致,须包含 **`${PD_MASTER_IP}`**)。 diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh index 4cd4e1a705..08f5300bc6 100644 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -3,7 +3,7 @@ # pd_master_ip: the ip of the pd master # sh pd_prefill.sh -### nixl pd mode used +### PD mode using the default KV transport export UCX_NET_DEVICES=$(ibv_devinfo | grep 'hca_id:' | grep -v -E 'mlx5_8|mlx5_9' | awk '{print $2":1"}' | paste -sd, -) export UCX_LOG_LEVEL=info export UCX_TLS=rc,cuda,gdr_copy From c606e4328cbb2013036aacf39cd00198117edd2a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 09:12:54 +0000 Subject: [PATCH 11/17] fix --- lightllm/server/core/objs/pd_kv_trans_params.py | 5 ++++- lightllm/server/pd_io_struct.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lightllm/server/core/objs/pd_kv_trans_params.py b/lightllm/server/core/objs/pd_kv_trans_params.py index 68d9de3aa9..ae56c870f7 100644 --- a/lightllm/server/core/objs/pd_kv_trans_params.py +++ b/lightllm/server/core/objs/pd_kv_trans_params.py @@ -22,7 +22,10 @@ def set(self, obj_bytes: Optional[bytes]): assert ( len(obj_bytes) <= LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES - ), f"PD_KV_TRANS_PARAM_OBJ bytes len {len(obj_bytes)} exceeds length of {LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES} bytes." + ), ( + f"PD_KV_TRANS_PARAM_OBJ bytes len {len(obj_bytes)} exceeds length of " + f"{LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES} bytes." + ) ctypes.memmove(self.data, obj_bytes, len(obj_bytes)) self.data_len = len(obj_bytes) return diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 1fa564cb99..1d68f81a9e 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -100,7 +100,10 @@ def __post_init__(self): def __str__(self): req_id = self.group_request_id pd_m_id = self.pd_master_node_id - return f"group_request_id: {req_id} pd_master_node_id: {pd_m_id} pd_kv_trans_params_len: {len(self.pd_kv_trans_params)}" + return ( + f"group_request_id: {req_id} pd_master_node_id: {pd_m_id} " + f"pd_kv_trans_params_len: {len(self.pd_kv_trans_params)}" + ) @dataclass From ca0496b10326a35ce46dabb5c50425283a5f2763 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 09:15:00 +0000 Subject: [PATCH 12/17] fix --- lightllm/server/core/objs/pd_kv_trans_params.py | 4 +--- lightllm/server/httpserver/manager.py | 4 +++- .../server/router/model_infer/mode_backend/base_backend.py | 4 +--- .../model_infer/mode_backend/pd/nccl_kv_transporter.py | 7 ++----- .../mode_backend/pd/prefill_node_impl/prefill_impl.py | 4 +--- 5 files changed, 8 insertions(+), 15 deletions(-) diff --git a/lightllm/server/core/objs/pd_kv_trans_params.py b/lightllm/server/core/objs/pd_kv_trans_params.py index ae56c870f7..290a4ca287 100644 --- a/lightllm/server/core/objs/pd_kv_trans_params.py +++ b/lightllm/server/core/objs/pd_kv_trans_params.py @@ -20,9 +20,7 @@ def set(self, obj_bytes: Optional[bytes]): self.data_len = 0 return - assert ( - len(obj_bytes) <= LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES - ), ( + assert len(obj_bytes) <= LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES, ( f"PD_KV_TRANS_PARAM_OBJ bytes len {len(obj_bytes)} exceeds length of " f"{LIGHTLLM_PD_KV_TRANS_PARAM_OBJ_MAX_BYTES} bytes." ) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 5049cd96f7..8fdd277f57 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -373,7 +373,9 @@ async def generate( if pd_upload_websocket is not None and self.pd_mode.is_P(): # 在 pd 模式下的 prefill 节点,为了兼容多模态推理流程,需要先上报 encode 好的 prompt ids, # 再等待 pd_master 下发对应请求的 decode 节点信息,然后执行后续流程。 - logger.info(f"pd prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}") + logger.info( + f"pd prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}" + ) await pd_upload_websocket.send( pickle.dumps((ObjType.PD_UPLOAD_PREFILL_PROMPT_IDS, group_request_id, prompt_ids)) ) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a1466d226b..4323a62d1c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -218,9 +218,7 @@ def init_model(self, kvargs): [rank for rank in range(self.global_world_size)], backend="nccl" ) - if ( - self.args.run_mode in ["prefill", "decode"] or self.args.enable_dp_prompt_cache_fetch - ): + if self.args.run_mode in ["prefill", "decode"] or self.args.enable_dp_prompt_cache_fetch: # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便 # 读取 self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager) diff --git a/lightllm/server/router/model_infer/mode_backend/pd/nccl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd/nccl_kv_transporter.py index 1bea648a83..2ed0335ca5 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd/nccl_kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/nccl_kv_transporter.py @@ -114,9 +114,7 @@ def connect_add_remote_agent(self, remote_agent: PDAgentMetadata): ), f"Peer name {metadata.agent_name} does not match remote name {remote_agent.agent_name}" self.remote_agents[remote_agent.agent_name] = remote_agent - logger.info( - f"Added NCCL remote agent {remote_agent.agent_name} at {metadata.host_ip}:{metadata.control_port}" - ) + logger.info(f"Added NCCL remote agent {remote_agent.agent_name} at {metadata.host_ip}:{metadata.control_port}") return def remove_remote_agent(self, peer_name: str): @@ -363,8 +361,7 @@ def _recv_page(self, trans_task: PDChunckedTransTask): stream = self._get_stream() comm.recv(page_tensor, src=0, stream=stream) logger.info( - f"NCCL recv page done request_id={trans_task.request_id} " - f"dst_page={trans_task.dst_page_index}" + f"NCCL recv page done request_id={trans_task.request_id} " f"dst_page={trans_task.dst_page_index}" ) except BaseException as e: trans_task.error_info = str(e) diff --git a/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py index 0f297a1fe5..2a501f509b 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/prefill_node_impl/prefill_impl.py @@ -39,9 +39,7 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: ans_list.append(req_obj) else: if req_obj.infer_aborted: - if req_obj.pd_task_num == ( - req_obj.pd_task_failed_num + req_obj.pd_task_success_num - ): + if req_obj.pd_task_num == (req_obj.pd_task_failed_num + req_obj.pd_task_success_num): ans_list.append(req_obj) else: continue From 7e900acb9b202d432e6de69d2c6dfefcc38b6e09 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 09:51:56 +0000 Subject: [PATCH 13/17] fix --- .../server/core/objs/py_sampling_params.py | 7 ---- lightllm/server/core/objs/sampling_params.py | 42 ------------------- .../server/router/model_infer/infer_batch.py | 6 --- 3 files changed, 55 deletions(-) diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 5d3a511d21..cbc63c898d 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -58,8 +58,6 @@ def __init__( invalid_token_ids: Optional[List[int]] = None, # p d mode used params group_request_id: Optional[int] = None, - # move kv to deocde node, only used in pd mode - move_kv_to_decode_node: Optional[dict] = None, # suggest dp index, deepseekv2 dp mode, use to suggest used dp_index suggested_dp_index: Optional[int] = None, seed: Optional[int] = -1, @@ -93,7 +91,6 @@ def __init__( self.allowed_token_ids = allowed_token_ids self.invalid_token_ids = invalid_token_ids self.group_request_id = group_request_id - self.move_kv_to_decode_node = move_kv_to_decode_node self.suggested_dp_index = suggested_dp_index self.seed = seed if self.do_sample is False: @@ -192,9 +189,6 @@ def verify(self): if not (self.group_request_id is None or isinstance(self.group_request_id, int)): raise ValueError(f"group_request_id must be None or int ,but get {self.group_request_id}") - if not (self.move_kv_to_decode_node is None or isinstance(self.move_kv_to_decode_node, dict)): - raise ValueError(f"move_kv_to_decode_node must be None or dict, but get {self.move_kv_to_decode_node}") - if not (self.suggested_dp_index is None or isinstance(self.suggested_dp_index, int)): raise ValueError(f"suggested_dp_index must be None or int, but get {self.suggested_dp_index}") @@ -273,7 +267,6 @@ def to_dict(self): ret["guided_json"] = self.guided_json ret["allowed_token_ids"] = self.allowed_token_ids ret["invalid_token_ids"] = self.invalid_token_ids - ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node ret["seed"] = self.seed return ret diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index f503a92fb3..c39559f5f6 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -260,44 +260,6 @@ def get(self) -> int: return (self.node_id_high << 64) | self.node_id_low -class DecodeNode(ctypes.Structure): - _pack_ = 4 - _fields_ = [ - ("exists", ctypes.c_bool), - ("node_id", NodeUUId), - ("ip", ctypes.c_int32 * 4), - ("rpyc_port", ctypes.c_int), - ("max_new_tokens", ctypes.c_int), - ] - - def initialize(self, data_dict): - if data_dict is None: - self.exists = False - return - - self.exists = True - - pd_node_id = data_dict["node_id"] - self.node_id = NodeUUId() - self.node_id.initialize(pd_node_id) - - ip_parts = [int(part) for part in data_dict["ip"].split(".")] - self.ip = (ctypes.c_int32 * 4)(*ip_parts) - - self.rpyc_port = data_dict["rpyc_port"] - self.max_new_tokens = data_dict["max_new_tokens"] - - def to_dict(self): - if not self.exists: - return None - return { - "node_id": self.node_id.get(), - "ip": ".".join(str(self.ip[i]) for i in range(4)), - "rpyc_port": self.rpyc_port, - "max_new_tokens": self.max_new_tokens, - } - - class SamplingParams(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -330,7 +292,6 @@ class SamplingParams(ctypes.Structure): ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), ("group_request_id", ctypes.c_int64), # p d mode used params ("suggested_dp_index", ctypes.c_int), # suggest dp index, deepseekv2 dp mode, use to suggest used dp_index - ("move_kv_to_decode_node", DecodeNode), # move kv to deocde node, only used in pd mode # in pd split mode, use to keep the id of pd master ("pd_master_node_id", NodeUUId), # pd params object, only used in pd mode, used to build kv transport connection in prefill and decode @@ -384,8 +345,6 @@ def init(self, tokenizer, **kwargs): self.exponential_decay_length_penalty = ExponentialDecayLengthPenalty() self.exponential_decay_length_penalty.initialize(kwargs.get("exponential_decay_length_penalty", (1, 1.0))) - self.move_kv_to_decode_node = DecodeNode() - self.move_kv_to_decode_node.initialize(kwargs.get("move_kv_to_decode_node", None)) self.pd_kv_trans_params = PDKVTransParamObj() self.pd_kv_trans_params.set(kwargs.get("pd_kv_trans_params", None)) self.pd_master_node_id = NodeUUId() @@ -522,7 +481,6 @@ def to_dict(self): "allowed_token_ids": self.allowed_token_ids.to_list(), "invalid_token_ids": self.invalid_token_ids.to_list(), "group_request_id": self.group_request_id, - "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), "skip_special_tokens": self.skip_special_tokens, "add_special_tokens": self.add_special_tokens, "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index cc925061f7..bae5ea1e3c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -441,12 +441,6 @@ def __init__( # if provided, invalid_token_ids are masked to -inf during sampling (see generic_post_process.sample) self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list() - # p d mode use params - if self.shm_param.move_kv_to_decode_node.exists: - self.move_kv_to_decode_node = self.shm_param.move_kv_to_decode_node.to_dict() - else: - self.move_kv_to_decode_node = None - # this check is not very good to placed here. to do... if self.allowed_token_ids is not None: if not all(e < vocab_size for e in self.allowed_token_ids): From b85292389319a8913ae8d96d068bd4931a3ba2ba Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 10:10:37 +0000 Subject: [PATCH 14/17] fix --- lightllm/server/api_cli.py | 6 ------ lightllm/server/api_start.py | 15 +++------------ lightllm/server/core/objs/start_args_type.py | 1 - .../model_infer/mode_backend/pd/kv_transporter.py | 13 ++++++++----- 4 files changed, 11 insertions(+), 24 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 15f1fa27ea..7e40421140 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -53,12 +53,6 @@ def make_argument_parser() -> argparse.ArgumentParser: default=1212, help="when run_mode set to prefill or decode, you need set this pd_mater_port", ) - parser.add_argument( - "--pd_decode_rpyc_port", - type=int, - default=None, - help="p d mode, decode node rpyc server port", - ) parser.add_argument( "--select_p_d_node_strategy", type=str, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 64cb276d34..aaaefb930d 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -338,8 +338,6 @@ def normal_or_p_d_start(args): already_uesd_ports = [args.port] if args.nccl_port is not None: already_uesd_ports.append(args.nccl_port) - if args.pd_decode_rpyc_port is not None: - already_uesd_ports.append(args.pd_decode_rpyc_port) if args.visual_nccl_ports is not None: already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) if not args.disable_audio and args.audio_nccl_ports is not None: @@ -352,7 +350,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, + num=9 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, used_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") @@ -366,9 +364,8 @@ def normal_or_p_d_start(args): cache_port, metric_port, multi_level_kv_cache_port, - pd_decode_rpyc_port, - ) = can_use_ports[0:10] - can_use_ports = can_use_ports[10:] + ) = can_use_ports[0:9] + can_use_ports = can_use_ports[9:] if args.visual_nccl_ports is None: args.visual_nccl_ports = can_use_ports[: args.visual_dp] @@ -385,8 +382,6 @@ def normal_or_p_d_start(args): # 将申请好的端口放入args参数中 if args.nccl_port is None: args.nccl_port = nccl_port - if args.pd_decode_rpyc_port is None: - args.pd_decode_rpyc_port = pd_decode_rpyc_port args.router_port = router_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port @@ -399,10 +394,6 @@ def normal_or_p_d_start(args): args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] # p d 分离模式下用于标识节点的id args.pd_node_id = uuid.uuid4().int - # p 节点用来建立torch kv 传输分布组的可用端口范围 - args.pd_p_allowed_port_min = 20000 - args.pd_p_allowed_port_max = 30000 - # p d 分离模式下,decode节点的调度间隙是0 if args.run_mode == "decode": args.router_max_wait_tokens = 0 diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 5ed804d6a0..e7f35780a4 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -23,7 +23,6 @@ class StartArgs: config_server_visual_redis_port: int = field(default=None) afs_image_embed_dir: str = field(default=None) afs_embed_capacity: int = field(default=250000) - pd_decode_rpyc_port: int = field(default=None) select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) diff --git a/lightllm/server/router/model_infer/mode_backend/pd/kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd/kv_transporter.py index 236a737448..c4e7951043 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd/kv_transporter.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/kv_transporter.py @@ -8,6 +8,9 @@ logger = init_logger(__name__) +_NCCL_CONTROL_PORT_MIN = 20000 +_NCCL_CONTROL_PORT_MAX = 30000 + def create_kv_transporter(args: StartArgs, node_id: int, tp_idx: int, kv_move_buffer: Tensor): backend = os.getenv("LIGHTLLM_PD_KV_TRANSPORT_BACKEND", "nixl").lower() @@ -20,11 +23,11 @@ def create_kv_transporter(args: StartArgs, node_id: int, tp_idx: int, kv_move_bu from .nccl_kv_transporter import NcclKVTransporter logger.info("Use NCCL as pd KV transporter backend") - port_min = args.pd_p_allowed_port_min + tp_idx * 100 - port_max = min(args.pd_p_allowed_port_max, port_min + 99) - if port_min > args.pd_p_allowed_port_max: - port_min = args.pd_p_allowed_port_min - port_max = args.pd_p_allowed_port_max + port_min = _NCCL_CONTROL_PORT_MIN + tp_idx * 100 + port_max = min(_NCCL_CONTROL_PORT_MAX, port_min + 99) + if port_min > _NCCL_CONTROL_PORT_MAX: + port_min = _NCCL_CONTROL_PORT_MIN + port_max = _NCCL_CONTROL_PORT_MAX return NcclKVTransporter( node_id=node_id, tp_idx=tp_idx, From 854c3b55a14e6a69e0131b1fba38037b34d72b41 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 10:19:41 +0000 Subject: [PATCH 15/17] fix --- test/start_scripts/multi_pd_master.sh | 34 +++++++++++++++++++ .../multi_pd_master/pd_decode.sh | 19 +++++++++++ .../multi_pd_master/pd_prefill.sh | 21 ++++++++++++ 3 files changed, 74 insertions(+) create mode 100644 test/start_scripts/multi_pd_master.sh create mode 100644 test/start_scripts/multi_pd_master/pd_decode.sh create mode 100644 test/start_scripts/multi_pd_master/pd_prefill.sh diff --git a/test/start_scripts/multi_pd_master.sh b/test/start_scripts/multi_pd_master.sh new file mode 100644 index 0000000000..7b83923929 --- /dev/null +++ b/test/start_scripts/multi_pd_master.sh @@ -0,0 +1,34 @@ +# 多 pd_master 节点部署示例 +python -m lightllm.server.api_server --run_mode "config_server" --config_server_host 10.120.114.74 --config_server_port 60088 + +python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat --run_mode "pd_master" --host 10.120.114.74 --port 60011 --config_server_host 10.120.114.74 --config_server_port 60088 + +python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat --run_mode "pd_master" --host 10.120.114.74 --port 60012 --config_server_host 10.120.114.74 --config_server_port 60088 + +nvidia-cuda-mps-control -d +CUDA_VISIBLE_DEVICES=0 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ +--run_mode "prefill" \ +--host 10.120.178.74 \ +--port 8019 \ +--tp 1 \ +--nccl_port 2732 \ +--max_total_token_num 40000 \ +--tokenizer_mode fast \ +--max_req_total_len 16000 \ +--running_max_req_size 128 \ +--disable_cudagraph \ +--config_server_host 10.120.114.74 \ +--config_server_port 60088 + +CUDA_VISIBLE_DEVICES=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ +--run_mode "decode" \ +--host 10.120.178.74 \ +--port 8121 \ +--nccl_port 12322 \ +--tp 1 \ +--max_total_token_num 40000 \ +--graph_max_len_in_batch 2048 \ +--graph_max_batch_size 16 \ +--tokenizer_mode fast \ +--config_server_host 10.120.114.74 \ +--config_server_port 60088 \ No newline at end of file diff --git a/test/start_scripts/multi_pd_master/pd_decode.sh b/test/start_scripts/multi_pd_master/pd_decode.sh new file mode 100644 index 0000000000..cb55ec338f --- /dev/null +++ b/test/start_scripts/multi_pd_master/pd_decode.sh @@ -0,0 +1,19 @@ +# decode +# host: the host of the decode server +# config_server_host: the host of the config server +# sh decode.sh +export host=$1 +export config_server_host=$2 +nvidia-cuda-mps-control -d +MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +--model_dir /path/DeepSeek-R1 \ +--run_mode "decode" \ +--host $host \ +--port 8121 \ +--nccl_port 12322 \ +--tp 8 \ +--dp 8 \ +--config_server_host $config_server_host \ +--config_server_port 60088 +# if you want to enable microbatch overlap, you can uncomment the following lines +#--enable_decode_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/multi_pd_master/pd_prefill.sh b/test/start_scripts/multi_pd_master/pd_prefill.sh new file mode 100644 index 0000000000..45f6c0c011 --- /dev/null +++ b/test/start_scripts/multi_pd_master/pd_prefill.sh @@ -0,0 +1,21 @@ +# prefill +# host: the host of the prefill server +# config_server_host: the host of the config server +# sh pd_prefill.sh +export host=$1 +export config_server_host=$2 +nvidia-cuda-mps-control -d +LOADWORKER=18 python -m lightllm.server.api_server \ +--model_dir /path/DeepSeek-R1 \ +--run_mode "prefill" \ +--host $host \ +--port 8019 \ +--tp 8 \ +--dp 8 \ +--nccl_port 2732 \ +--disable_cudagraph \ +--config_server_host $config_server_host \ +--config_server_port 60088 \ +--enable_ep_moe +# if you want to enable microbatch overlap, you can uncomment the following lines +#--enable_prefill_microbatch_overlap \ No newline at end of file From 128c8c8443def6dd12f03689aceb3326ab551ff9 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 10:46:51 +0000 Subject: [PATCH 16/17] fix --- .../model_infer/mode_backend/pd/p2p_fix.py | 56 ++++++++++++------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py b/lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py index 0307df582b..5a737c2fc6 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py +++ b/lightllm/server/router/model_infer/mode_backend/pd/p2p_fix.py @@ -1,15 +1,16 @@ # mypy: allow-untyped-defs +import multiprocessing +import os +import threading +from multiprocessing.reduction import ForkingPickler +from multiprocessing.util import register_after_fork +from typing import Union + import torch import torch.utils.hooks from torch._namedtensor_internals import check_serializing_named_tensor -from torch.multiprocessing.reductions import ( - StorageWeakRef, - reduce_nested_tensor, - reduce_sparse_tensor, - rebuild_tensor, - shared_cache, - storage_from_cache, -) +from torch.multiprocessing.reductions import storage_from_cache, shared_cache, StorageWeakRef +from torch.multiprocessing.reductions import reduce_nested_tensor, reduce_sparse_tensor, rebuild_tensor def p2p_fix_rebuild_cuda_tensor( @@ -29,7 +30,13 @@ def p2p_fix_rebuild_cuda_tensor( event_handle, event_sync_required, ): + # 因为接收进程在将 tensor 对应的 handle重新转化为指针的时候 + # 在其c++源码中会将当前显卡切换到storage_device再做操作,这样 + # 得到的指针可能不是接收进程当前上下文设备可以访问的,所以在这里 + # hack 修改了使用的 storage_device,这样后续tritonkernel同时 + # 访问几张显卡上的数据,进行p2p操作就不会出问题了。 storage_device = torch.cuda.current_device() + # If storage_handle is None, storage points to nullptr. if storage_handle is None or storage_size_bytes == 0: storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) else: @@ -48,10 +55,12 @@ def p2p_fix_rebuild_cuda_tensor( ) shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage) else: + # We already ref counting this Storage, but producer needs new ref-counters to be released. storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) _storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage - tensor = torch._utils._rebuild_tensor( + + t = torch._utils._rebuild_tensor( torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), tensor_offset, tensor_size, @@ -59,20 +68,22 @@ def p2p_fix_rebuild_cuda_tensor( ) if tensor_cls == torch.nn.parameter.Parameter: - tensor = torch.nn.parameter.Parameter(tensor, requires_grad=requires_grad) + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) else: - tensor.requires_grad = requires_grad + t.requires_grad = requires_grad - return tensor + return t def reduce_tensor(tensor): if tensor.requires_grad and not tensor.is_leaf: raise RuntimeError( "Cowardly refusing to serialize non-leaf tensor which requires_grad, " - "since autograd does not support crossing process boundaries. " + "since autograd does not support crossing process boundaries. " "If you just want to transfer the data, call detach() on the tensor " - "before serializing." + "before serializing (e.g., putting it on the queue)." ) check_serializing_named_tensor(tensor) @@ -95,8 +106,6 @@ def reduce_tensor(tensor): storage = tensor._typed_storage() if storage._untyped_storage.device.type == "cuda": - from lightllm.server.router.model_infer.mode_backend.pd.p2p_fix import p2p_fix_rebuild_cuda_tensor - ( device, handle, @@ -109,19 +118,25 @@ def reduce_tensor(tensor): ) = storage._share_cuda_() tensor_offset = tensor.storage_offset() shared_cache[handle] = StorageWeakRef(storage) + # _backward_hooks purposely omitted here, see + # Note [Don't serialize hooks] + from lightllm.server.router.model_infer.mode_backend.pd.p2p_fix import ( + p2p_fix_rebuild_cuda_tensor, + ) + return ( p2p_fix_rebuild_cuda_tensor, ( type(tensor), tensor.size(), tensor.stride(), - tensor_offset, + tensor_offset, # tensor offset in its storage type(storage), tensor.dtype, device, - handle, - storage_size_bytes, - storage_offset_bytes, + handle, # identifier which CUDA allocation is the storage in. + storage_size_bytes, # size(in bytes) of the storage + storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation tensor.requires_grad, ref_counter_handle, ref_counter_offset, @@ -130,6 +145,7 @@ def reduce_tensor(tensor): ), ) + # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] metadata = ( tensor.storage_offset(), tensor.size(), From 64644281ef4330252b69794007b92aa79557cc0b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 10 Jun 2026 12:14:15 +0000 Subject: [PATCH 17/17] fix --- skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md b/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md index 244a7a06a6..dca4fa9a2a 100644 --- a/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md +++ b/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md @@ -4,8 +4,8 @@ description: >- LightLLM Qwen3.5-0.8B PD disaggregation over NIXL gsm8k: pd_master on 8089, prefill on 8001, decode on 8002. Supports TP1 and TP2 runs by setting TP / PREFILL_CUDA_DEVICES / DECODE_CUDA_DEVICES. Qwen3.5 has linear-attention - state transfer; use --pd_kv_page_size 2048 and a large enough page_num - such as 256. lm_eval hits pd_master URL. Requires UCX/RDMA env, nvidia_peermem + state transfer; use --pd_kv_page_size 2048 and --pd_kv_page_num 16. + lm_eval hits pd_master URL. Requires UCX/RDMA env, nvidia_peermem check, curl warmup before lm_eval, registration wait in pd_master.log, and summary.txt. Includes optional repeated-prompt decode cache probe for linear-att page-boundary behavior. @@ -23,7 +23,7 @@ Qwen3.5 与 Qwen3-8B 的关键差异: |---|---| | linear-att 状态 | PD 传输除了 KV page,还会传 `linear_att_state` 特殊页 | | NIXL page size | 建议固定 **`--pd_kv_page_size 2048`**;`1024` 可能不足以容纳 linear-att 状态 | -| page num | 建议 **`--pd_kv_page_num 256`** 起步,避免 page 池过小干扰评测 | +| page num | 建议 **`--pd_kv_page_num 16`** 起步,避免 page 池过大导致显存压力 | | cache 判断 | repeated prompt 可能只在 prefill 侧命中,decode 侧不一定 decode-only 命中 | ## 日志目录 @@ -65,7 +65,7 @@ export TP=2 export PREFILL_CUDA_DEVICES='0,1' export DECODE_CUDA_DEVICES='2,3' export PD_KV_PAGE_SIZE=2048 -export PD_KV_PAGE_NUM=256 +export PD_KV_PAGE_NUM=16 export PD_MASTER_IP="$(hostname -I | awk '{print $1}')" export HOST="${PD_MASTER_IP}" ``` @@ -79,7 +79,7 @@ export TP=1 export PREFILL_CUDA_DEVICES='4' export DECODE_CUDA_DEVICES='5' export PD_KV_PAGE_SIZE=2048 -export PD_KV_PAGE_NUM=256 +export PD_KV_PAGE_NUM=16 export PD_MASTER_IP="$(hostname -I | awk '{print $1}')" export HOST="${PD_MASTER_IP}" ```