diff --git a/docker/Dockerfile b/docker/Dockerfile index a0929a019..cb27f9c44 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -69,7 +69,7 @@ ENV CUDA_HOME=/usr/local/cuda \ RUN if [ "${ENABLE_CACHE}" = "1" ]; then \ apt-get update && apt-get install -y libboost-dev && rm -rf /var/lib/apt/lists/*; \ - LIGHTMEM_REF=5900baf92d85ef4dbda6124093506b0af906011a; \ + LIGHTMEM_REF=9f9817b0ec6ae7055dea0542a63f66de2685ed90; \ pip install --no-deps -v "git+https://github.com/ModelTC/LightMem.git@${LIGHTMEM_REF}#egg=light_mem"; \ fi diff --git a/lightllm/server/multi_level_kv_cache/disk_cache_worker.py b/lightllm/server/multi_level_kv_cache/disk_cache_worker.py index b1e5fcf6f..5aa51feaa 100644 --- a/lightllm/server/multi_level_kv_cache/disk_cache_worker.py +++ b/lightllm/server/multi_level_kv_cache/disk_cache_worker.py @@ -21,6 +21,8 @@ ) raise ImportError("LightMem library is required for disk cache functionality") from e +TASK_WAIT_TIMEOUT_S = 60.0 + @dataclass class _PagePayload: @@ -43,11 +45,11 @@ def __init__( assert disk_cache_storage_size > 0 storage_size = int(disk_cache_storage_size * (1024 ** 3)) # num_shard与KVCACHE_MAX_BLOCK_SIZE相关,KVCACHE_MAX_BLOCK_SIZE默认64MB前提下, - # num_shard设置32, 能使disk cache的容量利用率达到90%,继续增大num_shard会导致容量利用率下降 - num_shard = 32 - num_worker = 48 - # 读写同时进行时,分配16线程用来写,32线程用来读 - max_concurrent_write_tasks = 16 + # num_shard设置8, 能使disk cache的容量利用率达到90%,继续增大num_shard会导致容量利用率下降 + num_shard = 8 + num_worker = 24 + # 读写同时进行时,分配8线程用来写,16线程用来读 + max_concurrent_write_tasks = 8 cache_dir = disk_cache_dir if not cache_dir: @@ -78,6 +80,24 @@ def __init__( def _prepare_tensor(self, tensor: torch.Tensor) -> torch.Tensor: return tensor.flatten(1).view(dtype=torch.uint8) + def _wait_task(self, task, cond_name: str) -> bool: + cond = getattr(task, cond_name) + deadline = time.monotonic() + TASK_WAIT_TIMEOUT_S + while not cond(): + if time.monotonic() >= deadline: + logger.error( + "disk cache task '%s' wait timeout after %.1fs, aborting task to avoid hang", + cond_name, + TASK_WAIT_TIMEOUT_S, + ) + try: + self.service.abort(task) + except Exception as e: + logger.error("disk cache abort task failed: %s", e) + return False + time.sleep(0.001) + return True + def run(self) -> None: while True: time.sleep(0.1) @@ -121,9 +141,16 @@ def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None: query_result = self.service.query(hashs) if not all(query_result): # 限制写入并发量,给读取操作留资源 + throttle_deadline = time.monotonic() + TASK_WAIT_TIMEOUT_S while ( self.service.active_threads("r") and self.service.active_threads("w") >= self.max_concurrent_write_tasks ): + if time.monotonic() >= throttle_deadline: + logger.error( + "disk cache write throttle wait timeout after %.1fs, proceeding to submit", + TASK_WAIT_TIMEOUT_S, + ) + break time.sleep(0.001) task = self.service.create(hash_128s=hashs, kv_page_indexer=kv_indexer, mode="w") @@ -133,9 +160,7 @@ def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None: self.cpu_cache_client.deref_pages(page_list=task.page_already_list) self.cpu_cache_client.lock.release() - # 数据安全即可结束等待,无需写入完成 - while not task.data_safe(): - time.sleep(0.001) + self._wait_task(task, "data_safe") # 释放剩余需要写入的页面 remining_indexes = list(set(page_indexes) - set(task.page_already_list)) @@ -181,6 +206,6 @@ def load_pages(self, hashs: List[int], page_indexes: List[int], start_pos: int = kv_indexer = torch.tensor(page_indexes, dtype=torch.int32, device="cpu") task = self.service.create(hash_128s=hashs, kv_page_indexer=kv_indexer, mode="r", start_pos=start_pos) - while not task.ready(): - time.sleep(0.001) + if not self._wait_task(task, "ready"): + return False return all(state == PyState.Finished for state in task.state()) diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py index 897d12507..2d54eedc1 100644 --- a/test/benchmark/service/benchmark_multiturn.py +++ b/test/benchmark/service/benchmark_multiturn.py @@ -35,6 +35,7 @@ """ import argparse +import hashlib import json import os import random @@ -42,8 +43,8 @@ import time import urllib.parse import urllib.request -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, List, Optional, Tuple, Union +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, as_completed, wait +from typing import Dict, List, Optional, Set, Tuple, Union import numpy as np import requests @@ -67,6 +68,17 @@ def seed_all(seed: int) -> None: np.random.seed(seed) +def derive_seed(base_seed: int, namespace: str, index: int = 0) -> int: + """Derive a deterministic, well-mixed seed from the benchmark seed. + + Adjacent --seed values should lead to unrelated per-session RNG streams, + while still keeping the overall request stream reproducible. + """ + payload = f"{base_seed}:{namespace}:{index}".encode("utf-8") + digest = hashlib.blake2b(payload, digest_size=8).digest() + return int.from_bytes(digest, byteorder="big", signed=False) + + def get_tokenizer(tokenizer_name: str) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) @@ -366,7 +378,162 @@ def stream_one_turn( return None -def run_session( +class SessionState: + """Holds the evolving conversation state for a single simulated user. + + A session is either active (currently issuing requests) or silent (paused). + A silent session keeps its accumulated prompt, turn counter and RNG state, + so when it is reactivated it resumes the conversation from where it left + off. All RNG usage for a session happens in the worker thread that runs its + turn; since a session has at most one in-flight turn at a time, its RNG + sequence is deterministic and independent of scheduling/concurrency. + """ + + def __init__( + self, + session_id: int, + tokenizer, + start_input_len: int, + max_input_len: int, + max_turns: int, + base_seed: int, + ) -> None: + self.session_id = session_id + self.max_input_len = max_input_len + self.max_turns = max_turns + session_rng_seed = derive_seed(base_seed, "session_rng", session_id) + session_prompt_seed = derive_seed(base_seed, "session_prompt", session_id) + self.rng = random.Random(session_rng_seed) + self.prompt, self.prompt_len = gen_session_initial_prompt(tokenizer, start_input_len, session_prompt_seed) + self.turn_idx = 0 + self.per_turn: List[Dict] = [] + self.in_flight = False + + def is_completed(self) -> bool: + return self.turn_idx >= self.max_turns or self.prompt_len >= self.max_input_len + + +def request_session_turn( + session: SessionState, + tokenizer, + url: str, + model_name: str, + min_output_len: int, + output_len: int, + request_timeout_s: int, +) -> Tuple[int, Optional[Dict]]: + turn_output_len = session.rng.randint(min_output_len, output_len) + result = stream_one_turn( + tokenizer=tokenizer, + url=url, + model_name=model_name, + prompt=session.prompt, + prompt_token_len=session.prompt_len, + max_new_tokens=turn_output_len, + request_timeout_s=request_timeout_s, + ) + return turn_output_len, result + + +def build_next_prompt( + session: SessionState, + tokenizer, + assistant_token_count: int, + min_turn_input_increment: int, + turn_input_increment: int, +) -> Tuple[str, int]: + turn_input_len = session.rng.randint(min_turn_input_increment, turn_input_increment) + return append_turn_input( + tokenizer, + session.prompt, + session.prompt_len, + assistant_token_count, + turn_input_len, + session.rng, + ) + + +def get_cache_hit_ratio_str( + prompt_tokens_total: int, + cached_tokens_total: int, + cached_reported_turns: int, +) -> str: + if cached_reported_turns > 0 and prompt_tokens_total > 0: + return f"{cached_tokens_total / prompt_tokens_total * 100.0:.2f}%" + return "n/a" + + +def print_progress_line( + concurrency: int, + finished_turns: int, + completed_sessions: int, + session_num: int, + active_sessions: int, + prompt_tokens_total: int, + cached_tokens_total: int, + cached_reported_turns: int, + failed_sessions: int, + wall_start: float, +) -> None: + elapsed_time = max(time.time() - wall_start, 1e-9) + current_qps = finished_turns / elapsed_time + cache_hit_ratio_str = get_cache_hit_ratio_str( + prompt_tokens_total, + cached_tokens_total, + cached_reported_turns, + ) + failed_str = f" failed={failed_sessions}" if failed_sessions else "" + print( + f"\rconc={concurrency} " + f"turns={finished_turns} " + f"sessions={completed_sessions}/{session_num} " + f"active={active_sessions}/{concurrency} " + f"cache_hit={cache_hit_ratio_str} " + f"qps={current_qps:.2f}{failed_str}\033[K", + end="", + flush=True, + ) + + +def run_one_turn( + session: SessionState, + tokenizer, + url: str, + model_name: str, + min_turn_input_increment: int, + turn_input_increment: int, + min_output_len: int, + output_len: int, + request_timeout_s: int, +) -> Tuple[SessionState, Optional[Dict], str, int]: + """Execute a single turn for the given session in a worker thread. + + Returns (session, result, next_prompt, next_prompt_len). `result` is None + when the turn failed. The session state is not mutated here; the scheduler + applies the returned values under its single-threaded control loop. + """ + turn_output_len, result = request_session_turn( + session, + tokenizer, + url, + model_name, + min_output_len, + output_len, + request_timeout_s, + ) + if result is None: + return session, None, session.prompt, session.prompt_len + new_prompt, new_len = build_next_prompt( + session, + tokenizer, + turn_output_len, + min_turn_input_increment, + turn_input_increment, + ) + return session, result, new_prompt, new_len + + +def run_full_session( session_id: int, tokenizer, url: str, @@ -383,54 +550,76 @@ def run_session( progress_state: Dict, progress_lock: threading.Lock, ) -> List[Dict]: - """Run a single multi-turn dialogue session. Returns a list of per-turn - stat dicts (same schema as stream_one_turn output).""" - rng = random.Random(base_seed + session_id) - prompt, prompt_len = gen_session_initial_prompt(tokenizer, start_input_len, base_seed + session_id) + """Run one session to completion in a worker thread. + This preserves the pre-pool execution model for the full-concurrency case, + avoiding the per-turn scheduler overhead in the benchmark client. + """ + session = SessionState(session_id, tokenizer, start_input_len, max_input_len, max_turns, base_seed) per_turn: List[Dict] = [] - turn_idx = 0 + failed = False + try: - while turn_idx < max_turns and prompt_len < max_input_len: - turn_output_len = rng.randint(min_output_len, output_len) - result = stream_one_turn( - tokenizer=tokenizer, - url=url, - model_name=model_name, - prompt=prompt, - prompt_token_len=prompt_len, - max_new_tokens=turn_output_len, - request_timeout_s=request_timeout_s, + while not session.is_completed(): + turn_output_len, result = request_session_turn( + session, + tokenizer, + url, + model_name, + min_output_len, + output_len, + request_timeout_s, ) if result is None: + failed = True break + per_turn.append(result) + session.turn_idx += 1 + session_completed = session.is_completed() + with progress_lock: progress_state["finished_turns"] += 1 - print( - f"\rconc={progress_state['concurrency']} " - f"finished_turns={progress_state['finished_turns']} " - f"active_sessions={progress_state['active_sessions']}\033[K", - end="", - flush=True, + progress_state["prompt_tokens_total"] += result["prompt_tokens"] + progress_state["cached_tokens_total"] += result["cached_tokens"] + if result.get("cached_tokens_reported"): + progress_state["cached_reported_turns"] += 1 + + print_progress_line( + concurrency=progress_state["concurrency"], + finished_turns=progress_state["finished_turns"], + completed_sessions=progress_state["completed_sessions"] + int(session_completed), + session_num=progress_state["session_num"], + active_sessions=progress_state["active_sessions"] - int(session_completed), + prompt_tokens_total=progress_state["prompt_tokens_total"], + cached_tokens_total=progress_state["cached_tokens_total"], + cached_reported_turns=progress_state["cached_reported_turns"], + failed_sessions=progress_state["failed_sessions"], + wall_start=progress_state["wall_start"], ) - turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment) - prompt, prompt_len = append_turn_input( + + if session_completed: + break + + session.prompt, session.prompt_len = build_next_prompt( + session, tokenizer, - prompt, - prompt_len, turn_output_len, - turn_input_len, - rng, + min_turn_input_increment, + turn_input_increment, ) - turn_idx += 1 finally: with progress_lock: progress_state["active_sessions"] -= 1 + if failed: + progress_state["failed_sessions"] += 1 + else: + progress_state["completed_sessions"] += 1 + return per_turn -def run_concurrency_level( +def run_full_concurrency_level( concurrency: int, tokenizer, url: str, @@ -445,19 +634,25 @@ def run_concurrency_level( base_seed: int, request_timeout_s: int, ) -> Dict: - """Run one concurrency level. Returns the aggregated stats dict.""" progress_state = { "concurrency": concurrency, "finished_turns": 0, "active_sessions": concurrency, + "session_num": concurrency, + "prompt_tokens_total": 0, + "cached_tokens_total": 0, + "cached_reported_turns": 0, + "completed_sessions": 0, + "failed_sessions": 0, + "wall_start": time.time(), } progress_lock = threading.Lock() - wall_start = time.time() + wall_start = progress_state["wall_start"] with ThreadPoolExecutor(max_workers=concurrency) as executor: futures = [ executor.submit( - run_session, + run_full_session, sid, tokenizer, url, @@ -477,13 +672,13 @@ def run_concurrency_level( for sid in range(concurrency) ] session_results: List[List[Dict]] = [] - for fut in as_completed(futures): - session_results.append(fut.result()) - wall_end = time.time() - wall_time = max(wall_end - wall_start, 1e-9) - print() # newline after progress bar + for future in as_completed(futures): + session_results.append(future.result()) + + wall_time = max(time.time() - wall_start, 1e-9) + print() - all_turns: List[Dict] = [t for s in session_results for t in s] + all_turns: List[Dict] = [turn for session_turns in session_results for turn in session_turns] return summarize( concurrency=concurrency, turns=all_turns, @@ -493,6 +688,198 @@ def run_concurrency_level( ) +def run_pooled_concurrency_level( + concurrency: int, + tokenizer, + url: str, + model_name: str, + start_input_len: int, + max_input_len: int, + min_turn_input_increment: int, + turn_input_increment: int, + min_output_len: int, + output_len: int, + max_turns: int, + base_seed: int, + request_timeout_s: int, + session_num: int, + swap_interval_turns: int, +) -> Dict: + wall_start = time.time() + all_sessions = [ + SessionState(sid, tokenizer, start_input_len, max_input_len, max_turns, base_seed) for sid in range(session_num) + ] + pool: Dict[int, SessionState] = {session.session_id: session for session in all_sessions} + active_ids: Set[int] = set() + selection_rng = random.Random(derive_seed(base_seed, "active_selection")) + futures: Dict = {} + + finished_turns = 0 + prompt_tokens_total = 0 + cached_tokens_total = 0 + cached_reported_turns = 0 + swaps_done = 0 + completed_sessions = 0 + failed_sessions = 0 + + def do_swap() -> None: + pool_ids = sorted(pool.keys()) + active_ids.clear() + active_ids.update(selection_rng.sample(pool_ids, min(concurrency, len(pool_ids)))) + + def ensure_active_filled() -> None: + target = min(concurrency, len(pool)) + if len(active_ids) >= target: + return + silent_ids = sorted(session_id for session_id in pool if session_id not in active_ids) + need = target - len(active_ids) + chosen = silent_ids if need >= len(silent_ids) else selection_rng.sample(silent_ids, need) + active_ids.update(chosen) + + def submit_active_turns(executor) -> None: + for session_id in sorted(active_ids): + if len(futures) >= concurrency: + break + session = pool.get(session_id) + if session is None or session.in_flight: + continue + session.in_flight = True + future = executor.submit( + run_one_turn, + session, + tokenizer, + url, + model_name, + min_turn_input_increment, + turn_input_increment, + min_output_len, + output_len, + request_timeout_s, + ) + futures[future] = session + + with ThreadPoolExecutor(max_workers=concurrency) as executor: + do_swap() + submit_active_turns(executor) + while futures: + done, _ = wait(list(futures.keys()), return_when=FIRST_COMPLETED) + for future in done: + session = futures.pop(future) + session.in_flight = False + _, result, new_prompt, new_len = future.result() + if result is None: + pool.pop(session.session_id, None) + active_ids.discard(session.session_id) + failed_sessions += 1 + else: + session.per_turn.append(result) + session.prompt = new_prompt + session.prompt_len = new_len + session.turn_idx += 1 + finished_turns += 1 + prompt_tokens_total += result["prompt_tokens"] + cached_tokens_total += result["cached_tokens"] + if result.get("cached_tokens_reported"): + cached_reported_turns += 1 + if session.is_completed(): + pool.pop(session.session_id, None) + active_ids.discard(session.session_id) + completed_sessions += 1 + + if concurrency < session_num and pool and finished_turns // swap_interval_turns > swaps_done: + swaps_done = finished_turns // swap_interval_turns + do_swap() + ensure_active_filled() + + print_progress_line( + concurrency=concurrency, + finished_turns=finished_turns, + completed_sessions=completed_sessions, + session_num=session_num, + active_sessions=len(active_ids), + prompt_tokens_total=prompt_tokens_total, + cached_tokens_total=cached_tokens_total, + cached_reported_turns=cached_reported_turns, + failed_sessions=failed_sessions, + wall_start=wall_start, + ) + + submit_active_turns(executor) + + wall_time = max(time.time() - wall_start, 1e-9) + print() + + all_turns: List[Dict] = [turn for session in all_sessions for turn in session.per_turn] + return summarize( + concurrency=concurrency, + turns=all_turns, + wall_time=wall_time, + num_sessions=session_num, + max_turns=max_turns, + ) + + +def run_concurrency_level( + concurrency: int, + tokenizer, + url: str, + model_name: str, + start_input_len: int, + max_input_len: int, + min_turn_input_increment: int, + turn_input_increment: int, + min_output_len: int, + output_len: int, + max_turns: int, + base_seed: int, + request_timeout_s: int, + session_num: int, + swap_interval_turns: int, +) -> Dict: + """Run one concurrency level. Returns the aggregated stats dict. + + A pool of `session_num` simulated users is built for this level. At any + moment only `concurrency` of them are active (issuing requests); the rest + stay silent while keeping their conversation state. Every + `swap_interval_turns` completed turns the active set is re-sampled from the + pool. Sessions that finish all their turns leave the pool permanently; the + level ends once the pool is empty. + """ + if concurrency == session_num: + return run_full_concurrency_level( + concurrency, + tokenizer, + url, + model_name, + start_input_len, + max_input_len, + min_turn_input_increment, + turn_input_increment, + min_output_len, + output_len, + max_turns, + base_seed, + request_timeout_s, + ) + return run_pooled_concurrency_level( + concurrency, + tokenizer, + url, + model_name, + start_input_len, + max_input_len, + min_turn_input_increment, + turn_input_increment, + min_output_len, + output_len, + max_turns, + base_seed, + request_timeout_s, + session_num, + swap_interval_turns, + ) + + def summarize( concurrency: int, turns: List[Dict], @@ -656,6 +1043,21 @@ def main() -> None: ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--request_timeout_s", type=int, default=3600) + parser.add_argument( + "--session_num", + type=int, + default=None, + help="Total number of simulated users (pool size) per concurrency level. " + "Must be >= every concurrency level. Defaults to max(concurrency_levels). " + "Only `concurrency` users are active at a time; the rest stay silent but " + "keep their conversation state for later reactivation.", + ) + parser.add_argument( + "--swap_interval_turns", + type=int, + default=100, + help="Every this many completed turns, re-sample the active user set from the pool.", + ) parser.add_argument( "--dump_file", type=str, @@ -674,6 +1076,8 @@ def main() -> None: raise ValueError("--min_turn_input_increment must be >= 0") if args.min_turn_input_increment > args.turn_input_increment: raise ValueError("--min_turn_input_increment must be <= --turn_input_increment") + if args.swap_interval_turns < 1: + raise ValueError("--swap_interval_turns must be >= 1") if args.dump_file and os.path.exists(args.dump_file) and os.path.getsize(args.dump_file) > 0: with open(args.dump_file, "r") as f: @@ -689,12 +1093,18 @@ def main() -> None: ) tokenizer = get_tokenizer(args.tokenizer_path) concurrency_levels = [int(x) for x in args.concurrency_levels.split(",") if x.strip()] + max_concurrency = max(concurrency_levels) if concurrency_levels else 0 + session_num = args.session_num if args.session_num is not None else max_concurrency + if session_num < max_concurrency: + raise ValueError(f"--session_num ({session_num}) must be >= the largest concurrency level ({max_concurrency}).") print(f"URL : {args.url}") print(f"Model : {model_name}") if model_name_note: print(f"Model note : {model_name_note}") print(f"Concurrency levels : {concurrency_levels}") + print(f"session_num : {session_num}") + print(f"swap_interval_turns: {args.swap_interval_turns}") print(f"start_input_len : {args.start_input_len}") print(f"max_input_len : {args.max_input_len}") print(f"min_turn_input_increment: {args.min_turn_input_increment}") @@ -719,6 +1129,8 @@ def main() -> None: max_turns=args.max_turns, base_seed=args.seed, request_timeout_s=args.request_timeout_s, + session_num=session_num, + swap_interval_turns=args.swap_interval_turns, ) print_summary(summary) all_summaries.append(summary) @@ -737,6 +1149,8 @@ def main() -> None: "min_output_len": args.min_output_len, "output_len": args.output_len, "max_turns": args.max_turns, + "session_num": session_num, + "swap_interval_turns": args.swap_interval_turns, "seed": args.seed, }, "results": all_summaries,