diff --git a/csrc/engine/compiler/chunk_prefill_compiler.cpp b/csrc/engine/compiler/chunk_prefill_compiler.cpp new file mode 100644 index 00000000..359ac742 --- /dev/null +++ b/csrc/engine/compiler/chunk_prefill_compiler.cpp @@ -0,0 +1,186 @@ +#include "chunk_prefill_compiler.hpp" +#include "../../global_state/global_state.hpp" +#include "infinicore/context/context.hpp" + + +namespace infinilm::engine { + +ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier) + : GraphCompiler(model, barrier) { + // Enumerate chunk sizes for chunk-prefill + for (size_t cs : {256}) { + chunk_sizes_.push_back(cs); + } + // Enumerate batch sizes for prefill (typically smaller than decode) + for (size_t b = 1; b < 32; b++) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 32; b < 64; b += 8) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 64; b < 128; b += 16) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 128; b < 256; b += 32) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 256; b <= 512; b += 64) { + prefill_batch_sizes_.push_back(b); + } +} + +void ChunkPrefillCompiler::compile() { + if (model_->get_cache_config() != nullptr && + dynamic_cast(model_->get_cache_config())) { + + const auto *paged_config = + dynamic_cast(model_->get_cache_config()); + size_t nblocks = paged_config->num_blocks(); + + compiled_map_prefill_.clear(); + + // Max total tokens to avoid OOM during graph recording + constexpr size_t MAX_TOTAL_TOKENS = 4096; + + // Pre-allocate a shared block_tables_holder for the largest (batch_size) we'll use + size_t max_batch = *std::max_element(prefill_batch_sizes_.begin(), prefill_batch_sizes_.end()); + size_t block_per_req = nblocks / max_batch; + block_tables_holder_ = infinicore::Tensor::zeros( + {nblocks}, infinicore::DataType::I32, infinicore::context::getDevice()); + + for (size_t b : prefill_batch_sizes_) { + for (size_t cs : chunk_sizes_) { + size_t total_tokens = b * cs; + 
if (total_tokens > MAX_TOTAL_TOKENS) { + continue; + } + + size_t bpr = nblocks / b; // block_per_req for this batch size + + InfinilmModel::Input input; + + // input_ids: [1, total_tokens] — all tokens for this batch packed together + input.input_ids = infinicore::Tensor::zeros( + {1, total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + + // position_ids: [total_tokens] + input.position_ids = infinicore::Tensor::zeros( + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + + // total_sequence_lengths: [b], set to cs (first-chunk scenario) + input.total_sequence_lengths = infinicore::Tensor::empty( + {b}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector tsl(b, static_cast(cs)); + infinicore::context::memcpyH2D( + input.total_sequence_lengths.value()->data(), + tsl.data(), b * sizeof(int32_t), false); + } + + // input_offsets: [b+1], stride = cs + input.input_offsets = infinicore::Tensor::empty( + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector offsets(b + 1); + for (size_t i = 0; i <= b; i++) { + offsets[i] = static_cast(i * cs); + } + infinicore::context::memcpyH2D( + input.input_offsets.value()->data(), + offsets.data(), (b + 1) * sizeof(int32_t), false); + } + + // cu_seqlens: [b+1], same layout as input_offsets for prefill + input.cu_seqlens = infinicore::Tensor::empty( + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector cu(b + 1); + for (size_t i = 0; i <= b; i++) { + cu[i] = static_cast(i * cs); + } + infinicore::context::memcpyH2D( + input.cu_seqlens.value()->data(), + cu.data(), (b + 1) * sizeof(int32_t), false); + } + + // block_tables: view into the shared holder [b, bpr] + input.block_tables = block_tables_holder_->as_strided( + {b, bpr}, {(ptrdiff_t)bpr, 1}); + + // slot_mapping: [total_tokens] + input.slot_mapping = infinicore::Tensor::zeros( + {total_tokens}, infinicore::DataType::I64, 
infinicore::context::getDevice()); + + // Attention reads attn_metadata from thread-local forward context. + infinilm::global_state::get_forward_context().attn_metadata = { + input.past_sequence_lengths, + input.total_sequence_lengths, + input.input_offsets, + input.cu_seqlens, + input.block_tables, + input.slot_mapping, + }; + + barrier_->wait(); + infinicore::context::startGraphRecording(); + auto output = model_->forward(input); + auto graph = infinicore::context::stopGraphRecording(); + barrier_->wait(); + + auto shared_output = std::shared_ptr( + new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)}); + + compiled_map_prefill_[std::make_tuple(b, cs)] = + CompiledResult{std::move(input), std::make_tuple(graph, shared_output)}; + } + } + } +} + +ChunkPrefillCompiler::Compiled ChunkPrefillCompiler::get_compiled(const InfinilmModel::Input &input) { + if (model_->get_cache_config() == nullptr || + !dynamic_cast(model_->get_cache_config())) { + return {nullptr, nullptr}; + } + + if (!input.block_tables.has_value() || !input.input_ids.has_value()) { + return {nullptr, nullptr}; + } + + size_t batch_size = input.block_tables.value()->size(0); + size_t block_per_req = input.block_tables.value()->size(1); + size_t total_tokens = input.input_ids.value()->size(1); + + // Prefill: total_tokens is a multiple of batch_size, and chunk_size > 1 + if (total_tokens == 0 || total_tokens % batch_size != 0) { + return {nullptr, nullptr}; + } + size_t chunk_size = total_tokens / batch_size; + if (chunk_size <= 1) { + // Single-token case belongs to decode + return {nullptr, nullptr}; + } + + auto result = compiled_map_prefill_.find(std::make_tuple(batch_size, chunk_size)); + if (result == compiled_map_prefill_.end()) { + return {nullptr, nullptr}; + } + + auto &graph_input = result->second.input; + + graph_input.input_ids.value()->copy_from(input.input_ids.value()); + graph_input.position_ids.value()->copy_from(input.position_ids.value()); + 
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); + graph_input.input_offsets.value()->copy_from(input.input_offsets.value()); + graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value()); + graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value()); + graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); + + auto graph = std::get<0>(result->second.compiled); + auto shared_output = std::shared_ptr( + new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + + return std::make_tuple(graph, shared_output); +} + +} // namespace infinilm::engine diff --git a/csrc/engine/compiler/chunk_prefill_compiler.hpp b/csrc/engine/compiler/chunk_prefill_compiler.hpp new file mode 100644 index 00000000..bd701158 --- /dev/null +++ b/csrc/engine/compiler/chunk_prefill_compiler.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "graph_compiler.hpp" + +#include + +namespace infinilm::engine { +class ChunkPrefillCompiler : public GraphCompiler { +public: + ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier); + + void compile() override; + + Compiled get_compiled(const InfinilmModel::Input &input) override; + +private: + struct TupleHash { + size_t operator()(const std::tuple &t) const noexcept { + auto h1 = std::hash{}(std::get<0>(t)); + auto h2 = std::hash{}(std::get<1>(t)); + return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); + } + }; + + std::vector chunk_sizes_; + std::vector prefill_batch_sizes_; + + infinicore::Tensor block_tables_holder_; + + struct CompiledResult { + InfinilmModel::Input input; + Compiled compiled; + }; + + // Key: (batch_size, chunk_size) + std::unordered_map< + std::tuple, + CompiledResult, + TupleHash> + compiled_map_prefill_; +}; +} // namespace infinilm::engine diff --git a/csrc/engine/compiler/general_compiler.cpp b/csrc/engine/compiler/general_compiler.cpp index 
84ee670d..36c6420f 100644 --- a/csrc/engine/compiler/general_compiler.cpp +++ b/csrc/engine/compiler/general_compiler.cpp @@ -1,13 +1,18 @@ #include "general_compiler.hpp" namespace infinilm::engine { -GeneralCompiler::GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier) : GraphCompiler(model, barrier) { +GeneralCompiler::GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier, bool enable_chunk_prefill_graph) + : GraphCompiler(model, barrier), enable_chunk_prefill_graph_(enable_chunk_prefill_graph) { static_batching_compiler_ = std::make_unique(model_, barrier); + chunk_prefill_compiler_ = std::make_unique(model_, barrier); paged_compiler_ = std::make_unique(model_, barrier); } void GeneralCompiler::compile() { static_batching_compiler_->compile(); + if (enable_chunk_prefill_graph_) { + chunk_prefill_compiler_->compile(); + } paged_compiler_->compile(); } @@ -19,6 +24,11 @@ GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Inp if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) { return result; } + // chunk-prefill must be checked before decode (decode would also match if chunk_size==1) + result = chunk_prefill_compiler_.get()->get_compiled(input); + if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) { + return result; + } result = paged_compiler_.get()->get_compiled(input); return result; } diff --git a/csrc/engine/compiler/general_compiler.hpp b/csrc/engine/compiler/general_compiler.hpp index e8b84b5d..3edbcea0 100644 --- a/csrc/engine/compiler/general_compiler.hpp +++ b/csrc/engine/compiler/general_compiler.hpp @@ -1,12 +1,13 @@ #pragma once +#include "chunk_prefill_compiler.hpp" #include "paged_compiler.hpp" #include "static_batching_compiler.hpp" namespace infinilm::engine { class GeneralCompiler : public GraphCompiler { public: - GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier); + GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier, 
bool enable_chunk_prefill_graph = false); void compile() override; @@ -15,5 +16,7 @@ class GeneralCompiler : public GraphCompiler { private: std::unique_ptr static_batching_compiler_; std::unique_ptr paged_compiler_; + std::unique_ptr chunk_prefill_compiler_; + bool enable_chunk_prefill_graph_; }; } // namespace infinilm::engine diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 5cee2814..7fc7272a 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -13,6 +13,7 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend, std::optional kv_cache_dtype) // Changed parameter : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { @@ -39,6 +40,7 @@ InferEngine::InferEngine( cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend_)); } // Compile the model on all workers diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 5a54db66..bbe8a0b6 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -26,6 +26,7 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, + bool enable_chunk_prefill_graph = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, std::optional kv_cache_dtype = std::nullopt); diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index d53a5ba6..2d59e55b 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -16,12 +16,14 @@ RankWorker::RankWorker( const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool 
enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) : infinilm_config_(infinilm_config), model_config_(infinilm_config->model_config), rank_info_(rank_info), attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), + enable_chunk_prefill_graph_(enable_chunk_prefill_graph), job_cmd_(Command::INIT), has_job_(false), job_done_(false), @@ -270,7 +272,7 @@ void RankWorker::thread_loop() { throw std::runtime_error("Failed to create model"); } if (enable_graph_compiling_) { - compiler_ = std::make_unique(model_, barrier_); + compiler_ = std::make_unique(model_, barrier_, enable_chunk_prefill_graph_); } init_done_ = true; diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 7d04f5b8..906802e7 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -75,6 +75,7 @@ class RankWorker { const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend); // Submit a parameter load job and wait until the load completes on the worker thread. @@ -125,6 +126,7 @@ class RankWorker { // Graph Compiling bool enable_graph_compiling_; + bool enable_chunk_prefill_graph_; std::unique_ptr compiler_; // Command for the pending job (protected by mutex_) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index ed262876..4717bac2 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -37,6 +37,7 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, const std::string &attention_backend, std::optional kv_cache_dtype) { return std::make_shared( @@ -45,6 +46,7 @@ inline void bind_infer_engine(py::module &m) { dev, cache_cfg ? 
cache_cfg.get() : nullptr, enable_graph_compiling, + enable_chunk_prefill_graph, infinilm::backends::parse_attention_backend(attention_backend), kv_cache_dtype); }), @@ -53,6 +55,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, + py::arg("enable_chunk_prefill_graph") = false, py::arg("attention_backend") = "default", py::arg("kv_cache_dtype") = py::none()) .def("load_param", &InferEngine::load_param, diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index 16b43320..b2e1b927 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -60,6 +60,8 @@ def __init__(self): self.attn = self.args.attn self.enable_graph = self.args.enable_graph + self.enable_chunk_prefill_graph = self.args.enable_chunk_prefill_graph + self.chunk_size = self.args.chunk_size self.enable_paged_attn = self.args.enable_paged_attn self.num_blocks = self.args.num_blocks self.block_size = self.args.block_size @@ -123,6 +125,8 @@ def _add_common_args(self): choices=["default", "paged-attn", "flash-attn"], ) self.parser.add_argument("--enable-graph", action="store_true") + self.parser.add_argument("--enable-chunk-prefill-graph", action="store_true", help="enable chunk-prefill graph compiling") + self.parser.add_argument("--chunk-size", type=int, default=0, help="tokens per chunked-prefill slice (0 to disable)") self.parser.add_argument( "--enable-paged-attn", action="store_true", diff --git a/python/infinilm/config/engine_config.py b/python/infinilm/config/engine_config.py index ff33448f..751ca29d 100644 --- a/python/infinilm/config/engine_config.py +++ b/python/infinilm/config/engine_config.py @@ -40,6 +40,8 @@ class EngineConfig: top_p: float = 0.8 top_k: int = 1 enable_graph: bool = False + enable_chunk_prefill_graph: bool = False + chunk_size: int = 0 attn_backend: str = "default" skip_load: bool 
= False kv_transfer_config: Optional[KVTransferConfig] = None diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 14c265bd..2abf52b5 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -59,6 +59,7 @@ def __init__( distributed_config=DistConfig(1), cache_config=None, enable_graph_compiling=False, + enable_chunk_prefill_graph=False, attention_backend="default", kv_cache_dtype=None, ): @@ -75,6 +76,7 @@ def __init__( device._underlying.type, cache_config, enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend, ( parse_dtype(kv_cache_dtype)._underlying diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 236ee04f..17828352 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -102,6 +102,8 @@ def __init__(self, config: EngineConfig): def add_request(self, request: InferenceRequest): """Add a request to the scheduler.""" + if self.cache_type == "paged" and self.config.chunk_size > 0: + request.chunk_size = self.config.chunk_size self.scheduler.add_request(request) def step(self) -> tuple[bool, list[tuple]]: @@ -150,7 +152,17 @@ def _update_requests( sampled_tokens: List[int], ) -> List[tuple]: """Update request status after inference step.""" - if is_prefill: + # Detect a chunked-prefill mid-step: single request, prefill phase, + # and this chunk does not yet cover the whole prompt. In that case + # we must NOT consume a sampled token, NOT commit prefill blocks, + # and re-enqueue the request to keep chunking. 
+ chunk_mid_step = ( + is_prefill + and len(requests) > 0 + and all(r.is_chunking() and not r.chunk_is_last() for r in requests) + ) + + if is_prefill and not chunk_mid_step: match self.cache_type: case "paged": pass @@ -158,6 +170,18 @@ def _update_requests( self.scheduler.update_cache() case _: raise ValueError(f"Unsupported cache_type: {self.cache_type}") + + if chunk_mid_step: + for req in requests: + req.chunk_prefill_offset += req.chunk_size + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted by client during chunked-prefill" + ) + continue + self.scheduler.requeue_chunking(req) + return [] + pending = [] for req, token_id in zip(requests, sampled_tokens): if req.is_aborted(): @@ -170,6 +194,12 @@ def _update_requests( req.mark_canceled() continue + if is_prefill: + # Clean up chunked-prefill state on the final chunk so the + # next forward pass on this request takes the decode path. + req.chunk_prefill_offset = 0 + req.chunk_size = 0 + req.generated_token_ids.append(token_id) pending_tokens = req.generated_token_ids[req._token_decode_offset :] delta = self.tokenizer.decode(pending_tokens) @@ -300,6 +330,8 @@ def __init__( top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", skip_load: bool = False, ): @@ -337,6 +369,8 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + enable_chunk_prefill_graph=enable_chunk_prefill_graph, + chunk_size=chunk_size, attn_backend=attn_backend, skip_load=skip_load, ) @@ -478,6 +512,8 @@ def __init__( top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", kv_transfer_config: Optional[KVTransferConfig] = None, ): @@ -518,6 +554,8 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + enable_chunk_prefill_graph=enable_chunk_prefill_graph, + 
chunk_size=chunk_size, attn_backend=attn_backend, kv_transfer_config=kv_transfer_config, ) diff --git a/python/infinilm/llm/model_runner/model_runner.py b/python/infinilm/llm/model_runner/model_runner.py index 973f0230..bccf4e29 100644 --- a/python/infinilm/llm/model_runner/model_runner.py +++ b/python/infinilm/llm/model_runner/model_runner.py @@ -72,6 +72,7 @@ def __init__(self, config: EngineConfig): distributed_config=DistConfig(config.tensor_parallel_size), cache_config=cache_config, enable_graph_compiling=config.enable_graph, + enable_chunk_prefill_graph=config.enable_chunk_prefill_graph, attention_backend=config.attn_backend, ) diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index a05c9f98..da07f446 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -166,6 +166,11 @@ def __init__( None # KV transfer parameters from the router ) + # Chunked-prefill state (0 = disabled, otherwise tokens per chunk) + self.chunk_size: int = 0 + # Number of prompt tokens already fed through forward as chunked-prefill + self.chunk_prefill_offset: int = 0 + # For server use self.request_data: Optional[dict] = request_data @@ -204,6 +209,17 @@ def get_num_blocks_required(self, block_size: int) -> int: def get_max_tokens(self) -> Optional[int]: return self.sampling_params.max_tokens + def is_chunking(self) -> bool: + """Return True if this request is in the middle of chunked-prefill.""" + return ( + self.chunk_size > 0 + and (self.prompt_length - self.num_local_cached_tokens) > self.chunk_size + ) + + def chunk_is_last(self) -> bool: + """Return True if the next chunk would finish the prompt.""" + return self.chunk_prefill_offset + self.chunk_size >= self.prompt_length + def is_finished(self) -> bool: return self.status in [ RequestStatus.FINISHED, diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index 55a7ba57..659de4b8 100644 --- a/python/infinilm/llm/scheduler.py +++ 
b/python/infinilm/llm/scheduler.py @@ -1,5 +1,16 @@ """ Scheduler - Request scheduling and batch management with Paged Attention KV Cache. + +Scheduling priority (3-tier with anti-starvation): + 1. New prefill (waiting_queue) — minimize TTFT. + 2. Decode (running_queue, plus remote-KV-finished promotions). + 3. Continue chunked-prefill (chunking_queue) — last priority, but guarded. + +Anti-starvation: + After max_waiting_yields consecutive steps where waiting OR decode beat a + non-empty chunking_queue, the next step is forced onto chunking_queue. + This bounds the worst-case TTFT a long-prompt request can suffer when there + is a steady inflow of new short requests or decode traffic. """ import os @@ -28,13 +39,7 @@ def __init__( class Scheduler: - """Request scheduler with integrated BlockManager for KV cache management. - - Scheduling logic: - 1. Running queue: Check for new blocks needed, update slot_mapping - 2. Waiting queue: Try block reuse (prefix caching), allocate new blocks - 3. Reference counting: Free blocks when requests complete - """ + """Request scheduler with integrated BlockManager for KV cache management.""" def __init__( self, @@ -43,11 +48,15 @@ def __init__( block_size: int = 256, max_num_batched_tokens: int = 1024, connector=None, + max_waiting_yields: int = 4, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() + # Requests in the middle of chunked-prefill. + self.chunking_queue = janus.Queue() self.max_batch_size = max_batch_size + # Remote-KV state (PD disaggregation) self.finished_receiving_kv_req_ids: set[str] = set() self.failed_receiving_kv_req_ids: set[str] = set() self.pending_free_blocks: dict[str, list[int]] = {} @@ -61,19 +70,106 @@ def __init__( self.connector = connector + # Anti-starvation state for chunking_queue. 
+ self._waiting_yields_in_a_row: int = 0 + self.max_waiting_yields: int = max_waiting_yields + def add_request(self, request: InferenceRequest): if request is not None: request.status = RequestStatus.WAITING self.waiting_queue.sync_q.put(request) + # ------------------------------------------------------------------ # + # Main scheduling entrypoint # + # ------------------------------------------------------------------ # def schedule(self) -> Optional[SchedulerOutput]: - """Schedule and return batch of requests to execute.""" - deferred_requests = [] - scheduled_requests = [] - is_prefill = False + # 0) Forced chunking after too many consecutive non-chunking yields. + if self._waiting_yields_in_a_row >= self.max_waiting_yields: + chunking_out = self._try_schedule_chunking() + if chunking_out is not None: + self._waiting_yields_in_a_row = 0 + return chunking_out + # chunking_queue was actually empty — fall through. + + # 1) New prefill. + chunking_was_nonempty = self.chunking_queue.sync_q.qsize() > 0 + waiting_out = self._try_schedule_waiting() + if waiting_out is not None: + if chunking_was_nonempty: + self._waiting_yields_in_a_row += 1 + else: + self._waiting_yields_in_a_row = 0 + return waiting_out + + # 2) Decode (also drains finished remote-KV transfers into decode batch). + chunking_was_nonempty = self.chunking_queue.sync_q.qsize() > 0 + decode_out = self._try_schedule_decode() + if decode_out is not None: + if chunking_was_nonempty: + self._waiting_yields_in_a_row += 1 + else: + self._waiting_yields_in_a_row = 0 + return decode_out + + # 3) Continue chunked-prefill. + chunking_out = self._try_schedule_chunking() + if chunking_out is not None: + self._waiting_yields_in_a_row = 0 + return chunking_out + + # 4) Connector-only output (no requests scheduled, but connector wants to ship metadata). 
+ if self.connector is not None: + scheduler_output = SchedulerOutput(scheduled_requests=[]) + meta = self.connector.build_connector_meta() + scheduler_output.kv_connector_metadata = meta + return scheduler_output + + return None + + # ------------------------------------------------------------------ # + # Per-queue schedulers # + # ------------------------------------------------------------------ # + def _try_schedule_chunking(self) -> Optional[SchedulerOutput]: + """Pull chunking requests into a batch. + + Multi-batch rule: only batch requests where the next chunk is a full + middle chunk (not the last). A last-chunk request (whose next chunk + finishes the prompt) is run alone, because it triggers token sampling + and prefill block commit — semantics differ from middle chunks. + """ + scheduled: List[InferenceRequest] = [] + while len(scheduled) < self.max_batch_size: + try: + req = self.chunking_queue.sync_q.get_nowait() + except queue.Empty: + break + if req.is_finished(): + self.complete_requests([req]) + continue + if req.chunk_is_last(): + if not scheduled: + return SchedulerOutput([req], is_prefill=True) + # Already batched some middle chunks; defer this last-chunk one. + self.chunking_queue.sync_q.put(req) + break + scheduled.append(req) + if scheduled: + return SchedulerOutput(scheduled, is_prefill=True) + return None + + def _try_schedule_waiting(self) -> Optional[SchedulerOutput]: + """Pull new prefill requests from waiting_queue and form a prefill batch. + + Mirrors HEAD's token-budget + connector-aware logic, plus chunked-prefill + kickoff: if a request triggers chunking (prompt_length - num_local_cached + > chunk_size > 0), it must be emitted alone (the chunking C++ graph + signature requires a clean per-step shape, and the first chunk drives + the rest via chunking_queue). 
+ """ + deferred_requests: List[InferenceRequest] = [] + scheduled_requests: List[InferenceRequest] = [] current_num_batched_tokens = 0 - # Process Waiting queue (prefill phase) while ( len(scheduled_requests) < self.max_batch_size and current_num_batched_tokens < self.max_num_batched_tokens @@ -82,7 +178,7 @@ def schedule(self) -> Optional[SchedulerOutput]: req = self.waiting_queue.sync_q.get_nowait() except queue.Empty: break - # Skip requests that were already finished (e.g., timed out/canceled while waiting) + if req.is_finished(): self.complete_requests([req]) continue @@ -111,8 +207,7 @@ def schedule(self) -> Optional[SchedulerOutput]: num_computed_tokens -= 1 num_new_tokens = req.get_prompt_length() - num_computed_tokens - # Early token budget check: skip can_accept_request and allocate_slots - # for requests that would exceed the per-schedule token budget. + # Early token budget check. if not load_kv_async: _num_tokens_this_step = ( req.get_prompt_length() - num_local_computed_tokens @@ -126,15 +221,11 @@ def schedule(self) -> Optional[SchedulerOutput]: deferred_requests.append(req) break - if not self.can_accept_request( - req, - num_local_computed_tokens, - ): + if not self.can_accept_request(req, num_local_computed_tokens): logger.warning( "Insufficient KV cache blocks for request %s, deferring.", req.request_id, ) - if num_local_computed_tokens > 0: self.cache_manager.free_blocks(cached_block_table) deferred_requests.append(req) @@ -183,31 +274,59 @@ def schedule(self) -> Optional[SchedulerOutput]: ) // self.block_size continue - scheduled_requests.append(req) + # Chunked-prefill kickoff detection: if this request needs chunking + # AND we have already started forming a normal-prefill batch, defer + # this long request to next tick. Otherwise emit it as singleton. 
+ remaining = req.get_prompt_length() - req.num_local_cached_tokens + if req.chunk_size > 0 and remaining > req.chunk_size: + if scheduled_requests: + # Send this long one back to waiting; flush the normal batch first. + req.status = RequestStatus.WAITING + self.waiting_queue.sync_q.put(req) + break + # Emit chunking-start as a singleton. + req.chunk_prefill_offset = req.num_local_cached_tokens + req.status = RequestStatus.RUNNING + scheduler_output = SchedulerOutput( + scheduled_requests=[req], + is_prefill=True, + ) + if self.connector is not None: + meta = self.connector.build_connector_meta() + scheduler_output.kv_connector_metadata = meta + # Restore deferred ones. + for d in deferred_requests: + self.waiting_queue.sync_q.put(d) + return scheduler_output + scheduled_requests.append(req) num_tokens_this_step = req.get_prompt_length() - req.num_local_cached_tokens current_num_batched_tokens += num_tokens_this_step - req.status = RequestStatus.RUNNING if deferred_requests: for req in deferred_requests: self.waiting_queue.sync_q.put(req) - # Return prefill batch if any waiting requests were scheduled if scheduled_requests: - is_prefill = True scheduler_output = SchedulerOutput( scheduled_requests=scheduled_requests, - is_prefill=is_prefill, + is_prefill=True, ) if self.connector is not None: meta = self.connector.build_connector_meta() scheduler_output.kv_connector_metadata = meta return scheduler_output - # Promote completed remote KV transfers directly into the decode batch. - # Failed transfers are re-queued for prefill. + return None + + def _try_schedule_decode(self) -> Optional[SchedulerOutput]: + """Drain finished remote-KV completions into decode batch, then pull + from running_queue. + """ + scheduled_requests: List[InferenceRequest] = [] + + # Promote completed remote-KV transfers directly into the decode batch. 
if self.connector is not None and self.remote_kv_requests: for req_id in list(self.remote_kv_requests.keys()): req = self.remote_kv_requests[req_id] @@ -230,18 +349,15 @@ def schedule(self) -> Optional[SchedulerOutput]: req.status = RequestStatus.RUNNING scheduled_requests.append(req) - # Process Running queue (decode phase) while len(scheduled_requests) < self.max_batch_size: try: req = self.running_queue.sync_q.get_nowait() except queue.Empty: break - # Skip requests that were already finished (e.g., timed out/canceled while running) if req.is_finished(): self.complete_requests([req]) continue - # Decode phase: allocate slot for newly generated token try: req.block_table, new_slot = self.cache_manager.append_slot( req.block_table, req.get_total_length(), req.get_all_token_ids() @@ -250,32 +366,28 @@ def schedule(self) -> Optional[SchedulerOutput]: req.num_blocks = len(req.block_table) req.num_local_cached_tokens = req.get_total_length() - 1 scheduled_requests.append(req) - except RuntimeError as e: raise RuntimeError("No available cache blocks for new token") from e - # Return decode batch if any running requests were scheduled if scheduled_requests: - is_prefill = False scheduler_output = SchedulerOutput( scheduled_requests=scheduled_requests, - is_prefill=is_prefill, + is_prefill=False, ) - # logger.info("Scheduled decode: %d", len(scheduled_requests)) - if self.connector is not None: meta = self.connector.build_connector_meta() scheduler_output.kv_connector_metadata = meta return scheduler_output - if self.connector is not None: - scheduler_output = SchedulerOutput(scheduled_requests=[]) - meta = self.connector.build_connector_meta() - scheduler_output.kv_connector_metadata = meta - return scheduler_output - return None + # ------------------------------------------------------------------ # + # External hooks # + # ------------------------------------------------------------------ # + def requeue_chunking(self, req: InferenceRequest): + """Put a request back 
into the chunking queue after a chunk has run.""" + self.chunking_queue.sync_q.put(req) + def update_waiting_for_remote_kv(self, request: InferenceRequest): self.remote_kv_requests.pop(request.request_id, None) self.pending_kv_decode_blocks -= ( @@ -306,7 +418,11 @@ def update_waiting_for_remote_kv(self, request: InferenceRequest): self.finished_receiving_kv_req_ids.remove(request.request_id) def complete_requests(self, requests: List[InferenceRequest]): - """Handle completed requests and free their blocks.""" + """Handle completed requests and free their blocks. + + Active (non-finished) requests are re-enqueued into the running_queue — + this is how prefill-finished requests migrate into the decode pipeline. + """ for req in requests: if req.status in [ RequestStatus.FINISHED, @@ -348,7 +464,6 @@ def complete_requests(self, requests: List[InferenceRequest]): f"Request {req.request_id[:8]}... timed out: {req.finish_reason}" ) else: - # Still running, put back in running queue self.running_queue.sync_q.put(req) def can_accept_request( @@ -356,7 +471,6 @@ def can_accept_request( ) -> bool: total_required_blocks = 0 - # Calculate blocks needed for running requests running_queue_size = self.running_queue.sync_q.qsize() for _ in range(running_queue_size): req = self.running_queue.sync_q.get() @@ -369,17 +483,13 @@ def can_accept_request( total_required_blocks += num_blocks_needed self.running_queue.sync_q.put(req) - # Calculate blocks needed for the new request total_length = request.get_prompt_length() - num_local_computed_tokens total_length += request.sampling_params.max_tokens num_blocks_needed = (total_length + self.block_size - 1) // self.block_size total_required_blocks += num_blocks_needed - # Include decode headroom for WAITING_FOR_REMOTE_KVS requests, which - # hold prompt blocks but will also need decode blocks once promoted. 
total_required_blocks += self.pending_kv_decode_blocks - # Compare with total usable blocks in cache manager return total_required_blocks <= self.cache_manager.get_total_usable_blocks() def update_from_output(self, model_output): @@ -401,17 +511,12 @@ def update_from_output(self, model_output): for req_id in finished_recving_req_ids: if req_id in self.pending_free_blocks: - # Aborted request: transfer complete, now safe to free blocks. self.cache_manager.free_blocks(self.pending_free_blocks.pop(req_id)) elif req_id in self.remote_kv_requests: - # Active request: mark ready for promotion in schedule(). self.finished_receiving_kv_req_ids.add(req_id) - # else: already processed or unknown, discard to avoid stale entries. for req_id in finished_sending_req_ids: self.cache_manager.free_blocks(self.pending_free_blocks.pop(req_id, [])) for req_id in failed_recving_req_ids: - # Only track failures for active (non-aborted) requests; aborted - # requests are handled via pending_free_blocks in finished_recving. 
if req_id in self.remote_kv_requests: self.failed_receiving_kv_req_ids.add(req_id) diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index 9145a02d..fdb8e637 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -215,24 +215,44 @@ def _build_model_input_from_batch_scheduler_output( for req in scheduler_output.scheduled_requests: num_cached = req.num_local_cached_tokens if scheduler_output.is_prefill: - # Prefill phase req_tokens = req.get_input_tokens() - tokens_to_compute = req_tokens[num_cached:] - tokens.extend(tokens_to_compute) - - compute_len = len(tokens_to_compute) - seq_len = len(req_tokens) - seq_lens.append(seq_len) - - current_offset += compute_len - seq_offsets.append(current_offset) - - slot_mapping.extend(req.slot_mapping) - cached_lens.append(num_cached) - position_ids.extend(range(num_cached, num_cached + compute_len)) + # Chunked-prefill: only feed [chunk_prefill_offset : +chunk_size). + if req.is_chunking(): + start = req.chunk_prefill_offset + end = min(start + req.chunk_size, len(req_tokens)) + tokens_to_compute = req_tokens[start:end] + compute_len = len(tokens_to_compute) + tokens.extend(tokens_to_compute) + seq_len = end + seq_lens.append(seq_len) + current_offset += compute_len + seq_offsets.append(current_offset) + # req.slot_mapping has length (prompt_length - num_cached) and is + # indexed [0..prompt_length-num_cached). Translate absolute token + # indices to slot_mapping indices. 
+ slot_start = start - num_cached + slot_end = end - num_cached + assert slot_start >= 0 and slot_end <= len(req.slot_mapping), ( + f"chunking slot slice out of range: start={start} " + f"end={end} num_cached={num_cached} " + f"len(slot_mapping)={len(req.slot_mapping)}" + ) + slot_mapping.extend(req.slot_mapping[slot_start:slot_end]) + cached_lens.append(start) + position_ids.extend(range(start, end)) + else: + tokens_to_compute = req_tokens[num_cached:] + tokens.extend(tokens_to_compute) + compute_len = len(tokens_to_compute) + seq_len = len(req_tokens) + seq_lens.append(seq_len) + current_offset += compute_len + seq_offsets.append(current_offset) + slot_mapping.extend(req.slot_mapping) + cached_lens.append(num_cached) + position_ids.extend(range(num_cached, num_cached + compute_len)) else: - # Decode phase seq_len = req.get_total_length() last_token = ( req.generated_token_ids[-1] @@ -241,20 +261,30 @@ def _build_model_input_from_batch_scheduler_output( ) tokens.append(last_token) seq_lens.append(seq_len) - current_offset += 1 seq_offsets.append(current_offset) - slot_mapping.extend(req.slot_mapping) cached_lens.append(num_cached) position_ids.append(seq_len - 1) - # Pad block_table to same length padded_block_table = req.block_table + [-1] * ( max_block_table_len - len(req.block_table) ) block_tables.append(padded_block_table) cu_seqlens.append(cu_seqlens[-1] + seq_len) + + # guarantee non-empty tokens and slot_mapping to avoid downstream errors. If empty, raise with detailed debug info. + if not tokens or not slot_mapping: + states = [ + (r.request_id[:8], r.is_chunking(), + r.chunk_prefill_offset, r.prompt_length, r.num_local_cached_tokens, + len(r.slot_mapping), r.status.name) + for r in scheduler_output.scheduled_requests + ] + raise RuntimeError( + f"build_model_inputs got empty tokens/slot_mapping. 
" + f"is_prefill={scheduler_output.is_prefill}, states={states}" + ) assert seq_offsets[-1] == len(tokens), ( f"seq_offsets[-1]={seq_offsets[-1]} != len(tokens)={len(tokens)}" diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 50e98d60..c5403fe3 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -110,6 +110,8 @@ def __init__( host: str = "0.0.0.0", port: int = 8000, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", ignore_eos: bool = False, # PD disaggregation @@ -134,6 +136,10 @@ def __init__( host: Server host address. port: Server port number. enable_graph: Whether to enable graph compiling. + enable_chunk_prefill_graph: Whether to enable chunk-prefill graph compiling. + chunk_size: Tokens per chunked-prefill slice (0 = disabled). When > 0 and paged + cache is used, long prompts are sliced and each slice goes through forward + separately so the C++ ChunkPrefillCompiler precompiled graph can be reused. attn_backend: Attention backend to use ('default', 'flash-attn'). 
""" self.model_path = model_path @@ -154,6 +160,8 @@ def __init__( self.host = host self.port = port self.enable_graph = enable_graph + self.enable_chunk_prefill_graph = enable_chunk_prefill_graph + self.chunk_size = chunk_size self.attn_backend = attn_backend self.ignore_eos = ignore_eos self.kv_transfer_config = kv_transfer_config @@ -187,12 +195,16 @@ async def lifespan(app: FastAPI): top_p=self.top_p, top_k=self.top_k, enable_graph=self.enable_graph, + enable_chunk_prefill_graph=self.enable_chunk_prefill_graph, + chunk_size=self.chunk_size, attn_backend=self.attn_backend, kv_transfer_config=self.kv_transfer_config, ) self.engine.start() logger.info(f"Engine initialized with model at {self.model_path}") logger.info(f" enable_graph: {self.enable_graph}") + logger.info(f" enable_chunk_prefill_graph: {self.enable_chunk_prefill_graph}") + logger.info(f" chunk_size: {self.chunk_size}") yield self.engine.stop() @@ -589,6 +601,8 @@ def main(): host=cfg.host, port=cfg.port, enable_graph=cfg.enable_graph, + enable_chunk_prefill_graph=cfg.enable_chunk_prefill_graph, + chunk_size=cfg.chunk_size, attn_backend=cfg.attn, ignore_eos=cfg.ignore_eos, kv_transfer_config=kv_transfer_config, diff --git a/scripts/test_chunk_prefill.py b/scripts/test_chunk_prefill.py new file mode 100644 index 00000000..5b668808 --- /dev/null +++ b/scripts/test_chunk_prefill.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +""" +一键对比 chunked prefill 开/关性能 + +该脚本会依次启动 launch_server.py (chunk-size=0/256),运行 test_perf_mix.py 取结果,最后输出对比。 +""" + +import argparse +import os +import re +import signal +import subprocess +import sys +import time + + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +LM_DIR = os.path.dirname(SCRIPT_DIR) +INFERENCE_SERVER = os.path.join(LM_DIR, "python", "infinilm", "server", "inference_server.py") +TEST_SCRIPT = os.path.join(SCRIPT_DIR, "test_perf_mix.py") + + +from openai import OpenAI, APIConnectionError, APIStatusError + +def wait_for_server(popen, host, port, model, 
timeout=300): + client = OpenAI(base_url=f"http://{host}:{port}", api_key="default") + deadline = time.time() + timeout + while time.time() < deadline: + if popen.poll() is not None: + raise RuntimeError(f"server exited early with code {popen.returncode}") + try: + client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "hi"}], + max_tokens=1, + ) + return + except (APIConnectionError, APIStatusError): + time.sleep(1) + raise TimeoutError(f"server not ready within {timeout}s") + + +def inference_server(chunk_size, device, port, batch_size, max_new_tokens, enable_paged_attn, enable_graph, model_path): + print(INFERENCE_SERVER) + args = ["CUDA_VISIBLE_DEVICES=8", sys.executable, INFERENCE_SERVER, + f"--chunk-size {chunk_size}", + f"--device {device}", + f"--port {port}", + f"--batch-size {batch_size}", + f"--max-new-tokens {max_new_tokens}", + f"--model {model_path}", + "--enable-chunk-prefill-graph"] + if enable_paged_attn: + args.append("--enable-paged-attn") + if enable_graph: + args.append("--enable-graph") + + popen = subprocess.Popen(" ".join(args), shell=True, preexec_fn=os.setsid, stderr=subprocess.STDOUT) + return popen + + +import socket + +def stop_server(popen, host="127.0.0.1", port=2333, timeout=30): + if popen and popen.poll() is None: + os.killpg(os.getpgid(popen.pid), signal.SIGTERM) + try: + popen.wait(timeout=timeout) + except subprocess.TimeoutExpired: + os.killpg(os.getpgid(popen.pid), signal.SIGKILL) + popen.wait(timeout=5) + + # 等端口真正释放(uvicorn 在 graceful shutdown 期间端口还开着) + deadline = time.time() + timeout + while time.time() < deadline: + try: + with socket.create_connection((host, port), timeout=0.5): + pass # 还有人监听,继续等 + except OSError: + return # 端口已释放 + time.sleep(0.3) + raise RuntimeError(f"port {port} still in use after stop_server") + + +def run_test_perf(): + cmd = f"{sys.executable} -u {TEST_SCRIPT}" + proc = subprocess.Popen( + cmd, shell=True, text=True, bufsize=1, + stdout=subprocess.PIPE, 
stderr=subprocess.STDOUT, + ) + lines = [] + for line in proc.stdout: + sys.stdout.write(line) + sys.stdout.flush() # forward the child's output to the parent terminal in real time + lines.append(line) + proc.wait() + return proc.returncode, "".join(lines) # return code + captured test_perf_mix.py output + +def parse_stats(output): + def grab(pat): + m = re.search(pat, output) + return float(m.group(1)) if m else None + + success_m = re.search(r"成功请求数\s*:\s*(\d+)", output) + return { + "avg_ttft_s": grab(r"Average TTFT\s*:\s*([0-9.]+)\s*s"), + "avg_e2e_s": grab(r"Average latency\s*:\s*([0-9.]+)\s*s"), + "avg_ms_per_token": grab(r"Avg time per token\s*:\s*([0-9.]+)\s*ms/token"), + "avg_tps": grab(r"Avg Token generation speed\s*:\s*([0-9.]+)"), + "rps": grab(r"请求速率 \(RPS\)\s*:\s*([0-9.]+)"), + "success": int(success_m.group(1)) if success_m else None, + } + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="比较 chunked prefill 开/关的 TTFT/E2E") + parser.add_argument("--device", type=str, default="iluvatar", help="设备类型") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--max-new-tokens", type=int, default=16) + parser.add_argument("--enable-paged-attn", type=lambda s: s.strip().lower() in ("1", "true", "yes", "y", "on"), default=True) # type=bool would turn any non-empty string, even "False", into True + parser.add_argument("--enable-graph", type=lambda s: s.strip().lower() in ("1", "true", "yes", "y", "on"), default=True) # same explicit string-to-bool parsing as above + parser.add_argument("--port", type=int, default=2333) + parser.add_argument("--model-path", type=str, default="/data-aisoft/mechdancer/models/9g_8b_thinking_llama/") + + + args = parser.parse_args() + + results = [] + + for chunk_size in (0, 256): + mode = "ON" if chunk_size > 0 else "OFF" + print("\n" + "="*78) + print(f"开始部署大模型推理服务,chunked prefill {mode} (chunk-size={chunk_size}),请等待服务启动完成...") + + server = inference_server(chunk_size=chunk_size, device=args.device, port=args.port, + batch_size=args.batch_size, max_new_tokens=args.max_new_tokens, + enable_paged_attn=args.enable_paged_attn, enable_graph=args.enable_graph, + model_path=args.model_path) + try: + wait_for_server(server, "127.0.0.1", args.port, model="FM9G-7B", 
timeout=300) + print("服务启动完成,开始跑 test_perf_mix.py (上一条200OK请求为服务测试成功标志)") + retcode, out = run_test_perf() # out is the captured output of the test subprocess + if retcode != 0: + print("test_perf_mix.py 执行失败,退出码", retcode) + print(out) + raise SystemExit(1) + + stats = parse_stats(out) # extract the performance metrics from the test output + stats.update({"chunk_size": chunk_size, "mode": mode}) + results.append(stats) + print(f"完成chunked prefill {mode}测试 -> {stats}") + + finally: + stop_server(server, host="127.0.0.1", port=args.port) + print("服务已停止") + + print("\n" + "#"*78) + print("最终对比(chunked prefill ON vs OFF)") + print("-"*78) + + def fmt(v, unit="", spec=".3f"): + return "N/A" if v is None else f"{v:{spec}}{unit}" + + def diff(a, b): + return None if (a is None or b is None) else a - b + + def speedup_pct(on, off): + # lower-is-better metric: a positive value means ON is faster than OFF + if on is None or off is None or off == 0: + return None + return (off - on) / off * 100 + + on_r = next((r for r in results if r["mode"] == "ON"), None) + off_r = next((r for r in results if r["mode"] == "OFF"), None) + + print(f"{'指标':<22}{'ON':>14}{'OFF':>14}{'Δ (ON-OFF)':>16}{'ON 提升':>12}") + print("-"*78) + + def row(label, key, unit, spec=".3f", lower_is_better=True): + a = (on_r or {}).get(key) + b = (off_r or {}).get(key) + pct = speedup_pct(a, b) if lower_is_better else speedup_pct(b, a) + print( + f"{label:<22}" + f"{fmt(a, unit, spec):>14}" + f"{fmt(b, unit, spec):>14}" + f"{fmt(diff(a, b), unit, '+'+spec):>16}" + f"{fmt(pct, '%', '+.2f'):>12}" + ) + + row("Avg TTFT", "avg_ttft_s", " s") + row("Avg E2E latency", "avg_e2e_s", " s") + row("Avg ms/token", "avg_ms_per_token", " ms", ".2f") + row("Avg tokens/s", "avg_tps", "", ".2f", lower_is_better=False) + row("RPS", "rps", "", ".2f", lower_is_better=False) + print("-"*78) diff --git a/scripts/test_perf_cp.py b/scripts/test_perf_cp.py new file mode 100644 index 00000000..4e55bae0 --- /dev/null +++ b/scripts/test_perf_cp.py @@ -0,0 +1,156 @@ +""" +Chunked Prefill TTFT Benchmark + +Test: send a long request, wait a short delay, 
then send a short request. +Measure the short request's TTFT and E2E. + +With chunked prefill: short request inserts at next chunk boundary → lower TTFT +Without chunked prefill: short request waits for full long prefill → higher TTFT + +Usage: + python3 scripts/test_perf.py [--rounds 5] [--delay 0.1] +""" +import asyncio +import time +from openai import AsyncOpenAI +import argparse + +API_URL = "http://127.0.0.1:2333" +MODEL = "jiuge" +MAX_TOKENS = 30 # decode的tokens数 + +_BASE_PARAGRAPHS = [ + '''人工智能(Artificial Intelligence,简称AI)是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能从诞生以来,理论和技术日益成熟,应用领域也不断扩大。可以设想,未来人工智能带来的科技产品,将会是人类智慧的容器。''', + '''1956年夏季,以麦卡赛、明斯基、罗切斯特和申农等为首的一批有远见卓识的年轻科学家在一起聚会,共同研究和探讨用机器模拟智能的一系列有关问题,并首次提出了人工智能这一术语,它标志着人工智能这门新兴学科的正式诞生。此后,IBM公司研制的专用计算机深蓝击败了国际象棋世界冠军卡斯帕罗夫。谷歌公司开发的AlphaGo程序战胜了围棋世界冠军李世石,这被认为是人工智能发展史上的一个重要里程碑。''', + '''量子计算是一种利用量子力学原理进行信息处理的计算方式。与经典计算机使用比特作为信息的基本单位不同,量子计算机使用量子比特。量子比特具有叠加态的特性,即一个量子比特可以同时处于0和1的状态,这使得量子计算机在处理某些特定问题时具有经典计算机无法比拟的优势。量子纠缠是量子计算中另一个关键概念,当两个量子比特发生纠缠时,测量其中一个的状态会立即影响另一个的状态。''', + '''根据联合国政府间气候变化专门委员会第六次评估报告,全球平均温度已经比工业化前水平上升了约1.1摄氏度。报告指出,人类活动是导致全球变暖的主要原因,其中化石燃料的燃烧、工业生产和土地利用变化是温室气体排放的主要来源。极端天气事件的频率和强度都在增加,包括热浪、干旱、暴雨和洪水。北极海冰面积持续缩小,格陵兰和南极冰盖加速融化。''', + '''深度学习是机器学习的一个分支,其核心是利用多层神经网络从大量数据中自动学习特征表示。卷积神经网络在图像识别领域取得了巨大成功,循环神经网络和Transformer架构则在自然语言处理领域展现了强大能力。近年来,大语言模型如GPT、BERT、LLaMA等引领了自然语言处理的技术革新,这些模型通过在海量文本数据上进行预训练,获得了强大的语言理解和生成能力。''', + '''在计算机体系结构领域,冯诺依曼架构仍然是现代计算机的基础。然而随着摩尔定律逐渐放缓,研究人员开始探索新型计算范式,包括神经形态计算、存内计算、光子计算等。GPU和TPU等专用加速器的发展极大推动了深度学习的进步。RISC-V开源指令集架构的兴起为芯片设计带来了新的可能性,而chiplet技术和先进封装则为突破制程限制提供了新的路径。''', + '''可再生能源的成本大幅下降,太阳能和风能已经成为最便宜的新增发电来源。电动汽车市场快速增长,电池技术不断进步。碳捕获和储存技术、绿色氢能等前沿技术也在加速发展。然而,要实现全球碳中和目标,仍需要在能源系统、交通运输、工业生产、建筑等领域进行深刻的变革。智能电网、储能技术、虚拟电厂等概念正在从理论走向实践。''', + '''生物信息学是一门利用计算机技术和数学方法研究生物学问题的交叉学科。基因组学、蛋白质组学、代谢组学等组学技术的发展产生了海量的生物数据。AlphaFold2在蛋白质结构预测方面取得了革命性突破,为药物研发和生命科学研究开辟了新的方向。CRISPR基因编辑技术的发展使得精准修改基因成为可能,为遗传疾病的治疗带来了希望。''', +] + + +def build_long_prompt(idx, 
target_chars=9000): + parts = [f"(文档编号{idx}) 请仔细阅读以下学术材料并总结:\n\n"] + i = idx + while sum(len(p) for p in parts) < target_chars: + parts.append(_BASE_PARAGRAPHS[i % len(_BASE_PARAGRAPHS)]) + parts.append("\n\n") + i += 1 + parts.append(f"以上是第{idx}份材料,请给出详细分析。") + return "".join(parts) + + +async def measure_one_round(client, round_idx, delay_sec): + """ + 1. Fire a long request (starts prefill immediately) + 2. After delay_sec, fire a short request + 3. Return both TTFT and E2E for both requests + """ + long_prompt = build_long_prompt(round_idx, target_chars=9000) + short_prompt = f"(编号{round_idx}) 1+1等于几?" + + long_result = {} + short_result = {} + + async def do_request(prompt, result_dict, delay=0): + if delay > 0: + await asyncio.sleep(delay) + t0 = time.time() + stream = await client.chat.completions.create( + model=MODEL, + messages=[{"role": "user", "content": prompt}], + stream=True, + max_tokens=MAX_TOKENS, # the SDK keyword is max_tokens; max_new_tokens raises TypeError + temperature=1.0, + top_p=1.0, + extra_body={"top_k": 1}, + ) + first_token_time = None + total_tokens = 0 + async for chunk in stream: + if chunk.choices[0].delta.content: + if first_token_time is None: + first_token_time = time.time() + total_tokens += 1 + if chunk.choices[0].finish_reason is not None: + break + end_time = time.time() + result_dict["ttft"] = (first_token_time - t0) if first_token_time else None + result_dict["e2e"] = end_time - t0 + result_dict["tokens"] = total_tokens + + await asyncio.gather( + do_request(long_prompt, long_result, delay=0), + do_request(short_prompt, short_result, delay=delay_sec), + ) + return long_result, short_result + + +async def run_benchmark(rounds, delay): + client = AsyncOpenAI(base_url=API_URL, api_key="default") + + # Warmup + print("Warmup...") + await measure_one_round(client, 100, delay) + print("Warmup done.\n") + + long_ttfts = [] + long_e2es = [] + short_ttfts = [] + short_e2es = [] + + for i in range(rounds): + lr, sr = await measure_one_round(client, i, delay) # lr = long request result, sr = 
short request result + + lt = lr["ttft"] * 1000 if lr["ttft"] else 0 + le = lr["e2e"] * 1000 + st = sr["ttft"] * 1000 if sr["ttft"] else 0 + se = sr["e2e"] * 1000 + + print(f" Round {i}: LONG TTFT={lt:>7.1f}ms E2E={le:>8.1f}ms tokens={lr['tokens']}") + print(f" SHORT TTFT={st:>7.1f}ms E2E={se:>8.1f}ms tokens={sr['tokens']}") + + if lr["ttft"]: + long_ttfts.append(lr["ttft"]) + long_e2es.append(lr["e2e"]) + if sr["ttft"]: + short_ttfts.append(sr["ttft"]) + short_e2es.append(sr["e2e"]) + + sep = "=" * 60 + print(f"\n{sep}") + print(f" Chunked Prefill TTFT Benchmark") + print(f"{sep}") + print(f" Rounds: {rounds}") + print(f" Delay before short request: {delay}s") + print(f" Long prompt: ~9000 chars") + print(f" Max tokens: {MAX_TOKENS}") + + def print_stats(label, ttfts, e2es): + if not ttfts: + return + print(f"\n [{label}]") + print(f" Avg TTFT: {sum(ttfts)/len(ttfts)*1000:>8.1f} ms") + print(f" Min TTFT: {min(ttfts)*1000:>8.1f} ms") + print(f" Max TTFT: {max(ttfts)*1000:>8.1f} ms") + print(f" Avg E2E: {sum(e2es)/len(e2es)*1000:>8.1f} ms") + + print_stats("LONG ", long_ttfts, long_e2es) + print_stats("SHORT", short_ttfts, short_e2es) + + if short_ttfts: + print(f"\n >>> SHORT Avg TTFT = {sum(short_ttfts)/len(short_ttfts)*1000:.1f} ms <<<") + print(f" >>> SHORT Avg E2E = {sum(short_e2es)/len(short_e2es)*1000:.1f} ms <<<") + print(f"{sep}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--rounds", type=int, default=5) + parser.add_argument("--delay", type=float, default=0.1, + help="Seconds to wait before sending short request (default: 0.1)") + args = parser.parse_args() + + asyncio.run(run_benchmark(args.rounds, args.delay)) diff --git a/scripts/test_perf_mix.py b/scripts/test_perf_mix.py new file mode 100644 index 00000000..e5506721 --- /dev/null +++ b/scripts/test_perf_mix.py @@ -0,0 +1,319 @@ +import asyncio +import time +from openai import AsyncOpenAI +import argparse +import random + +PROMPTS = [ + + # ~10000 
tokens:极限长上下文,多文件代码重构 + "下面给出 4 个相关文件,请重构以消除重复逻辑并提取公共抽象:\n\n" + + "# file: scheduler_v1.py\n" + "def schedule(reqs):\n return sorted(reqs, key=lambda r: r.arrival)\n" * 100 + + "\n# file: scheduler_v2.py\n" + "def schedule(reqs):\n return sorted(reqs, key=lambda r: -r.priority)\n" * 100 + + "\n# file: scheduler_v3.py\n" + "def schedule(reqs):\n return sorted(reqs, key=lambda r: r.prompt_len)\n" * 100 + + "\n# file: scheduler_v4.py\n" + "def schedule(reqs):\n return sorted(reqs, key=lambda r: r.slo_deadline)\n" * 100, + + "1+1=?", + +] + + + + + +'''# Chunked prefill / TTFT-biased workload. +# +# 设计目标: +# 1. 每个并发窗口都混入 1-2 个多 chunk 长 prompt 和多个极短 prompt; +# 2. 长 prompt 负责制造 prefill 压力,短 prompt 提供更敏感的 TTFT 样本; +# 3. 所有问题都要求短回答,尽量避免 decode 阶段掩盖 TTFT 差异。 +PROMPTS = ( + # 0: 长日志诊断,主要制造多 chunk prefill。 + "你是推理服务性能分析员。请只输出 3 条结论,每条不超过 18 字。\n\n" + + "\n".join( + [ + ( + f"[Trace {i:03d}] model=FM9G-7B chunk_size=256 " + f"prompt_len={4096 + (i % 7) * 768} queued_prefill={18 + i % 13} " + f"decode_batch={2 + i % 5} ttft_ms={780 + (i % 17) * 91} " + "现象:长 prompt prefill 与短请求 decode 竞争同一个调度窗口。" + ) + for i in range(120) + ] + ) + + "\n\n问题:哪些迹象说明短请求被长 prefill 阻塞?", + + "只回答数字:17 * 23 = ?", + + "只回答一个英文短语:TTFT 的全称是什么?", + + # 3: 长代码审查,prompt 很长但输出极短。 + "阅读下面的调度器伪代码。只输出 2 个最可能影响 TTFT 的问题。\n\n```python\n" + + "\n".join( + [ + ( + f"def schedule_step_{i}(waiting, running, budget):\n" + f" long_prefill = [r for r in waiting if r.prompt_len > {2048 + i * 8}]\n" + " short_prefill = [r for r in waiting if r.prompt_len <= 128]\n" + " batch = long_prefill + short_prefill + running\n" + " return batch[:budget]\n" + ) + for i in range(80) + ] + ) + + "```\n问题:这个策略为什么不利于短请求首 token?", + + "只回答星期几:2026-05-28 后三天是周几?", + + "只回答 yes 或 no:chunked prefill 会把长 prompt 拆成多个 prefill 片段吗?", + + # 6: RAG 长上下文,模拟检索拼接。 + "基于检索结果回答最后的问题。只输出一句话结论。\n\n" + + "\n\n".join( + [ + ( + f"[Doc {i:02d}] 在线推理系统中,超长 prompt 的 prefill 会占用连续计算预算。" + "当调度器支持 chunked prefill 时,长请求的 KV cache 写入被拆成固定大小块," + 
"短请求可以穿插进入同一个批次,从而降低短请求排队到首 token 的时间。" + ) + for i in range(48) + ] + ) + + "\n\n问题:为什么这个负载更容易体现 chunked prefill 对 TTFT 的收益?", + + "把 'batch scheduler' 翻译成中文,只给译文。", + + "只回答一个数字:2 的 10 次方是多少?", + + "只回答一句话:KV cache 的作用是什么?", + + # 10: 长用户会话,模拟客服/RAG 历史。 + "以下是用户会话历史。请只给出一个优先级最高的处理建议。\n\n" + + "\n".join( + [ + ( + f"用户{i:03d}:我的请求 prompt_len={512 + (i % 9) * 1024}," + f"排队后 TTFT 超过 {1.2 + (i % 6) * 0.7:.1f}s," + "短问答和长文档摘要同时进入服务端,怀疑 prefill 阶段产生队头阻塞。" + ) + for i in range(110) + ] + ) + + "\n最后问题:应该优先检查哪个调度指标?", + + "只回答一个词:prefill 后逐 token 生成的阶段叫什么?", + + "只输出 JSON:{\"ttft_sensitive\": true}", + + # 13: 长表格分析,制造稳定长 prompt。 + "阅读下面的压测表格,只输出 3 个异常点编号。\n\n" + + "\n".join( + [ + ( + f"case={i:03d}, qps={4 + i % 8}, concurrency=5, " + f"prompt_tokens={256 * (4 + i % 24)}, output_tokens={8 + i % 5}, " + f"ttft_p50={380 + (i % 11) * 70}ms, ttft_p99={1200 + (i % 19) * 160}ms, " + "note=短请求应当快速进入 decode,但长 prefill 持续占用预算。" + ) + for i in range(100) + ] + ) + + "\n问题:哪些 case 最像 chunked prefill 关闭时的队头阻塞?", + + "只回答数字:4096 / 256 = ?", + + "用 8 个字以内解释:为什么短请求关注 TTFT?", + + # 16: 长合同/规则文本,非代码类长输入。 + "阅读以下服务等级条款,只输出对延迟指标最不利的 2 条。\n\n" + + "\n".join( + [ + ( + f"第{i:03d}条:当单个请求 prompt 长度超过 {1024 + (i % 12) * 512} tokens 时," + "系统可以暂缓短请求首 token 返回,直到当前 prefill 批次完成;" + "若启用 chunked prefill,应允许短请求在下一调度片段中插入。" + ) + for i in range(90) + ] + ) + + "\n问题:哪些条款会直接拉高短请求 TTFT?", + + "只回答一个单词:latency 的中文常用译法是什么?", + + "只回答 A/B:更适合测 TTFT 的输出长度是 A.很短 B.很长", + + "只回答一句话:chunk size 过大会怎样影响短请求?", + + # 20: 长混合材料,压住队列尾部。 + "下面混合了日志、设计说明和用户反馈。请只输出一句总体判断。\n\n" + + "\n".join( + [ + ( + f"[mix-{i:03d}] 设计:prefill_budget={256 * (1 + i % 6)}, " + f"decode_budget={8 + i % 4}; 日志:queue_depth={32 + i % 64}, " + f"long_prompt={4096 + (i % 10) * 1024}; " + "反馈:短问题的首 token 等待时间比总生成时间更影响交互体验。" + ) + for i in range(130) + ] + ) + + "\n问题:当前负载是否适合观察 chunked prefill 对 TTFT 的改善?", + + "只回答数字:5 个并发里 1 个长请求,短请求有几个?", + + "只回答英文缩写:time to first token 简写是什么?", + + "把这句话压缩到 12 个字以内:长 prompt 不应该长期阻塞短请求。", +) +''' 
+NUM_REQUESTS = len(PROMPTS) +CONCURRENCY = 5 +API_URL = "http://127.0.0.1:2333" +MODEL = "FM9G-7B" + + +async def benchmark_user(client, semaphore, queue, results, user_id, verbose): + while True: + async with semaphore: + task_id = await queue.get() + if task_id is None: + queue.task_done() + break + + question = PROMPTS[task_id] + try: + print(f"🚀 User#{user_id} Sending request #{task_id}") + + start_time = time.time() + stream = await client.chat.completions.create( + model=MODEL, + messages=[{"role": "user", "content": question}], + stream=True, + ) + + first_token_time = None + total_tokens = 0 + answer_chunks = [] + + async for chunk in stream: + if first_token_time is None: + first_token_time = time.time() + delta = chunk.choices[0].delta.content + if delta: + answer_chunks.append(delta) + total_tokens += 1 + if chunk.choices[0].finish_reason is not None: + break + + end_time = time.time() + + ttft = first_token_time - start_time if first_token_time else None + elapsed_time = end_time - start_time if start_time else None + ms_per_token = ( + (elapsed_time / total_tokens * 1000) + if total_tokens > 0 and elapsed_time + else None + ) + tokens_per_second = ( + total_tokens / elapsed_time if elapsed_time > 0 else 0 + ) + + answer = "".join(answer_chunks) + + results.append( + (total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token) + ) + + if verbose: + print(f"\n📝 Request #{task_id} (User #{user_id})") + if ttft is not None: + print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s") + if elapsed_time is not None: + print(f" ⏱ 总耗时: {elapsed_time:.3f}s") + + print(f" 🔤 解码 token 总数: {total_tokens}") + if ms_per_token is not None: + print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token") + else: + print(f" 📏 平均 token 解码时间: N/A (no token generated)") + print(f" ❓ 提问: {question}") + print(f" 💬 回答: {answer}\n") + + queue.task_done() + except Exception as e: + if verbose: + print(f"\n⚠️ Request #{task_id} (User #{user_id}) FAILED:") + print(f" ❌ Error: {e}\n") + 
queue.task_done() + + +async def run_benchmark(verbose=False): + client = AsyncOpenAI(base_url=API_URL, api_key="default") + semaphore = asyncio.Semaphore(CONCURRENCY) + queue = asyncio.Queue() + results = [] + for i in range(NUM_REQUESTS): + await queue.put(i) + for _ in range(CONCURRENCY): + await queue.put(None) + + users = [ + asyncio.create_task( + benchmark_user(client, semaphore, queue, results, user_id, verbose) + ) + for user_id in range(CONCURRENCY) + ] + + start_time = time.time() + await queue.join() + await asyncio.gather(*users) + end_time = time.time() + + total_elapsed_time = end_time - start_time + tokens_list = [r[0] for r in results if r and r[0] is not None] + latencies = [r[1] for r in results if r and r[1] is not None] + tokens_per_second_list = [r[2] for r in results if r and r[2] is not None] + ttft_list = [r[3] for r in results if r and r[3] is not None] + ms_per_token_list = [r[4] for r in results if r and r[4] is not None] + + successful_requests = len(results) + requests_per_second = ( + successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + ) + avg_latency = sum(latencies) / len(latencies) if latencies else 0 + avg_tokens_per_second = ( + sum(tokens_per_second_list) / len(tokens_per_second_list) + if tokens_per_second_list + else 0 + ) + avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0 + avg_ms_per_token = ( + sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None + ) + + width_label = 24 + sep = "-" * 60 + + print(f"\n=== 📊 性能指标汇总 ({MODEL}) ===") + print(sep) + print(f"{'并发数':<{width_label}}: {CONCURRENCY}") + print(f"{'请求总数':<{width_label}}: {NUM_REQUESTS}") + print(f"{'成功请求数':<{width_label}}: {successful_requests}") + print(f"{'总耗时':<{width_label}}: {total_elapsed_time:.2f} s") + print(f"{'总输出token数':<{width_label}}: {sum(tokens_list)}") + print(f"{'请求速率 (RPS)':<{width_label}}: {requests_per_second:.2f} requests/s") + print(sep) + print(f"{'Average latency':<{width_label}}: 
{avg_latency:.2f} s") + print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s") + print(f"{'Avg time per token':<{width_label}}: {'N/A' if avg_ms_per_token is None else format(avg_ms_per_token, '.2f') + ' ms/token'}") # avg_ms_per_token is None when no request produced tokens + print( + f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + asyncio.run(run_benchmark(args.verbose))