diff --git a/csrc/engine/compiler/chunk_prefill_compiler.cpp b/csrc/engine/compiler/chunk_prefill_compiler.cpp new file mode 100644 index 000000000..55ad56f31 --- /dev/null +++ b/csrc/engine/compiler/chunk_prefill_compiler.cpp @@ -0,0 +1,197 @@ +#include "chunk_prefill_compiler.hpp" +#include "../../global_state/global_state.hpp" +#include "infinicore/context/context.hpp" + + +namespace { +inline void set_zeros(infinicore::Tensor &tensor) { + std::vector zeros(tensor->nbytes(), 0); + infinicore::context::memcpyH2D(tensor->data(), zeros.data(), tensor->nbytes(), false); +} +} // namespace + +namespace infinilm::engine { + +ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier) + : GraphCompiler(model, barrier) { + // Enumerate chunk sizes for chunk-prefill + for (size_t cs : {256}) { + chunk_sizes_.push_back(cs); + } + // Enumerate batch sizes for prefill (typically smaller than decode) + for (size_t b = 1; b < 32; b++) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 32; b < 64; b += 8) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 64; b < 128; b += 16) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 128; b < 256; b += 32) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 256; b <= 512; b += 64) { + prefill_batch_sizes_.push_back(b); + } +} + +void ChunkPrefillCompiler::compile() { + if (model_->get_cache_config() != nullptr && + dynamic_cast(model_->get_cache_config())) { + + const auto *paged_config = + dynamic_cast(model_->get_cache_config()); + size_t nblocks = paged_config->num_blocks(); + + compiled_map_prefill_.clear(); + + // Max total tokens to avoid OOM during graph recording + constexpr size_t MAX_TOTAL_TOKENS = 4096; + + // Pre-allocate a shared block_tables_holder for the largest (batch_size) we'll use + size_t max_batch = *std::max_element(prefill_batch_sizes_.begin(), prefill_batch_sizes_.end()); + size_t block_per_req = nblocks / max_batch; + block_tables_holder_ = infinicore::Tensor::empty( + {nblocks}, infinicore::DataType::I32, infinicore::context::getDevice()); + set_zeros(block_tables_holder_); + + for (size_t b : prefill_batch_sizes_) { + for (size_t cs : chunk_sizes_) { + size_t total_tokens = b * cs; + if (total_tokens > MAX_TOTAL_TOKENS) { + continue; + } + + size_t bpr = nblocks / b; // block_per_req for this batch size + + InfinilmModel::Input input; + + // input_ids: [1, total_tokens] — all tokens for this batch packed together + input.input_ids = infinicore::Tensor::empty( + {1, total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.input_ids.value()); + + // position_ids: [total_tokens] + input.position_ids = infinicore::Tensor::empty( + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.position_ids.value()); + + // total_sequence_lengths: [b], set to cs (first-chunk scenario) + input.total_sequence_lengths = infinicore::Tensor::empty( + {b}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector tsl(b, static_cast(cs)); + infinicore::context::memcpyH2D( + input.total_sequence_lengths.value()->data(), + tsl.data(), b * sizeof(int32_t), false); + } + + // input_offsets: [b+1], stride = cs + input.input_offsets = infinicore::Tensor::empty( + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector offsets(b + 1); + for (size_t i = 0; i <= b; i++) { + offsets[i] = static_cast(i * cs); + } + infinicore::context::memcpyH2D( + input.input_offsets.value()->data(), + offsets.data(), (b + 1) * sizeof(int32_t), false); + } + + // cu_seqlens: [b+1], same layout as input_offsets for prefill + input.cu_seqlens = infinicore::Tensor::empty( + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector cu(b + 1); + for (size_t i = 0; i <= b; i++) { + cu[i] = static_cast(i * cs); + } + infinicore::context::memcpyH2D( + input.cu_seqlens.value()->data(), + cu.data(), (b + 1) * sizeof(int32_t), false); + } + + // block_tables: view into the shared holder [b, bpr] + input.block_tables = block_tables_holder_->as_strided( + {b, bpr}, {(ptrdiff_t)bpr, 1}); + + // slot_mapping: [total_tokens] + input.slot_mapping = infinicore::Tensor::empty( + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.slot_mapping.value()); + + // Attention reads attn_metadata from thread-local forward context. + infinilm::global_state::get_forward_context().attn_metadata = { + input.past_sequence_lengths, + input.total_sequence_lengths, + input.input_offsets, + input.cu_seqlens, + input.block_tables, + input.slot_mapping, + }; + + barrier_->wait(); + infinicore::context::startGraphRecording(); + auto output = model_->forward(input); + auto graph = infinicore::context::stopGraphRecording(); + barrier_->wait(); + + auto shared_output = std::shared_ptr( + new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)}); + + compiled_map_prefill_[std::make_tuple(b, cs)] = + CompiledResult{std::move(input), std::make_tuple(graph, shared_output)}; + } + } + } +} + +ChunkPrefillCompiler::Compiled ChunkPrefillCompiler::get_compiled(const InfinilmModel::Input &input) { + if (model_->get_cache_config() == nullptr || + !dynamic_cast(model_->get_cache_config())) { + return {nullptr, nullptr}; + } + + if (!input.block_tables.has_value() || !input.input_ids.has_value()) { + return {nullptr, nullptr}; + } + + size_t batch_size = input.block_tables.value()->size(0); + size_t block_per_req = input.block_tables.value()->size(1); + size_t total_tokens = input.input_ids.value()->size(1); + + // Prefill: total_tokens is a multiple of batch_size, and chunk_size > 1 + if (total_tokens == 0 || total_tokens % batch_size != 0) { + return {nullptr, nullptr}; + } + size_t chunk_size = total_tokens / batch_size; + if (chunk_size <= 1) { + // Single-token case belongs to decode + return {nullptr, nullptr}; + } + + auto result = compiled_map_prefill_.find(std::make_tuple(batch_size, chunk_size)); + if (result == compiled_map_prefill_.end()) { + return {nullptr, nullptr}; + } + + auto &graph_input = result->second.input; + + graph_input.input_ids.value()->copy_from(input.input_ids.value()); + graph_input.position_ids.value()->copy_from(input.position_ids.value()); + graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); + graph_input.input_offsets.value()->copy_from(input.input_offsets.value()); + graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value()); + graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value()); + graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); + + auto graph = std::get<0>(result->second.compiled); + auto shared_output = std::shared_ptr( + new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + + return std::make_tuple(graph, shared_output); +} + +} // namespace infinilm::engine diff --git a/csrc/engine/compiler/chunk_prefill_compiler.hpp b/csrc/engine/compiler/chunk_prefill_compiler.hpp new file mode 100644 index 000000000..bd701158a --- /dev/null +++ b/csrc/engine/compiler/chunk_prefill_compiler.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "graph_compiler.hpp" + +#include + +namespace infinilm::engine { +class ChunkPrefillCompiler : public GraphCompiler { +public: + ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier); + + void compile() override; + + Compiled get_compiled(const InfinilmModel::Input &input) override; + +private: + struct TupleHash { + size_t operator()(const std::tuple &t) const noexcept { + auto h1 = std::hash{}(std::get<0>(t)); + auto h2 = std::hash{}(std::get<1>(t)); + return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); + } + }; + + std::vector chunk_sizes_; + std::vector prefill_batch_sizes_; + + infinicore::Tensor block_tables_holder_; + + struct CompiledResult { + InfinilmModel::Input input; + Compiled compiled; + }; + + // Key: (batch_size, chunk_size) + std::unordered_map< + std::tuple, + CompiledResult, + TupleHash> + compiled_map_prefill_; +}; +} // namespace infinilm::engine diff --git a/csrc/engine/compiler/general_compiler.cpp b/csrc/engine/compiler/general_compiler.cpp index 84ee670d4..36c6420f0 100644 --- a/csrc/engine/compiler/general_compiler.cpp +++ b/csrc/engine/compiler/general_compiler.cpp @@ -1,13 +1,18 @@ #include "general_compiler.hpp" namespace infinilm::engine { -GeneralCompiler::GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier) : GraphCompiler(model, barrier) { +GeneralCompiler::GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier, bool enable_chunk_prefill_graph) + : GraphCompiler(model, barrier), enable_chunk_prefill_graph_(enable_chunk_prefill_graph) { static_batching_compiler_ = std::make_unique(model_, barrier); + chunk_prefill_compiler_ = std::make_unique(model_, barrier); paged_compiler_ = std::make_unique(model_, barrier); } void GeneralCompiler::compile() { static_batching_compiler_->compile(); + if (enable_chunk_prefill_graph_) { + chunk_prefill_compiler_->compile(); + } paged_compiler_->compile(); } @@ -19,6 +24,11 @@ GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Inp if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) { return result; } + // chunk-prefill must be checked before decode (decode would also match if chunk_size==1) + result = chunk_prefill_compiler_.get()->get_compiled(input); + if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) { + return result; + } result = paged_compiler_.get()->get_compiled(input); return result; } diff --git a/csrc/engine/compiler/general_compiler.hpp b/csrc/engine/compiler/general_compiler.hpp index e8b84b5d9..3edbcea0c 100644 --- a/csrc/engine/compiler/general_compiler.hpp +++ b/csrc/engine/compiler/general_compiler.hpp @@ -1,12 +1,13 @@ #pragma once +#include "chunk_prefill_compiler.hpp" #include "paged_compiler.hpp" #include "static_batching_compiler.hpp" namespace infinilm::engine { class GeneralCompiler : public GraphCompiler { public: - GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier); + GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier, bool enable_chunk_prefill_graph = false); void compile() override; @@ -15,5 +16,7 @@ class GeneralCompiler : public GraphCompiler { private: std::unique_ptr static_batching_compiler_; std::unique_ptr paged_compiler_; + std::unique_ptr chunk_prefill_compiler_; + bool enable_chunk_prefill_graph_; }; } // namespace infinilm::engine diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index db0dfdd47..5b6ea143e 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -25,6 +25,7 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) // Changed parameter : communication_group_(distributed_config, device_type), legacy_model_config_(config), @@ -43,6 +44,7 @@ InferEngine::InferEngine( cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend_)); } @@ -56,6 +58,7 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend, std::optional kv_cache_dtype) // Changed parameter : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { @@ -82,6 +85,7 @@ InferEngine::InferEngine( cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend_)); } // Compile the model on all workers diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index e36ec3699..153600c48 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -39,6 +39,7 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, + bool enable_chunk_prefill_graph = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); InferEngine( @@ -47,6 +48,7 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, + bool enable_chunk_prefill_graph = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, std::optional kv_cache_dtype = std::nullopt); diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 8a94c441e..e607c569f 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -27,11 +27,13 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config, const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) : legacy_model_config_(model_config), rank_info_(rank_info), attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), + enable_chunk_prefill_graph_(enable_chunk_prefill_graph), job_cmd_(Command::INIT), has_job_(false), job_done_(false), @@ -56,12 +58,14 @@ RankWorker::RankWorker( const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) : infinilm_config_(infinilm_config), model_config_(infinilm_config->model_config), rank_info_(rank_info), attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), + enable_chunk_prefill_graph_(enable_chunk_prefill_graph), job_cmd_(Command::INIT), has_job_(false), job_done_(false), @@ -303,7 +307,7 @@ void RankWorker::thread_loop() { throw std::runtime_error("Failed to create model"); } if (enable_graph_compiling_) { - compiler_ = std::make_unique(model_, barrier_); + compiler_ = std::make_unique(model_, barrier_, enable_chunk_prefill_graph_); } init_done_ = true; diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index f6adcf476..b045adf65 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -75,6 +75,7 @@ class RankWorker { const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend); RankWorker(std::shared_ptr infinilm_config, @@ -82,6 +83,7 @@ class RankWorker { const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend); // Submit a parameter load job and wait until the load completes on the worker thread. @@ -131,6 +133,7 @@ class RankWorker { // Graph Compiling bool enable_graph_compiling_; + bool enable_chunk_prefill_graph_; std::unique_ptr compiler_; // Command for the pending job (protected by mutex_) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 2741c9cd7..a479f66be 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -37,6 +37,7 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, const std::string &attention_backend) { return std::make_shared( cfg, @@ -44,6 +45,7 @@ inline void bind_infer_engine(py::module &m) { dev, cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, + enable_chunk_prefill_graph, infinilm::backends::parse_attention_backend(attention_backend)); }), py::arg("config"), @@ -51,6 +53,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, + py::arg("enable_chunk_prefill_graph") = false, py::arg("attention_backend") = "default") .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), @@ -81,6 +84,7 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, const std::string &attention_backend, std::optional kv_cache_dtype) { return std::make_shared( @@ -89,6 +93,7 @@ inline void bind_infer_engine(py::module &m) { dev, cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, + enable_chunk_prefill_graph, infinilm::backends::parse_attention_backend(attention_backend), kv_cache_dtype); }), @@ -97,6 +102,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, + py::arg("enable_chunk_prefill_graph") = false, py::arg("attention_backend") = "default", py::arg("kv_cache_dtype") = py::none()) .def("load_param", &InferEngine::load_param, diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index aab5dd459..5ef2a8ffb 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -61,6 +61,8 @@ def __init__(self): self.attn = self.args.attn self.enable_graph = self.args.enable_graph + self.enable_chunk_prefill_graph = self.args.enable_chunk_prefill_graph + self.chunk_size = self.args.chunk_size self.enable_paged_attn = self.args.enable_paged_attn self.num_blocks = self.args.num_blocks self.block_size = self.args.block_size @@ -122,6 +124,8 @@ def _add_common_args(self): choices=["default", "paged-attn", "flash-attn"], ) self.parser.add_argument("--enable-graph", action="store_true") + self.parser.add_argument("--enable-chunk-prefill-graph", action="store_true", help="enable chunk-prefill graph compiling") + self.parser.add_argument("--chunk-size", type=int, default=0, help="tokens per chunked-prefill slice (0 to disable)") self.parser.add_argument( "--enable-paged-attn", action="store_true", diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 13bb18a19..2477bbc61 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -45,6 +45,7 @@ def __init__( distributed_config=DistConfig(1), cache_config=None, enable_graph_compiling=False, + enable_chunk_prefill_graph=False, attention_backend="default", kv_cache_dtype=None, ): @@ -60,6 +61,7 @@ def __init__( device._underlying.type, cache_config, enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend, ( parse_dtype(kv_cache_dtype)._underlying diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py index 44ca13762..df9f19577 100644 --- a/python/infinilm/llm/cache_manager.py +++ b/python/infinilm/llm/cache_manager.py @@ -119,24 +119,51 @@ def allocate_blocks( ) -> tuple[List[int], List[int], int]: """Allocate cache blocks for new request with prefix caching support. - Args: - token_ids: Input token sequence - block_table: Existing block_table (for decode phase) + Idempotent: if block_table already fully covers token_ids with valid + (still-active) blocks, returns a consistent (block_table, slot_mapping, + num_cached_tokens=0) without re-allocating. - Returns: - Tuple of (block_table, slot_mapping, num_cached_tokens) + Convention: len(slot_mapping) == num_tokens - num_cached_tokens + (one slot per token that needs to be (re)computed). """ if block_table is None: block_table = [] num_tokens = len(token_ids) - num_blocks = (num_tokens + self.block_size - 1) // self.block_size + if num_tokens == 0: + return [], [], 0 + + num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size + + # -------------------------------------------------------------- # + # Idempotent re-entry path # + # -------------------------------------------------------------- # + # If block_table already covers the prompt AND all those blocks + # are still alive (ref_count > 0), reconstruct slot_mapping from + # block_table and return num_cached_tokens=0 (i.e., the forward + # will recompute everything into the same slots — wasteful but + # always correct, and keeps the slot_mapping length convention). + if block_table and len(block_table) >= num_blocks_needed: + bt = list(block_table[:num_blocks_needed]) + if all(self.blocks[bid].ref_count > 0 for bid in bt): + slot_mapping = [ + bt[i // self.block_size] * self.block_size + (i % self.block_size) + for i in range(num_tokens) + ] + # length = num_tokens = num_tokens - 0 ✓ matches convention + return bt, slot_mapping, 0 + # Otherwise the block_table is stale — drop it and re-allocate. + block_table = [] + + # -------------------------------------------------------------- # + # Below: original code unchanged # + # -------------------------------------------------------------- # slot_mapping = [] num_cached_tokens = 0 prefix_hash = -1 cache_miss = False - for block_idx in range(num_blocks): + for block_idx in range(num_blocks_needed): start_idx = block_idx * self.block_size end_idx = min(start_idx + self.block_size, num_tokens) block_tokens = token_ids[start_idx:end_idx] diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index cba3af83a..dd8687e52 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -72,6 +72,8 @@ class EngineConfig: top_p: float = 0.8 top_k: int = 1 enable_graph: bool = False + enable_chunk_prefill_graph: bool = False + chunk_size: int = 0 attn_backend: str = "default" skip_load: bool = False @@ -91,6 +93,7 @@ def __init__(self, config: EngineConfig): device=self.device, distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, + enable_chunk_prefill_graph=config.enable_chunk_prefill_graph, attention_backend=config.attn_backend, ) @@ -167,6 +170,8 @@ def _init_device(self): def add_request(self, request: InferenceRequest): """Add a request to the scheduler.""" + if self.cache_type == "paged" and self.config.chunk_size > 0: + request.chunk_size = self.config.chunk_size self.scheduler.add_request(request) def step(self) -> tuple[list[InferenceRequest], list[tuple]]: @@ -210,7 +215,17 @@ def _update_requests( sampled_tokens: List[int], ) -> List[tuple]: """Update request status after inference step.""" - if is_prefill: + # Detect a chunked-prefill mid-step: single request, prefill phase, + # and this chunk does not yet cover the whole prompt. In that case + # we must NOT consume a sampled token, NOT commit prefill blocks, + # and re-enqueue the request to keep chunking. + chunk_mid_step = ( + is_prefill + and len(requests) > 0 + and all(r.is_chunking() and not r.chunk_is_last() for r in requests) + ) + + if is_prefill and not chunk_mid_step: match self.cache_type: case "paged": self.scheduler.cache_manager.reset_req_blocks() @@ -218,6 +233,18 @@ def _update_requests( self.scheduler.update_cache() case _: raise ValueError(f"Unsupported cache_type: {self.cache_type}") + + if chunk_mid_step: + for req in requests: + req.chunk_prefill_offset += req.chunk_size + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted by client during chunked-prefill" + ) + continue + self.scheduler.requeue_chunking(req) + return [] + pending = [] for req, token_id in zip(requests, sampled_tokens): if req.is_aborted(): @@ -227,6 +254,10 @@ def _update_requests( continue if req.is_prefill: + # Clean up chunked-prefill state on the final chunk so the + # next forward pass on this request takes the decode path. + req.chunk_prefill_offset = 0 + req.chunk_size = 0 req.is_prefill = False req.generated_token_ids.append(token_id) @@ -361,6 +392,8 @@ def __init__( top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", skip_load: bool = False, ): @@ -398,6 +431,8 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + enable_chunk_prefill_graph=enable_chunk_prefill_graph, + chunk_size=chunk_size, attn_backend=attn_backend, skip_load=skip_load, ) @@ -539,6 +574,8 @@ def __init__( top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", ): """Initialize AsyncLLMEngine. @@ -575,6 +612,8 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + enable_chunk_prefill_graph=enable_chunk_prefill_graph, + chunk_size=chunk_size, attn_backend=attn_backend, ) self.engine = LLMEngine(config) @@ -620,8 +659,13 @@ def _step_loop(self): requests, pending = self.engine.step() if not requests: time.sleep(0.01) - elif pending: - self._loop.call_soon_threadsafe(self._batch_put, pending) + else: + if pending: + self._loop.call_soon_threadsafe(self._batch_put, pending) + # Yield GIL so the asyncio main thread can deliver tokens + # to clients between inference steps. Without this, the step + # thread monopolizes the GIL and token streaming stalls. + time.sleep(0.0005) except Exception as e: logger.error(f"Error in step loop: {e}", exc_info=True) self._healthy = False diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index 15bcf69f4..ef5c8cd2e 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -144,6 +144,11 @@ def __init__( self.num_cached_tokens: int = 0 self.num_blocks: int = 0 + # Chunked-prefill state (0 = disabled, otherwise tokens per chunk) + self.chunk_size: int = 0 + # Number of prompt tokens already fed through forward as chunked-prefill + self.chunk_prefill_offset: int = 0 + # For server use self.request_data: Optional[dict] = request_data self.http_request: Optional[Any] = http_request @@ -186,6 +191,18 @@ def get_num_blocks_required(self, block_size: int) -> int: def get_max_tokens(self) -> Optional[int]: return self.sampling_params.max_tokens + def is_chunking(self) -> bool: + """Return True if this request is in the middle of chunked-prefill.""" + return ( + self.chunk_size > 0 + and self.is_prefill + and (self.prompt_length - self.num_cached_tokens) > self.chunk_size + ) + + def chunk_is_last(self) -> bool: + """Return True if the next chunk would finish the prompt.""" + return self.chunk_prefill_offset + self.chunk_size >= self.prompt_length + def is_finished(self) -> bool: return self.status in [ RequestStatus.FINISHED, diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index f9c11635a..e235842a3 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -28,10 +28,17 @@ def __init__( class Scheduler: """Request scheduler with integrated BlockManager for KV cache management. - Scheduling logic: - 1. Running queue: Check for new blocks needed, update slot_mapping - 2. Waiting queue: Try block reuse (prefix caching), allocate new blocks - 3. Reference counting: Free blocks when requests complete + Scheduling priority (option A + B): + 1. Decode (running_queue) — latency-sensitive, never starves anyone. + 2. New prefill (waiting_queue) — preempts in-flight chunking so newly + arrived short requests don't wait for an entire long prefill. + 3. Continue chunked-prefill (chunking_queue) — single-request batch. + + Anti-starvation (option B): + After `max_waiting_yields` consecutive steps where waiting_queue won + over a non-empty chunking_queue, the next step is forced onto the + chunking_queue. This bounds the worst-case delay a long-prompt request + can suffer when there is a steady inflow of new short requests. """ def __init__( @@ -39,44 +46,135 @@ def __init__( max_batch_size: int = 16, num_blocks: int = 512, block_size: int = 256, + max_waiting_yields: int = 4, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() + # Requests in the middle of chunked-prefill — single-request batch only + # (matches the C++ ChunkPrefillCompiler graph signature). + self.chunking_queue = janus.Queue() self.max_batch_size = max_batch_size self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) self.block_size = block_size + # ---- Anti-starvation state ---- + # How many times waiting_queue has won over a non-empty chunking_queue + # since the last time chunking actually ran. Reset to 0 every time we + # run a chunking step. + self._waiting_yields_in_a_row: int = 0 + # Upper bound on _waiting_yields_in_a_row before chunking is forced. + self.max_waiting_yields: int = max_waiting_yields + def add_request(self, request: InferenceRequest): if request is not None: request.status = RequestStatus.WAITING self.waiting_queue.sync_q.put(request) + # ------------------------------------------------------------------ # + # Main scheduling entrypoint # + # ------------------------------------------------------------------ # def schedule(self) -> Optional[SchedulerOutput]: - """Schedule and return batch of requests to execute.""" - scheduled_requests = [] - is_prefill = False + """Schedule and return batch of requests to execute. + + Priority (prefill first): + 1. New prefill (waiting_queue) — minimize TTFT for new requests. + 2. Decode (running_queue). + 3. Continue chunked-prefill (chunking_queue). + + Anti-starvation (only guards chunking against waiting): + After `max_waiting_yields` consecutive steps where waiting_queue won + over a non-empty chunking_queue, the next step is forced onto the + chunking_queue. + """ + # 0) Forced chunking after too many consecutive waiting yields. + if self._waiting_yields_in_a_row >= self.max_waiting_yields: + chunking_out = self._try_schedule_chunking() + if chunking_out is not None: + self._waiting_yields_in_a_row = 0 + return chunking_out + # chunking_queue was actually empty — fall through to normal path. + + # 1) New prefill + chunking_was_nonempty = self.chunking_queue.sync_q.qsize() > 0 + waiting_out = self._try_schedule_waiting() + if waiting_out is not None: + if chunking_was_nonempty: + self._waiting_yields_in_a_row += 1 + else: + self._waiting_yields_in_a_row = 0 + return waiting_out + + # 2) Decode. + chunking_was_nonempty = self.chunking_queue.sync_q.qsize() > 0 + decode_out = self._try_schedule_decode() + if decode_out is not None: + if chunking_was_nonempty: + self._waiting_yields_in_a_row += 1 + else: + self._waiting_yields_in_a_row = 0 + return decode_out - # Process Waiting queue (prefill phase) - while len(scheduled_requests) < self.max_batch_size: + # 3) Continue an in-flight chunked-prefill request. + chunking_out = self._try_schedule_chunking() + if chunking_out is not None: + self._waiting_yields_in_a_row = 0 + return chunking_out + + return None + + # ------------------------------------------------------------------ # + # Per-queue schedulers # + # ------------------------------------------------------------------ # + def _try_schedule_chunking(self) -> Optional[SchedulerOutput]: + scheduled: List[InferenceRequest] = [] + while len(scheduled) < self.max_batch_size: try: - req = self.waiting_queue.sync_q.get_nowait() + req = self.chunking_queue.sync_q.get_nowait() except queue.Empty: break - # Skip requests that were already finished (e.g., timed out/canceled while waiting) if req.is_finished(): self.complete_requests([req]) continue + # 最后一块(partial 或恰好等于 chunk_size 的最后整块)单独跑。 + # 不能和中间整块混批:最后一块要采样+提交 block,中间块两个都不做。 + if req.chunk_is_last(): + if not scheduled: + return SchedulerOutput([req], is_prefill=True) + # 已经攒了中间块,先把这个 last-chunk 放回队头,等下个 step 单独跑。 + self.chunking_queue.sync_q.put(req) + break + scheduled.append(req) + if scheduled: + return SchedulerOutput(scheduled, is_prefill=True) + return None - if not self.can_accept_request(req): - self.waiting_queue.sync_q.put(req) + def _try_schedule_waiting(self) -> Optional[SchedulerOutput]: + """Pull new prefill requests from waiting_queue and form a prefill batch. + + If any request triggers chunked-prefill (prompt_length > chunk_size > 0), + it's emitted alone as a single-request batch (the chunking graph requires + a uniform chunk_size across the batch, and we don't mix chunking with + regular prefill in the same batch). + """ + scheduled_requests: List[InferenceRequest] = [] + + while len(scheduled_requests) < self.max_batch_size: + try: + req = self.waiting_queue.sync_q.get_nowait() + except queue.Empty: break - # Skip requests that were already finished (e.g., timed out/canceled while waiting) + # Skip requests that were already finished (timed out / canceled while waiting). if req.is_finished(): self.complete_requests([req]) continue + if not self.can_accept_request(req): + # Put it back; we'll retry next tick when cache pressure eases. + self.waiting_queue.sync_q.put(req) + break + req_tokens = req.get_input_tokens() num_required_blocks = req.get_num_blocks_required(self.block_size) @@ -84,35 +182,53 @@ def schedule(self) -> Optional[SchedulerOutput]: if not self.cache_manager.try_free_blocks(num_required_blocks): raise RuntimeError("No available cache blocks for new request") - # Allocate blocks with automatic prefix caching support - req.block_table, req.slot_mapping, req.num_cached_tokens = ( - self.cache_manager.allocate_blocks(req_tokens, req.block_table) - ) - + # Allocate blocks (prefix caching applied automatically). + if not req.block_table: + req.block_table, req.slot_mapping, req.num_cached_tokens = ( + self.cache_manager.allocate_blocks(req_tokens, req.block_table) + ) + req.num_blocks = len(req.block_table) req.status = RequestStatus.RUNNING + + # Start chunked-prefill: emit a single-request batch immediately + # to keep the C++ graph signature stable. The request will be + # requeued into chunking_queue by llm._update_requests after each + # chunk runs. + remaining = req.prompt_length - req.num_cached_tokens + if req.chunk_size > 0 and remaining > req.chunk_size: + req.chunk_prefill_offset = req.num_cached_tokens + if scheduled_requests: + for already in scheduled_requests: + already.status = RequestStatus.WAITING + self.waiting_queue.sync_q.put(already) + return SchedulerOutput([req], is_prefill=True) + scheduled_requests.append(req) - # Return prefill batch if any waiting requests were scheduled if scheduled_requests: - is_prefill = True return SchedulerOutput( scheduled_requests=scheduled_requests, - is_prefill=is_prefill, + is_prefill=True, ) + return None + + def _try_schedule_decode(self) -> Optional[SchedulerOutput]: + """Pull running_queue requests into a decode batch.""" + scheduled_requests: List[InferenceRequest] = [] - # Process Running queue (decode phase) while len(scheduled_requests) < self.max_batch_size: try: req = self.running_queue.sync_q.get_nowait() except queue.Empty: break - # Skip requests that were already finished (e.g., timed out/canceled while running) + + # Skip requests that were already finished (timed out / canceled while running). if req.is_finished(): self.complete_requests([req]) continue - # Decode phase: allocate slot for newly generated token + # Decode phase: allocate slot for newly generated token. try: req.block_table, new_slot = self.cache_manager.append_slot( req.block_table, req.get_total_length(), req.get_all_token_ids() @@ -121,22 +237,30 @@ def schedule(self) -> Optional[SchedulerOutput]: req.num_blocks = len(req.block_table) req.num_cached_tokens = req.get_total_length() - 1 scheduled_requests.append(req) - except RuntimeError as e: raise RuntimeError("No available cache blocks for new token") from e - # Return decode batch if any running requests were scheduled if scheduled_requests: - is_prefill = False return SchedulerOutput( scheduled_requests=scheduled_requests, - is_prefill=is_prefill, + is_prefill=False, ) - return None + # ------------------------------------------------------------------ # + # External hooks (unchanged behavior) # + # ------------------------------------------------------------------ # + def requeue_chunking(self, req: InferenceRequest): + """Put a request back into the chunking queue after a chunk has run.""" + self.chunking_queue.sync_q.put(req) + def complete_requests(self, requests: List[InferenceRequest]): - """Handle completed requests and free their blocks.""" + """Handle completed requests and free their blocks. + + Active (non-finished) requests passed here are re-enqueued into the + running_queue — this is how prefill-finished requests migrate into + the decode pipeline. + """ for req in requests: if req.status in [ RequestStatus.FINISHED, @@ -196,4 +320,4 @@ def get_cache_stats(self) -> dict: "num_free_blocks": self.cache_manager.get_num_free_blocks(), "num_req_blocks": len(self.cache_manager.req_block_ids), "num_used_blocks": len(self.cache_manager.used_block_ids), - } + } \ No newline at end of file diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index 070a40622..397e9068f 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -183,42 +183,72 @@ def _build_model_input_from_batch_scheduler_output( for req in scheduler_output.scheduled_requests: num_cached = req.num_cached_tokens if scheduler_output.is_prefill: - # Prefill phase req_tokens = req.get_input_tokens() - tokens_to_compute = req_tokens[num_cached:] - tokens.extend(tokens_to_compute) - - compute_len = len(tokens_to_compute) - seq_len = len(req_tokens) - seq_lens.append(seq_len) - - current_offset += compute_len - seq_offsets.append(current_offset) - - slot_mapping.extend(req.slot_mapping) - cached_lens.append(num_cached) - position_ids.extend(range(num_cached, num_cached + compute_len)) + # Chunked-prefill: only feed [chunk_prefill_offset : +chunk_size). + if req.is_chunking(): + start = req.chunk_prefill_offset + end = min(start + req.chunk_size, len(req_tokens)) + tokens_to_compute = req_tokens[start:end] + compute_len = len(tokens_to_compute) + tokens.extend(tokens_to_compute) + seq_len = end + seq_lens.append(seq_len) + current_offset += compute_len + seq_offsets.append(current_offset) + # req.slot_mapping has length (prompt_length - num_cached) and is + # indexed [0..prompt_length-num_cached). Translate absolute token + # indices to slot_mapping indices. + slot_start = start - num_cached + slot_end = end - num_cached + assert slot_start >= 0 and slot_end <= len(req.slot_mapping), ( + f"chunking slot slice out of range: start={start} " + f"end={end} num_cached={num_cached} " + f"len(slot_mapping)={len(req.slot_mapping)}" + ) + slot_mapping.extend(req.slot_mapping[slot_start:slot_end]) + cached_lens.append(start) + position_ids.extend(range(start, end)) + else: + tokens_to_compute = req_tokens[num_cached:] + tokens.extend(tokens_to_compute) + compute_len = len(tokens_to_compute) + seq_len = len(req_tokens) + seq_lens.append(seq_len) + current_offset += compute_len + seq_offsets.append(current_offset) + slot_mapping.extend(req.slot_mapping) + cached_lens.append(num_cached) + position_ids.extend(range(num_cached, num_cached + compute_len)) else: - # Decode phase seq_len = req.get_total_length() last_token = req.generated_token_ids[-1] tokens.append(last_token) seq_lens.append(seq_len) - current_offset += 1 seq_offsets.append(current_offset) - slot_mapping.extend(req.slot_mapping) cached_lens.append(num_cached) position_ids.append(seq_len - 1) - # Pad block_table to same length padded_block_table = req.block_table + [-1] * ( max_block_table_len - len(req.block_table) ) block_tables.append(padded_block_table) cu_seqlens.append(cu_seqlens[-1] + seq_len) + + # guarantee non-empty tokens and slot_mapping to avoid downstream errors. If empty, raise with detailed debug info. + if not tokens or not slot_mapping: + states = [ + (r.request_id[:8], r.is_prefill, r.is_chunking(), + r.chunk_prefill_offset, r.prompt_length, r.num_cached_tokens, + len(r.slot_mapping), r.status.name) + for r in scheduler_output.scheduled_requests + ] + raise RuntimeError( + f"build_model_inputs got empty tokens/slot_mapping. " + f"is_prefill={scheduler_output.is_prefill}, states={states}" + ) return { "input_ids": infinicore.from_list([tokens], dtype=infinicore.int64), diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 71e9c992f..ac7e94e71 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -108,6 +108,8 @@ def __init__( host: str = "0.0.0.0", port: int = 8000, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", ignore_eos: bool = False, ): @@ -130,6 +132,10 @@ def __init__( host: Server host address. port: Server port number. enable_graph: Whether to enable graph compiling. + enable_chunk_prefill_graph: Whether to enable chunk-prefill graph compiling. + chunk_size: Tokens per chunked-prefill slice (0 = disabled). When > 0 and paged + cache is used, long prompts are sliced and each slice goes through forward + separately so the C++ ChunkPrefillCompiler precompiled graph can be reused. attn_backend: Attention backend to use ('default', 'flash-attn'). """ self.model_path = model_path @@ -150,6 +156,8 @@ def __init__( self.host = host self.port = port self.enable_graph = enable_graph + self.enable_chunk_prefill_graph = enable_chunk_prefill_graph + self.chunk_size = chunk_size self.attn_backend = attn_backend self.ignore_eos = ignore_eos @@ -182,11 +190,15 @@ async def lifespan(app: FastAPI): top_p=self.top_p, top_k=self.top_k, enable_graph=self.enable_graph, + enable_chunk_prefill_graph=self.enable_chunk_prefill_graph, + chunk_size=self.chunk_size, attn_backend=self.attn_backend, ) self.engine.start() logger.info(f"Engine initialized with model at {self.model_path}") logger.info(f" enable_graph: {self.enable_graph}") + logger.info(f" enable_chunk_prefill_graph: {self.enable_chunk_prefill_graph}") + logger.info(f" chunk_size: {self.chunk_size}") yield self.engine.stop() @@ -572,6 +584,8 @@ def main(): host=cfg.host, port=cfg.port, enable_graph=cfg.enable_graph, + enable_chunk_prefill_graph=cfg.enable_chunk_prefill_graph, + chunk_size=cfg.chunk_size, attn_backend=cfg.attn, ignore_eos=cfg.ignore_eos, ) diff --git a/scripts/infer_task.py b/scripts/infer_task.py index 0d1231b77..1851f0a0a 100644 --- a/scripts/infer_task.py +++ b/scripts/infer_task.py @@ -10,6 +10,8 @@ def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): self.end_tokens = end_tokens self._kv_cache = None self.pos = 0 + self._discard_output = False + self._remaining_tokens = None def bind_kvcache(self, kv_cache, pos=0): self._kv_cache = kv_cache @@ -24,6 +26,25 @@ def release_kvcache(self): def kvcache(self): return self._kv_cache + def setup_chunked_prefill(self, chunk_size): + if chunk_size <= 0 or len(self.tokens) <= chunk_size: + return + self._remaining_tokens = self.tokens[chunk_size:] + self.tokens = self.tokens[:chunk_size] + self._discard_output = True + + def advance_prefill_chunk(self, chunk_size): + self._kv_cache.update_tokens(self.tokens, self.pos) + self.pos += len(self.tokens) + + if len(self._remaining_tokens) <= chunk_size: + self.tokens = self._remaining_tokens + self._remaining_tokens = None + self._discard_output = False + else: + self.tokens = self._remaining_tokens[:chunk_size] + self._remaining_tokens = self._remaining_tokens[chunk_size:] + def next(self, out_token): self._kv_cache.update_tokens(self.tokens, self.pos) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index d04d4f69d..0639a28b4 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -64,6 +64,13 @@ def parse_args(): default=None, help="Max token sequence length that model will handle (follows model config if not provided)", ) + parser.add_argument( + "--chunk-size", + type=int, + default=512, + help="Maximum number of tokens per prefill chunk (default: 512). " + "Set to 0 to disable chunked prefill.", + ) parser.add_argument( "--awq", action="store_true", @@ -86,8 +93,10 @@ def parse_args(): USE_AWQ = args.awq USE_GPTQ = args.gptq MAX_BATCH = args.max_batch +CHUNK_SIZE = args.chunk_size print( - f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." + f"Using MAX_BATCH={MAX_BATCH}, CHUNK_SIZE={CHUNK_SIZE}. " + f"Try reduce these values if out of memory error occurs." ) @@ -163,32 +172,66 @@ async def lifespan(app: FastAPI): # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. +# Uses priority scheduling: decode/short tasks first, then prefill chunks. def worker_loop(app): + pending_prefill = [] # Low priority: chunked prefill tasks + while True: + # Drain all available tasks from the queue + incoming = [] try: task = app.state.request_queue.sync_q.get(timeout=0.01) + if task is None: + return + incoming.append(task) except queue.Empty: - continue - - if task is None: - return + pass - batch = [task] - while len(batch) < MAX_BATCH: + while True: try: - req = app.state.request_queue.sync_q.get_nowait() - if req is not None: - batch.append(req) + task = app.state.request_queue.sync_q.get_nowait() + if task is None: + return + incoming.append(task) except queue.Empty: break + + # Separate into high priority (decode/new short) and low priority (prefill chunks) + high_priority = [] + for t in incoming: + if t._discard_output: + pending_prefill.append(t) + else: + high_priority.append(t) + + # Build batch: high priority first, then fill with prefill chunks + batch = [] + while high_priority and len(batch) < MAX_BATCH: + batch.append(high_priority.pop(0)) + while pending_prefill and len(batch) < MAX_BATCH: + batch.append(pending_prefill.pop(0)) + + if not batch: + continue + output_tokens = app.state.model.batch_infer_one_round(batch) for task, token in zip(batch, output_tokens): - task.output(token) - if task.finish_reason is None: - app.state.request_queue.sync_q.put(task) + if task._discard_output: + task.advance_prefill_chunk(CHUNK_SIZE) + if task.finish_reason is None: + if task._discard_output: + pending_prefill.append(task) + else: + app.state.request_queue.sync_q.put(task) + else: + app.state.kv_cache_pool.release_sync(task) else: - print(f"[INFO] Task {task.id} finished infer.") - app.state.kv_cache_pool.release_sync(task) + task.output(token) + if task.finish_reason is None: + app.state.request_queue.sync_q.put(task) + else: + print(f"[INFO] Task {task.id} finished infer.") + app.state.kv_cache_pool.release_sync(task) def build_task(id_, request_data, request: Request): @@ -214,6 +257,7 @@ async def chat_stream(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) + infer_task.setup_chunked_prefill(CHUNK_SIZE) # Initial empty content chunk = json.dumps( @@ -255,6 +299,7 @@ async def chat(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) + infer_task.setup_chunked_prefill(CHUNK_SIZE) request.app.state.request_queue.sync_q.put(infer_task) output = [] while True: diff --git a/scripts/test_chunk_prefill.py b/scripts/test_chunk_prefill.py new file mode 100644 index 000000000..90d01b9e5 --- /dev/null +++ b/scripts/test_chunk_prefill.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +""" +一键对比 chunked prefill 开/关性能 + +该脚本会依次启动 launch_server.py (chunk-size=0/256),运行 test_perf_mix.py 取结果,最后输出对比。 +""" + +import argparse +import os +import re +import signal +import subprocess +import sys +import time + + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +LM_DIR = os.path.dirname(SCRIPT_DIR) +INFERENCE_SERVER = os.path.join(LM_DIR, "python", "infinilm", "server", "inference_server.py") +TEST_SCRIPT = os.path.join(SCRIPT_DIR, "test_perf_mix.py") + + +from openai import OpenAI, APIConnectionError, APIStatusError + +def wait_for_server(popen, host, port, model, timeout=300): + client = OpenAI(base_url=f"http://{host}:{port}", api_key="default") + deadline = time.time() + timeout + while time.time() < deadline: + if popen.poll() is not None: + raise RuntimeError(f"server exited early with code {popen.returncode}") + try: + client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "hi"}], + max_tokens=1, + ) + return + except (APIConnectionError, APIStatusError): + time.sleep(1) + raise TimeoutError(f"server not ready within {timeout}s") + + +def inference_server(chunk_size, device, port, batch_size, max_new_tokens, enable_paged_attn, enable_graph, model_path): + print(INFERENCE_SERVER) + args = ["CUDA_VISIBLE_DEVICES=12", sys.executable, INFERENCE_SERVER, + f"--chunk-size {chunk_size}", + f"--device {device}", + f"--port {port}", + f"--batch-size {batch_size}", + f"--max-new-tokens {max_new_tokens}", + f"--model {model_path}"] + if enable_paged_attn: + args.append("--enable-paged-attn") + if enable_graph: + args.append("--enable-graph") + + popen = subprocess.Popen(" ".join(args), shell=True, preexec_fn=os.setsid, stderr=subprocess.STDOUT) + return popen + + +import socket + +def stop_server(popen, host="127.0.0.1", port=2333, timeout=30): + if popen and popen.poll() is None: + os.killpg(os.getpgid(popen.pid), signal.SIGTERM) + try: + popen.wait(timeout=timeout) + except subprocess.TimeoutExpired: + os.killpg(os.getpgid(popen.pid), signal.SIGKILL) + popen.wait(timeout=5) + + # 等端口真正释放(uvicorn 在 graceful shutdown 期间端口还开着) + deadline = time.time() + timeout + while time.time() < deadline: + try: + with socket.create_connection((host, port), timeout=0.5): + pass # 还有人监听,继续等 + except OSError: + return # 端口已释放 + time.sleep(0.3) + raise RuntimeError(f"port {port} still in use after stop_server") + + +def run_test_perf(): + cmd = f"{sys.executable} -u {TEST_SCRIPT}" + proc = subprocess.Popen( + cmd, shell=True, text=True, bufsize=1, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + lines = [] + for line in proc.stdout: + sys.stdout.write(line) + sys.stdout.flush() # 子终端的输出直接转发到父终端,保持实时显示 + lines.append(line) + proc.wait() + return proc.returncode, "".join(lines) # 返回码 + test_perf_mix.py的输出文本 + +def parse_stats(output): + def grab(pat): + m = re.search(pat, output) + return float(m.group(1)) if m else None + + success_m = re.search(r"成功请求数\s*:\s*(\d+)", output) + return { + "avg_ttft_s": grab(r"Average TTFT\s*:\s*([0-9.]+)\s*s"), + "avg_e2e_s": grab(r"Average latency\s*:\s*([0-9.]+)\s*s"), + "avg_ms_per_token": grab(r"Avg time per token\s*:\s*([0-9.]+)\s*ms/token"), + "avg_tps": grab(r"Avg Token generation speed\s*:\s*([0-9.]+)"), + "rps": grab(r"请求速率 \(RPS\)\s*:\s*([0-9.]+)"), + "success": int(success_m.group(1)) if success_m else None, + } + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="比较 chunked prefill 开/关的 TTFT/E2E") + parser.add_argument("--device", type=str, default="iluvatar", help="设备类型") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--max-new-tokens", type=int, default=16) + parser.add_argument("--enable-paged-attn", type=bool, default=True) + parser.add_argument("--enable-graph", type=bool, default=True) + parser.add_argument("--port", type=int, default=2333) + parser.add_argument("--model-path", type=str, default="/data-aisoft/mechdancer/models/9g_8b_thinking_llama/") + + + args = parser.parse_args() + + results = [] + + for chunk_size in (0, 256): + mode = "ON" if chunk_size > 0 else "OFF" + print("\n" + "="*78) + print(f"开始部署大模型推理服务,chunked prefill {mode} (chunk-size={chunk_size}),请等待服务启动完成...") + + server = inference_server(chunk_size=chunk_size, device=args.device, port=args.port, + batch_size=args.batch_size, max_new_tokens=args.max_new_tokens, + enable_paged_attn=args.enable_paged_attn, enable_graph=args.enable_graph, + model_path=args.model_path) + try: + wait_for_server(server, "127.0.0.1", args.port, model="FM9G-7B", timeout=300) + print("服务启动完成,开始跑 test_perf_mix.py (上一条200OK请求为服务测试成功标志)") + retcode, out = run_test_perf() + if retcode != 0: + print("test_perf_cp.py 执行失败,退出码", retcode) + print(out) + raise SystemExit(1) + + stats = parse_stats(out) + stats.update({"chunk_size": chunk_size, "mode": mode}) + results.append(stats) + print(f"完成chunked prefill {mode}测试 -> {stats}") + + finally: + stop_server(server, host="127.0.0.1", port=args.port) + print("服务已停止") + + print("\n" + "#"*78) + print("最终对比(chunked prefill ON vs OFF)") + print("-"*78) + + def fmt(v, unit="", spec=".3f"): + return "N/A" if v is None else f"{v:{spec}}{unit}" + + def diff(a, b): + return None if (a is None or b is None) else a - b + + def speedup_pct(on, off): + # 越小越快的指标:正数 = ON 比 OFF 快 + if on is None or off is None or off == 0: + return None + return (off - on) / off * 100 + + on_r = next((r for r in results if r["mode"] == "ON"), None) + off_r = next((r for r in results if r["mode"] == "OFF"), None) + + print(f"{'指标':<22}{'ON':>14}{'OFF':>14}{'Δ (ON-OFF)':>16}{'ON 提升':>12}") + print("-"*78) + + def row(label, key, unit, spec=".3f", lower_is_better=True): + a = (on_r or {}).get(key) + b = (off_r or {}).get(key) + pct = speedup_pct(a, b) if lower_is_better else speedup_pct(b, a) + print( + f"{label:<22}" + f"{fmt(a, unit, spec):>14}" + f"{fmt(b, unit, spec):>14}" + f"{fmt(diff(a, b), unit, '+'+spec):>16}" + f"{fmt(pct, '%', '+.2f'):>12}" + ) + + row("Avg TTFT", "avg_ttft_s", " s") + row("Avg E2E latency", "avg_e2e_s", " s") + row("Avg ms/token", "avg_ms_per_token", " ms", ".2f") + row("Avg tokens/s", "avg_tps", "", ".2f", lower_is_better=False) + row("RPS", "rps", "", ".2f", lower_is_better=False) + print("-"*78) diff --git a/scripts/test_perf.py b/scripts/test_perf.py index 6a33d8f0d..3e4116f54 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -29,7 +29,7 @@ NUM_REQUESTS = 64 CONCURRENCY = 20 -API_URL = "http://127.0.0.1:8000" +API_URL = "http://127.0.0.1:2333" MODEL = "FM9G-7B" diff --git a/scripts/test_perf_cp.py b/scripts/test_perf_cp.py new file mode 100644 index 000000000..4e55bae0f --- /dev/null +++ b/scripts/test_perf_cp.py @@ -0,0 +1,156 @@ +""" +Chunked Prefill TTFT Benchmark + +Test: send a long request, wait a short delay, then send a short request. +Measure the short request's TTFT and E2E. + +With chunked prefill: short request inserts at next chunk boundary → lower TTFT +Without chunked prefill: short request waits for full long prefill → higher TTFT + +Usage: + python3 scripts/test_perf.py [--rounds 5] [--delay 0.1] +""" +import asyncio +import time +from openai import AsyncOpenAI +import argparse + +API_URL = "http://127.0.0.1:2333" +MODEL = "jiuge" +MAX_TOKENS = 30 # decode的tokens数 + +_BASE_PARAGRAPHS = [ + '''人工智能(Artificial Intelligence,简称AI)是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能从诞生以来,理论和技术日益成熟,应用领域也不断扩大。可以设想,未来人工智能带来的科技产品,将会是人类智慧的容器。''', + '''1956年夏季,以麦卡赛、明斯基、罗切斯特和申农等为首的一批有远见卓识的年轻科学家在一起聚会,共同研究和探讨用机器模拟智能的一系列有关问题,并首次提出了人工智能这一术语,它标志着人工智能这门新兴学科的正式诞生。此后,IBM公司研制的专用计算机深蓝击败了国际象棋世界冠军卡斯帕罗夫。谷歌公司开发的AlphaGo程序战胜了围棋世界冠军李世石,这被认为是人工智能发展史上的一个重要里程碑。''', + '''量子计算是一种利用量子力学原理进行信息处理的计算方式。与经典计算机使用比特作为信息的基本单位不同,量子计算机使用量子比特。量子比特具有叠加态的特性,即一个量子比特可以同时处于0和1的状态,这使得量子计算机在处理某些特定问题时具有经典计算机无法比拟的优势。量子纠缠是量子计算中另一个关键概念,当两个量子比特发生纠缠时,测量其中一个的状态会立即影响另一个的状态。''', + '''根据联合国政府间气候变化专门委员会第六次评估报告,全球平均温度已经比工业化前水平上升了约1.1摄氏度。报告指出,人类活动是导致全球变暖的主要原因,其中化石燃料的燃烧、工业生产和土地利用变化是温室气体排放的主要来源。极端天气事件的频率和强度都在增加,包括热浪、干旱、暴雨和洪水。北极海冰面积持续缩小,格陵兰和南极冰盖加速融化。''', + '''深度学习是机器学习的一个分支,其核心是利用多层神经网络从大量数据中自动学习特征表示。卷积神经网络在图像识别领域取得了巨大成功,循环神经网络和Transformer架构则在自然语言处理领域展现了强大能力。近年来,大语言模型如GPT、BERT、LLaMA等引领了自然语言处理的技术革新,这些模型通过在海量文本数据上进行预训练,获得了强大的语言理解和生成能力。''', + '''在计算机体系结构领域,冯诺依曼架构仍然是现代计算机的基础。然而随着摩尔定律逐渐放缓,研究人员开始探索新型计算范式,包括神经形态计算、存内计算、光子计算等。GPU和TPU等专用加速器的发展极大推动了深度学习的进步。RISC-V开源指令集架构的兴起为芯片设计带来了新的可能性,而chiplet技术和先进封装则为突破制程限制提供了新的路径。''', + '''可再生能源的成本大幅下降,太阳能和风能已经成为最便宜的新增发电来源。电动汽车市场快速增长,电池技术不断进步。碳捕获和储存技术、绿色氢能等前沿技术也在加速发展。然而,要实现全球碳中和目标,仍需要在能源系统、交通运输、工业生产、建筑等领域进行深刻的变革。智能电网、储能技术、虚拟电厂等概念正在从理论走向实践。''', + '''生物信息学是一门利用计算机技术和数学方法研究生物学问题的交叉学科。基因组学、蛋白质组学、代谢组学等组学技术的发展产生了海量的生物数据。AlphaFold2在蛋白质结构预测方面取得了革命性突破,为药物研发和生命科学研究开辟了新的方向。CRISPR基因编辑技术的发展使得精准修改基因成为可能,为遗传疾病的治疗带来了希望。''', +] + + +def build_long_prompt(idx, target_chars=9000): + parts = [f"(文档编号{idx}) 请仔细阅读以下学术材料并总结:\n\n"] + i = idx + while sum(len(p) for p in parts) < target_chars: + parts.append(_BASE_PARAGRAPHS[i % len(_BASE_PARAGRAPHS)]) + parts.append("\n\n") + i += 1 + parts.append(f"以上是第{idx}份材料,请给出详细分析。") + return "".join(parts) + + +async def measure_one_round(client, round_idx, delay_sec): + """ + 1. Fire a long request (starts prefill immediately) + 2. After delay_sec, fire a short request + 3. Return both TTFT and E2E for both requests + """ + long_prompt = build_long_prompt(round_idx, target_chars=9000) + short_prompt = f"(编号{round_idx}) 1+1等于几?" + + long_result = {} + short_result = {} + + async def do_request(prompt, result_dict, delay=0): + if delay > 0: + await asyncio.sleep(delay) + t0 = time.time() + stream = await client.chat.completions.create( + model=MODEL, + messages=[{"role": "user", "content": prompt}], + stream=True, + max_new_tokens=MAX_TOKENS, + temperature=1.0, + top_p=1.0, + extra_body={"top_k": 1}, + ) + first_token_time = None + total_tokens = 0 + async for chunk in stream: + if chunk.choices[0].delta.content: + if first_token_time is None: + first_token_time = time.time() + total_tokens += 1 + if chunk.choices[0].finish_reason is not None: + break + end_time = time.time() + result_dict["ttft"] = (first_token_time - t0) if first_token_time else None + result_dict["e2e"] = end_time - t0 + result_dict["tokens"] = total_tokens + + await asyncio.gather( + do_request(long_prompt, long_result, delay=0), + do_request(short_prompt, short_result, delay=delay_sec), + ) + return long_result, short_result + + +async def run_benchmark(rounds, delay): + client = AsyncOpenAI(base_url=API_URL, api_key="default") + + # Warmup + print("Warmup...") + await measure_one_round(client, 100, delay) + print("Warmup done.\n") + + long_ttfts = [] + long_e2es = [] + short_ttfts = [] + short_e2es = [] + + for i in range(rounds): + lr, sr = await measure_one_round(client, i, delay) # lr = long request result, sr = short request result + + lt = lr["ttft"] * 1000 if lr["ttft"] else 0 + le = lr["e2e"] * 1000 + st = sr["ttft"] * 1000 if sr["ttft"] else 0 + se = sr["e2e"] * 1000 + + print(f" Round {i}: LONG TTFT={lt:>7.1f}ms E2E={le:>8.1f}ms tokens={lr['tokens']}") + print(f" SHORT TTFT={st:>7.1f}ms E2E={se:>8.1f}ms tokens={sr['tokens']}") + + if lr["ttft"]: + long_ttfts.append(lr["ttft"]) + long_e2es.append(lr["e2e"]) + if sr["ttft"]: + short_ttfts.append(sr["ttft"]) + short_e2es.append(sr["e2e"]) + + sep = "=" * 60 + print(f"\n{sep}") + print(f" Chunked Prefill TTFT Benchmark") + print(f"{sep}") + print(f" Rounds: {rounds}") + print(f" Delay before short request: {delay}s") + print(f" Long prompt: ~9000 chars") + print(f" Max tokens: {MAX_TOKENS}") + + def print_stats(label, ttfts, e2es): + if not ttfts: + return + print(f"\n [{label}]") + print(f" Avg TTFT: {sum(ttfts)/len(ttfts)*1000:>8.1f} ms") + print(f" Min TTFT: {min(ttfts)*1000:>8.1f} ms") + print(f" Max TTFT: {max(ttfts)*1000:>8.1f} ms") + print(f" Avg E2E: {sum(e2es)/len(e2es)*1000:>8.1f} ms") + + print_stats("LONG ", long_ttfts, long_e2es) + print_stats("SHORT", short_ttfts, short_e2es) + + if short_ttfts: + print(f"\n >>> SHORT Avg TTFT = {sum(short_ttfts)/len(short_ttfts)*1000:.1f} ms <<<") + print(f" >>> SHORT Avg E2E = {sum(short_e2es)/len(short_e2es)*1000:.1f} ms <<<") + print(f"{sep}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--rounds", type=int, default=5) + parser.add_argument("--delay", type=float, default=0.1, + help="Seconds to wait before sending short request (default: 0.1)") + args = parser.parse_args() + + asyncio.run(run_benchmark(args.rounds, args.delay)) diff --git a/scripts/test_perf_mix.py b/scripts/test_perf_mix.py new file mode 100644 index 000000000..a9dd433e9 --- /dev/null +++ b/scripts/test_perf_mix.py @@ -0,0 +1,173 @@ +import asyncio +import time +from openai import AsyncOpenAI +import argparse +import random + +PROMPTS = [ + + # ~10000 tokens:极限长上下文,多文件代码重构 + "下面给出 4 个相关文件,请重构以消除重复逻辑并提取公共抽象:\n\n" + + "# file: scheduler_v1.py\n" + "def schedule(reqs):\n return sorted(reqs, key=lambda r: r.arrival)\n" * 100 + + "\n# file: scheduler_v2.py\n" + "def schedule(reqs):\n return sorted(reqs, key=lambda r: -r.priority)\n" * 100 + + "\n# file: scheduler_v3.py\n" + "def schedule(reqs):\n return sorted(reqs, key=lambda r: r.prompt_len)\n" * 100 + + "\n# file: scheduler_v4.py\n" + "def schedule(reqs):\n return sorted(reqs, key=lambda r: r.slo_deadline)\n" * 100, + + + + "1+1=?", + +] + +NUM_REQUESTS = len(PROMPTS) +CONCURRENCY = 5 +API_URL = "http://127.0.0.1:2333" +MODEL = "FM9G-7B" + + +async def benchmark_user(client, semaphore, queue, results, user_id, verbose): + while True: + async with semaphore: + task_id = await queue.get() + if task_id is None: + queue.task_done() + break + + question = PROMPTS[task_id] + try: + print(f"🚀 User#{user_id} Sending request #{task_id}") + + start_time = time.time() + stream = await client.chat.completions.create( + model=MODEL, + messages=[{"role": "user", "content": question}], + stream=True, + ) + + first_token_time = None + total_tokens = 0 + answer_chunks = [] + + async for chunk in stream: + if first_token_time is None: + first_token_time = time.time() + delta = chunk.choices[0].delta.content + if delta: + answer_chunks.append(delta) + total_tokens += 1 + if chunk.choices[0].finish_reason is not None: + break + + end_time = time.time() + + ttft = first_token_time - start_time if first_token_time else None + elapsed_time = end_time - start_time if start_time else None + ms_per_token = ( + (elapsed_time / total_tokens * 1000) + if total_tokens > 0 and elapsed_time + else None + ) + tokens_per_second = ( + total_tokens / elapsed_time if elapsed_time > 0 else 0 + ) + + answer = "".join(answer_chunks) + + results.append( + (total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token) + ) + + if verbose: + print(f"\n📝 Request #{task_id} (User #{user_id})") + if ttft is not None: + print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s") + if elapsed_time is not None: + print(f" ⏱ 总耗时: {elapsed_time:.3f}s") + + print(f" 🔤 解码 token 总数: {total_tokens}") + if ms_per_token is not None: + print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token") + else: + print(f" 📏 平均 token 解码时间: N/A (no token generated)") + print(f" ❓ 提问: {question}") + print(f" 💬 回答: {answer}\n") + + queue.task_done() + except Exception as e: + if verbose: + print(f"\n⚠️ Request #{task_id} (User #{user_id}) FAILED:") + print(f" ❌ Error: {e}\n") + queue.task_done() + + +async def run_benchmark(verbose=False): + client = AsyncOpenAI(base_url=API_URL, api_key="default") + semaphore = asyncio.Semaphore(CONCURRENCY) + queue = asyncio.Queue() + results = [] + for i in range(NUM_REQUESTS): + await queue.put(i) + for _ in range(CONCURRENCY): + await queue.put(None) + + users = [ + asyncio.create_task( + benchmark_user(client, semaphore, queue, results, user_id, verbose) + ) + for user_id in range(CONCURRENCY) + ] + + start_time = time.time() + await queue.join() + await asyncio.gather(*users) + end_time = time.time() + + total_elapsed_time = end_time - start_time + tokens_list = [r[0] for r in results if r and r[0] is not None] + latencies = [r[1] for r in results if r and r[1] is not None] + tokens_per_second_list = [r[2] for r in results if r and r[2] is not None] + ttft_list = [r[3] for r in results if r and r[3] is not None] + ms_per_token_list = [r[4] for r in results if r and r[4] is not None] + + successful_requests = len(results) + requests_per_second = ( + successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + ) + avg_latency = sum(latencies) / len(latencies) if latencies else 0 + avg_tokens_per_second = ( + sum(tokens_per_second_list) / len(tokens_per_second_list) + if tokens_per_second_list + else 0 + ) + avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0 + avg_ms_per_token = ( + sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None + ) + + width_label = 24 + sep = "-" * 60 + + print(f"\n=== 📊 性能指标汇总 ({MODEL}) ===") + print(sep) + print(f"{'并发数':<{width_label}}: {CONCURRENCY}") + print(f"{'请求总数':<{width_label}}: {NUM_REQUESTS}") + print(f"{'成功请求数':<{width_label}}: {successful_requests}") + print(f"{'总耗时':<{width_label}}: {total_elapsed_time:.2f} s") + print(f"{'总输出token数':<{width_label}}: {sum(tokens_list)}") + print(f"{'请求速率 (RPS)':<{width_label}}: {requests_per_second:.2f} requests/s") + print(sep) + print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s") + print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s") + print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token") + print( + f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + asyncio.run(run_benchmark(args.verbose)) +