From ef61358c9124bb807ec892ad876023f2e820c491 Mon Sep 17 00:00:00 2001 From: helloyongyang Date: Thu, 11 Jun 2026 15:06:21 +0800 Subject: [PATCH 1/2] Support platform server --- lightx2v/server/services/distributed_utils.py | 24 +- scripts/hidream_o1_image/post.sh | 34 +++ scripts/hidream_o1_image/post_all.sh | 21 ++ .../post_async_t2i_and_wait.py | 280 ++++++++++++++++++ scripts/hidream_o1_image/start_server.sh | 35 +++ 5 files changed, 386 insertions(+), 8 deletions(-) create mode 100644 scripts/hidream_o1_image/post.sh create mode 100644 scripts/hidream_o1_image/post_all.sh create mode 100644 scripts/hidream_o1_image/post_async_t2i_and_wait.py create mode 100644 scripts/hidream_o1_image/start_server.sh diff --git a/lightx2v/server/services/distributed_utils.py b/lightx2v/server/services/distributed_utils.py index d96c3537e..0dfebd127 100644 --- a/lightx2v/server/services/distributed_utils.py +++ b/lightx2v/server/services/distributed_utils.py @@ -7,6 +7,9 @@ import torch.distributed as dist from loguru import logger +import lightx2v_platform # noqa: F401 +from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER + class DistributedManager: def __init__(self): @@ -19,27 +22,32 @@ def __init__(self): CHUNK_SIZE = 1024 * 1024 + def _get_platform_device(self): + platform = os.getenv("PLATFORM", "cuda") + platform_device = PLATFORM_DEVICE_REGISTER.get(platform, None) + if platform_device is None: + available_platforms = list(PLATFORM_DEVICE_REGISTER.keys()) + raise RuntimeError(f"Unsupported PLATFORM: {platform}. 
Available PLATFORM: {available_platforms}") + return platform_device + def init_process_group(self) -> bool: try: self.rank = int(os.environ.get("LOCAL_RANK", 0)) self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + platform_device = self._get_platform_device() if self.world_size > 1: - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method="env://") + platform_device.init_parallel_env() + backend = dist.get_backend() logger.info(f"Setup backend: {backend}") task_timeout = timedelta(days=30) self.task_pg = dist.new_group(backend="gloo", timeout=task_timeout) logger.info("Created gloo process group for task distribution with 30-day timeout") - if torch.cuda.is_available(): - torch.cuda.set_device(self.rank) - self.device = f"cuda:{self.rank}" - else: - self.device = "cpu" + self.device = f"{platform_device.get_device()}:{self.rank}" else: - self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.device = f"{platform_device.get_device()}:0" self.is_initialized = True logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully") diff --git a/scripts/hidream_o1_image/post.sh b/scripts/hidream_o1_image/post.sh new file mode 100644 index 000000000..3ee991dff --- /dev/null +++ b/scripts/hidream_o1_image/post.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +lightx2v_path=/root/yongyang3/LightX2V +port=8000 +server_url=http://127.0.0.1:${port} + +prompt="medium shot, eye-level, front view. A woman is seated in an ornate bedroom, illuminated by candlelight, with a calm and composed expression. The subject is a young woman with fair skin, light brown hair styled in an updo with loose tendrils framing her face, and blue eyes. She wears a cream-colored satin robe with delicate floral embroidery and lace trim along the neckline. Her ears are adorned with pearl drop earrings. She is seated on a bed with a dark, intricately carved wooden headboard. 
To her left, a wooden nightstand holds three lit white candles and a candelabra with multiple lit candles in the background. The bed is covered with patterned pillows and a dark, textured blanket. The walls are paneled with dark wood and feature a large, ornate tapestry with muted earth tones. The lighting creates soft highlights on her face and robe, with warm shadows cast across the room." +negative_prompt="" +infer_steps=28 +seed=32 +aspect_ratio=1:1 +target_height=2048 +target_width=2048 + +# Keep this relative so /v1/tasks/{task_id}/result can download it from the server output dir. +server_save_result_path=hidream_o1_image_t2i_dev_2604_request.png +output=${lightx2v_path}/save_results/hidream_o1_image_t2i_dev_2604_request.png +timeout_seconds=1200 +poll_interval=2.0 + +export PYTHONPATH="${lightx2v_path}:${PYTHONPATH:-}" + +python "${lightx2v_path}/scripts/server/post_async_t2i_and_wait.py" \ +--url "${server_url}" \ +--prompt "${prompt}" \ +--negative_prompt "${negative_prompt}" \ +--infer_steps "${infer_steps}" \ +--seed "${seed}" \ +--aspect_ratio "${aspect_ratio}" \ +--target_shape "${target_height}" "${target_width}" \ +--save_result_path "${server_save_result_path}" \ +--timeout_seconds "${timeout_seconds}" \ +--poll_interval "${poll_interval}" \ +--output "${output}" diff --git a/scripts/hidream_o1_image/post_all.sh b/scripts/hidream_o1_image/post_all.sh new file mode 100644 index 000000000..16affc769 --- /dev/null +++ b/scripts/hidream_o1_image/post_all.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lightx2v_path=/root/yongyang3/LightX2V +test_json=${TEST_JSON:-/root/test.json} +port=${PORT:-8000} +server_url=${SERVER_URL:-http://127.0.0.1:${port}} + +export PYTHONPATH="${lightx2v_path}:${PYTHONPATH:-}" + +python "${lightx2v_path}/scripts/hidream_o1_image/post_async_t2i_and_wait.py" \ +--url "${server_url}" \ +--prompt_json "${test_json}" \ +--negative_prompt "" \ +--infer_steps 28 \ +--seed 42 \ +--aspect_ratio 1:1 \ +--target_shape 2048 2048 \ 
+--timeout_seconds 1200 \ +--poll_interval 2.0 \ +--output_dir "${lightx2v_path}/save_results/hidream_o1_image_test_json" \ +--output_prefix hidream_o1_image_test_json diff --git a/scripts/hidream_o1_image/post_async_t2i_and_wait.py b/scripts/hidream_o1_image/post_async_t2i_and_wait.py new file mode 100644 index 000000000..e23238c35 --- /dev/null +++ b/scripts/hidream_o1_image/post_async_t2i_and_wait.py @@ -0,0 +1,280 @@ +import argparse +import json +import os +import time +from datetime import datetime +from pathlib import Path +from typing import List, Optional +from urllib.parse import urlparse + +import requests + + +def submit_t2i_task( + base_url: str, + prompt: str, + negative_prompt: str, + infer_steps: int, + seed: int, + aspect_ratio: str, + target_shape: Optional[List[int]], + save_result_path: str, + use_prompt_enhancer: bool, +) -> str: + payload = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "infer_steps": infer_steps, + "seed": seed, + "aspect_ratio": aspect_ratio, + "save_result_path": save_result_path, + "use_prompt_enhancer": use_prompt_enhancer, + } + if target_shape: + payload["target_shape"] = target_shape + + submit_url = f"{base_url.rstrip('/')}/v1/tasks/image/" + response = requests.post(submit_url, json=payload, timeout=30) + if response.status_code != 200: + raise RuntimeError(f"Submit task failed ({response.status_code}): {response.text}") + + data = response.json() + task_id = data.get("task_id") + if not task_id: + raise RuntimeError(f"Submit task succeeded but no task_id found: {data}") + return task_id + + +def wait_task_done(base_url: str, task_id: str, timeout_seconds: int, poll_interval: float) -> dict: + status_url = f"{base_url.rstrip('/')}/v1/tasks/{task_id}/status" + deadline = time.time() + timeout_seconds + + while time.time() < deadline: + response = requests.get(status_url, timeout=15) + if response.status_code != 200: + raise RuntimeError(f"Get task status failed ({response.status_code}): 
{response.text}") + + status = response.json() + task_status = status.get("status") + print(f"[poll] task_id={task_id}, status={task_status}") + + if task_status == "completed": + return status + if task_status in ("failed", "cancelled"): + raise RuntimeError(f"Task ended with status={task_status}, detail={status.get('error')}") + + time.sleep(poll_interval) + + raise TimeoutError(f"Task {task_id} timeout after {timeout_seconds}s") + + +def download_result(base_url: str, task_id: str, output: str) -> Path: + result_url = f"{base_url.rstrip('/')}/v1/tasks/{task_id}/result" + response = requests.get(result_url, timeout=120) + if response.status_code != 200: + raise RuntimeError(f"Download result failed ({response.status_code}): {response.text}") + + output_path = Path(output) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_bytes(response.content) + return output_path + + +def format_duration(seconds: float) -> str: + total_ms = int(round(seconds * 1000)) + total_seconds, milliseconds = divmod(total_ms, 1000) + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}" + + +def ensure_local_no_proxy(base_url: str) -> None: + hostname = urlparse(base_url).hostname + if hostname not in {"127.0.0.1", "localhost", "0.0.0.0", "::1"}: + return + + local_hosts = ["127.0.0.1", "localhost", "0.0.0.0", "::1"] + for env_name in ("NO_PROXY", "no_proxy"): + existing = [item.strip() for item in os.environ.get(env_name, "").split(",") if item.strip()] + merged = local_hosts + [item for item in existing if item not in local_hosts] + os.environ[env_name] = ",".join(merged) + + +def load_prompts(prompt_json: str) -> List[str]: + path = Path(prompt_json) + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, list): + raise ValueError(f"{path} must contain a JSON array") + + prompts = [] + for index, item in enumerate(data): + if 
isinstance(item, str): + prompt = item + elif isinstance(item, dict) and isinstance(item.get("prompt"), str): + prompt = item["prompt"] + else: + raise ValueError(f"{path}[{index}] must be a string or an object with a string prompt") + prompts.append(prompt) + return prompts + + +def write_summary_line(summary_file: Optional[Path], line: str) -> None: + if summary_file is None: + return + with summary_file.open("a", encoding="utf-8") as f: + f.write(line) + f.write("\n") + + +def run_one_task(args, prompt: str, save_result_path: str, output: str) -> dict: + case_start = time.perf_counter() + task_id = submit_t2i_task( + base_url=args.url, + prompt=prompt, + negative_prompt=args.negative_prompt, + infer_steps=args.infer_steps, + seed=args.seed, + aspect_ratio=args.aspect_ratio, + target_shape=args.target_shape, + save_result_path=save_result_path, + use_prompt_enhancer=args.use_prompt_enhancer, + ) + print(f"Task submitted successfully, task_id={task_id}") + + final_status = wait_task_done( + base_url=args.url, + task_id=task_id, + timeout_seconds=args.timeout_seconds, + poll_interval=args.poll_interval, + ) + print(f"Task completed: {final_status}") + + output_path = download_result(args.url, task_id, output) + elapsed = time.perf_counter() - case_start + print(f"Result saved to: {output_path}") + print(f"Task elapsed: {format_duration(elapsed)} ({elapsed:.3f}s)") + + return { + "task_id": task_id, + "status": "success", + "elapsed": elapsed, + "output": str(output_path), + "save_result_path": final_status.get("save_result_path", save_result_path), + } + + +def run_batch(args) -> int: + prompts = load_prompts(args.prompt_json) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S") + summary_file = Path(args.summary_file) if args.summary_file else output_dir / f"{args.output_prefix}_summary_{run_stamp}.log" + summary_file.parent.mkdir(parents=True, exist_ok=True) + + started_at = 
datetime.now().strftime("%Y-%m-%d %H:%M:%S") + total = len(prompts) + completed = 0 + failed = 0 + batch_start = time.perf_counter() + + summary_file.write_text( + "\n".join( + [ + f"Run started at: {started_at}", + f"Prompt JSON: {args.prompt_json}", + f"Server URL: {args.url}", + f"Output directory: {output_dir}", + f"Total prompts: {total}", + "", + ] + ), + encoding="utf-8", + ) + + print(f"Posting {total} prompts from {args.prompt_json} to {args.url}") + print(f"Output directory: {output_dir}") + print(f"Summary file: {summary_file}") + + for index, prompt in enumerate(prompts, 1): + number = f"{index:03d}" + server_save_result_path = f"{args.output_prefix}_{number}.png" + output = output_dir / server_save_result_path + case_started_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + case_start = time.perf_counter() + + print(f"[{index}/{total}] submitting {server_save_result_path}") + write_summary_line(summary_file, f"Case {number} started at: {case_started_at}") + + try: + result = run_one_task(args, prompt, server_save_result_path, str(output)) + except Exception as e: + failed += 1 + elapsed = time.perf_counter() - case_start + print(f"[{index}/{total}] failed: {server_save_result_path}: {e}") + write_summary_line(summary_file, f"Case {number} status: failed, elapsed: {format_duration(elapsed)} ({elapsed:.3f}s), error: {e}") + if args.stop_on_error: + break + else: + completed += 1 + write_summary_line( + summary_file, + f"Case {number} status: success, task_id: {result['task_id']}, elapsed: {format_duration(result['elapsed'])} ({result['elapsed']:.3f}s), output: {result['output']}", + ) + print(f"[{index}/{total}] saved {output}") + + total_elapsed = time.perf_counter() - batch_start + ended_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + write_summary_line(summary_file, "") + write_summary_line(summary_file, f"Run ended at: {ended_at}") + write_summary_line(summary_file, f"Elapsed seconds: {total_elapsed:.3f}") + write_summary_line(summary_file, 
f"Elapsed time: {format_duration(total_elapsed)}") + write_summary_line(summary_file, f"Completed prompts: {completed}") + write_summary_line(summary_file, f"Failed prompts: {failed}") + + print(f"Finished: completed={completed}/{total}, failed={failed}/{total}") + print(f"Total elapsed: {format_duration(total_elapsed)} ({total_elapsed:.3f}s)") + print(f"Summary written to: {summary_file}") + return 1 if failed else 0 + + +def main(): + parser = argparse.ArgumentParser(description="Submit T2I task to /v1/tasks/image/ and wait for final result.") + parser.add_argument("--url", type=str, default="http://127.0.0.1:8000", help="Server base url") + parser.add_argument("--prompt", type=str, default=None, help="Prompt text") + parser.add_argument("--prompt_json", "--json", dest="prompt_json", type=str, default=None, help="JSON file containing prompts") + parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt text") + parser.add_argument("--infer_steps", type=int, default=30, help="Inference steps") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--aspect_ratio", type=str, default="16:9", help="Aspect ratio for image task") + parser.add_argument( + "--target_shape", + type=int, + nargs="+", + default=None, + help="Target output shape, e.g. 
--target_shape 1536 2752", + ) + parser.add_argument("--save_result_path", type=str, default="", help="Server-side save_result_path") + parser.add_argument("--use_prompt_enhancer", action="store_true", help="Enable prompt enhancer") + parser.add_argument("--timeout_seconds", type=int, default=600, help="Polling timeout in seconds") + parser.add_argument("--poll_interval", type=float, default=2.0, help="Polling interval in seconds") + parser.add_argument("--output", type=str, default="save_results/t2i_result.png", help="Local output image path") + parser.add_argument("--output_dir", type=str, default="save_results/hidream_o1_image_test_json", help="Batch output directory") + parser.add_argument("--output_prefix", type=str, default="hidream_o1_image_test_json", help="Batch output filename prefix") + parser.add_argument("--summary_file", type=str, default=None, help="Batch timing summary file") + parser.add_argument("--stop_on_error", action="store_true", help="Stop batch mode after the first failed prompt") + + args = parser.parse_args() + ensure_local_no_proxy(args.url) + + if args.prompt_json: + raise SystemExit(run_batch(args)) + + if not args.prompt: + parser.error("--prompt is required unless --prompt_json is set") + + run_one_task(args, args.prompt, args.save_result_path, args.output) + + +if __name__ == "__main__": + main() diff --git a/scripts/hidream_o1_image/start_server.sh b/scripts/hidream_o1_image/start_server.sh new file mode 100644 index 000000000..d5b032696 --- /dev/null +++ b/scripts/hidream_o1_image/start_server.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +lightx2v_path=/root/yongyang3/LightX2V +model_path=/root/wushuo/models/HiDream-ai/HiDream-O1-Image-Dev-2604 +config_json=/root/yongyang3/LightX2V/configs/hidream_o1_image/mlu/hidream_o1_image_t2i_dev_2604_dist.json + +host=0.0.0.0 +port=8000 +metric_port=8001 +max_queue_size=10 +nproc_per_node=4 +master_port=29500 + +export PLATFORM=cambricon_mlu +export MLU_VISIBLE_DEVICES=0,1,2,3 +export 
PYTORCH_MLU_ALLOC_CONF=expandable_segments:True +export LD_LIBRARY_PATH=/usr/local/neuware/lib64:${LD_LIBRARY_PATH} + +# set environment variables +source "${lightx2v_path}/scripts/base/base.sh" + +echo "Starting HiDream-O1-Image Dev-2604 T2I distributed service on ${host}:${port}" +echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}, NPROC_PER_NODE=${nproc_per_node}, MASTER_PORT=${master_port}" + +torchrun --nproc_per_node="${nproc_per_node}" --master_port="${master_port}" -m lightx2v.server \ +--model_cls hidream_o1_image \ +--task t2i \ +--model_path "${model_path}" \ +--config_json "${config_json}" \ +--host "${host}" \ +--port "${port}" \ +--metric_port "${metric_port}" \ +--max_queue_size "${max_queue_size}" + +echo "Service stopped" From fe71fefb255ec1fa2a91d6a42ab71f46620f69d3 Mon Sep 17 00:00:00 2001 From: helloyongyang Date: Thu, 11 Jun 2026 17:00:17 +0800 Subject: [PATCH 2/2] Update server. (Add threading.Condition for create_task) --- .../hidream_o1_image_runner.py | 4 +- lightx2v/server/api/openai_images.py | 58 ++- lightx2v/server/api/server.py | 8 +- lightx2v/server/task_manager.py | 21 +- scripts/hidream_o1_image/post.sh | 34 -- scripts/hidream_o1_image/post_all.sh | 21 -- .../post_async_t2i_and_wait.py | 280 -------------- scripts/hidream_o1_image/post_t2i_openai.sh | 20 + .../test_openai_images_client.py | 348 ++++++++++++++++++ 9 files changed, 442 insertions(+), 352 deletions(-) delete mode 100644 scripts/hidream_o1_image/post.sh delete mode 100644 scripts/hidream_o1_image/post_all.sh delete mode 100644 scripts/hidream_o1_image/post_async_t2i_and_wait.py create mode 100755 scripts/hidream_o1_image/post_t2i_openai.sh create mode 100644 scripts/hidream_o1_image/test_openai_images_client.py diff --git a/lightx2v/models/runners/hidream_o1_image/hidream_o1_image_runner.py b/lightx2v/models/runners/hidream_o1_image/hidream_o1_image_runner.py index 260012b1b..06ef32fa8 100644 --- a/lightx2v/models/runners/hidream_o1_image/hidream_o1_image_runner.py +++
b/lightx2v/models/runners/hidream_o1_image/hidream_o1_image_runner.py @@ -262,7 +262,7 @@ def run_pipeline(self, input_info): save_result_path = self.inputs.get("save_result_path") if self.input_info.return_result_tensor: self.end_run() - return {"image": image} + return {"images": [image]} if save_result_path: os.makedirs(os.path.dirname(os.path.abspath(save_result_path)), exist_ok=True) image.save(save_result_path) @@ -271,7 +271,7 @@ def run_pipeline(self, input_info): if GET_RECORDER_MODE(): monitor_cli.lightx2v_worker_request_success.inc() self.end_run() - return {"image": None} + return {"images": None} def end_run(self): if hasattr(self, "scheduler") and self.scheduler is not None: diff --git a/lightx2v/server/api/openai_images.py b/lightx2v/server/api/openai_images.py index a4bd2b6f8..e11f2864d 100644 --- a/lightx2v/server/api/openai_images.py +++ b/lightx2v/server/api/openai_images.py @@ -3,6 +3,7 @@ import re import time import uuid +from datetime import datetime from pathlib import Path from typing import Literal, Optional @@ -17,6 +18,7 @@ router = APIRouter() _SIZE_PATTERN = re.compile(r"^\s*(\d+)\s*x\s*(\d+)\s*$", re.IGNORECASE) +OPENAI_IMAGE_RESULT_POLL_INTERVAL_SECONDS = 0.2 class OpenAIImageGenerationRequest(BaseModel): @@ -52,7 +54,9 @@ def _shape_from_size(size: str) -> tuple[int, int]: async def _wait_task_result_png(task_id: str, timeout_seconds: int, poll_interval_seconds: float) -> bytes: start_time = time.monotonic() + status_checks = 0 while True: + status_checks += 1 task_status = task_manager.get_task_status(task_id) if not task_status: raise HTTPException(status_code=500, detail=f"Task status not found: {task_id}") @@ -61,6 +65,16 @@ async def _wait_task_result_png(task_id: str, timeout_seconds: int, poll_interva if status == TaskStatus.COMPLETED.value: result_png = task_manager.get_task_result_png(task_id) if result_png: + wait_elapsed_ms = (time.monotonic() - start_time) * 1000 + completion_observe_lag_ms = 0.0 + end_time = 
task_status.get("end_time") + if end_time: + completion_observe_lag_ms = (datetime.now() - end_time).total_seconds() * 1000 + logger.info( + f"Task {task_id} OpenAI image wait_task_result cost total={wait_elapsed_ms:.2f} ms " + f"completion_observe_lag={completion_observe_lag_ms:.2f} ms " + f"poll_interval={poll_interval_seconds:.2f} s status_checks={status_checks}" + ) return result_png raise HTTPException(status_code=500, detail=f"Task completed but no in-memory image found: {task_id}") @@ -89,28 +103,39 @@ async def _watch_client_disconnect(request: Request, task_id: str, poll_interval async def _run_sync_image_task(request: Request, message: ImageTaskRequest) -> bytes: task_id = None timeout_seconds = 600 - poll_interval_seconds = 0.5 + poll_interval_seconds = OPENAI_IMAGE_RESULT_POLL_INTERVAL_SECONDS try: message.prefer_memory_result = True + create_task_start = time.perf_counter() task_id = task_manager.create_task(message) + create_task_elapsed_ms = (time.perf_counter() - create_task_start) * 1000 message.task_id = task_id + logger.info(f"Task {task_id} OpenAI image create_task cost {create_task_elapsed_ms:.2f} ms prompt_chars={len(message.prompt)} target_shape={message.target_shape}") wait_task = asyncio.create_task(_wait_task_result_png(task_id, timeout_seconds, poll_interval_seconds)) disconnect_task = asyncio.create_task(_watch_client_disconnect(request, task_id)) done, pending = await asyncio.wait({wait_task, disconnect_task}, return_when=asyncio.FIRST_COMPLETED) - for pending_task in pending: - pending_task.cancel() - await asyncio.gather(*pending, return_exceptions=True) + + if wait_task in done: + result_png = wait_task.result() + for pending_task in pending: + pending_task.cancel() + if pending: + _, still_pending = await asyncio.wait(pending, timeout=0) + if still_pending: + logger.debug(f"Task {task_id} disconnect watcher cancellation is still pending") + logger.info(f"Task {task_id} OpenAI image task result ready, building response") + return 
result_png if disconnect_task in done and disconnect_task.result(): if not wait_task.done(): wait_task.cancel() - await asyncio.gather(wait_task, return_exceptions=True) + await asyncio.wait({wait_task}, timeout=0) raise HTTPException(status_code=499, detail=f"Client disconnected, task {task_id} cancelled") - return wait_task.result() + raise HTTPException(status_code=500, detail=f"Task {task_id} ended without image result") except RuntimeError as e: raise HTTPException(status_code=503, detail=str(e)) except HTTPException: @@ -138,10 +163,25 @@ def _build_url_response(request: Request, task_id: str, image_bytes: bytes) -> s def _build_openai_response(request: Request, task_id: str, image_bytes: bytes, response_format: Literal["url", "b64_json"]): + total_start = time.perf_counter() if response_format == "b64_json": - return OpenAIImageResponse(created=int(time.time()), data=[{"b64_json": base64.b64encode(image_bytes).decode("utf-8")}]) - - return OpenAIImageResponse(created=int(time.time()), data=[{"url": _build_url_response(request, task_id, image_bytes)}]) + base64_start = time.perf_counter() + b64_json = base64.b64encode(image_bytes).decode("utf-8") + base64_elapsed_ms = (time.perf_counter() - base64_start) * 1000 + response = OpenAIImageResponse(created=int(time.time()), data=[{"b64_json": b64_json}]) + total_elapsed_ms = (time.perf_counter() - total_start) * 1000 + logger.info( + f"Task {task_id} OpenAI image response build cost total={total_elapsed_ms:.2f} ms base64={base64_elapsed_ms:.2f} ms format=b64_json png_bytes={len(image_bytes)} b64_chars={len(b64_json)}" + ) + return response + + url_start = time.perf_counter() + url = _build_url_response(request, task_id, image_bytes) + url_elapsed_ms = (time.perf_counter() - url_start) * 1000 + response = OpenAIImageResponse(created=int(time.time()), data=[{"url": url}]) + total_elapsed_ms = (time.perf_counter() - total_start) * 1000 + logger.info(f"Task {task_id} OpenAI image response build cost 
total={total_elapsed_ms:.2f} ms url_write={url_elapsed_ms:.2f} ms format=url png_bytes={len(image_bytes)}") + return response def _build_image_task_request( diff --git a/lightx2v/server/api/server.py b/lightx2v/server/api/server.py index 7027b5a1a..af438b855 100644 --- a/lightx2v/server/api/server.py +++ b/lightx2v/server/api/server.py @@ -1,6 +1,7 @@ import asyncio import threading import time +from datetime import datetime from pathlib import Path from typing import Any, Optional @@ -14,6 +15,8 @@ from .deps import ServiceContainer, get_services from .router import create_api_router +TASK_PROCESSING_IDLE_WAIT_TIMEOUT_SECONDS = 0.2 + class ApiServer: def __init__(self, max_queue_size: int = 10, app: Optional[FastAPI] = None): @@ -85,10 +88,9 @@ def _task_processing_loop(self): loop = asyncio.get_event_loop() while not self.stop_processing.is_set(): - task_id = task_manager.get_next_pending_task() + task_id = task_manager.wait_for_next_pending_task(timeout=TASK_PROCESSING_IDLE_WAIT_TIMEOUT_SECONDS) if task_id is None: - time.sleep(1) continue task_info = task_manager.get_task(task_id) @@ -112,6 +114,8 @@ async def _process_single_task(self, task_info: Any): return try: + pending_elapsed_ms = (datetime.now() - task_info.start_time).total_seconds() * 1000 + logger.info(f"Task {task_id} scheduler pending wait {pending_elapsed_ms:.2f} ms") task_manager.start_task(task_id) if task_info.stop_event.is_set(): diff --git a/lightx2v/server/task_manager.py b/lightx2v/server/task_manager.py index 5c53ea096..b7db45029 100644 --- a/lightx2v/server/task_manager.py +++ b/lightx2v/server/task_manager.py @@ -40,6 +40,7 @@ def __init__(self, max_queue_size: int = 100): self._tasks: OrderedDict[str, TaskInfo] = OrderedDict() self._lock = threading.RLock() + self._task_available = threading.Condition(self._lock) self._processing_lock = threading.Lock() self._current_processing_task: Optional[str] = None @@ -50,7 +51,7 @@ def __init__(self, max_queue_size: int = 100): 
self._emit_queue_metrics_unlocked() def create_task(self, message: Any) -> str: - with self._lock: + with self._task_available: if hasattr(message, "task_id") and message.task_id in self._tasks: raise RuntimeError(f"Task ID {message.task_id} already exists") @@ -66,6 +67,7 @@ def create_task(self, message: Any) -> str: self._cleanup_old_tasks() self._emit_queue_metrics_unlocked() + self._task_available.notify() return task_id @@ -202,9 +204,20 @@ def release_processing_lock(self, task_id: str): def get_next_pending_task(self) -> Optional[str]: with self._lock: - for task_id, task in self._tasks.items(): - if task.status == TaskStatus.PENDING: - return task_id + return self._get_next_pending_task_unlocked() + + def wait_for_next_pending_task(self, timeout: Optional[float] = None) -> Optional[str]: + with self._task_available: + task_id = self._get_next_pending_task_unlocked() + if task_id: + return task_id + self._task_available.wait(timeout=timeout) + return self._get_next_pending_task_unlocked() + + def _get_next_pending_task_unlocked(self) -> Optional[str]: + for task_id, task in self._tasks.items(): + if task.status == TaskStatus.PENDING: + return task_id return None def get_service_status(self) -> Dict[str, Any]: diff --git a/scripts/hidream_o1_image/post.sh b/scripts/hidream_o1_image/post.sh deleted file mode 100644 index 3ee991dff..000000000 --- a/scripts/hidream_o1_image/post.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -lightx2v_path=/root/yongyang3/LightX2V -port=8000 -server_url=http://127.0.0.1:${port} - -prompt="medium shot, eye-level, front view. A woman is seated in an ornate bedroom, illuminated by candlelight, with a calm and composed expression. The subject is a young woman with fair skin, light brown hair styled in an updo with loose tendrils framing her face, and blue eyes. She wears a cream-colored satin robe with delicate floral embroidery and lace trim along the neckline. Her ears are adorned with pearl drop earrings. 
She is seated on a bed with a dark, intricately carved wooden headboard. To her left, a wooden nightstand holds three lit white candles and a candelabra with multiple lit candles in the background. The bed is covered with patterned pillows and a dark, textured blanket. The walls are paneled with dark wood and feature a large, ornate tapestry with muted earth tones. The lighting creates soft highlights on her face and robe, with warm shadows cast across the room." -negative_prompt="" -infer_steps=28 -seed=32 -aspect_ratio=1:1 -target_height=2048 -target_width=2048 - -# Keep this relative so /v1/tasks/{task_id}/result can download it from the server output dir. -server_save_result_path=hidream_o1_image_t2i_dev_2604_request.png -output=${lightx2v_path}/save_results/hidream_o1_image_t2i_dev_2604_request.png -timeout_seconds=1200 -poll_interval=2.0 - -export PYTHONPATH="${lightx2v_path}:${PYTHONPATH:-}" - -python "${lightx2v_path}/scripts/server/post_async_t2i_and_wait.py" \ ---url "${server_url}" \ ---prompt "${prompt}" \ ---negative_prompt "${negative_prompt}" \ ---infer_steps "${infer_steps}" \ ---seed "${seed}" \ ---aspect_ratio "${aspect_ratio}" \ ---target_shape "${target_height}" "${target_width}" \ ---save_result_path "${server_save_result_path}" \ ---timeout_seconds "${timeout_seconds}" \ ---poll_interval "${poll_interval}" \ ---output "${output}" diff --git a/scripts/hidream_o1_image/post_all.sh b/scripts/hidream_o1_image/post_all.sh deleted file mode 100644 index 16affc769..000000000 --- a/scripts/hidream_o1_image/post_all.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -lightx2v_path=/root/yongyang3/LightX2V -test_json=${TEST_JSON:-/root/test.json} -port=${PORT:-8000} -server_url=${SERVER_URL:-http://127.0.0.1:${port}} - -export PYTHONPATH="${lightx2v_path}:${PYTHONPATH:-}" - -python "${lightx2v_path}/scripts/hidream_o1_image/post_async_t2i_and_wait.py" \ ---url "${server_url}" \ ---prompt_json "${test_json}" \ ---negative_prompt "" \ ---infer_steps 28 \ 
---seed 42 \ ---aspect_ratio 1:1 \ ---target_shape 2048 2048 \ ---timeout_seconds 1200 \ ---poll_interval 2.0 \ ---output_dir "${lightx2v_path}/save_results/hidream_o1_image_test_json" \ ---output_prefix hidream_o1_image_test_json diff --git a/scripts/hidream_o1_image/post_async_t2i_and_wait.py b/scripts/hidream_o1_image/post_async_t2i_and_wait.py deleted file mode 100644 index e23238c35..000000000 --- a/scripts/hidream_o1_image/post_async_t2i_and_wait.py +++ /dev/null @@ -1,280 +0,0 @@ -import argparse -import json -import os -import time -from datetime import datetime -from pathlib import Path -from typing import List, Optional -from urllib.parse import urlparse - -import requests - - -def submit_t2i_task( - base_url: str, - prompt: str, - negative_prompt: str, - infer_steps: int, - seed: int, - aspect_ratio: str, - target_shape: Optional[List[int]], - save_result_path: str, - use_prompt_enhancer: bool, -) -> str: - payload = { - "prompt": prompt, - "negative_prompt": negative_prompt, - "infer_steps": infer_steps, - "seed": seed, - "aspect_ratio": aspect_ratio, - "save_result_path": save_result_path, - "use_prompt_enhancer": use_prompt_enhancer, - } - if target_shape: - payload["target_shape"] = target_shape - - submit_url = f"{base_url.rstrip('/')}/v1/tasks/image/" - response = requests.post(submit_url, json=payload, timeout=30) - if response.status_code != 200: - raise RuntimeError(f"Submit task failed ({response.status_code}): {response.text}") - - data = response.json() - task_id = data.get("task_id") - if not task_id: - raise RuntimeError(f"Submit task succeeded but no task_id found: {data}") - return task_id - - -def wait_task_done(base_url: str, task_id: str, timeout_seconds: int, poll_interval: float) -> dict: - status_url = f"{base_url.rstrip('/')}/v1/tasks/{task_id}/status" - deadline = time.time() + timeout_seconds - - while time.time() < deadline: - response = requests.get(status_url, timeout=15) - if response.status_code != 200: - raise 
RuntimeError(f"Get task status failed ({response.status_code}): {response.text}") - - status = response.json() - task_status = status.get("status") - print(f"[poll] task_id={task_id}, status={task_status}") - - if task_status == "completed": - return status - if task_status in ("failed", "cancelled"): - raise RuntimeError(f"Task ended with status={task_status}, detail={status.get('error')}") - - time.sleep(poll_interval) - - raise TimeoutError(f"Task {task_id} timeout after {timeout_seconds}s") - - -def download_result(base_url: str, task_id: str, output: str) -> Path: - result_url = f"{base_url.rstrip('/')}/v1/tasks/{task_id}/result" - response = requests.get(result_url, timeout=120) - if response.status_code != 200: - raise RuntimeError(f"Download result failed ({response.status_code}): {response.text}") - - output_path = Path(output) - output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_bytes(response.content) - return output_path - - -def format_duration(seconds: float) -> str: - total_ms = int(round(seconds * 1000)) - total_seconds, milliseconds = divmod(total_ms, 1000) - hours, remainder = divmod(total_seconds, 3600) - minutes, seconds = divmod(remainder, 60) - return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}" - - -def ensure_local_no_proxy(base_url: str) -> None: - hostname = urlparse(base_url).hostname - if hostname not in {"127.0.0.1", "localhost", "0.0.0.0", "::1"}: - return - - local_hosts = ["127.0.0.1", "localhost", "0.0.0.0", "::1"] - for env_name in ("NO_PROXY", "no_proxy"): - existing = [item.strip() for item in os.environ.get(env_name, "").split(",") if item.strip()] - merged = local_hosts + [item for item in existing if item not in local_hosts] - os.environ[env_name] = ",".join(merged) - - -def load_prompts(prompt_json: str) -> List[str]: - path = Path(prompt_json) - data = json.loads(path.read_text(encoding="utf-8")) - if not isinstance(data, list): - raise ValueError(f"{path} must contain a JSON array") - 
- prompts = [] - for index, item in enumerate(data): - if isinstance(item, str): - prompt = item - elif isinstance(item, dict) and isinstance(item.get("prompt"), str): - prompt = item["prompt"] - else: - raise ValueError(f"{path}[{index}] must be a string or an object with a string prompt") - prompts.append(prompt) - return prompts - - -def write_summary_line(summary_file: Optional[Path], line: str) -> None: - if summary_file is None: - return - with summary_file.open("a", encoding="utf-8") as f: - f.write(line) - f.write("\n") - - -def run_one_task(args, prompt: str, save_result_path: str, output: str) -> dict: - case_start = time.perf_counter() - task_id = submit_t2i_task( - base_url=args.url, - prompt=prompt, - negative_prompt=args.negative_prompt, - infer_steps=args.infer_steps, - seed=args.seed, - aspect_ratio=args.aspect_ratio, - target_shape=args.target_shape, - save_result_path=save_result_path, - use_prompt_enhancer=args.use_prompt_enhancer, - ) - print(f"Task submitted successfully, task_id={task_id}") - - final_status = wait_task_done( - base_url=args.url, - task_id=task_id, - timeout_seconds=args.timeout_seconds, - poll_interval=args.poll_interval, - ) - print(f"Task completed: {final_status}") - - output_path = download_result(args.url, task_id, output) - elapsed = time.perf_counter() - case_start - print(f"Result saved to: {output_path}") - print(f"Task elapsed: {format_duration(elapsed)} ({elapsed:.3f}s)") - - return { - "task_id": task_id, - "status": "success", - "elapsed": elapsed, - "output": str(output_path), - "save_result_path": final_status.get("save_result_path", save_result_path), - } - - -def run_batch(args) -> int: - prompts = load_prompts(args.prompt_json) - output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S") - summary_file = Path(args.summary_file) if args.summary_file else output_dir / f"{args.output_prefix}_summary_{run_stamp}.log" - 
summary_file.parent.mkdir(parents=True, exist_ok=True) - - started_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - total = len(prompts) - completed = 0 - failed = 0 - batch_start = time.perf_counter() - - summary_file.write_text( - "\n".join( - [ - f"Run started at: {started_at}", - f"Prompt JSON: {args.prompt_json}", - f"Server URL: {args.url}", - f"Output directory: {output_dir}", - f"Total prompts: {total}", - "", - ] - ), - encoding="utf-8", - ) - - print(f"Posting {total} prompts from {args.prompt_json} to {args.url}") - print(f"Output directory: {output_dir}") - print(f"Summary file: {summary_file}") - - for index, prompt in enumerate(prompts, 1): - number = f"{index:03d}" - server_save_result_path = f"{args.output_prefix}_{number}.png" - output = output_dir / server_save_result_path - case_started_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - case_start = time.perf_counter() - - print(f"[{index}/{total}] submitting {server_save_result_path}") - write_summary_line(summary_file, f"Case {number} started at: {case_started_at}") - - try: - result = run_one_task(args, prompt, server_save_result_path, str(output)) - except Exception as e: - failed += 1 - elapsed = time.perf_counter() - case_start - print(f"[{index}/{total}] failed: {server_save_result_path}: {e}") - write_summary_line(summary_file, f"Case {number} status: failed, elapsed: {format_duration(elapsed)} ({elapsed:.3f}s), error: {e}") - if args.stop_on_error: - break - else: - completed += 1 - write_summary_line( - summary_file, - f"Case {number} status: success, task_id: {result['task_id']}, elapsed: {format_duration(result['elapsed'])} ({result['elapsed']:.3f}s), output: {result['output']}", - ) - print(f"[{index}/{total}] saved {output}") - - total_elapsed = time.perf_counter() - batch_start - ended_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - write_summary_line(summary_file, "") - write_summary_line(summary_file, f"Run ended at: {ended_at}") - write_summary_line(summary_file, 
f"Elapsed seconds: {total_elapsed:.3f}") - write_summary_line(summary_file, f"Elapsed time: {format_duration(total_elapsed)}") - write_summary_line(summary_file, f"Completed prompts: {completed}") - write_summary_line(summary_file, f"Failed prompts: {failed}") - - print(f"Finished: completed={completed}/{total}, failed={failed}/{total}") - print(f"Total elapsed: {format_duration(total_elapsed)} ({total_elapsed:.3f}s)") - print(f"Summary written to: {summary_file}") - return 1 if failed else 0 - - -def main(): - parser = argparse.ArgumentParser(description="Submit T2I task to /v1/tasks/image/ and wait for final result.") - parser.add_argument("--url", type=str, default="http://127.0.0.1:8000", help="Server base url") - parser.add_argument("--prompt", type=str, default=None, help="Prompt text") - parser.add_argument("--prompt_json", "--json", dest="prompt_json", type=str, default=None, help="JSON file containing prompts") - parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt text") - parser.add_argument("--infer_steps", type=int, default=30, help="Inference steps") - parser.add_argument("--seed", type=int, default=42, help="Random seed") - parser.add_argument("--aspect_ratio", type=str, default="16:9", help="Aspect ratio for image task") - parser.add_argument( - "--target_shape", - type=int, - nargs="+", - default=None, - help="Target output shape, e.g. 
--target_shape 1536 2752", - ) - parser.add_argument("--save_result_path", type=str, default="", help="Server-side save_result_path") - parser.add_argument("--use_prompt_enhancer", action="store_true", help="Enable prompt enhancer") - parser.add_argument("--timeout_seconds", type=int, default=600, help="Polling timeout in seconds") - parser.add_argument("--poll_interval", type=float, default=2.0, help="Polling interval in seconds") - parser.add_argument("--output", type=str, default="save_results/t2i_result.png", help="Local output image path") - parser.add_argument("--output_dir", type=str, default="save_results/hidream_o1_image_test_json", help="Batch output directory") - parser.add_argument("--output_prefix", type=str, default="hidream_o1_image_test_json", help="Batch output filename prefix") - parser.add_argument("--summary_file", type=str, default=None, help="Batch timing summary file") - parser.add_argument("--stop_on_error", action="store_true", help="Stop batch mode after the first failed prompt") - - args = parser.parse_args() - ensure_local_no_proxy(args.url) - - if args.prompt_json: - raise SystemExit(run_batch(args)) - - if not args.prompt: - parser.error("--prompt is required unless --prompt_json is set") - - run_one_task(args, args.prompt, args.save_result_path, args.output) - - -if __name__ == "__main__": - main() diff --git a/scripts/hidream_o1_image/post_t2i_openai.sh b/scripts/hidream_o1_image/post_t2i_openai.sh new file mode 100755 index 000000000..3c5d3f340 --- /dev/null +++ b/scripts/hidream_o1_image/post_t2i_openai.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +lightx2v_path=/root/yongyang3/LightX2V +test_json=/root/test.json +base_url=http://127.0.0.1:8000/v1 +output_dir=${lightx2v_path}/save_results/hidream_o1_image_openai_test + +export PYTHONPATH="${lightx2v_path}" + +python "${lightx2v_path}/scripts/hidream_o1_image/test_openai_images_client.py" \ +--base_url "${base_url}" \ +--api_key "dummy-key" \ +--model "gpt-image-1" \ +--mode generate \ 
+--prompt_json "${test_json}" \ +--seed 42 \ +--size "2048x2048" \ +--response_format "b64_json" \ +--output_dir "${output_dir}" \ +--output_prefix "hidream_o1_image_openai" diff --git a/scripts/hidream_o1_image/test_openai_images_client.py b/scripts/hidream_o1_image/test_openai_images_client.py new file mode 100644 index 000000000..b6faa96a4 --- /dev/null +++ b/scripts/hidream_o1_image/test_openai_images_client.py @@ -0,0 +1,348 @@ +import argparse +import base64 +import json +import os +import time +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import requests + +try: + from openai import OpenAI # pyright: ignore[reportMissingImports] +except ImportError: + OpenAI = None # type: ignore[assignment] + + +@dataclass +class SaveImageResult: + path: Path + source: str + bytes_written: int + total_seconds: float + decode_seconds: float = 0.0 + download_seconds: float = 0.0 + write_seconds: float = 0.0 + + +@dataclass +class ImageRequestResult: + path: Path + http_sdk_parse_seconds: float + save: SaveImageResult + + +def _extract_data_item(response: Any) -> dict[str, Any]: + if not hasattr(response, "data") or not response.data: + raise RuntimeError(f"Invalid OpenAI images response: {response}") + item = response.data[0] + if hasattr(item, "model_dump"): + return item.model_dump() # openai pydantic object + if isinstance(item, dict): + return item + raise RuntimeError(f"Unsupported data item type: {type(item)!r}") + + +def _save_image_from_item(item: dict[str, Any], output_path: Path) -> SaveImageResult: + output_path.parent.mkdir(parents=True, exist_ok=True) + total_start = time.perf_counter() + + if "b64_json" in item and item["b64_json"]: + decode_start = time.perf_counter() + image_bytes = base64.b64decode(item["b64_json"]) + decode_seconds = time.perf_counter() - decode_start + + write_start = time.perf_counter() + output_path.write_bytes(image_bytes) + 
write_seconds = time.perf_counter() - write_start + + return SaveImageResult( + path=output_path, + source="b64_json", + bytes_written=len(image_bytes), + total_seconds=time.perf_counter() - total_start, + decode_seconds=decode_seconds, + write_seconds=write_seconds, + ) + + if "url" in item and item["url"]: + download_start = time.perf_counter() + resp = requests.get(item["url"], timeout=120) + resp.raise_for_status() + download_seconds = time.perf_counter() - download_start + + write_start = time.perf_counter() + output_path.write_bytes(resp.content) + write_seconds = time.perf_counter() - write_start + + return SaveImageResult( + path=output_path, + source="url", + bytes_written=len(resp.content), + total_seconds=time.perf_counter() - total_start, + download_seconds=download_seconds, + write_seconds=write_seconds, + ) + + raise RuntimeError(f"Response item has neither b64_json nor url: {item}") + + +def _summarize_response_item(item: dict[str, Any]) -> dict[str, Any]: + summary = dict(item) + if "b64_json" in summary and summary["b64_json"]: + summary["b64_json"] = f"" + return summary + + +def _format_duration(seconds: float) -> str: + total_ms = int(round(seconds * 1000)) + total_seconds, milliseconds = divmod(total_ms, 1000) + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}" + + +def _format_request_timing(result: ImageRequestResult) -> str: + save = result.save + parts = [ + f"http_sdk_parse={result.http_sdk_parse_seconds:.3f}s", + f"save_total={save.total_seconds:.3f}s", + ] + if save.decode_seconds: + parts.append(f"base64_decode={save.decode_seconds:.3f}s") + if save.download_seconds: + parts.append(f"download={save.download_seconds:.3f}s") + parts.extend( + [ + f"disk_write={save.write_seconds:.3f}s", + f"bytes={save.bytes_written}", + f"source={save.source}", + ] + ) + return ", ".join(parts) + + +def _ensure_local_no_proxy(base_url: str) -> 
None: + hostname = urlparse(base_url).hostname + if hostname not in {"127.0.0.1", "localhost", "0.0.0.0", "::1"}: + return + + local_hosts = ["127.0.0.1", "localhost", "0.0.0.0", "::1"] + for env_name in ("NO_PROXY", "no_proxy"): + existing = [item.strip() for item in os.environ.get(env_name, "").split(",") if item.strip()] + merged = local_hosts + [item for item in existing if item not in local_hosts] + os.environ[env_name] = ",".join(merged) + + +def _load_prompts(prompt_json: str) -> list[str]: + path = Path(prompt_json) + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, list): + raise ValueError(f"{path} must contain a JSON array") + + prompts = [] + for index, item in enumerate(data): + if isinstance(item, str): + prompt = item + elif isinstance(item, dict) and isinstance(item.get("prompt"), str): + prompt = item["prompt"] + else: + raise ValueError(f"{path}[{index}] must be a string or an object with a string prompt") + prompts.append(prompt) + return prompts + + +def _write_summary_line(summary_file: Path, line: str) -> None: + with summary_file.open("a", encoding="utf-8") as f: + f.write(line) + f.write("\n") + + +def _extra_body_from_args(args: argparse.Namespace) -> dict[str, Any] | None: + if args.seed is None: + return None + return {"seed": args.seed} + + +def run_generate(client: Any, args: argparse.Namespace, prompt: str | None = None, output_path: Path | None = None) -> ImageRequestResult: + prompt = args.prompt if prompt is None else prompt + output_path = Path(args.output_dir) / "generate.png" if output_path is None else output_path + + request_start = time.perf_counter() + response = client.images.generate( + model=args.model, + prompt=prompt, + size=args.size, + response_format=args.response_format, + extra_body=_extra_body_from_args(args), + ) + http_sdk_parse_seconds = time.perf_counter() - request_start + + item = _extract_data_item(response) + print(f"[generate] response item: {_summarize_response_item(item)}") + 
save_result = _save_image_from_item(item, output_path) + result = ImageRequestResult(path=save_result.path, http_sdk_parse_seconds=http_sdk_parse_seconds, save=save_result) + print(f"[generate] timing: {_format_request_timing(result)}") + return result + + +def run_edit(client: Any, args: argparse.Namespace) -> ImageRequestResult: + if not args.image: + raise ValueError("--image is required for edit mode") + + image_path = Path(args.image) + if not image_path.exists(): + raise FileNotFoundError(f"Image file not found: {image_path}") + + with image_path.open("rb") as image_file: + kwargs = { + "model": args.model, + "image": image_file, + "prompt": args.edit_prompt or args.prompt, + "size": args.size, + "response_format": args.response_format, + "extra_body": _extra_body_from_args(args), + } + request_start = time.perf_counter() + if args.mask: + mask_path = Path(args.mask) + if not mask_path.exists(): + raise FileNotFoundError(f"Mask file not found: {mask_path}") + with mask_path.open("rb") as mask_file: + response = client.images.edit(mask=mask_file, **kwargs) + else: + response = client.images.edit(**kwargs) + http_sdk_parse_seconds = time.perf_counter() - request_start + + item = _extract_data_item(response) + print(f"[edit] response item: {_summarize_response_item(item)}") + save_result = _save_image_from_item(item, Path(args.output_dir) / "edit.png") + result = ImageRequestResult(path=save_result.path, http_sdk_parse_seconds=http_sdk_parse_seconds, save=save_result) + print(f"[edit] timing: {_format_request_timing(result)}") + return result + + +def run_generate_batch(client: Any, args: argparse.Namespace) -> int: + prompts = _load_prompts(args.prompt_json) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S") + summary_file = Path(args.summary_file) if args.summary_file else output_dir / f"{args.output_prefix}_summary_{run_stamp}.log" + 
summary_file.parent.mkdir(parents=True, exist_ok=True) + + total = len(prompts) + completed = 0 + failed = 0 + batch_start = time.perf_counter() + started_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + summary_file.write_text( + "\n".join( + [ + f"Run started at: {started_at}", + f"Prompt JSON: {args.prompt_json}", + f"Base URL: {args.base_url}", + f"Model: {args.model}", + f"Seed: {args.seed}", + f"Size: {args.size}", + f"Response format: {args.response_format}", + f"Output directory: {output_dir}", + f"Total prompts: {total}", + "", + ] + ), + encoding="utf-8", + ) + + print(f"Posting {total} OpenAI-format image prompts from {args.prompt_json} to {args.base_url}") + print(f"Output directory: {output_dir}") + print(f"Summary file: {summary_file}") + + for index, prompt in enumerate(prompts, 1): + number = f"{index:03d}" + output_path = output_dir / f"{args.output_prefix}_{number}.png" + case_start = time.perf_counter() + case_started_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + print(f"[{index}/{total}] submitting {output_path.name}") + _write_summary_line(summary_file, f"Case {number} started at: {case_started_at}") + + try: + result = run_generate(client, args, prompt=prompt, output_path=output_path) + except Exception as e: + failed += 1 + elapsed = time.perf_counter() - case_start + print(f"[{index}/{total}] failed: {output_path.name}: {e}") + _write_summary_line(summary_file, f"Case {number} status: failed, elapsed: {_format_duration(elapsed)} ({elapsed:.3f}s), error: {e}") + if args.stop_on_error: + break + else: + completed += 1 + elapsed = time.perf_counter() - case_start + print(f"[{index}/{total}] saved {result.path}") + print(f"[{index}/{total}] elapsed: {_format_duration(elapsed)} ({elapsed:.3f}s)") + _write_summary_line( + summary_file, + f"Case {number} status: success, elapsed: {_format_duration(elapsed)} ({elapsed:.3f}s), {_format_request_timing(result)}, output: {result.path}", + ) + + total_elapsed = time.perf_counter() - 
batch_start + ended_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + _write_summary_line(summary_file, "") + _write_summary_line(summary_file, f"Run ended at: {ended_at}") + _write_summary_line(summary_file, f"Elapsed seconds: {total_elapsed:.3f}") + _write_summary_line(summary_file, f"Elapsed time: {_format_duration(total_elapsed)}") + _write_summary_line(summary_file, f"Completed prompts: {completed}") + _write_summary_line(summary_file, f"Failed prompts: {failed}") + + print(f"Finished: completed={completed}/{total}, failed={failed}/{total}") + print(f"Total elapsed: {_format_duration(total_elapsed)} ({total_elapsed:.3f}s)") + print(f"Summary written to: {summary_file}") + return 1 if failed else 0 + + +def main() -> None: + parser = argparse.ArgumentParser(description="Test OpenAI-compatible image APIs on LightX2V server.") + parser.add_argument("--base_url", type=str, default="http://127.0.0.1:8000/v1", help="OpenAI-compatible base URL") + parser.add_argument("--api_key", type=str, default="dummy-key", help="OpenAI API key placeholder") + parser.add_argument("--model", type=str, default="gpt-image-1", help="Model name (for compatibility only)") + parser.add_argument("--mode", choices=["generate", "edit", "all"], default="generate", help="Test mode") + parser.add_argument("--prompt", type=str, default="a futuristic city at sunset", help="Prompt for generation") + parser.add_argument("--prompt_json", "--json", dest="prompt_json", type=str, default="", help="JSON file containing prompts for batch generation") + parser.add_argument("--edit_prompt", type=str, default="", help="Prompt for edit (defaults to --prompt)") + parser.add_argument("--seed", type=int, default=None, help="Optional generation seed") + parser.add_argument("--size", type=str, default="1024x1024", help="Image size, e.g. 
1024x1024") + parser.add_argument("--response_format", choices=["url", "b64_json"], default="b64_json", help="OpenAI response format") + parser.add_argument("--image", type=str, default="", help="Input image path for edit mode") + parser.add_argument("--mask", type=str, default="", help="Optional mask image path for edit mode") + parser.add_argument("--output_dir", type=str, default="outputs/openai_images_test", help="Directory to save outputs") + parser.add_argument("--output_prefix", type=str, default="openai_generate", help="Batch output filename prefix") + parser.add_argument("--summary_file", type=str, default="", help="Batch timing summary file") + parser.add_argument("--stop_on_error", action="store_true", help="Stop batch mode after the first failed prompt") + args = parser.parse_args() + + if OpenAI is None: + raise RuntimeError("Missing dependency: openai. Please install it with `pip install openai`.") + + _ensure_local_no_proxy(args.base_url) + client = OpenAI(api_key=args.api_key, base_url=args.base_url) + + if args.prompt_json: + raise SystemExit(run_generate_batch(client, args)) + + output_results: list[ImageRequestResult] = [] + if args.mode in ("generate", "all"): + output_results.append(run_generate(client, args)) + if args.mode in ("edit", "all"): + output_results.append(run_edit(client, args)) + + for result in output_results: + print(f"[saved] {result.path}") + + +if __name__ == "__main__": + main()